# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

- """
- Transforms and data augmentation for both image + bbox.
- """

import logging
import random
from typing import Iterable

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage
from torchvision.transforms import InterpolationMode

from training.utils.data_utils import VideoDatapoint


def hflip(datapoint, index):
    datapoint.frames[index].data = F.hflip(datapoint.frames[index].data)
    for obj in datapoint.frames[index].objects:
        if obj.segment is not None:
            obj.segment = F.hflip(obj.segment)
    return datapoint


def get_size_with_aspect_ratio(image_size, size, max_size=None):
    w, h = image_size
    if max_size is not None:
        min_original_size = float(min((w, h)))
        max_original_size = float(max((w, h)))
        if max_original_size / min_original_size * size > max_size:
            size = max_size * min_original_size / max_original_size

    if (w <= h and w == size) or (h <= w and h == size):
        return (h, w)

    if w < h:
        ow = int(round(size))
        oh = int(round(size * h / w))
    else:
        oh = int(round(size))
        ow = int(round(size * w / h))

    return (oh, ow)
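
# Worked example (illustrative, not from the original file): for a 640x480
# image, get_size_with_aspect_ratio((640, 480), size=256, max_size=512)
# scales the short side to 256 and returns (256, 341) as (height, width);
# max_size only clamps the target when scaling the short side to `size`
# would push the long side past it.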


def resize(datapoint, index, size, max_size=None, square=False, v2=False):
    # size can be min_size (scalar) or (w, h) tuple
    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    if square:
        size = size, size
    else:
        cur_size = (
            datapoint.frames[index].data.size()[-2:][::-1]
            if v2
            else datapoint.frames[index].data.size
        )
        size = get_size(cur_size, size, max_size)

    if v2:
        datapoint.frames[index].data = Fv2.resize(
            datapoint.frames[index].data, size, antialias=True
        )
    else:
        datapoint.frames[index].data = F.resize(datapoint.frames[index].data, size)

    # resize the masks to the same target size
    for obj in datapoint.frames[index].objects:
        if obj.segment is not None:
            obj.segment = F.resize(obj.segment[None, None], size).squeeze()

    h, w = size
    datapoint.frames[index].size = (h, w)
    return datapoint
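
# Note on the conventions above (illustrative): resize(datapoint, i, 1024,
# square=True) produces an exact 1024x1024 frame, ignoring aspect ratio,
# while a (w, h) tuple such as (640, 480) is reversed to the (h, w) =
# (480, 640) order that torchvision's resize expects.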


def pad(datapoint, index, padding, v2=False):
    old_h, old_w = datapoint.frames[index].size
    h, w = old_h, old_w
    if len(padding) == 2:
        # assumes we only pad on the right and bottom sides
        datapoint.frames[index].data = F.pad(
            datapoint.frames[index].data, (0, 0, padding[0], padding[1])
        )
        h += padding[1]
        w += padding[0]
    else:
        # padding is (left, top, right, bottom)
        datapoint.frames[index].data = F.pad(
            datapoint.frames[index].data,
            (padding[0], padding[1], padding[2], padding[3]),
        )
        h += padding[1] + padding[3]
        w += padding[0] + padding[2]

    datapoint.frames[index].size = (h, w)
    for obj in datapoint.frames[index].objects:
        if obj.segment is not None:
            if v2:
                if len(padding) == 2:
                    obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1]))
                else:
                    obj.segment = Fv2.pad(obj.segment, tuple(padding))
            else:
                if len(padding) == 2:
                    obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
                else:
                    obj.segment = F.pad(obj.segment, tuple(padding))
    return datapoint
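
# Example of the two padding conventions above (illustrative values): for a
# frame with size (h=480, w=640), pad(datapoint, i, (8, 4)) pads 8 pixels on
# the right and 4 on the bottom, giving size (484, 648), while
# pad(datapoint, i, (2, 3, 8, 4)) pads left/top/right/bottom, giving
# size (487, 650).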


class RandomHorizontalFlip:
    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            # one coin flip shared by all frames
            if random.random() < self.p:
                for i in range(len(datapoint.frames)):
                    datapoint = hflip(datapoint, i)
            return datapoint
        # otherwise, flip each frame independently
        for i in range(len(datapoint.frames)):
            if random.random() < self.p:
                datapoint = hflip(datapoint, i)
        return datapoint


class RandomResizeAPI:
    def __init__(
        self, sizes, consistent_transform, max_size=None, square=False, v2=False
    ):
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            size = random.choice(self.sizes)
            for i in range(len(datapoint.frames)):
                datapoint = resize(
                    datapoint, i, size, self.max_size, square=self.square, v2=self.v2
                )
            return datapoint
        for i in range(len(datapoint.frames)):
            size = random.choice(self.sizes)
            datapoint = resize(
                datapoint, i, size, self.max_size, square=self.square, v2=self.v2
            )
        return datapoint


class ToTensorAPI:
    def __init__(self, v2=False):
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for img in datapoint.frames:
            if self.v2:
                # note: in torchvision >= 0.16 this function was renamed to
                # `Fv2.to_image`
                img.data = Fv2.to_image_tensor(img.data)
            else:
                img.data = F.to_tensor(img.data)
        return datapoint


class NormalizeAPI:
    def __init__(self, mean, std, v2=False):
        self.mean = mean
        self.std = std
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for img in datapoint.frames:
            if self.v2:
                img.data = Fv2.convert_image_dtype(img.data, torch.float32)
                img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
            else:
                img.data = F.normalize(img.data, mean=self.mean, std=self.std)
        return datapoint


class ComposeAPI:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
        for t in self.transforms:
            datapoint = t(datapoint, **kwargs)
        return datapoint

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string
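
# A minimal usage sketch (hypothetical values, not prescribed by this module):
# a typical video pipeline chains the APIs in this file and is applied to a
# whole VideoDatapoint at once.
#
#   transforms = ComposeAPI(
#       [
#           RandomHorizontalFlip(consistent_transform=True, p=0.5),
#           RandomResizeAPI(sizes=1024, consistent_transform=True, square=True),
#           ColorJitter(
#               consistent_transform=True,
#               brightness=0.1,
#               contrast=0.03,
#               saturation=0.03,
#               hue=None,
#           ),
#           ToTensorAPI(),
#           NormalizeAPI(
#               mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
#           ),
#       ]
#   )
#   datapoint = transforms(datapoint)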


class RandomGrayscale:
    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        if self.consistent_transform:
            if random.random() < self.p:
                for img in datapoint.frames:
                    img.data = self.Grayscale(img.data)
            return datapoint
        for img in datapoint.frames:
            if random.random() < self.p:
                img.data = self.Grayscale(img.data)
        return datapoint


class ColorJitter:
    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        self.consistent_transform = consistent_transform
        self.brightness = (
            brightness
            if isinstance(brightness, list)
            else [max(0, 1 - brightness), 1 + brightness]
        )
        self.contrast = (
            contrast
            if isinstance(contrast, list)
            else [max(0, 1 - contrast), 1 + contrast]
        )
        self.saturation = (
            saturation
            if isinstance(saturation, list)
            else [max(0, 1 - saturation), 1 + saturation]
        )
        self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])
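
    # Example of the range convention above (illustrative): brightness=0.1
    # samples a factor in [0.9, 1.1] and hue=0.05 samples a shift in
    # [-0.05, 0.05]; a 2-element list is used directly as the sampling range.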

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        if self.consistent_transform:
            # sample one set of color-jitter parameters shared by all frames
            (
                fn_idx,
                brightness_factor,
                contrast_factor,
                saturation_factor,
                hue_factor,
            ) = T.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue
            )
        for img in datapoint.frames:
            if not self.consistent_transform:
                # otherwise, sample new parameters for every frame
                (
                    fn_idx,
                    brightness_factor,
                    contrast_factor,
                    saturation_factor,
                    hue_factor,
                ) = T.ColorJitter.get_params(
                    self.brightness, self.contrast, self.saturation, self.hue
                )
            for fn_id in fn_idx:
                if fn_id == 0 and brightness_factor is not None:
                    img.data = F.adjust_brightness(img.data, brightness_factor)
                elif fn_id == 1 and contrast_factor is not None:
                    img.data = F.adjust_contrast(img.data, contrast_factor)
                elif fn_id == 2 and saturation_factor is not None:
                    img.data = F.adjust_saturation(img.data, saturation_factor)
                elif fn_id == 3 and hue_factor is not None:
                    img.data = F.adjust_hue(img.data, hue_factor)
        return datapoint


class RandomAffine:
    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, the same random affine is applied to all
        frames and masks.
        """
        self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
        self.scale = scale
        self.shear = (
            shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
        )
        self.translate = translate
        self.fill_img = image_mean
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives
        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
        elif image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
        else:
            raise NotImplementedError

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res
        if self.log_warning:
            logging.warning(
                f"Skip RandomAffine for zero-area mask in the first frame after {self.num_tentatives} attempts"
            )
        return datapoint

    def transform_datapoint(self, datapoint: VideoDatapoint):
        _, height, width = F.get_dimensions(datapoint.frames[0].data)
        img_size = [width, height]

        if self.consistent_transform:
            # sample a single affine transformation shared by all frames
            affine_params = T.RandomAffine.get_params(
                degrees=self.degrees,
                translate=self.translate,
                scale_ranges=self.scale,
                shears=self.shear,
                img_size=img_size,
            )

        for img_idx, img in enumerate(datapoint.frames):
            this_masks = [
                obj.segment.unsqueeze(0) if obj.segment is not None else None
                for obj in img.objects
            ]
            if not self.consistent_transform:
                # otherwise, sample new affine parameters for every frame/mask pair
                affine_params = T.RandomAffine.get_params(
                    degrees=self.degrees,
                    translate=self.translate,
                    scale_ranges=self.scale,
                    shears=self.shear,
                    img_size=img_size,
                )
            transformed_bboxes, transformed_masks = [], []
            for i in range(len(img.objects)):
                if this_masks[i] is None:
                    transformed_masks.append(None)
                    # Dummy bbox for a dummy target
                    transformed_bboxes.append(torch.tensor([[0, 0, 1, 1]]))
                else:
                    transformed_mask = F.affine(
                        this_masks[i],
                        *affine_params,
                        interpolation=InterpolationMode.NEAREST,
                        fill=0.0,
                    )
                    if img_idx == 0 and transformed_mask.max() == 0:
                        # the object is no longer visible in the first frame after
                        # the transform; return None so the caller can retry
                        return None
                    transformed_masks.append(transformed_mask.squeeze())

            for i in range(len(img.objects)):
                img.objects[i].segment = transformed_masks[i]

            img.data = F.affine(
                img.data,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
        return datapoint


def random_mosaic_frame(
    datapoint,
    index,
    grid_h,
    grid_w,
    target_grid_y,
    target_grid_x,
    should_hflip,
):
    # Step 1: downsize the image and paste copies of it into every cell of a
    # grid_h x grid_w mosaic
    image_data = datapoint.frames[index].data
    is_pil = isinstance(image_data, PILImage.Image)
    if is_pil:
        H_im = image_data.height
        W_im = image_data.width
        image_data_output = PILImage.new("RGB", (W_im, H_im))
    else:
        H_im = image_data.size(-2)
        W_im = image_data.size(-1)
        image_data_output = torch.zeros_like(image_data)

    downsize_cache = {}
    for grid_y in range(grid_h):
        for grid_x in range(grid_w):
            y_offset_b = grid_y * H_im // grid_h
            x_offset_b = grid_x * W_im // grid_w
            y_offset_e = (grid_y + 1) * H_im // grid_h
            x_offset_e = (grid_x + 1) * W_im // grid_w
            H_im_downsize = y_offset_e - y_offset_b
            W_im_downsize = x_offset_e - x_offset_b

            if (H_im_downsize, W_im_downsize) in downsize_cache:
                image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
            else:
                image_data_downsize = F.resize(
                    image_data,
                    size=(H_im_downsize, W_im_downsize),
                    interpolation=InterpolationMode.BILINEAR,
                    antialias=True,  # antialiasing for downsizing
                )
                downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
            if should_hflip[grid_y, grid_x].item():
                image_data_downsize = F.hflip(image_data_downsize)

            # paste the downsized image into this grid cell of the mosaic
            if is_pil:
                image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
            else:
                image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
                    image_data_downsize
                )

    datapoint.frames[index].data = image_data_output

    # Step 2: downsize the masks and paste them into only the target grid cell
    # of the mosaic
    for obj in datapoint.frames[index].objects:
        if obj.segment is None:
            continue
        assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
        segment_output = torch.zeros_like(obj.segment)

        target_y_offset_b = target_grid_y * H_im // grid_h
        target_x_offset_b = target_grid_x * W_im // grid_w
        target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
        target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
        target_H_im_downsize = target_y_offset_e - target_y_offset_b
        target_W_im_downsize = target_x_offset_e - target_x_offset_b

        segment_downsize = F.resize(
            obj.segment[None, None],
            size=(target_H_im_downsize, target_W_im_downsize),
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,  # antialiasing for downsizing
        )[0, 0]
        if should_hflip[target_grid_y, target_grid_x].item():
            segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]

        segment_output[
            target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
        ] = segment_downsize
        obj.segment = segment_output
    return datapoint
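
# Illustration of the mosaic layout (not from the original file): with
# grid_h = grid_w = 2 on a 512x512 frame, the frame is downsized to 256x256
# and tiled into all four quadrants, while the object masks survive only in
# the quadrant at (target_grid_y, target_grid_x); the model therefore has to
# pick the annotated copy out of several identical ones.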


class RandomMosaicVideoAPI:
    def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
        self.prob = prob
        self.grid_h = grid_h
        self.grid_w = grid_w
        self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
        if random.random() > self.prob:
            return datapoint

        # select a random grid cell to place the target masks in the mosaic
        target_grid_y = random.randint(0, self.grid_h - 1)
        target_grid_x = random.randint(0, self.grid_w - 1)
        # decide whether to horizontally flip each grid cell in the mosaic
        if self.use_random_hflip:
            should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
        else:
            should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
        for i in range(len(datapoint.frames)):
            datapoint = random_mosaic_frame(
                datapoint,
                i,
                grid_h=self.grid_h,
                grid_w=self.grid_w,
                target_grid_y=target_grid_y,
                target_grid_x=target_grid_x,
                should_hflip=should_hflip,
            )
        return datapoint