修改手动实现线性层的测试代码
This commit is contained in:
parent
5958a62045
commit
22b33d63de
@ -147,7 +147,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 题目2\n",
|
||||
"1. **利用Tensor创建两个大小分别$3\\times 2$和$4\\times 2$的随机数矩阵P和Q,要求服从均值为$0$,标准差$0.01$为的正态分布;**\n",
|
||||
"1. **利用Tensor创建两个大小分别$3\\times 2$和$4\\times 2$的随机数矩阵$P$和$Q$,要求服从均值为$0$,标准差$0.01$为的正态分布;**\n",
|
||||
"2. **对第二步得到的矩阵$Q$进行形状变换得到$Q$的转置$Q^T$;**\n",
|
||||
"3. **对上述得到的矩阵$P$和矩阵$Q^T$求矩阵相乘。**"
|
||||
]
|
||||
@ -163,21 +163,21 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"矩阵 P:\n",
|
||||
"tensor([[ 0.0043, 0.0009],\n",
|
||||
" [-0.0008, 0.0021],\n",
|
||||
" [-0.0012, -0.0091]])\n",
|
||||
"tensor([[ 0.0069, 0.0082],\n",
|
||||
" [-0.0052, -0.0124],\n",
|
||||
" [ 0.0055, -0.0014]])\n",
|
||||
"矩阵 Q:\n",
|
||||
"tensor([[ 0.0016, 0.0073],\n",
|
||||
" [-0.0092, 0.0024],\n",
|
||||
" [ 0.0026, 0.0171],\n",
|
||||
" [ 0.0101, -0.0038]])\n",
|
||||
"tensor([[ 0.0050, 0.0075],\n",
|
||||
" [ 0.0161, 0.0070],\n",
|
||||
" [-0.0009, -0.0014],\n",
|
||||
" [-0.0003, 0.0024]])\n",
|
||||
"矩阵 QT:\n",
|
||||
"tensor([[ 0.0016, -0.0092, 0.0026, 0.0101],\n",
|
||||
" [ 0.0073, 0.0024, 0.0171, -0.0038]])\n",
|
||||
"tensor([[ 0.0050, 0.0161, -0.0009, -0.0003],\n",
|
||||
" [ 0.0075, 0.0070, -0.0014, 0.0024]])\n",
|
||||
"矩阵相乘的结果:\n",
|
||||
"tensor([[ 1.3472e-05, -3.7060e-05, 2.7148e-05, 3.9682e-05],\n",
|
||||
" [ 1.3877e-05, 1.2322e-05, 3.3492e-05, -1.6069e-05],\n",
|
||||
" [-6.8047e-05, -1.1357e-05, -1.5886e-04, 2.3463e-05]])\n"
|
||||
"tensor([[ 9.6016e-05, 1.6860e-04, -1.7451e-05, 1.8011e-05],\n",
|
||||
" [-1.1894e-04, -1.7065e-04, 2.1900e-05, -2.8712e-05],\n",
|
||||
" [ 1.6918e-05, 7.8455e-05, -2.7165e-06, -4.9904e-06]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -415,17 +415,17 @@
|
||||
"tensor([[1.],\n",
|
||||
" [2.]], requires_grad=True)\n",
|
||||
"权重:\n",
|
||||
"tensor([[1.],\n",
|
||||
" [2.],\n",
|
||||
" [3.]])\n",
|
||||
"tensor([[ 0.4240],\n",
|
||||
" [-0.2577],\n",
|
||||
" [ 0.4972]])\n",
|
||||
"偏置:\n",
|
||||
"tensor([[1.]])\n",
|
||||
"tensor([0.6298, 0.6243, 0.8217])\n",
|
||||
"My_Linear输出:\n",
|
||||
"tensor([[2., 3., 4.],\n",
|
||||
" [3., 5., 7.]], grad_fn=<AddBackward0>)\n",
|
||||
"tensor([[1.0539, 0.3666, 1.3189],\n",
|
||||
" [1.4779, 0.1089, 1.8161]], grad_fn=<AddBackward0>)\n",
|
||||
"nn.Linear输出:\n",
|
||||
"tensor([[2., 3., 4.],\n",
|
||||
" [3., 5., 7.]], grad_fn=<AddmmBackward0>)\n"
|
||||
"tensor([[1.0539, 0.3666, 1.3189],\n",
|
||||
" [1.4779, 0.1089, 1.8161]], grad_fn=<AddmmBackward0>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -455,12 +455,8 @@
|
||||
"# 测试\n",
|
||||
"my_linear = My_Linear(1, 3)\n",
|
||||
"nn_linear = nn.Linear(1, 3)\n",
|
||||
"weight = torch.nn.Parameter(torch.tensor([[1.],\n",
|
||||
" [2.],\n",
|
||||
" [3.]]), requires_grad=True)\n",
|
||||
"bias = torch.nn.Parameter(torch.tensor([[1.]]), requires_grad=True)\n",
|
||||
"nn_linear.weight, my_linear.weight = weight, weight\n",
|
||||
"nn_linear.bias, my_linear.bias = bias, bias\n",
|
||||
"my_linear.weight = nn_linear.weight.clone().requires_grad_()\n",
|
||||
"my_linear.bias = nn_linear.bias.clone().requires_grad_()\n",
|
||||
"x = torch.tensor([[1.], [2.]], requires_grad=True)\n",
|
||||
"print(f\"输入:\\n{x}\")\n",
|
||||
"print(f\"权重:\\n{my_linear.weight.data}\")\n",
|
||||
@ -613,18 +609,18 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10, Loss: 696.7355516552925, Acc: 0.9442648464635192\n",
|
||||
"Epoch 2/10, Loss: 680.1393249630928, Acc: 0.9911232759674347\n",
|
||||
"Epoch 3/10, Loss: 677.770676612854, Acc: 0.9956893804390976\n",
|
||||
"Epoch 4/10, Loss: 677.294788479805, Acc: 0.9982881159072501\n",
|
||||
"Epoch 5/10, Loss: 677.1979722976685, Acc: 0.9991794744511796\n",
|
||||
"Epoch 6/10, Loss: 677.1792464852333, Acc: 0.999493084950588\n",
|
||||
"Epoch 7/10, Loss: 677.1751466989517, Acc: 0.9998704799602793\n",
|
||||
"Epoch 8/10, Loss: 677.1746656894684, Acc: 0.9999325569195194\n",
|
||||
"Epoch 9/10, Loss: 677.1742008328438, Acc: 0.9999852565480795\n",
|
||||
"Epoch 10/10, Loss: 677.1745100617409, Acc: 0.9999026775350377\n",
|
||||
"Model weights: -0.0018169691320508718, bias: 0.018722545355558395\n",
|
||||
"Prediction for test data: 0.5042262673377991\n"
|
||||
"Epoch 1/10, Loss: 688.6783249974251, Acc: 0.9766838179955138\n",
|
||||
"Epoch 2/10, Loss: 679.506599009037, Acc: 0.992039453911494\n",
|
||||
"Epoch 3/10, Loss: 677.644762635231, Acc: 0.9961844975781526\n",
|
||||
"Epoch 4/10, Loss: 677.2690716981888, Acc: 0.998395304269398\n",
|
||||
"Epoch 5/10, Loss: 677.1928514242172, Acc: 0.9993592246184307\n",
|
||||
"Epoch 6/10, Loss: 677.1781670451164, Acc: 0.9996570376204033\n",
|
||||
"Epoch 7/10, Loss: 677.1744618415833, Acc: 0.9998465339227576\n",
|
||||
"Epoch 8/10, Loss: 677.1738814711571, Acc: 0.9998001679325041\n",
|
||||
"Epoch 9/10, Loss: 677.1742851734161, Acc: 0.9998804348705138\n",
|
||||
"Epoch 10/10, Loss: 677.1740592718124, Acc: 0.9999446971149187\n",
|
||||
"Model weights: -0.0037125118542462587, bias: 0.017451055347919464\n",
|
||||
"Prediction for test data: 0.5034345984458923\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -726,18 +722,18 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10, Loss: 638.0160498773107, Acc: 0.9937490349019101\n",
|
||||
"Epoch 2/10, Loss: 583.3716011882699, Acc: 0.9804301403247839\n",
|
||||
"Epoch 3/10, Loss: 571.1401001196623, Acc: 0.9896515985806724\n",
|
||||
"Epoch 4/10, Loss: 567.6159870155185, Acc: 0.9943010522600507\n",
|
||||
"Epoch 5/10, Loss: 566.5014995958526, Acc: 0.9966799384882902\n",
|
||||
"Epoch 6/10, Loss: 566.1252285088149, Acc: 0.998098013624422\n",
|
||||
"Epoch 7/10, Loss: 565.9985610526666, Acc: 0.9988608489236236\n",
|
||||
"Epoch 8/10, Loss: 565.9526960214308, Acc: 0.9993323768578708\n",
|
||||
"Epoch 9/10, Loss: 565.9374750639024, Acc: 0.9995989407216784\n",
|
||||
"Epoch 10/10, Loss: 565.9291789969539, Acc: 0.9997716274081613\n",
|
||||
"Model weights: -3.685411369051284, bias: 1.8638604353928832\n",
|
||||
"Prediction for test data: 0.1392477539282304\n"
|
||||
"Epoch 1/10, Loss: 660.2008021697803, Acc: 0.9355364605682331\n",
|
||||
"Epoch 2/10, Loss: 589.2025169091534, Acc: 0.9769773185253259\n",
|
||||
"Epoch 3/10, Loss: 572.7106042209589, Acc: 0.9881629137259633\n",
|
||||
"Epoch 4/10, Loss: 568.0903503441508, Acc: 0.9935173218188225\n",
|
||||
"Epoch 5/10, Loss: 566.6528526848851, Acc: 0.9962586560919562\n",
|
||||
"Epoch 6/10, Loss: 566.1778871576632, Acc: 0.9978209774304773\n",
|
||||
"Epoch 7/10, Loss: 566.0143385848835, Acc: 0.9987369762885633\n",
|
||||
"Epoch 8/10, Loss: 565.9605239629793, Acc: 0.9992563563084009\n",
|
||||
"Epoch 9/10, Loss: 565.9402079010808, Acc: 0.9995321069396558\n",
|
||||
"Epoch 10/10, Loss: 565.9281422200424, Acc: 0.9997496312356398\n",
|
||||
"Model weights: -3.6833968323036084, bias: 1.8628376037952126\n",
|
||||
"Prediction for test data: 0.13936666014014443\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -899,15 +895,15 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"输入:\n",
|
||||
"tensor([[ 1.0624, 1.7008, -1.2849, 0.4049, -0.3993],\n",
|
||||
" [ 0.0757, 1.0636, 0.3586, -0.0252, -1.1431],\n",
|
||||
" [ 0.4754, -1.9538, 0.6616, -1.0363, 0.6049]], requires_grad=True)\n",
|
||||
"tensor([[ 0.7600, 0.4269, 0.7948, -0.6086, 1.2527],\n",
|
||||
" [-0.4749, 0.5720, -0.0164, -0.2126, -0.0410],\n",
|
||||
" [ 1.3269, 1.8524, -0.9815, 0.0156, 1.6971]], requires_grad=True)\n",
|
||||
"标签:\n",
|
||||
"tensor([[0., 0., 0., 0., 1.],\n",
|
||||
" [0., 1., 0., 0., 0.],\n",
|
||||
" [0., 0., 0., 0., 1.]])\n",
|
||||
"My_CrossEntropyLoss损失值: 1.5949448347091675\n",
|
||||
"nn.CrossEntropyLoss损失值: 1.594944953918457\n"
|
||||
"tensor([[0., 1., 0., 0., 0.],\n",
|
||||
" [0., 0., 0., 1., 0.],\n",
|
||||
" [1., 0., 0., 0., 0.]])\n",
|
||||
"My_CrossEntropyLoss损失值: 1.7417106628417969\n",
|
||||
"nn.CrossEntropyLoss损失值: 1.7417105436325073\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1081,16 +1077,16 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10, Loss: 84.3017807006836, Acc: 0.5087000131607056\n",
|
||||
"Epoch 2/10, Loss: 37.02857208251953, Acc: 0.5946999788284302\n",
|
||||
"Epoch 3/10, Loss: 30.553579330444336, Acc: 0.6287999749183655\n",
|
||||
"Epoch 4/10, Loss: 27.279203414916992, Acc: 0.6550999879837036\n",
|
||||
"Epoch 5/10, Loss: 25.244386672973633, Acc: 0.6694999933242798\n",
|
||||
"Epoch 6/10, Loss: 23.713878631591797, Acc: 0.6798999905586243\n",
|
||||
"Epoch 7/10, Loss: 22.5694580078125, Acc: 0.6924999952316284\n",
|
||||
"Epoch 8/10, Loss: 21.611900329589844, Acc: 0.6965000033378601\n",
|
||||
"Epoch 9/10, Loss: 20.85039520263672, Acc: 0.7014999985694885\n",
|
||||
"Epoch 10/10, Loss: 20.116191864013672, Acc: 0.7102000117301941\n"
|
||||
"Epoch 1/10, Loss: 87.64246368408203, Acc: 0.45329999923706055\n",
|
||||
"Epoch 2/10, Loss: 42.025726318359375, Acc: 0.5523999929428101\n",
|
||||
"Epoch 3/10, Loss: 34.06425094604492, Acc: 0.5947999954223633\n",
|
||||
"Epoch 4/10, Loss: 30.135021209716797, Acc: 0.620199978351593\n",
|
||||
"Epoch 5/10, Loss: 27.43822479248047, Acc: 0.6401000022888184\n",
|
||||
"Epoch 6/10, Loss: 25.72039031982422, Acc: 0.6525999903678894\n",
|
||||
"Epoch 7/10, Loss: 24.28335952758789, Acc: 0.6638999581336975\n",
|
||||
"Epoch 8/10, Loss: 23.18214988708496, Acc: 0.671999990940094\n",
|
||||
"Epoch 9/10, Loss: 22.18520164489746, Acc: 0.680899977684021\n",
|
||||
"Epoch 10/10, Loss: 21.393451690673828, Acc: 0.6875999569892883\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1212,16 +1208,16 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10, Loss: 19.19444465637207, Acc: 0.7229999899864197\n",
|
||||
"Epoch 2/10, Loss: 12.180685043334961, Acc: 0.7491999864578247\n",
|
||||
"Epoch 3/10, Loss: 10.77286148071289, Acc: 0.7608999609947205\n",
|
||||
"Epoch 4/10, Loss: 10.058968544006348, Acc: 0.7716000080108643\n",
|
||||
"Epoch 5/10, Loss: 9.58817195892334, Acc: 0.7815999984741211\n",
|
||||
"Epoch 6/10, Loss: 9.245816230773926, Acc: 0.7861999869346619\n",
|
||||
"Epoch 7/10, Loss: 8.98766040802002, Acc: 0.7924000024795532\n",
|
||||
"Epoch 8/10, Loss: 8.778538703918457, Acc: 0.7949999570846558\n",
|
||||
"Epoch 9/10, Loss: 8.59365177154541, Acc: 0.795699954032898\n",
|
||||
"Epoch 10/10, Loss: 8.442872047424316, Acc: 0.7998999953269958\n"
|
||||
"Epoch 1/10, Loss: 19.15451431274414, Acc: 0.7202000021934509\n",
|
||||
"Epoch 2/10, Loss: 12.260371208190918, Acc: 0.7486000061035156\n",
|
||||
"Epoch 3/10, Loss: 10.835549354553223, Acc: 0.7615999579429626\n",
|
||||
"Epoch 4/10, Loss: 10.09542179107666, Acc: 0.7701999545097351\n",
|
||||
"Epoch 5/10, Loss: 9.626176834106445, Acc: 0.777899980545044\n",
|
||||
"Epoch 6/10, Loss: 9.264442443847656, Acc: 0.7854999899864197\n",
|
||||
"Epoch 7/10, Loss: 9.017412185668945, Acc: 0.7879999876022339\n",
|
||||
"Epoch 8/10, Loss: 8.786051750183105, Acc: 0.7915999889373779\n",
|
||||
"Epoch 9/10, Loss: 8.613431930541992, Acc: 0.79749995470047\n",
|
||||
"Epoch 10/10, Loss: 8.462657928466797, Acc: 0.7996999621391296\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -163,21 +163,21 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"矩阵 P:\n",
|
||||
"tensor([[-6.8759e-03, -1.5809e-02],\n",
|
||||
" [ 8.0131e-03, -5.1682e-05],\n",
|
||||
" [ 4.1277e-03, 4.1408e-03]])\n",
|
||||
"tensor([[ 0.0069, 0.0082],\n",
|
||||
" [-0.0052, -0.0124],\n",
|
||||
" [ 0.0055, -0.0014]])\n",
|
||||
"矩阵 Q:\n",
|
||||
"tensor([[-0.0037, -0.0188],\n",
|
||||
" [-0.0083, 0.0147],\n",
|
||||
" [-0.0029, -0.0061],\n",
|
||||
" [-0.0123, 0.0162]])\n",
|
||||
"tensor([[ 0.0050, 0.0075],\n",
|
||||
" [ 0.0161, 0.0070],\n",
|
||||
" [-0.0009, -0.0014],\n",
|
||||
" [-0.0003, 0.0024]])\n",
|
||||
"矩阵 QT:\n",
|
||||
"tensor([[-0.0037, -0.0083, -0.0029, -0.0123],\n",
|
||||
" [-0.0188, 0.0147, -0.0061, 0.0162]])\n",
|
||||
"tensor([[ 0.0050, 0.0161, -0.0009, -0.0003],\n",
|
||||
" [ 0.0075, 0.0070, -0.0014, 0.0024]])\n",
|
||||
"矩阵相乘的结果:\n",
|
||||
"tensor([[ 3.2228e-04, -1.7527e-04, 1.1701e-04, -1.7221e-04],\n",
|
||||
" [-2.8839e-05, -6.7375e-05, -2.2812e-05, -9.9158e-05],\n",
|
||||
" [-9.3070e-05, 2.6567e-05, -3.7366e-05, 1.6558e-05]])\n"
|
||||
"tensor([[ 9.6016e-05, 1.6860e-04, -1.7451e-05, 1.8011e-05],\n",
|
||||
" [-1.1894e-04, -1.7065e-04, 2.1900e-05, -2.8712e-05],\n",
|
||||
" [ 1.6918e-05, 7.8455e-05, -2.7165e-06, -4.9904e-06]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -415,17 +415,17 @@
|
||||
"tensor([[1.],\n",
|
||||
" [2.]], requires_grad=True)\n",
|
||||
"权重:\n",
|
||||
"tensor([[-0.3470],\n",
|
||||
" [-0.4461],\n",
|
||||
" [-0.8590]])\n",
|
||||
"tensor([[ 0.4240],\n",
|
||||
" [-0.2577],\n",
|
||||
" [ 0.4972]])\n",
|
||||
"偏置:\n",
|
||||
"tensor([ 0.9821, 0.8701, -0.4202])\n",
|
||||
"tensor([0.6298, 0.6243, 0.8217])\n",
|
||||
"My_Linear输出:\n",
|
||||
"tensor([[ 0.6351, 0.4240, -1.2792],\n",
|
||||
" [ 0.2882, -0.0221, -2.1383]], grad_fn=<AddBackward0>)\n",
|
||||
"tensor([[1.0539, 0.3666, 1.3189],\n",
|
||||
" [1.4779, 0.1089, 1.8161]], grad_fn=<AddBackward0>)\n",
|
||||
"nn.Linear输出:\n",
|
||||
"tensor([[ 0.6351, 0.4240, -1.2792],\n",
|
||||
" [ 0.2882, -0.0221, -2.1383]], grad_fn=<AddmmBackward0>)\n"
|
||||
"tensor([[1.0539, 0.3666, 1.3189],\n",
|
||||
" [1.4779, 0.1089, 1.8161]], grad_fn=<AddmmBackward0>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -455,7 +455,8 @@
|
||||
"# 测试\n",
|
||||
"my_linear = My_Linear(1, 3)\n",
|
||||
"nn_linear = nn.Linear(1, 3)\n",
|
||||
"my_linear.weight, my_linear.bias = nn_linear.weight, nn_linear.bias\n",
|
||||
"my_linear.weight = nn_linear.weight.clone().requires_grad_()\n",
|
||||
"my_linear.bias = nn_linear.bias.clone().requires_grad_()\n",
|
||||
"x = torch.tensor([[1.], [2.]], requires_grad=True)\n",
|
||||
"print(f\"输入:\\n{x}\")\n",
|
||||
"print(f\"权重:\\n{my_linear.weight.data}\")\n",
|
||||
@ -603,7 +604,26 @@
|
||||
"execution_count": 10,
|
||||
"id": "5612661e-2809-4d46-96c2-33ee9f44116d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10, Loss: 688.6783249974251, Acc: 0.9766838179955138\n",
|
||||
"Epoch 2/10, Loss: 679.506599009037, Acc: 0.992039453911494\n",
|
||||
"Epoch 3/10, Loss: 677.644762635231, Acc: 0.9961844975781526\n",
|
||||
"Epoch 4/10, Loss: 677.2690716981888, Acc: 0.998395304269398\n",
|
||||
"Epoch 5/10, Loss: 677.1928514242172, Acc: 0.9993592246184307\n",
|
||||
"Epoch 6/10, Loss: 677.1781670451164, Acc: 0.9996570376204033\n",
|
||||
"Epoch 7/10, Loss: 677.1744618415833, Acc: 0.9998465339227576\n",
|
||||
"Epoch 8/10, Loss: 677.1738814711571, Acc: 0.9998001679325041\n",
|
||||
"Epoch 9/10, Loss: 677.1742851734161, Acc: 0.9998804348705138\n",
|
||||
"Epoch 10/10, Loss: 677.1740592718124, Acc: 0.9999446971149187\n",
|
||||
"Model weights: -0.0037125118542462587, bias: 0.017451055347919464\n",
|
||||
"Prediction for test data: 0.5034345984458923\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"learning_rate = 5e-2\n",
|
||||
"num_epochs = 10\n",
|
||||
@ -668,7 +688,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"id": "fa121afd-a1af-4193-9b54-68041e0ed068",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -694,7 +714,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"id": "93b0fdb6-be8b-4663-b59e-05ed19a9ea09",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -702,18 +722,18 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10, Loss: 638.0160498773107, Acc: 0.9937490349019101\n",
|
||||
"Epoch 2/10, Loss: 583.3716011882699, Acc: 0.9804301403247839\n",
|
||||
"Epoch 3/10, Loss: 571.1401001196623, Acc: 0.9896515985806724\n",
|
||||
"Epoch 4/10, Loss: 567.6159870155185, Acc: 0.9943010522600507\n",
|
||||
"Epoch 5/10, Loss: 566.5014995958526, Acc: 0.9966799384882902\n",
|
||||
"Epoch 6/10, Loss: 566.1252285088149, Acc: 0.998098013624422\n",
|
||||
"Epoch 7/10, Loss: 565.9985610526666, Acc: 0.9988608489236236\n",
|
||||
"Epoch 8/10, Loss: 565.9526960214308, Acc: 0.9993323768578708\n",
|
||||
"Epoch 9/10, Loss: 565.9374750639024, Acc: 0.9995989407216784\n",
|
||||
"Epoch 10/10, Loss: 565.9291789969539, Acc: 0.9997716274081613\n",
|
||||
"Model weights: -3.685411369051284, bias: 1.8638604353928832\n",
|
||||
"Prediction for test data: 0.1392477539282304\n"
|
||||
"Epoch 1/10, Loss: 660.2008021697803, Acc: 0.9355364605682331\n",
|
||||
"Epoch 2/10, Loss: 589.2025169091534, Acc: 0.9769773185253259\n",
|
||||
"Epoch 3/10, Loss: 572.7106042209589, Acc: 0.9881629137259633\n",
|
||||
"Epoch 4/10, Loss: 568.0903503441508, Acc: 0.9935173218188225\n",
|
||||
"Epoch 5/10, Loss: 566.6528526848851, Acc: 0.9962586560919562\n",
|
||||
"Epoch 6/10, Loss: 566.1778871576632, Acc: 0.9978209774304773\n",
|
||||
"Epoch 7/10, Loss: 566.0143385848835, Acc: 0.9987369762885633\n",
|
||||
"Epoch 8/10, Loss: 565.9605239629793, Acc: 0.9992563563084009\n",
|
||||
"Epoch 9/10, Loss: 565.9402079010808, Acc: 0.9995321069396558\n",
|
||||
"Epoch 10/10, Loss: 565.9281422200424, Acc: 0.9997496312356398\n",
|
||||
"Model weights: -3.6833968323036084, bias: 1.8628376037952126\n",
|
||||
"Prediction for test data: 0.13936666014014443\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -802,7 +822,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 13,
|
||||
"id": "e605f1b0-1d32-410f-bddf-402a85ccc9ff",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -866,7 +886,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 14,
|
||||
"id": "759a3bb2-b5f4-4ea5-a2d7-15f0c4cdd14b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -875,15 +895,15 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"输入:\n",
|
||||
"tensor([[ 1.0624, 1.7008, -1.2849, 0.4049, -0.3993],\n",
|
||||
" [ 0.0757, 1.0636, 0.3586, -0.0252, -1.1431],\n",
|
||||
" [ 0.4754, -1.9538, 0.6616, -1.0363, 0.6049]], requires_grad=True)\n",
|
||||
"tensor([[ 0.7600, 0.4269, 0.7948, -0.6086, 1.2527],\n",
|
||||
" [-0.4749, 0.5720, -0.0164, -0.2126, -0.0410],\n",
|
||||
" [ 1.3269, 1.8524, -0.9815, 0.0156, 1.6971]], requires_grad=True)\n",
|
||||
"标签:\n",
|
||||
"tensor([[0., 0., 0., 0., 1.],\n",
|
||||
" [0., 1., 0., 0., 0.],\n",
|
||||
" [0., 0., 0., 0., 1.]])\n",
|
||||
"My_CrossEntropyLoss损失值: 1.5949448347091675\n",
|
||||
"nn.CrossEntropyLoss损失值: 1.594944953918457\n"
|
||||
"tensor([[0., 1., 0., 0., 0.],\n",
|
||||
" [0., 0., 0., 1., 0.],\n",
|
||||
" [1., 0., 0., 0., 0.]])\n",
|
||||
"My_CrossEntropyLoss损失值: 1.7417106628417969\n",
|
||||
"nn.CrossEntropyLoss损失值: 1.7417105436325073\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -928,7 +948,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 15,
|
||||
"id": "74322629-8325-4823-b80f-f28182d577c1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -989,7 +1009,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 16,
|
||||
"id": "bb31a75e-464c-4b94-b927-b219a765e35d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -1049,7 +1069,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 17,
|
||||
"id": "d816dae1-5fbe-4c29-9597-19d66b5eb6b4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1057,16 +1077,16 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10, Loss: 84.3017807006836, Acc: 0.5087000131607056\n",
|
||||
"Epoch 2/10, Loss: 37.02857208251953, Acc: 0.5946999788284302\n",
|
||||
"Epoch 3/10, Loss: 30.553579330444336, Acc: 0.6287999749183655\n",
|
||||
"Epoch 4/10, Loss: 27.279203414916992, Acc: 0.6550999879837036\n",
|
||||
"Epoch 5/10, Loss: 25.244386672973633, Acc: 0.6694999933242798\n",
|
||||
"Epoch 6/10, Loss: 23.713878631591797, Acc: 0.6798999905586243\n",
|
||||
"Epoch 7/10, Loss: 22.5694580078125, Acc: 0.6924999952316284\n",
|
||||
"Epoch 8/10, Loss: 21.611900329589844, Acc: 0.6965000033378601\n",
|
||||
"Epoch 9/10, Loss: 20.85039520263672, Acc: 0.7014999985694885\n",
|
||||
"Epoch 10/10, Loss: 20.116191864013672, Acc: 0.7102000117301941\n"
|
||||
"Epoch 1/10, Loss: 87.64246368408203, Acc: 0.45329999923706055\n",
|
||||
"Epoch 2/10, Loss: 42.025726318359375, Acc: 0.5523999929428101\n",
|
||||
"Epoch 3/10, Loss: 34.06425094604492, Acc: 0.5947999954223633\n",
|
||||
"Epoch 4/10, Loss: 30.135021209716797, Acc: 0.620199978351593\n",
|
||||
"Epoch 5/10, Loss: 27.43822479248047, Acc: 0.6401000022888184\n",
|
||||
"Epoch 6/10, Loss: 25.72039031982422, Acc: 0.6525999903678894\n",
|
||||
"Epoch 7/10, Loss: 24.28335952758789, Acc: 0.6638999581336975\n",
|
||||
"Epoch 8/10, Loss: 23.18214988708496, Acc: 0.671999990940094\n",
|
||||
"Epoch 9/10, Loss: 22.18520164489746, Acc: 0.680899977684021\n",
|
||||
"Epoch 10/10, Loss: 21.393451690673828, Acc: 0.6875999569892883\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1153,7 +1173,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 18,
|
||||
"id": "0163b9f7-1019-429c-8c29-06436d0a4c98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -1180,7 +1200,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 19,
|
||||
"id": "6d241c05-b153-4f56-a845-0f2362f6459b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1188,16 +1208,16 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10, Loss: 19.19444465637207, Acc: 0.7229999899864197\n",
|
||||
"Epoch 2/10, Loss: 12.180685043334961, Acc: 0.7491999864578247\n",
|
||||
"Epoch 3/10, Loss: 10.77286148071289, Acc: 0.7608999609947205\n",
|
||||
"Epoch 4/10, Loss: 10.058968544006348, Acc: 0.7716000080108643\n",
|
||||
"Epoch 5/10, Loss: 9.58817195892334, Acc: 0.7815999984741211\n",
|
||||
"Epoch 6/10, Loss: 9.245816230773926, Acc: 0.7861999869346619\n",
|
||||
"Epoch 7/10, Loss: 8.98766040802002, Acc: 0.7924000024795532\n",
|
||||
"Epoch 8/10, Loss: 8.778538703918457, Acc: 0.7949999570846558\n",
|
||||
"Epoch 9/10, Loss: 8.59365177154541, Acc: 0.795699954032898\n",
|
||||
"Epoch 10/10, Loss: 8.442872047424316, Acc: 0.7998999953269958\n"
|
||||
"Epoch 1/10, Loss: 19.15451431274414, Acc: 0.7202000021934509\n",
|
||||
"Epoch 2/10, Loss: 12.260371208190918, Acc: 0.7486000061035156\n",
|
||||
"Epoch 3/10, Loss: 10.835549354553223, Acc: 0.7615999579429626\n",
|
||||
"Epoch 4/10, Loss: 10.09542179107666, Acc: 0.7701999545097351\n",
|
||||
"Epoch 5/10, Loss: 9.626176834106445, Acc: 0.777899980545044\n",
|
||||
"Epoch 6/10, Loss: 9.264442443847656, Acc: 0.7854999899864197\n",
|
||||
"Epoch 7/10, Loss: 9.017412185668945, Acc: 0.7879999876022339\n",
|
||||
"Epoch 8/10, Loss: 8.786051750183105, Acc: 0.7915999889373779\n",
|
||||
"Epoch 9/10, Loss: 8.613431930541992, Acc: 0.79749995470047\n",
|
||||
"Epoch 10/10, Loss: 8.462657928466797, Acc: 0.7996999621391296\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1297,7 +1317,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
39
Lab1/code/.ipynb_checkpoints/1.1-checkpoint.py
Normal file
39
Lab1/code/.ipynb_checkpoints/1.1-checkpoint.py
Normal file
@ -0,0 +1,39 @@
|
||||
import torch
|
||||
|
||||
A = torch.tensor([[1, 2, 3]])
|
||||
|
||||
B = torch.tensor([[4],
|
||||
[5]])
|
||||
|
||||
# 方法1: 使用PyTorch的减法操作符
|
||||
result1 = A - B
|
||||
|
||||
# 方法2: 使用PyTorch的sub函数
|
||||
result2 = torch.sub(A, B)
|
||||
|
||||
# 方法3: 手动实现广播机制并作差
|
||||
def my_sub(a:torch.Tensor, b:torch.Tensor):
|
||||
if not (
|
||||
(a.size(0) == 1 and b.size(1) == 1)
|
||||
or
|
||||
(a.size(1) == 1 and b.size(0) == 1)
|
||||
):
|
||||
raise ValueError("输入的张量大小无法满足广播机制的条件。")
|
||||
else:
|
||||
target_shape = torch.Size([max(A.size(0), B.size(0)), max(A.size(1), B.size(1))])
|
||||
A_broadcasted = A.expand(target_shape)
|
||||
B_broadcasted = B.expand(target_shape)
|
||||
result = torch.zeros(target_shape, dtype=torch.int64).to(device=A_broadcasted.device)
|
||||
for i in range(target_shape[0]):
|
||||
for j in range(target_shape[1]):
|
||||
result[i, j] = A_broadcasted[i, j] - B_broadcasted[i, j]
|
||||
return result
|
||||
|
||||
result3 = my_sub(A, B)
|
||||
|
||||
print("方法1的结果:")
|
||||
print(result1)
|
||||
print("方法2的结果:")
|
||||
print(result2)
|
||||
print("方法3的结果:")
|
||||
print(result3)
|
Loading…
x
Reference in New Issue
Block a user