修改手动实现线性层的测试代码
This commit is contained in:
parent
1e25f418ae
commit
5958a62045
@ -147,7 +147,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 题目2\n",
|
||||
"1. **利用Tensor创建两个大小分别$3\\times 2$和$4\\times 2$的随机数矩阵P和Q,要求服从均值为$0$,标准差$0.01$为的正态分布;**\n",
|
||||
"1. **利用Tensor创建两个大小分别$3\\times 2$和$4\\times 2$的随机数矩阵$P$和$Q$,要求服从均值为$0$,标准差$0.01$为的正态分布;**\n",
|
||||
"2. **对第二步得到的矩阵$Q$进行形状变换得到$Q$的转置$Q^T$;**\n",
|
||||
"3. **对上述得到的矩阵$P$和矩阵$Q^T$求矩阵相乘。**"
|
||||
]
|
||||
@ -163,21 +163,21 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"矩阵 P:\n",
|
||||
"tensor([[ 0.0043, 0.0009],\n",
|
||||
" [-0.0008, 0.0021],\n",
|
||||
" [-0.0012, -0.0091]])\n",
|
||||
"tensor([[-6.8759e-03, -1.5809e-02],\n",
|
||||
" [ 8.0131e-03, -5.1682e-05],\n",
|
||||
" [ 4.1277e-03, 4.1408e-03]])\n",
|
||||
"矩阵 Q:\n",
|
||||
"tensor([[ 0.0016, 0.0073],\n",
|
||||
" [-0.0092, 0.0024],\n",
|
||||
" [ 0.0026, 0.0171],\n",
|
||||
" [ 0.0101, -0.0038]])\n",
|
||||
"tensor([[-0.0037, -0.0188],\n",
|
||||
" [-0.0083, 0.0147],\n",
|
||||
" [-0.0029, -0.0061],\n",
|
||||
" [-0.0123, 0.0162]])\n",
|
||||
"矩阵 QT:\n",
|
||||
"tensor([[ 0.0016, -0.0092, 0.0026, 0.0101],\n",
|
||||
" [ 0.0073, 0.0024, 0.0171, -0.0038]])\n",
|
||||
"tensor([[-0.0037, -0.0083, -0.0029, -0.0123],\n",
|
||||
" [-0.0188, 0.0147, -0.0061, 0.0162]])\n",
|
||||
"矩阵相乘的结果:\n",
|
||||
"tensor([[ 1.3472e-05, -3.7060e-05, 2.7148e-05, 3.9682e-05],\n",
|
||||
" [ 1.3877e-05, 1.2322e-05, 3.3492e-05, -1.6069e-05],\n",
|
||||
" [-6.8047e-05, -1.1357e-05, -1.5886e-04, 2.3463e-05]])\n"
|
||||
"tensor([[ 3.2228e-04, -1.7527e-04, 1.1701e-04, -1.7221e-04],\n",
|
||||
" [-2.8839e-05, -6.7375e-05, -2.2812e-05, -9.9158e-05],\n",
|
||||
" [-9.3070e-05, 2.6567e-05, -3.7366e-05, 1.6558e-05]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -415,17 +415,17 @@
|
||||
"tensor([[1.],\n",
|
||||
" [2.]], requires_grad=True)\n",
|
||||
"权重:\n",
|
||||
"tensor([[1.],\n",
|
||||
" [2.],\n",
|
||||
" [3.]])\n",
|
||||
"tensor([[-0.3470],\n",
|
||||
" [-0.4461],\n",
|
||||
" [-0.8590]])\n",
|
||||
"偏置:\n",
|
||||
"tensor([[1.]])\n",
|
||||
"tensor([ 0.9821, 0.8701, -0.4202])\n",
|
||||
"My_Linear输出:\n",
|
||||
"tensor([[2., 3., 4.],\n",
|
||||
" [3., 5., 7.]], grad_fn=<AddBackward0>)\n",
|
||||
"tensor([[ 0.6351, 0.4240, -1.2792],\n",
|
||||
" [ 0.2882, -0.0221, -2.1383]], grad_fn=<AddBackward0>)\n",
|
||||
"nn.Linear输出:\n",
|
||||
"tensor([[2., 3., 4.],\n",
|
||||
" [3., 5., 7.]], grad_fn=<AddmmBackward0>)\n"
|
||||
"tensor([[ 0.6351, 0.4240, -1.2792],\n",
|
||||
" [ 0.2882, -0.0221, -2.1383]], grad_fn=<AddmmBackward0>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -455,12 +455,7 @@
|
||||
"# 测试\n",
|
||||
"my_linear = My_Linear(1, 3)\n",
|
||||
"nn_linear = nn.Linear(1, 3)\n",
|
||||
"weight = torch.nn.Parameter(torch.tensor([[1.],\n",
|
||||
" [2.],\n",
|
||||
" [3.]]), requires_grad=True)\n",
|
||||
"bias = torch.nn.Parameter(torch.tensor([[1.]]), requires_grad=True)\n",
|
||||
"nn_linear.weight, my_linear.weight = weight, weight\n",
|
||||
"nn_linear.bias, my_linear.bias = bias, bias\n",
|
||||
"my_linear.weight, my_linear.bias = nn_linear.weight, nn_linear.bias\n",
|
||||
"x = torch.tensor([[1.], [2.]], requires_grad=True)\n",
|
||||
"print(f\"输入:\\n{x}\")\n",
|
||||
"print(f\"权重:\\n{my_linear.weight.data}\")\n",
|
||||
@ -608,26 +603,7 @@
|
||||
"execution_count": 10,
|
||||
"id": "5612661e-2809-4d46-96c2-33ee9f44116d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10, Loss: 696.7355516552925, Acc: 0.9442648464635192\n",
|
||||
"Epoch 2/10, Loss: 680.1393249630928, Acc: 0.9911232759674347\n",
|
||||
"Epoch 3/10, Loss: 677.770676612854, Acc: 0.9956893804390976\n",
|
||||
"Epoch 4/10, Loss: 677.294788479805, Acc: 0.9982881159072501\n",
|
||||
"Epoch 5/10, Loss: 677.1979722976685, Acc: 0.9991794744511796\n",
|
||||
"Epoch 6/10, Loss: 677.1792464852333, Acc: 0.999493084950588\n",
|
||||
"Epoch 7/10, Loss: 677.1751466989517, Acc: 0.9998704799602793\n",
|
||||
"Epoch 8/10, Loss: 677.1746656894684, Acc: 0.9999325569195194\n",
|
||||
"Epoch 9/10, Loss: 677.1742008328438, Acc: 0.9999852565480795\n",
|
||||
"Epoch 10/10, Loss: 677.1745100617409, Acc: 0.9999026775350377\n",
|
||||
"Model weights: -0.0018169691320508718, bias: 0.018722545355558395\n",
|
||||
"Prediction for test data: 0.5042262673377991\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learning_rate = 5e-2\n",
|
||||
"num_epochs = 10\n",
|
||||
@ -692,7 +668,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"id": "fa121afd-a1af-4193-9b54-68041e0ed068",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -718,7 +694,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": null,
|
||||
"id": "93b0fdb6-be8b-4663-b59e-05ed19a9ea09",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -826,7 +802,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": null,
|
||||
"id": "e605f1b0-1d32-410f-bddf-402a85ccc9ff",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -890,7 +866,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": null,
|
||||
"id": "759a3bb2-b5f4-4ea5-a2d7-15f0c4cdd14b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -952,7 +928,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": null,
|
||||
"id": "74322629-8325-4823-b80f-f28182d577c1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1013,7 +989,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": null,
|
||||
"id": "bb31a75e-464c-4b94-b927-b219a765e35d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -1073,7 +1049,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": null,
|
||||
"id": "d816dae1-5fbe-4c29-9597-19d66b5eb6b4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1177,7 +1153,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": null,
|
||||
"id": "0163b9f7-1019-429c-8c29-06436d0a4c98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -1204,7 +1180,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": null,
|
||||
"id": "6d241c05-b153-4f56-a845-0f2362f6459b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1321,7 +1297,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
Loading…
x
Reference in New Issue
Block a user