58 lines
2.3 KiB
Python
58 lines
2.3 KiB
Python
from PIL import Image
|
|
from torch.utils.data import Dataset
|
|
import os
|
|
import clip
|
|
|
|
|
|
class Classes:
    """Bidirectional mapping between CUB class indices and CLIP text prompts.

    Each non-blank line of ``classes_file`` is parsed as
    ``"<1-based index> <NNN>.<Name_With_Underscores>"``.  The stored prompt is
    ``prompt_prefix`` followed by the name with underscores replaced by
    spaces, and indices are shifted to be 0-based.

    Args:
        classes_file: Path to the classes listing.
        prompt_prefix: Text prepended to every class name; defaults to
            ``'A photo of '``.
    """

    def __init__(self, classes_file, prompt_prefix='A photo of '):
        self.class2index = {}  # prompt text -> 0-based class index
        self.index2class = {}  # 0-based class index -> prompt text

        # `with` closes the file deterministically instead of relying on GC.
        with open(classes_file, encoding='utf-8') as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. trailing padding) carry no class;
                    # skipping them avoids a tuple-unpacking ValueError.
                    continue
                # maxsplit=1 keeps the name intact and tolerates repeated spaces.
                index, birdname = line.split(maxsplit=1)
                index = int(index) - 1
                # Drop the "NNN." numeric prefix; split only once so names that
                # themselves contain a '.' are not truncated.
                birdname = birdname.split('.', 1)[1].replace('_', ' ')
                prompt = prompt_prefix + birdname
                self.class2index[prompt] = index
                self.index2class[index] = prompt

    def __len__(self):
        """Return the number of distinct prompts."""
        return len(self.class2index)

    def get_class(self, num: int):
        """Return the prompt for 0-based class index ``num``, or None if unknown."""
        return self.index2class.get(num)

    def get_id(self, class_name: str):
        """Return the 0-based class index for prompt ``class_name``, or None if unknown."""
        return self.class2index.get(class_name)
|
|
|
|
|
|
class MyDataset(Dataset):
    """CUB-200-2011 image/prompt dataset for CLIP-style training.

    Reads the CUB metadata files under ``root`` and keeps either the train or
    the test split.  Each item is ``(image, token, target)``: ``image`` is
    ``processor`` applied to the RGB image, ``token`` is the row of
    ``clip.tokenize`` output for the image's class prompt, and ``target`` is
    the 0-based class index.

    Args:
        processor: Callable applied to each PIL RGB image.
        train: Keep the training split when truthy, the test split otherwise.
        root: Directory containing ``classes.txt``, ``images.txt``,
            ``image_class_labels.txt``, ``train_test_split.txt`` and the
            ``images/`` folder.  Defaults to the original hard-coded location.

    Raises:
        ValueError: If the per-image metadata files have different line counts.
    """

    def __init__(self, processor, train=True, root='/home/kejingfan/cub'):
        classes = Classes(os.path.join(root, 'classes.txt'))
        class_list = [classes.get_class(i) for i in range(len(classes))]
        # Row i of self.tokens is the tokenized prompt for class index i.
        self.tokens = clip.tokenize(class_list)

        self.img_process = processor
        self.root_dir = os.path.join(root, 'images')

        # The three per-image files are "<image_id> <value>" per line and are
        # aligned by line position, so only the second column is kept.
        images_list = self._read_second_column(os.path.join(root, 'images.txt'))
        labels = self._read_second_column(os.path.join(root, 'image_class_labels.txt'))
        split_flags = self._read_second_column(os.path.join(root, 'train_test_split.txt'))
        if not len(images_list) == len(labels) == len(split_flags):
            raise ValueError(
                'CUB metadata files disagree in length: '
                f'{len(images_list)} images, {len(labels)} labels, '
                f'{len(split_flags)} split flags'
            )

        want_train = bool(train)
        # Entries are [absolute image path, 0-based class index].
        self.images = [
            [os.path.join(self.root_dir, rel_path), int(label) - 1]
            for rel_path, label, flag in zip(images_list, labels, split_flags)
            if (flag == '1') == want_train
        ]

    @staticmethod
    def _read_second_column(path):
        """Return the second whitespace-separated field of each non-blank line of ``path``."""
        with open(path, encoding='utf-8') as fh:
            return [line.strip().split(maxsplit=1)[1] for line in fh if line.strip()]

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(processed_image, class_prompt_token, class_index)`` for ``index``."""
        path, target = self.images[index]
        token = self.tokens[target]
        # The context manager releases the file handle immediately; convert()
        # loads the pixel data and returns an independent image.
        with Image.open(path) as img:
            image = img.convert("RGB")
        image = self.img_process(image)
        return image, token, target
|