import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import clip

from get_loader import MyDataset
from test import test


def convert_models_to_fp32(model):
    """Cast every parameter (and its gradient, if any) of *model* to fp32.

    ``clip.model.convert_weights`` puts the model in fp16 for GPU inference,
    but Adam's update is numerically unstable in fp16 — so the training loop
    temporarily restores fp32 right before ``optimizer.step()``.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # FIX: parameters that did not receive a gradient this step have
        # p.grad is None; the original `p.grad.data.float()` crashed on them.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def train():
    """Fine-tune CLIP ViT-L/14 with the symmetric image/text contrastive loss.

    Each epoch: run the train loader, average the image->text and text->image
    cross-entropy losses (matching pairs are the diagonal of the logit
    matrix), evaluate on the test split, and checkpoint model + optimizer.
    """
    batch_size = 64
    learning_rate = 1e-6
    num_epochs = 500
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    net, preprocess = clip.load("ViT-L/14", device=device, jit=False)
    # FIX: the original compared a torch.device to the string 'cpu', which is
    # not a reliable equality — compare the device *type* instead. Without
    # this, convert_weights() (fp16) could run on CPU and break training.
    if device.type == "cpu":
        net.float()
    else:
        # fp16 weights on GPU, per the official CLIP fine-tuning recipe.
        clip.model.convert_weights(net)

    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()
    # Hyperparameters follow the CLIP paper's Adam settings.
    optimizer = optim.Adam(net.parameters(), lr=learning_rate,
                           betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

    train_dataset = MyDataset(processor=preprocess, train=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=64, pin_memory=True)
    test_dataset = MyDataset(processor=preprocess, train=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             num_workers=64, shuffle=True, pin_memory=True)
    print(f'Train dataset size: {len(train_dataset)}\nTest dataset size: {len(test_dataset)}\n')

    # FIX: make sure the checkpoint directory exists before torch.save().
    os.makedirs("model_checkpoint", exist_ok=True)

    for epoch in range(num_epochs):
        total_epoch_loss = 0
        for index, (images, tokens, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
            optimizer.zero_grad()
            images = images.to(device)
            tokens = tokens.to(device)

            # Grad mode is already enabled during training; the original's
            # `with torch.set_grad_enabled(True):` was a no-op and is dropped.
            logits_per_image, logits_per_text = net(images, tokens)
            # The i-th image matches the i-th caption: targets are the
            # diagonal indices of the batch-size x batch-size logit matrix.
            ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
            cur_loss = (loss_img(logits_per_image, ground_truth)
                        + loss_txt(logits_per_text, ground_truth)) / 2
            total_epoch_loss += cur_loss.item()
            cur_loss.backward()

            if device.type == "cpu":
                optimizer.step()
            else:
                # Adam is unstable in fp16: step in fp32, then return to fp16.
                convert_models_to_fp32(net)
                optimizer.step()
                clip.model.convert_weights(net)

        test_acc = test(net, test_dataset, test_loader, device)
        print(f'Total train loss: {total_epoch_loss:.6f}, Test accuracy: {test_acc:.6%}')
        print("--------------------------------------------------------------")
        torch.save({'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': total_epoch_loss,
                    }, f"model_checkpoint/model-{epoch + 1}_acc-{test_acc*100:.3f}.pt")


if __name__ == "__main__":
    train()