sam_utils.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import numpy as np
  2. from PIL import Image
  3. import io
  4. import base64
  5. def show_mask(mask):
  6. h, w = mask.shape[-2:]
  7. mask = mask.astype(np.uint8)
  8. # 创建一个全黑的RGBA图像
  9. mask_image = np.zeros((h, w, 4), dtype=np.uint8)
  10. # 将掩码区域设为白色(前景)
  11. mask_image[mask > 0] = [0, 122, 204, 255] # 白色前景
  12. mask_image[mask == 0] = [0, 0, 0, 0] # 黑色背景
  13. return mask_image
  14. def show_masks(masks, scores):
  15. base64_images = []
  16. for i, (mask, score) in enumerate(zip(masks, scores)):
  17. np_arr = show_mask(mask)
  18. arr_img = Image.fromarray(np_arr)
  19. # 将图像转换为base64
  20. buffered = io.BytesIO()
  21. arr_img.save(buffered, format="PNG")
  22. img_str = base64.b64encode(buffered.getvalue()).decode()
  23. base64_images.append(img_str)
  24. return base64_images
  25. def convert_to_serializable(obj):
  26. """递归将numpy和torch类型转换为可序列化的Python类型"""
  27. import torch
  28. if isinstance(obj, np.ndarray):
  29. return obj.tolist() # 将numpy数组转换为Python列表
  30. elif isinstance(obj, torch.Tensor):
  31. return obj.detach().cpu().numpy().tolist() # 将tensor转换为numpy再转为列表
  32. elif isinstance(obj, (np.floating, np.integer, np.bool_)):
  33. return obj.item() # 将numpy标量转换为Python原生类型
  34. elif isinstance(obj, (list, tuple)):
  35. return [convert_to_serializable(item) for item in obj]
  36. elif isinstance(obj, dict):
  37. return {key: convert_to_serializable(value) for key, value in obj.items()}
  38. else:
  39. return obj