总结:
数据加载的时候会同时加载数据和数据增强方式。数据增强的时候会默认调用加载数据集时的getitem方法,去获取对应的数据和标签。
定义好普通的transformers
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
结合有Cutout的自定义Transformer
自定义的transformer可以接收两个参数:img,label。
传统的transformer只接受一个参数:img
class ConditionalTransform:
def __init__(self, transform, num_per_cls_dict):
self.transform = transform
self.num_per_cls_dict = num_per_cls_dict
self.n_holes_dict = {}
def __call__(self, img, label):
total_samples = sum(self.num_per_cls_dict.values())
cls_num_list = list(self.num_per_cls_dict.values())
per_cls_weights = 1.0 / np.array(cls_num_list)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
per_cls_weights = torch.FloatTensor(per_cls_weights).to(torch.device('cpu')) # 假设我们在 CPU 上运行
# 计算 n_holes 的数量
n_holes = 1 + int(per_cls_weights[label] * 3) # 确保 n_holes 在 1 到 4 之间
n_holes = min(4, max(1, n_holes))
# 保存每个类别的 n_holes 数量
self.n_holes_dict[label] = n_holes
img = self.transform(img)
cutout_transform = Cutout(n_holes=n_holes, length=16)
return cutout_transform(img)
cutout
class Cutout(object):
"""Randomly mask out one or more patches from an image.
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
"""
def __init__(self, n_holes, length):
self.n_holes = n_holes
self.length = length
def __call__(self, img):
"""
Args:
img (Tensor): Tensor image of size (C, H, W).
Returns:
Tensor: Image with n_holes of dimension length x length cut out of it.
"""
h = img.size(1)
w = img.size(2)
mask = np.ones((h, w), np.float32)
for n in range(self.n_holes):
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img = img * mask
return img
数据加载
标准的CIFAR10
部分代码:
class CIFAR10(VisionDataset):
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(CIFAR10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
self.data: Any = []
self.targets = []
# now load the picked numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
self._load_meta()
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
__getitem__通常是会被自动调用的?
以下是一些常见的情况:
1. 索引操作
当你使用索引操作(如 dataset[index]
)访问数据集对象时,__getitem__
方法会被自动调用。例如:
dataset = CIFAR10(root='data', train=True, download=True)
image, label = dataset[0] # 这里会自动调用 dataset.__getitem__(0)
在上面的代码中,当你尝试访问 dataset[0]
时,__getitem__
方法会被调用,返回第一个图像和标签。
2. 与 DataLoader
一起使用
在深度学习中,__getitem__
方法经常与 PyTorch 的 DataLoader
类一起使用。DataLoader
会在训练或测试过程中自动调用数据集的 __getitem__
方法来获取数据。例如:
from torch.utils.data import DataLoader
dataset = CIFAR10(root='data', train=True, download=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
for images, labels in dataloader:
# 这里会自动调用 dataset.__getitem__(index) 来获取数据
# 进行训练或测试的相关操作
pass
在这个例子中,DataLoader
会在每次迭代时调用 __getitem__
方法来获取数据集中的样本。
3. 自定义数据集
在创建自定义数据集时,你可以通过实现 __getitem__
方法来定义如何访问和处理数据。例如:
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, index):
# 这里定义如何获取和处理数据
return self.data[index], self.labels[index]
data = ... # 数据
labels = ... # 标签
dataset = MyDataset(data, labels)
在这个自定义数据集中,__getitem__
方法定义了如何根据索引访问数据和标签。
总结
__getitem__
方法是一种魔法方法(magic method),在特定场景下会被自动调用,尤其是当你使用索引操作访问对象或与某些库(如 PyTorch 的 DataLoader
)一起使用时。通过实现 __getitem__
方法,你可以自定义对象的索引行为,从而更方便地处理和访问数据。
自定义CIFAR10
class CustomCIFAR10(datasets.CIFAR10):
cls_num = 10
def __init__(self, root, imb_type='exp', imb_factor=0.01, train=True,
transform=None, target_transform=None,
download=False):
super(CustomCIFAR10, self).__init__(root, train=train, transform=None, target_transform=target_transform, download=download)
self.num_per_cls_dict = {}
if imb_type is not None:
img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
self.gen_imbalanced_data(img_num_list)
# 初始化条件变换,传入类别分布
if transform is None:
transform = transforms.Compose([
transforms.ToTensor()
])
# 初始化条件变换,传入类别分布
self.transform = ConditionalTransform(transform, self.num_per_cls_dict)
def __getitem__(self, index):
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img, target)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
和官方的cifar10的区别在与,自定义的getitem方法里的transformer里多了一个参数target
官方的:
img = self.transform(img)
自制的:
img = self.transform(img, target)
如果不重写getitem方法,就会出现这样的报错:
File "/opt/conda/envs/py38/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 120, in __getitem__
img = self.transform(img)
TypeError: __call__() missing 1 required positional argument: 'label'