23 lines
717 B
Python
23 lines
717 B
Python
import torch
|
|
import torch.nn
|
|
import clip
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
from get_loader import Classes
|
|
|
|
|
|
def test(net, test_dataset, test_loader, device):
    """Evaluate top-1 classification accuracy of ``net`` on a test set.

    Every image batch is scored against the full set of text tokens held by
    ``test_dataset.tokens``. The predicted label for an image is the index of
    the text with the highest image-to-text logit, and it counts as correct
    when it equals the batch target.

    Args:
        net: Model called as ``net(images, texts)`` that returns a
            ``(logits_per_image, logits_per_text)`` pair.
        test_dataset: Dataset exposing a ``tokens`` tensor. Presumably one
            tokenized prompt per class, in class-index order -- TODO confirm
            against ``get_loader``.
        test_loader: Iterable of ``(images, tokens, targets)`` batches; the
            per-sample ``tokens`` are not used here.
        device: Device the model lives on. Images, texts and targets are
            moved there before scoring.

    Returns:
        Fraction of evaluated samples classified correctly, as a float.
        Returns ``0.0`` if the loader yields no samples.

    The model's train/eval mode on entry is restored on exit, including when
    evaluation raises.
    """
    was_training = net.training
    correct = 0
    seen = 0
    net.eval()
    try:
        # The text tokens are identical for every batch, so move them once.
        texts = test_dataset.tokens.to(device)
        with torch.no_grad():
            for images, _tokens, targets in tqdm(test_loader, total=len(test_loader)):
                images = images.to(device)
                targets = targets.to(device)
                logits_per_image, _logits_per_text = net(images, texts)
                # softmax is monotonic, so argmax over the raw logits yields the
                # same prediction without the extra work or a host copy.
                preds = logits_per_image.argmax(dim=-1)
                correct += (preds == targets).sum().item()
                # Count samples actually scored, not len(test_dataset), so a
                # subset sampler or drop_last=True does not skew the ratio.
                seen += targets.size(0)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        net.train(was_training)
    if seen == 0:
        return 0.0
    return correct / seen
|