| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- import colorsys
- from dataclasses import dataclass
- from typing import List, Tuple
- import cv2
- import matplotlib as mpl
- import matplotlib.colors as mplc
- import numpy as np
- import pycocotools.mask as mask_utils
def rgb_to_hex(rgb_color):
    """
    Convert an RGB color to a lowercase hex color string.

    Args:
        rgb_color (tuple/list of ints): RGB color in tuple or list format. Each
            channel must be an integer (Python or numpy) in the range [0, 255].
    Returns:
        str: Hex color, e.g. '#ff00f4'.
    Raises:
        TypeError: if a channel is not an integer (e.g. a float).
        ValueError: if a channel is outside [0, 255].
    Example:
        ```
        >>> rgb_to_hex((255, 0, 244))
        '#ff00f4'
        ```
    """
    import operator

    channels = []
    for c in rgb_color:
        # operator.index accepts Python/numpy integers and rejects floats with a
        # TypeError, matching the previous hex()-based behavior.
        value = operator.index(c)
        # Out-of-range values would otherwise yield malformed strings
        # (256 -> '100', -1 -> 'x1').
        if not 0 <= value <= 255:
            raise ValueError(f"RGB channel value {value} is outside [0, 255]")
        channels.append(f"{value:02x}")
    return "#" + "".join(channels)
- # DEFAULT_COLOR_HEX_TO_NAME = {
- # rgb_to_hex((255, 0, 0)): "red",
- # rgb_to_hex((0, 255, 0)): "lime",
- # rgb_to_hex((0, 0, 255)): "blue",
- # rgb_to_hex((255, 255, 0)): "yellow",
- # rgb_to_hex((255, 0, 255)): "fuchsia",
- # rgb_to_hex((0, 255, 255)): "aqua",
- # rgb_to_hex((255, 165, 0)): "orange",
- # rgb_to_hex((128, 0, 128)): "purple",
- # rgb_to_hex((255, 215, 0)): "gold",
- # }
# Maps "#rrggbb" hex strings (built with rgb_to_hex, defined earlier in this
# module) to human-readable color names. Dict insertion order is the palette order.
DEFAULT_COLOR_HEX_TO_NAME = {
    rgb_to_hex(rgb): color_name
    for rgb, color_name in (
        # The top 20 approved colors
        ((255, 255, 0), "yellow"),
        ((0, 255, 0), "lime"),
        ((0, 255, 255), "cyan"),
        ((255, 0, 255), "magenta"),
        ((255, 0, 0), "red"),
        ((255, 127, 0), "orange"),
        ((127, 255, 0), "chartreuse"),
        ((0, 255, 127), "spring green"),
        ((255, 0, 127), "rose"),
        ((127, 0, 255), "violet"),
        ((192, 255, 0), "electric lime"),
        ((255, 192, 0), "vivid orange"),
        ((0, 255, 192), "turquoise"),
        ((192, 0, 255), "bright violet"),
        ((255, 0, 192), "bright pink"),
        ((255, 64, 0), "fiery orange"),
        ((64, 255, 0), "bright chartreuse"),
        ((0, 255, 64), "malachite"),
        ((64, 0, 255), "deep violet"),
        ((255, 0, 64), "hot pink"),
    )
}
# Hex strings of the default palette, in the same order as the mapping above.
DEFAULT_COLOR_PALETTE = list(DEFAULT_COLOR_HEX_TO_NAME)
def _validate_color_hex(color_hex: str):
    """
    Check that ``color_hex`` is a 3- or 6-digit hex color string.

    Leading '#' characters are ignored. Returns None when the string is valid.

    Raises:
        ValueError: if a non-hex character is present (checked first), or if the
            number of hex digits is not 3 or 6.
    """
    digits = color_hex.lstrip("#")
    hex_alphabet = set("0123456789abcdefABCDEF")
    if any(ch not in hex_alphabet for ch in digits):
        raise ValueError("Invalid characters in color hash")
    if len(digits) != 3 and len(digits) != 6:
        raise ValueError("Invalid length of color hash")
- # copied from https://github.com/roboflow/supervision/blob/c8f557af0c61b5c03392bad2cc36c8835598b1e1/supervision/draw/color.py
@dataclass
class Color:
    """
    Represents a color in RGB format.

    Attributes:
        r (int): Red channel.
        g (int): Green channel.
        b (int): Blue channel.
    """

    r: int
    g: int
    b: int

    @classmethod
    def from_hex(cls, color_hex: str):
        """
        Create a Color instance from a hex string.

        Args:
            color_hex (str): Hex string of the color with 3 or 6 hex digits,
                optionally prefixed with '#'.
        Returns:
            Color: Instance representing the color.
        Raises:
            ValueError: if ``color_hex`` is not a valid 3- or 6-digit hex string.
        Example:
            ```
            >>> Color.from_hex('#ff00ff')
            Color(r=255, g=0, b=255)
            ```
        """
        _validate_color_hex(color_hex)
        color_hex = color_hex.lstrip("#")
        if len(color_hex) == 3:
            # Expand shorthand notation, e.g. "f0a" -> "ff00aa".
            color_hex = "".join(c * 2 for c in color_hex)
        r, g, b = (int(color_hex[i : i + 2], 16) for i in range(0, 6, 2))
        return cls(r, g, b)

    @classmethod
    def to_hex(cls, color):
        """
        Convert a Color instance to a hex string.

        This is a classmethod: call it as ``Color.to_hex(color)``.

        Args:
            color (Color): Color instance to convert.
        Returns:
            str: Lowercase hex string such as '#ff00ff'.
        """
        return rgb_to_hex((color.r, color.g, color.b))

    def as_rgb(self) -> Tuple[int, int, int]:
        """
        Returns the color as an RGB tuple.

        Returns:
            Tuple[int, int, int]: RGB tuple.
        Example:
            ```
            >>> Color(r=255, g=128, b=0).as_rgb()
            (255, 128, 0)
            ```
        """
        return self.r, self.g, self.b

    def as_bgr(self) -> Tuple[int, int, int]:
        """
        Returns the color as a BGR tuple (channel order used by OpenCV).

        Returns:
            Tuple[int, int, int]: BGR tuple.
        Example:
            ```
            >>> Color(r=255, g=128, b=0).as_bgr()
            (0, 128, 255)
            ```
        """
        return self.b, self.g, self.r

    # Named constructors use `cls` so that subclasses get instances of their own type.
    @classmethod
    def white(cls):
        """Return white (#ffffff)."""
        return cls.from_hex(color_hex="#ffffff")

    @classmethod
    def black(cls):
        """Return black (#000000)."""
        return cls.from_hex(color_hex="#000000")

    @classmethod
    def red(cls):
        """Return red (#ff0000)."""
        return cls.from_hex(color_hex="#ff0000")

    @classmethod
    def green(cls):
        """Return green (#00ff00)."""
        return cls.from_hex(color_hex="#00ff00")

    @classmethod
    def blue(cls):
        """Return blue (#0000ff)."""
        return cls.from_hex(color_hex="#0000ff")
@dataclass
class ColorPalette:
    """
    An ordered list of colors, indexed cyclically by ``by_idx``.

    Attributes:
        colors (List[Color]): Colors in palette order.
    """

    colors: List[Color]

    @classmethod
    def default(cls):
        """
        Returns the default color palette built from DEFAULT_COLOR_PALETTE.

        Returns:
            ColorPalette: A ColorPalette instance with default colors.
        Example:
            ```
            >>> ColorPalette.default()
            ColorPalette(colors=[Color(r=255, g=255, b=0), Color(r=0, g=255, b=0), ...])
            ```
        """
        return cls.from_hex(color_hex_list=DEFAULT_COLOR_PALETTE)

    @classmethod
    def from_hex(cls, color_hex_list: List[str]):
        """
        Create a ColorPalette instance from a list of hex strings.

        Args:
            color_hex_list (List[str]): List of color hex strings.
        Returns:
            ColorPalette: A ColorPalette instance.
        Raises:
            ValueError: if any entry is not a valid hex color.
        Example:
            ```
            >>> ColorPalette.from_hex(['#ff0000', '#00ff00', '#0000ff'])
            ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
            ```
        """
        return cls([Color.from_hex(color_hex) for color_hex in color_hex_list])

    def by_idx(self, idx: int) -> Color:
        """
        Return the color at a given index, wrapping around the palette length.

        Args:
            idx (int): Non-negative index; values >= len(colors) wrap around.
        Returns:
            Color: Color at ``idx % len(colors)``.
        Raises:
            ValueError: if ``idx`` is negative.
        Example:
            ```
            >>> color_palette.by_idx(1)
            Color(r=0, g=255, b=0)
            ```
        """
        if idx < 0:
            raise ValueError("idx argument should not be negative")
        return self.colors[idx % len(self.colors)]

    def find_farthest_color(self, img_array):
        """
        Return the palette color that is, on average, farthest from the given pixels.

        For each palette color, the Euclidean RGB distance to every pixel is
        computed and averaged over all pixels. The color with the largest mean
        distance is returned (first one on ties).

        Args:
            img_array (np array): any *x3 array (or array-like); the last axis holds
                the RGB channels.
        Returns:
            Tuple[Color, str]: The farthest color and its name from
                DEFAULT_COLOR_HEX_TO_NAME, or "unknown" if it is not listed there.
        Raises:
            ValueError: if the palette is empty.
        """
        if not self.colors:
            raise ValueError("cannot find the farthest color of an empty palette")
        # Flatten to (num_pixels, 3).
        pixels = np.asarray(img_array).reshape((-1, 3))
        # int64 palette rows; subtracting an array promotes uint8 pixels to int64,
        # so the differences cannot wrap around.
        color_values = np.array([[c.r, c.g, c.b] for c in self.colors])
        # One pass per palette color keeps peak memory at O(num_pixels) instead of
        # materialising a (num_pixels, num_colors, 3) broadcast array.
        mean_distances = [
            np.mean(np.sqrt(np.sum((pixels - color_value) ** 2, axis=1)))
            for color_value in color_values
        ]
        farthest_idx = int(np.argmax(mean_distances))
        farthest_color = self.colors[farthest_idx]
        farthest_color_name = DEFAULT_COLOR_HEX_TO_NAME.get(
            Color.to_hex(farthest_color), "unknown"
        )
        return farthest_color, farthest_color_name
def draw_box(ax, box_coord, alpha=0.8, edge_color="g", line_style="-", linewidth=2.0):
    """
    Outline an axis-aligned box on a matplotlib axes (the box is not filled).

    Args:
        ax: matplotlib axes to draw on.
        box_coord: (x0, y0, width, height); (x0, y0) is the Rectangle anchor point.
        alpha (float): Opacity of the outline.
        edge_color: matplotlib color of the outline.
        line_style (str): matplotlib line style of the outline.
        linewidth (float): Width of the outline.
    """
    left, top, box_width, box_height = box_coord
    outline = mpl.patches.Rectangle(
        (left, top),
        box_width,
        box_height,
        fill=False,
        edgecolor=edge_color,
        linewidth=linewidth,
        alpha=alpha,
        linestyle=line_style,
    )
    ax.add_patch(outline)
def draw_text(
    ax,
    text,
    position,
    font_size=None,
    color="g",
    horizontal_alignment="left",
    rotation=0,
):
    """
    Draw a text label on a matplotlib axes, top-aligned at ``position``.

    The requested color is adjusted before drawing: every RGB channel is raised to
    at least 0.2, and the strongest channel is raised to at least 0.8.

    Args:
        ax: matplotlib axes to draw on.
        text (str): Label text.
        position (tuple): (x, y) anchor passed to ``ax.text``.
        font_size (float, optional): Font size; any falsy value (None, 0) falls
            back to ``mpl.rcParams["font.size"]``.
        color: Any matplotlib color specification.
        horizontal_alignment (str): Horizontal alignment passed to ``ax.text``.
        rotation (float): Text rotation passed to ``ax.text``.
    """
    size = font_size if font_size else mpl.rcParams["font.size"]

    rgb = np.maximum(list(mplc.to_rgb(color)), 0.2)
    dominant = np.argmax(rgb)
    rgb[dominant] = max(0.8, np.max(rgb))

    anchor_x, anchor_y = position
    text_box = dict(facecolor="none", alpha=0.5, pad=0.7, edgecolor="none")
    ax.text(
        anchor_x,
        anchor_y,
        text,
        size=size,
        family="sans-serif",
        bbox=text_box,
        verticalalignment="top",
        horizontalalignment=horizontal_alignment,
        color=rgb,
        rotation=rotation,
    )
def _decode_mask(rle, name):
    """
    Return ``rle`` as a mask array: dicts are decoded with pycocotools, and
    ``np.ndarray`` inputs are returned unchanged.

    Args:
        rle: COCO RLE dict or a mask ``np.ndarray``.
        name (str): Argument name to report in the error message.
    Raises:
        ValueError: if ``rle`` is neither a dict nor an ``np.ndarray``.
    """
    if isinstance(rle, dict):
        return mask_utils.decode(rle)
    if isinstance(rle, np.ndarray):
        return rle
    raise ValueError(f"Unsupported type for {name}: {type(rle)}")


def draw_mask(
    ax, rle, color, show_holes=True, alpha=0.15, upsample_factor=1.0, rle_upsampled=None
):
    """
    Draw a segmentation mask on a matplotlib axes.

    Contours are extracted from the ``rle`` mask and rescaled by
    ``upsample_factor``. When ``show_holes`` is True, the mask is also drawn as
    a semi-transparent RGBA image and the contour polygons get a transparent
    face. Otherwise the contour polygons are filled with ``color``.

    Args:
        ax: matplotlib axes to draw on.
        rle: COCO RLE dict (decoded with pycocotools) or a mask ``np.ndarray``.
        color: RGB color with 3 components (array-like). Presumably floats in
            [0, 1] as matplotlib expects — TODO confirm with callers.
        show_holes (bool): Draw the mask as an image overlay plus outline-only
            polygons, instead of filled polygons.
        alpha (float): Opacity of the overlay image when ``show_holes`` is True.
            Values above 0.8 also darken the polygon edge color.
        upsample_factor (float): Scale applied to contour coordinates.
        rle_upsampled: Mask at the upsampled resolution (same accepted types as
            ``rle``). Required when ``upsample_factor > 1.0`` and ``show_holes``.
    Raises:
        ValueError: for unsupported mask types, or when ``rle_upsampled`` is
            required but missing.
    """
    mask = _decode_mask(rle, "rle")

    mask_upsampled = None
    if upsample_factor > 1.0 and show_holes:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if rle_upsampled is None:
            raise ValueError(
                "rle_upsampled is required when upsample_factor > 1.0 and show_holes=True"
            )
        mask_upsampled = _decode_mask(rle_upsampled, "rle_upsampled")

    if show_holes:
        if mask_upsampled is None:
            mask_upsampled = mask
        h, w = mask_upsampled.shape
        # RGBA overlay: constant RGB everywhere, per-pixel alpha = mask * alpha.
        mask_img = np.zeros((h, w, 4))
        mask_img[:, :, :-1] = np.asarray(color)
        mask_img[:, :, -1] = mask_upsampled * alpha
        ax.imshow(mask_img)

    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy); taking the second-to-last item works for both.
    *_, contours, _ = cv2.findContours(
        mask.astype(np.uint8).copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    # Rescale contour coordinates, treating integer coordinates as pixel centers.
    upsampled_contours = [(cont + 0.5) * upsample_factor - 0.5 for cont in contours]

    # NOTE(review): with show_holes=False the face color is `color` with no alpha,
    # so the fill is drawn opaque and `alpha` is ignored — confirm this is intended.
    facecolor = (0, 0, 0, 0) if show_holes else color
    if alpha > 0.8:
        edge_color = _change_color_brightness(color, brightness_factor=-0.7)
    else:
        edge_color = color
    for cont in upsampled_contours:
        polygon = mpl.patches.Polygon(
            # Each contour point is wrapped as [[x, y]]; unwrap to [x, y].
            [el[0] for el in cont],
            edgecolor=edge_color,
            linewidth=2.0,
            facecolor=facecolor,
        )
        ax.add_patch(polygon)
def _change_color_brightness(color, brightness_factor):
    """
    Return a lighter or darker version of ``color`` by scaling its HLS lightness.

    The color is converted to HLS and its lightness L is replaced by
    ``L + brightness_factor * L``, clamped to [0.0, 1.0]. It is then converted
    back to RGB. Hue and saturation are unchanged.

    Args:
        color: color of the polygon. Refer to `matplotlib.colors` for a full list of
            formats that are accepted.
        brightness_factor (float): a value in [-1.0, 1.0] range. A factor of 0
            leaves the color unchanged. A factor in [-1.0, 0) darkens it, and a
            factor in (0, 1.0] lightens it.
    Returns:
        modified_color (tuple[float]): a tuple containing the RGB values of the
            modified color. Each value in the tuple is in the [0.0, 1.0] range.
    Raises:
        ValueError: if ``brightness_factor`` is outside [-1.0, 1.0].
    """
    if not -1.0 <= brightness_factor <= 1.0:
        raise ValueError(
            f"brightness_factor must be in [-1.0, 1.0], got {brightness_factor}"
        )
    hue, lightness, saturation = colorsys.rgb_to_hls(*mplc.to_rgb(color))
    modified_lightness = lightness + (brightness_factor * lightness)
    modified_lightness = min(max(modified_lightness, 0.0), 1.0)
    return colorsys.hls_to_rgb(hue, modified_lightness, saturation)
|