修改手动实现线性层的测试代码

This commit is contained in:
Jingfan Ke 2023-10-11 22:36:08 +08:00
parent 1e25f418ae
commit 5958a62045

View File

@ -147,7 +147,7 @@
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 题目2\n", "## 题目2\n",
"1. **利用Tensor创建两个大小分别$3\\times 2$和$4\\times 2$的随机数矩阵P和Q要求服从均值为$0$,标准差$0.01$为的正态分布;**\n", "1. **利用Tensor创建两个大小分别$3\\times 2$和$4\\times 2$的随机数矩阵$P$$Q$,要求服从均值为$0$,标准差$0.01$为的正态分布;**\n",
"2. **对第二步得到的矩阵$Q$进行形状变换得到$Q$的转置$Q^T$**\n", "2. **对第二步得到的矩阵$Q$进行形状变换得到$Q$的转置$Q^T$**\n",
"3. **对上述得到的矩阵$P$和矩阵$Q^T$求矩阵相乘。**" "3. **对上述得到的矩阵$P$和矩阵$Q^T$求矩阵相乘。**"
] ]
@ -163,21 +163,21 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"矩阵 P:\n", "矩阵 P:\n",
"tensor([[ 0.0043, 0.0009],\n", "tensor([[-6.8759e-03, -1.5809e-02],\n",
" [-0.0008, 0.0021],\n", " [ 8.0131e-03, -5.1682e-05],\n",
" [-0.0012, -0.0091]])\n", " [ 4.1277e-03, 4.1408e-03]])\n",
"矩阵 Q:\n", "矩阵 Q:\n",
"tensor([[ 0.0016, 0.0073],\n", "tensor([[-0.0037, -0.0188],\n",
" [-0.0092, 0.0024],\n", " [-0.0083, 0.0147],\n",
" [ 0.0026, 0.0171],\n", " [-0.0029, -0.0061],\n",
" [ 0.0101, -0.0038]])\n", " [-0.0123, 0.0162]])\n",
"矩阵 QT:\n", "矩阵 QT:\n",
"tensor([[ 0.0016, -0.0092, 0.0026, 0.0101],\n", "tensor([[-0.0037, -0.0083, -0.0029, -0.0123],\n",
" [ 0.0073, 0.0024, 0.0171, -0.0038]])\n", " [-0.0188, 0.0147, -0.0061, 0.0162]])\n",
"矩阵相乘的结果:\n", "矩阵相乘的结果:\n",
"tensor([[ 1.3472e-05, -3.7060e-05, 2.7148e-05, 3.9682e-05],\n", "tensor([[ 3.2228e-04, -1.7527e-04, 1.1701e-04, -1.7221e-04],\n",
" [ 1.3877e-05, 1.2322e-05, 3.3492e-05, -1.6069e-05],\n", " [-2.8839e-05, -6.7375e-05, -2.2812e-05, -9.9158e-05],\n",
" [-6.8047e-05, -1.1357e-05, -1.5886e-04, 2.3463e-05]])\n" " [-9.3070e-05, 2.6567e-05, -3.7366e-05, 1.6558e-05]])\n"
] ]
} }
], ],
@ -415,17 +415,17 @@
"tensor([[1.],\n", "tensor([[1.],\n",
" [2.]], requires_grad=True)\n", " [2.]], requires_grad=True)\n",
"权重:\n", "权重:\n",
"tensor([[1.],\n", "tensor([[-0.3470],\n",
" [2.],\n", " [-0.4461],\n",
" [3.]])\n", " [-0.8590]])\n",
"偏置:\n", "偏置:\n",
"tensor([[1.]])\n", "tensor([ 0.9821, 0.8701, -0.4202])\n",
"My_Linear输出\n", "My_Linear输出\n",
"tensor([[2., 3., 4.],\n", "tensor([[ 0.6351, 0.4240, -1.2792],\n",
" [3., 5., 7.]], grad_fn=<AddBackward0>)\n", " [ 0.2882, -0.0221, -2.1383]], grad_fn=<AddBackward0>)\n",
"nn.Linear输出\n", "nn.Linear输出\n",
"tensor([[2., 3., 4.],\n", "tensor([[ 0.6351, 0.4240, -1.2792],\n",
" [3., 5., 7.]], grad_fn=<AddmmBackward0>)\n" " [ 0.2882, -0.0221, -2.1383]], grad_fn=<AddmmBackward0>)\n"
] ]
} }
], ],
@ -455,12 +455,7 @@
"# 测试\n", "# 测试\n",
"my_linear = My_Linear(1, 3)\n", "my_linear = My_Linear(1, 3)\n",
"nn_linear = nn.Linear(1, 3)\n", "nn_linear = nn.Linear(1, 3)\n",
"weight = torch.nn.Parameter(torch.tensor([[1.],\n", "my_linear.weight, my_linear.bias = nn_linear.weight, nn_linear.bias\n",
" [2.],\n",
" [3.]]), requires_grad=True)\n",
"bias = torch.nn.Parameter(torch.tensor([[1.]]), requires_grad=True)\n",
"nn_linear.weight, my_linear.weight = weight, weight\n",
"nn_linear.bias, my_linear.bias = bias, bias\n",
"x = torch.tensor([[1.], [2.]], requires_grad=True)\n", "x = torch.tensor([[1.], [2.]], requires_grad=True)\n",
"print(f\"输入:\\n{x}\")\n", "print(f\"输入:\\n{x}\")\n",
"print(f\"权重:\\n{my_linear.weight.data}\")\n", "print(f\"权重:\\n{my_linear.weight.data}\")\n",
@ -608,26 +603,7 @@
"execution_count": 10, "execution_count": 10,
"id": "5612661e-2809-4d46-96c2-33ee9f44116d", "id": "5612661e-2809-4d46-96c2-33ee9f44116d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10, Loss: 696.7355516552925, Acc: 0.9442648464635192\n",
"Epoch 2/10, Loss: 680.1393249630928, Acc: 0.9911232759674347\n",
"Epoch 3/10, Loss: 677.770676612854, Acc: 0.9956893804390976\n",
"Epoch 4/10, Loss: 677.294788479805, Acc: 0.9982881159072501\n",
"Epoch 5/10, Loss: 677.1979722976685, Acc: 0.9991794744511796\n",
"Epoch 6/10, Loss: 677.1792464852333, Acc: 0.999493084950588\n",
"Epoch 7/10, Loss: 677.1751466989517, Acc: 0.9998704799602793\n",
"Epoch 8/10, Loss: 677.1746656894684, Acc: 0.9999325569195194\n",
"Epoch 9/10, Loss: 677.1742008328438, Acc: 0.9999852565480795\n",
"Epoch 10/10, Loss: 677.1745100617409, Acc: 0.9999026775350377\n",
"Model weights: -0.0018169691320508718, bias: 0.018722545355558395\n",
"Prediction for test data: 0.5042262673377991\n"
]
}
],
"source": [ "source": [
"learning_rate = 5e-2\n", "learning_rate = 5e-2\n",
"num_epochs = 10\n", "num_epochs = 10\n",
@ -692,7 +668,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": null,
"id": "fa121afd-a1af-4193-9b54-68041e0ed068", "id": "fa121afd-a1af-4193-9b54-68041e0ed068",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -718,7 +694,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": null,
"id": "93b0fdb6-be8b-4663-b59e-05ed19a9ea09", "id": "93b0fdb6-be8b-4663-b59e-05ed19a9ea09",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -826,7 +802,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": null,
"id": "e605f1b0-1d32-410f-bddf-402a85ccc9ff", "id": "e605f1b0-1d32-410f-bddf-402a85ccc9ff",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -890,7 +866,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": null,
"id": "759a3bb2-b5f4-4ea5-a2d7-15f0c4cdd14b", "id": "759a3bb2-b5f4-4ea5-a2d7-15f0c4cdd14b",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -952,7 +928,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": null,
"id": "74322629-8325-4823-b80f-f28182d577c1", "id": "74322629-8325-4823-b80f-f28182d577c1",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1013,7 +989,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": null,
"id": "bb31a75e-464c-4b94-b927-b219a765e35d", "id": "bb31a75e-464c-4b94-b927-b219a765e35d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -1073,7 +1049,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": null,
"id": "d816dae1-5fbe-4c29-9597-19d66b5eb6b4", "id": "d816dae1-5fbe-4c29-9597-19d66b5eb6b4",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1177,7 +1153,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": null,
"id": "0163b9f7-1019-429c-8c29-06436d0a4c98", "id": "0163b9f7-1019-429c-8c29-06436d0a4c98",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -1204,7 +1180,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": null,
"id": "6d241c05-b153-4f56-a845-0f2362f6459b", "id": "6d241c05-b153-4f56-a845-0f2362f6459b",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1321,7 +1297,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.13" "version": "3.10.12"
} }
}, },
"nbformat": 4, "nbformat": 4,