| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- import os
- import numpy as np
- import pytest
- import torch
- from PIL import Image
- from sam3.perflib.masks_ops import masks_to_boxes
class TestMasksToBoxes:
    """Tests for ``masks_to_boxes`` using reference masks stored in ``assets/masks.tiff``."""

    def test_masks_box(self):
        """Boxes computed from the TIFF masks must match the reference boxes.

        Every frame of the multi-frame TIFF becomes one slice of a
        ``(n_frames, height, width)`` mask tensor. The check is repeated for
        float16, float32 and float64 masks; in every case the output must be
        a float32 tensor equal to ``expected`` within ``atol`` (``rtol=0``).
        """

        def masks_box_check(masks, expected, atol=1e-4):
            # One entry per mask, all set to 1.
            # NOTE(review): the meaning of this second argument is defined by
            # masks_to_boxes, which is not visible here — confirm.
            out = masks_to_boxes(masks, [1] * masks.shape[0])
            # Output dtype is expected to be float32 regardless of mask dtype.
            assert out.dtype == torch.float
            torch.testing.assert_close(
                out, expected, rtol=0.0, check_dtype=True, atol=atol
            )

        def _mask_path():
            """Return the path of the multi-frame mask TIFF next to this file."""
            assets_directory = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "assets"
            )
            return os.path.join(assets_directory, "masks.tiff")

        def _load_masks(image, dtype):
            """Stack every frame of ``image`` into an (n_frames, H, W) tensor of ``dtype``."""
            masks = torch.zeros(
                (image.n_frames, image.height, image.width), dtype=dtype
            )
            for index in range(image.n_frames):
                # Multi-frame images are read one frame at a time via seek().
                image.seek(index)
                # Assignment casts the frame's values to the tensor's dtype.
                masks[index] = torch.tensor(np.array(image))
            return masks

        # Reference boxes, one row per TIFF frame. The values are integral but
        # stored as float because masks_to_boxes is expected to return float32.
        # NOTE(review): coordinate order is presumably (x_min, y_min, x_max,
        # y_max) — confirm against masks_to_boxes.
        expected = torch.tensor(
            [
                [127, 2, 165, 40],
                [2, 50, 44, 92],
                [56, 63, 98, 100],
                [139, 68, 175, 104],
                [160, 112, 198, 145],
                [49, 138, 99, 182],
                [108, 148, 152, 213],
            ],
            dtype=torch.float,
        )

        # Keep the image open only while frames are being read; the context
        # manager closes the underlying file handle PIL holds for seek().
        with Image.open(_mask_path()) as image:
            for dtype in (torch.float16, torch.float32, torch.float64):
                masks_box_check(_load_masks(image, dtype), expected)
|