import torch
import torch.nn
import clip
import numpy as np
from tqdm import tqdm

from get_loader import Classes


def test(net, test_dataset, test_loader, device):
    """Return the top-1 zero-shot accuracy of ``net`` over ``test_loader``.

    Each batch of images is scored against every row of
    ``test_dataset.tokens``. The predicted label is the index of the
    highest-scoring row, and it is compared with the batch targets.

    Args:
        net: Model called as ``net(images, texts)`` that returns a pair
            ``(logits_per_image, logits_per_text)``.
        test_dataset: Dataset exposing a ``tokens`` tensor. This presumably
            holds one tokenized prompt per class, so that row index equals
            class label (TODO confirm against get_loader).
        test_loader: Iterable of ``(images, tokens, targets)`` batches, where
            ``targets`` is a tensor of class indices.
        device: Device the images and text tokens are moved to.

    Returns:
        Fraction of evaluated samples whose prediction matches the target.
        The model's train/eval mode is restored to what it was on entry.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    texts = test_dataset.tokens.to(device)
    was_training = net.training
    correct = 0
    seen = 0
    net.eval()
    try:
        with torch.no_grad():
            for images, _tokens, targets in tqdm(test_loader):
                images = images.to(device)
                logits_per_image, _ = net(images, texts)
                # softmax is monotonic, so argmax over the raw logits already
                # yields the predicted class; no need to normalize first.
                preds = logits_per_image.argmax(dim=-1).cpu()
                correct += (preds == targets).sum().item()
                # Count what was actually evaluated (robust to drop_last or
                # samplers that do not cover the whole dataset).
                seen += len(targets)
    finally:
        # Restore the caller's train/eval mode, even if evaluation raised.
        net.train(was_training)
    return correct / seen