import os

import clip
from PIL import Image
from torch.utils.data import Dataset

# Default dataset root used by the original hard-coded paths.
CUB_ROOT = '/home/kejingfan/cub'


def _read_lines(path):
    """Return the whitespace-stripped, non-empty lines of the text file at *path*.

    The file is closed before returning; blank lines (e.g. a trailing newline
    at end of file) are skipped so callers can split every returned line safely.
    """
    with open(path, encoding='utf-8') as f:
        return [stripped for stripped in (line.strip() for line in f) if stripped]


def _parse_id_file(path):
    """Parse a ``"<id> <value>"`` per-line file into a ``{id: value}`` dict.

    Only the first whitespace run separates id from value, so the value keeps
    any remaining text intact. Insertion order follows the file's line order.
    """
    return dict(line.split(maxsplit=1) for line in _read_lines(path))


class Classes:
    """Bidirectional mapping between CLIP text prompts and 0-based class ids.

    Each line of *classes_file* is expected to look like
    ``"<1-based index> <prefix>.<Name_With_Underscores>"``; the text after the
    first ``.`` (with ``_`` replaced by spaces) becomes the prompt
    ``"A photo of <Name With Underscores>"``, mapped to ``index - 1``.
    """

    def __init__(self, classes_file):
        self.class2index = {}
        self.index2class = {}
        for row in _read_lines(classes_file):
            raw_index, raw_name = row.split(maxsplit=1)
            index = int(raw_index) - 1  # file ids are 1-based; store 0-based
            # split only on the first '.' so names containing dots stay whole
            birdname = raw_name.split('.', 1)[1].replace('_', ' ')
            prompt = 'A photo of ' + birdname
            self.class2index[prompt] = index
            self.index2class[index] = prompt

    def __len__(self):
        return len(self.class2index)

    def get_class(self, num: int):
        """Return the prompt for 0-based class id *num*, or ``None`` if unknown."""
        return self.index2class.get(num)

    def get_id(self, class_name: str):
        """Return the 0-based class id for prompt *class_name*, or ``None`` if unknown."""
        return self.class2index.get(class_name)


class MyDataset(Dataset):
    """Image/text dataset over a CUB-style directory layout.

    Reads ``classes.txt``, ``images.txt``, ``image_class_labels.txt`` and
    ``train_test_split.txt`` from *root* and keeps the images whose split flag
    (``'1'`` = train) matches *train*. Rows of the three per-image files are
    joined on their leading image id, so every image must appear in all three
    (a missing id raises ``KeyError``).

    Args:
        processor: callable applied to each RGB PIL image in ``__getitem__``
            (presumably the preprocess transform returned by ``clip.load`` —
            confirm against the caller).
        train: select the training split if truthy, otherwise the test split.
        root: dataset root directory; defaults to the original hard-coded path.
    """

    def __init__(self, processor, train=True, root=CUB_ROOT):
        classes = Classes(os.path.join(root, 'classes.txt'))
        class_list = [classes.get_class(i) for i in range(len(classes))]
        # Row i of self.tokens is the tokenized prompt for class id i.
        self.tokens = clip.tokenize(class_list)
        self.img_process = processor
        self.root_dir = os.path.join(root, 'images')

        image_paths = _parse_id_file(os.path.join(root, 'images.txt'))
        labels = _parse_id_file(os.path.join(root, 'image_class_labels.txt'))
        splits = _parse_id_file(os.path.join(root, 'train_test_split.txt'))

        want_train = bool(train)
        self.images = []
        for image_id, rel_path in image_paths.items():
            is_train = splits[image_id] == '1'
            if is_train == want_train:
                # labels are 1-based in the file; targets are 0-based
                self.images.append(
                    [os.path.join(self.root_dir, rel_path), int(labels[image_id]) - 1]
                )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(processed_image, class_token, target)`` for sample *index*."""
        path, target = self.images[index]
        token = self.tokens[target]
        # Context manager guarantees the underlying file handle is released.
        with Image.open(path) as img:
            image = img.convert("RGB")
        image = self.img_process(image)
        return image, token, target