# Source snapshot metadata: 2024-09-05 12:56:46 +08:00 | 23 lines | 717 B | Python

import torch
import torch.nn
import clip
import numpy as np
from tqdm import tqdm
from get_loader import Classes
def test(net, test_dataset, test_loader, device):
    """Return the top-1 zero-shot accuracy of a CLIP-style model on a test set.

    Every image batch from ``test_loader`` is scored against the full set of
    prompt tokens in ``test_dataset.tokens``. The prediction for each image
    is the index of the prompt with the highest image-to-text logit, and it
    is compared against ``targets``.

    Args:
        net: Module called as ``net(images, texts)`` that returns a pair
            ``(logits_per_image, logits_per_text)``; only the first element
            is used.
        test_dataset: Dataset exposing a ``tokens`` tensor holding the
            candidate prompts (presumably one row per class, ordered by
            class index -- TODO confirm against get_loader).
        test_loader: Iterable of ``(images, tokens, targets)`` batches;
            the per-batch ``tokens`` are ignored.
        device: Device that ``images`` and the prompt tokens are moved to.

    Returns:
        Fraction of evaluated samples whose predicted index equals the
        target. The denominator is the number of samples the loader
        actually yielded, so results stay correct with ``drop_last`` or a
        subset sampler.

    Raises:
        ZeroDivisionError: if the loader yields no samples (accuracy is
            undefined).
    """
    # Remember the caller's mode so evaluation never leaves the model in a
    # different train/eval state than it was handed in.
    was_training = net.training
    net.eval()
    texts = test_dataset.tokens.to(device)
    correct = 0
    seen = 0
    try:
        with torch.no_grad():
            for images, _tokens, targets in tqdm(test_loader):
                images = images.to(device)
                logits_per_image, _ = net(images, texts)
                # softmax is monotonic, so argmax over raw logits gives the
                # same prediction without the extra work.
                preds = logits_per_image.argmax(dim=-1).cpu()
                correct += (preds == targets.cpu()).sum().item()
                seen += len(targets)
    finally:
        # Restore the mode even if the forward pass raised.
        net.train(was_training)
    if seen == 0:
        raise ZeroDivisionError("test_loader yielded no samples; accuracy is undefined")
    return correct / seen