98 lines
2.8 KiB
Python

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch import nn
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import ipdb
class Model_3_2(nn.Module):
    """Linear classifier: flatten the input, then apply one dense layer.

    Every non-batch dimension of the input is merged into a single feature
    axis and passed through one ``nn.Linear`` layer. The output is the raw,
    unnormalised class scores (logits); no softmax is applied here.

    Args:
        num_classes: Number of output classes (width of the logits).
        in_features: Number of elements in one flattened sample. Defaults to
            ``28 * 28``, the value that used to be hard-coded, so existing
            callers that pass only ``num_classes`` behave exactly as before.
    """

    def __init__(self, num_classes: int, in_features: int = 28 * 28) -> None:
        super().__init__()
        # Attribute names are kept as ``flatten`` / ``linear`` so that
        # state_dict keys stay compatible with previously saved weights.
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, num_classes)`` for the batch ``x``."""
        x = self.flatten(x)
        x = self.linear(x)
        return x
if __name__ == "__main__":
    # Train the linear classifier on FashionMNIST and report, after every
    # epoch, the summed training loss and the accuracy on the test split.

    # Hyper-parameters.
    learning_rate = 5e-2
    num_epochs = 10
    batch_size = 4096
    num_classes = 10

    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            # Mean 0.5 with std 1.0: shifts the pixel values, no rescaling.
            transforms.Normalize((0.5,), (1.0,)),
        ]
    )
    train_dataset = datasets.FashionMNIST(
        root="../dataset", train=True, transform=transform, download=True
    )
    test_dataset = datasets.FashionMNIST(
        root="../dataset", train=False, transform=transform, download=True
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=14,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
        pin_memory=use_cuda,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        # Accuracy does not depend on sample order, so no shuffling here.
        shuffle=False,
        num_workers=14,
        pin_memory=use_cuda,
    )

    model = Model_3_2(num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        total_epoch_loss = 0.0
        for images, targets in tqdm(train_loader):
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            # CrossEntropyLoss accepts integer class indices directly; this
            # gives the same value as one-hot float targets without building
            # an extra (batch, num_classes) tensor.
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # .item() yields a plain float. Accumulating the tensor itself
            # would keep each batch's autograd history alive for the epoch.
            total_epoch_loss += loss.item()

        # ---- evaluation pass ----
        model.eval()
        num_correct = 0
        with torch.no_grad():
            for images, targets in tqdm(test_loader):
                images = images.to(device)
                targets = targets.to(device)
                outputs = model(images)
                num_correct += (outputs.argmax(1) == targets).sum().item()

        # "Loss" is the sum of the per-batch mean losses over the epoch.
        print(
            f"Epoch {epoch + 1}/{num_epochs} Train, Loss: {total_epoch_loss}, Acc: {num_correct / len(test_dataset)}"
        )