# Transforms
# Preprocessing pipeline, applied in this order:
#   1. convert the input to an image tensor,
#   2. cast to float32 and scale pixel values (scale=True),
#   3. take a 160x160 center crop,
#   4. resize the crop to 224x224.
TRANSFORMS = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.CenterCrop(size=(160, 160)),
        v2.Resize(size=(224, 224)),
    ]
)
# Per-channel normalization using the commonly used ImageNet mean/std values.
# Kept separate from TRANSFORMS; NOTE(review): confirm where it is applied.
NORMALIZE = v2.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
# Datasets
# Both splits are read from the same local root at 160px resolution and share
# TRANSFORMS. Downloading is disabled explicitly on both splits, so the data
# must already exist under ./example_data.
train_ds = Imagenette(root="./example_data", split='train', size='160px', transform=TRANSFORMS, download=False)
# Reference set: every 10th training sample (indices 0, 10, 20, ...).
ref_set = torch.utils.data.Subset(train_ds, indices=list(range(0, len(train_ds), 10)))
val_ds = Imagenette(root="./example_data", split='val', size='160px', transform=TRANSFORMS, download=False)
# DataLoaders
# Reference loader: batches of 32 from the strided training subset, in a fixed
# (unshuffled) order.
ref_loader = DataLoader(dataset=ref_set, batch_size=32, shuffle=False)
# Inference loader: small shuffled batches of 6 from the validation split.
inf_loader = DataLoader(dataset=val_ds, batch_size=6, shuffle=True)
# Labels mapping
# Class names ordered by integer label: the position in this list is the
# label / logit index. NOTE(review): order is assumed to match the dataset's
# class indices; confirm against train_ds.classes.
CLASS_NAMES = [
    "tench",
    "English springer",
    "cassette player",
    "chain saw",
    "church",
    "French horn",
    "garbage truck",
    "gas pump",
    "golf ball",
    "parachute",
]
# Logit index -> class name. Derived from CLASS_NAMES so the two tables cannot
# drift apart.
LOGIT2NAME = dict(enumerate(CLASS_NAMES))