From:migrating_from_torchvision_to_albumentations

torchvision 提供了 transforms 数据增强库.

torchvision - transform 数据增强 - AIUAI

Albumentations 是一个更强大的数据增强库.

Python库 - Albumentations 图片数据增强库

1. 基于 torchvision 的 pipeline

from PIL import Image
import cv2
import numpy as np

from torch.utils.data import Dataset
from torchvision import transforms

class TorchvisionDataset(Dataset):
    """Image dataset that loads files with PIL and applies a torchvision-style transform.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned index-for-index with ``file_paths``.
        transform: Optional callable that receives the loaded PIL image and
            returns the transformed sample. ``None`` disables transformation.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per file path)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is always converted to RGB. Without a transform it is
        returned as a PIL image; otherwise it is whatever the transform returns.
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it as soon as the pixel data has been decoded.
        with Image.open(file_path) as img:
            # Force three channels so that transforms which assume RGB input
            # (e.g. per-channel normalization with three mean/std values)
            # also work for grayscale, palette, or RGBA files.
            image = img.convert('RGB')

        # Compare against None explicitly: a valid transform object is not
        # guaranteed to be truthy.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Preprocessing pipeline: resize to 256x256, take a random 224x224 crop,
# flip horizontally with probability 0.5, convert to a tensor, and finally
# normalize each channel with the given mean/std.
torchvision_transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.RandomCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Example dataset: three image files, each paired with an integer label,
# all passed through the pipeline defined above.
torchvision_dataset = TorchvisionDataset(
    file_paths=[
        './images/image_1.jpg',
        './images/image_2.jpg',
        './images/image_3.jpg',
    ],
    labels=[1, 2, 3],
    transform=torchvision_transform,
)

2. 基于 albumentations 的 pipeline

2.1. opencv 版

import cv2
import numpy as np
from torch.utils.data import Dataset
from albumentations import Compose, RandomCrop, Normalize, HorizontalFlip, Resize
from albumentations.pytorch import ToTensor

class AlbumentationsDataset(Dataset):
    """Image dataset that loads files with OpenCV and applies an albumentations-style transform.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned index-for-index with ``file_paths``.
        transform: Optional callable invoked as ``transform(image=array)`` that
            returns a dict holding the result under the ``'image'`` key.
            ``None`` disables transformation.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per file path)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            OSError: If OpenCV cannot read or decode the image file.
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with OpenCV.
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than letting cvtColor
        # die with an opaque assertion error.
        if image is None:
            raise OSError(f'Failed to read image: {file_path!r}')

        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label

# Augmentation pipeline: resize to 256x256, take a random 224x224 crop,
# flip horizontally with probability 0.5, normalize each channel with the
# given mean/std, and convert the result to a tensor as the last step.
albumentations_transform = Compose([
    Resize(height=256, width=256),
    RandomCrop(height=224, width=224),
    HorizontalFlip(p=0.5),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensor(),
])

# Example dataset: three image files, each paired with an integer label,
# all passed through the pipeline defined above.
albumentations_dataset = AlbumentationsDataset(
    file_paths=[
        './images/image_1.jpg',
        './images/image_2.jpg',
        './images/image_3.jpg',
    ],
    labels=[1, 2, 3],
    transform=albumentations_transform,
)

2.2. PIL 版

采用 PIL 库时,在数据增强前需要先将 PIL 图像转换为 numpy 数组;数据增强完成后,再将 numpy 数组转换回 PIL 图像.

from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from albumentations import Compose, RandomCrop, Normalize, HorizontalFlip, Resize
from albumentations.pytorch import ToTensor

class AlbumentationsPilDataset(Dataset):
    """Image dataset that loads files with PIL but augments them as numpy arrays.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned index-for-index with ``file_paths``.
        transform: Optional callable invoked as ``transform(image=array)`` that
            returns a dict holding the augmented array under the ``'image'``
            key. ``None`` disables transformation.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per file path)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The returned image is an RGB PIL image, whether or not a transform
        is configured.
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it as soon as the pixel data has been decoded.
        with Image.open(file_path) as img:
            # Convert to RGB before going through numpy: palette or bilevel
            # images would otherwise become index/bool arrays, which array
            # based augmentations cannot process meaningfully.
            image = img.convert('RGB')

        if self.transform is not None:
            # Convert PIL image to numpy array
            image_np = np.array(image)
            # Apply transformations
            augmented = self.transform(image=image_np)
            # Convert numpy array back to a PIL image
            image = Image.fromarray(augmented['image'])
        return image, label

# Augmentation pipeline: resize to 256x256, take a random 224x224 crop, and
# flip horizontally with probability 0.5. There is no normalization or tensor
# conversion step in this pipeline.
albumentations_pil_transform = Compose([
    Resize(height=256, width=256),
    RandomCrop(height=224, width=224),
    HorizontalFlip(p=0.5),
])


# Note that this dataset will output PIL images and not numpy arrays nor PyTorch tensors
albumentations_pil_dataset = AlbumentationsPilDataset(
    file_paths=[
        './images/image_1.jpg',
        './images/image_2.jpg',
        './images/image_3.jpg',
    ],
    labels=[1, 2, 3],
    transform=albumentations_pil_transform,
)

3. torchvision 与 albumentations 等价的变换

| torchvision transform | albumentations transform | albumentations example |
| --- | --- | --- |
| Compose | Compose | Compose([Resize(256, 256), RandomCrop(224, 224)]) |
| CenterCrop | CenterCrop | CenterCrop(256, 256) |
| ColorJitter | HueSaturationValue | HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5) |
| Pad | PadIfNeeded | PadIfNeeded(min_height=512, min_width=512) |
| RandomAffine | ShiftScaleRotate | ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=0.5) |
| RandomCrop | RandomCrop | RandomCrop(256, 256) |
| RandomGrayscale | ToGray | ToGray(p=0.5) |
| RandomHorizontalFlip | HorizontalFlip | HorizontalFlip(p=0.5) |
| RandomRotation | Rotate | Rotate(limit=45, p=0.5) |
| RandomVerticalFlip | VerticalFlip | VerticalFlip(p=0.5) |
| Resize | Resize | Resize(256, 256) |
| Normalize | Normalize | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
Last modification:June 16th, 2020 at 04:31 pm