rerun and refactor content

This commit is contained in:
Jingfan Ke 2023-10-10 13:56:56 +08:00
parent f814db12ae
commit c384059131

View File

@ -36,7 +36,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 20,
"id": "a4e12268-bad4-44c4-92d5-883624d93e25",
"metadata": {},
"outputs": [],
@ -69,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 21,
"id": "79ea46db-cf49-436c-9b5b-c6562d0da9e2",
"metadata": {},
"outputs": [
@ -138,7 +138,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 22,
"id": "41e4ee02-1d05-4101-b3f0-477bac0277fb",
"metadata": {},
"outputs": [
@ -147,21 +147,21 @@
"output_type": "stream",
"text": [
"矩阵 P:\n",
"tensor([[ 0.0098, -0.0111],\n",
" [-0.0057, 0.0051],\n",
" [-0.0180, 0.0194]])\n",
"tensor([[-0.0094, -0.0073],\n",
" [-0.0087, -0.0008],\n",
" [-0.0012, 0.0103]])\n",
"矩阵 Q:\n",
"tensor([[ 0.0010, -0.0026],\n",
" [-0.0095, -0.0059],\n",
" [-0.0168, 0.0194],\n",
" [ 0.0022, 0.0125]])\n",
"tensor([[ 0.0094, -0.0126],\n",
" [-0.0082, 0.0005],\n",
" [-0.0079, -0.0101],\n",
" [-0.0002, -0.0161]])\n",
"矩阵 QT:\n",
"tensor([[ 0.0010, -0.0095, -0.0168, 0.0022],\n",
" [-0.0026, -0.0059, 0.0194, 0.0125]])\n",
"tensor([[ 0.0094, -0.0082, -0.0079, -0.0002],\n",
" [-0.0126, 0.0005, -0.0101, -0.0161]])\n",
"矩阵相乘的结果:\n",
"tensor([[ 3.8758e-05, -2.7672e-05, -3.7944e-04, -1.1683e-04],\n",
" [-1.8842e-05, 2.4259e-05, 1.9324e-04, 5.0424e-05],\n",
" [-6.8471e-05, 5.7510e-05, 6.7733e-04, 2.0131e-04]])\n"
"tensor([[ 4.8768e-06, 7.3478e-05, 1.4821e-04, 1.2020e-04],\n",
" [-7.1462e-05, 7.1558e-05, 7.7439e-05, 1.4915e-05],\n",
" [-1.4130e-04, 1.4922e-05, -9.4810e-05, -1.6629e-04]])\n"
]
}
],
@ -199,7 +199,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 23,
"id": "951512cd-d915-4d04-959f-eb99d1971e2d",
"metadata": {},
"outputs": [
@ -250,7 +250,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 24,
"id": "e31b86ec-4114-48dd-8d73-fe4e0686419a",
"metadata": {},
"outputs": [
@ -304,7 +304,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 25,
"id": "0297066c-9fc1-448d-bdcb-29a6f1519117",
"metadata": {},
"outputs": [
@ -364,7 +364,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 26,
"id": "8e18695a-d8c5-4f77-8b5c-de40d9240fb9",
"metadata": {},
"outputs": [
@ -446,7 +446,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 27,
"id": "e7de7e4b-a084-4793-812e-46e8550ecd8d",
"metadata": {},
"outputs": [],
@ -495,7 +495,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 28,
"id": "c39fbafb-62e4-4b8c-9d65-6718d25f2970",
"metadata": {},
"outputs": [
@ -566,7 +566,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 29,
"id": "5612661e-2809-4d46-96c2-33ee9f44116d",
"metadata": {},
"outputs": [
@ -574,18 +574,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10, Loss: 685.3895894885063, Acc: 0.9642275829459848\n",
"Epoch 2/10, Loss: 677.4121572375298, Acc: 0.9974711945592333\n",
"Epoch 3/10, Loss: 677.2220785021782, Acc: 0.9990452451894614\n",
"Epoch 4/10, Loss: 677.1839035749435, Acc: 0.9993094710137819\n",
"Epoch 5/10, Loss: 677.1762611865997, Acc: 0.9998272919002676\n",
"Epoch 6/10, Loss: 677.1740638613701, Acc: 0.9999073923880469\n",
"Epoch 7/10, Loss: 677.1739921569824, Acc: 0.9997274632391843\n",
"Epoch 8/10, Loss: 677.1744710803032, Acc: 0.9999882508320989\n",
"Epoch 9/10, Loss: 677.1742913126945, Acc: 0.999904539547138\n",
"Epoch 10/10, Loss: 677.173879802227, Acc: 0.9997605824956097\n",
"Model weights: -0.0010404698550701141, bias: 0.02203504741191864\n",
"Prediction for test data: 0.505248486995697\n"
"Epoch 1/10, Loss: 681.7797573804855, Acc: 0.9645753421871553\n",
"Epoch 2/10, Loss: 677.2049961090088, Acc: 0.9990700532071279\n",
"Epoch 3/10, Loss: 677.1804099082947, Acc: 0.9996768410577491\n",
"Epoch 4/10, Loss: 677.175698697567, Acc: 0.9996360650992927\n",
"Epoch 5/10, Loss: 677.1747546195984, Acc: 0.999986148984189\n",
"Epoch 6/10, Loss: 677.1744914650917, Acc: 0.9998796786709696\n",
"Epoch 7/10, Loss: 677.1742819547653, Acc: 0.9999521451026462\n",
"Epoch 8/10, Loss: 677.1738398075104, Acc: 0.9999777880946412\n",
"Epoch 9/10, Loss: 677.1740134358406, Acc: 0.9997993523341308\n",
"Epoch 10/10, Loss: 677.1745718121529, Acc: 0.9998104022783462\n",
"Model weights: -0.0036095045506954193, bias: 0.016485782340168953\n",
"Prediction for test data: 0.5032190084457397\n"
]
}
],
@ -653,7 +653,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 30,
"id": "fa121afd-a1af-4193-9b54-68041e0ed068",
"metadata": {},
"outputs": [],
@ -679,7 +679,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 31,
"id": "93b0fdb6-be8b-4663-b59e-05ed19a9ea09",
"metadata": {},
"outputs": [
@ -687,18 +687,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10, Loss: 576.7015416165058, Acc: 0.9735617914738028\n",
"Epoch 2/10, Loss: 565.9262382361084, Acc: 0.9999925140596344\n",
"Epoch 3/10, Loss: 565.9295897295112, Acc: 0.9999952212094322\n",
"Epoch 4/10, Loss: 565.9272355019373, Acc: 0.9999899716045327\n",
"Epoch 5/10, Loss: 565.9276486165418, Acc: 0.9999941261622728\n",
"Epoch 6/10, Loss: 565.9258608743777, Acc: 0.999994099092236\n",
"Epoch 7/10, Loss: 565.9304406750343, Acc: 0.9999997538554865\n",
"Epoch 8/10, Loss: 565.9290585726536, Acc: 0.9999990918784897\n",
"Epoch 9/10, Loss: 565.9277625135361, Acc: 0.9999886345247774\n",
"Epoch 10/10, Loss: 565.9291837050997, Acc: 0.9999944677252854\n",
"Model weights: -3.712182683343629, bias: 1.8752337556721546\n",
"Prediction for test data: 0.13741241440796031\n"
"Epoch 1/10, Loss: 582.1114150944223, Acc: 0.9622313076669672\n",
"Epoch 2/10, Loss: 565.93256158834, Acc: 0.9999686256703629\n",
"Epoch 3/10, Loss: 565.9305296230643, Acc: 0.9999988205402547\n",
"Epoch 4/10, Loss: 565.9292865398384, Acc: 0.9999988799203948\n",
"Epoch 5/10, Loss: 565.928863850198, Acc: 0.9999991768121363\n",
"Epoch 6/10, Loss: 565.9304914128694, Acc: 0.9999969140456769\n",
"Epoch 7/10, Loss: 565.9264041730053, Acc: 0.9999955753695261\n",
"Epoch 8/10, Loss: 565.9313891761873, Acc: 0.9999980937154029\n",
"Epoch 9/10, Loss: 565.9266170542029, Acc: 0.9999949410275989\n",
"Epoch 10/10, Loss: 565.9337094448973, Acc: 0.9999975812010478\n",
"Model weights: -3.7012964947839575, bias: 1.8774806436910758\n",
"Prediction for test data: 0.13897650708244993\n"
]
}
],
@ -787,7 +787,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 32,
"id": "e605f1b0-1d32-410f-bddf-402a85ccc9ff",
"metadata": {},
"outputs": [
@ -851,7 +851,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 33,
"id": "759a3bb2-b5f4-4ea5-a2d7-15f0c4cdd14b",
"metadata": {},
"outputs": [
@ -860,15 +860,15 @@
"output_type": "stream",
"text": [
"输入:\n",
"tensor([[ 0.9415, 0.4358, -1.1650, 0.4496, -0.9394],\n",
" [-0.1956, -0.1466, -0.7704, 0.1465, -0.4571],\n",
" [-0.9923, -1.0455, -0.4241, 0.3850, 2.1680]], requires_grad=True)\n",
"tensor([[ 0.4113, 1.0890, -0.4301, -0.1975, 2.2331],\n",
" [ 0.7901, 1.8117, -2.3197, -0.8144, -0.5751],\n",
" [-1.8110, -0.5550, -0.2773, 2.3990, 0.1804]], requires_grad=True)\n",
"标签:\n",
"tensor([[0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 1.],\n",
" [0., 0., 0., 0., 1.]])\n",
"My_CrossEntropyLoss损失值: 1.1712640523910522\n",
"nn.CrossEntropyLoss损失值: 1.1712640523910522\n"
"tensor([[0., 1., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0.],\n",
" [0., 0., 0., 1., 0.]])\n",
"My_CrossEntropyLoss损失值: 1.1033374071121216\n",
"nn.CrossEntropyLoss损失值: 1.1033374071121216\n"
]
}
],
@ -913,7 +913,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 34,
"id": "74322629-8325-4823-b80f-f28182d577c1",
"metadata": {},
"outputs": [
@ -974,7 +974,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 35,
"id": "bb31a75e-464c-4b94-b927-b219a765e35d",
"metadata": {},
"outputs": [],
@ -1034,7 +1034,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 36,
"id": "d816dae1-5fbe-4c29-9597-19d66b5eb6b4",
"metadata": {},
"outputs": [
@ -1134,7 +1134,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 37,
"id": "0163b9f7-1019-429c-8c29-06436d0a4c98",
"metadata": {},
"outputs": [],
@ -1161,7 +1161,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 38,
"id": "a58a23e1-368c-430a-ad62-0e256dff564d",
"metadata": {},
"outputs": [
@ -1169,16 +1169,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10, Loss: 15.949012756347656, Acc: 0.7468000054359436\n",
"Epoch 2/10, Loss: 9.318169593811035, Acc: 0.7906999588012695\n",
"Epoch 3/10, Loss: 8.015625953674316, Acc: 0.8120999932289124\n",
"Epoch 4/10, Loss: 7.471133708953857, Acc: 0.8168999552726746\n",
"Epoch 5/10, Loss: 7.215029239654541, Acc: 0.8253999948501587\n",
"Epoch 6/10, Loss: 7.007692337036133, Acc: 0.8244999647140503\n",
"Epoch 7/10, Loss: 6.847175598144531, Acc: 0.828499972820282\n",
"Epoch 8/10, Loss: 6.6865668296813965, Acc: 0.8323000073432922\n",
"Epoch 9/10, Loss: 6.595873832702637, Acc: 0.8307999968528748\n",
"Epoch 10/10, Loss: 6.535965919494629, Acc: 0.8348999619483948\n"
"Epoch 1/10, Loss: 15.768913269042969, Acc: 0.7530999779701233\n",
"Epoch 2/10, Loss: 9.122207641601562, Acc: 0.7967000007629395\n",
"Epoch 3/10, Loss: 7.9603657722473145, Acc: 0.8100999593734741\n",
"Epoch 4/10, Loss: 7.427120208740234, Acc: 0.8179000020027161\n",
"Epoch 5/10, Loss: 7.115703582763672, Acc: 0.8248999714851379\n",
"Epoch 6/10, Loss: 6.900459289550781, Acc: 0.8259999752044678\n",
"Epoch 7/10, Loss: 6.802896976470947, Acc: 0.8269000053405762\n",
"Epoch 8/10, Loss: 6.687209606170654, Acc: 0.832099974155426\n",
"Epoch 9/10, Loss: 6.6183180809021, Acc: 0.833299994468689\n",
"Epoch 10/10, Loss: 6.531178951263428, Acc: 0.8341999650001526\n"
]
}
],
@ -1252,7 +1252,7 @@
"\n",
"首先是数据集的设置。如果数据没有进行归一化,很容易出现梯度爆炸。这是在我以前直接使用图片数据集的经历中没有遇到过的问题。\n",
"\n",
"在实现logistic回归模型时通过手动实现各个组件如优化器、线性层等让我对这些模块的工作原理有了更清晰的认识。尤其是在实现广播机制时需要充分理解张量操作的维度变换规律。而使用Pytorch内置模块进行实现时通过继承nn.Module可以自动获得多功能,使代码更加简洁。\n",
"在实现logistic回归模型时通过手动实现各个组件如优化器、线性层等让我对这些模块的工作原理有了更清晰的认识。尤其是在实现广播机制时需要充分理解张量操作的维度变换规律。而使用Pytorch内置模块进行实现时通过继承nn.Module可以自动获得多功能,使代码更加简洁。\n",
"\n",
"在实现softmax回归时则遇到了更大的困难。手动实现的模型很容易出现梯度爆炸的问题而使用Pytorch内置的损失函数和优化器则可以稳定训练。这让我意识到了选择合适的优化方法的重要性。另外Pytorch强大的自动微分机制也是构建深度神经网络的重要基础。\n",
"\n",