eval_base.yaml

# @package _global_
defaults:
  - _self_
# This config is the base configuration for all evaluations. Amongst other things, it defines:
# - the model
# - the image transforms
# - the post processors
# - cluster configuration (only relevant for slurm-based evals, ignored otherwise)
#
# Most of the parameters should be kept as-is. The main modifications you may want to make are:
# - the cluster configuration, to adjust partitions/qos to your system
# - the flag gather_pred_via_filesys if your RAM is tight
# - num_val_workers if your number of cores is small (should be roughly number of cores / number of gpus)
# - the paths below
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
  # If you leave the checkpoint path as null, the model will be downloaded from Hugging Face. Otherwise, provide a path.
  checkpoint_path: null
  # the experiments will be subfolders of this
  base_experiment_log_dir: <YOUR_EXPERIMENT_LOG_DIR>
  # base path to the annotation folder for gold (refer to the readmes on how to download)
  base_annotation_path: <YOUR_GOLD_GT_DIR>
  # base path to the annotation folder for silver (refer to the readmes on how to download)
  base_annotation_path_silver: <YOUR_SILVER_GT_DIR>
  # path to the metaclip images, used for SA-Co gold (refer to the readme for instructions). Can be null if you don't intend to evaluate on this dataset.
  metaclip_img_path: <YOUR_METACLIP_IMG_DIR>
  # path to the sa1b images, used for SA-Co gold (refer to the readme for instructions). Can be null if you don't intend to evaluate on this dataset.
  sa1b_img_path: <YOUR_SA1B_IMG_DIR>
  # path to the SA-Co/silver images
  silver_img_path: <YOUR_SILVER_IMG_DIR>
  bpe_path: <BPE_PATH> # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz
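# For illustration only, a filled-in paths block might look like the following
# (all directories below are hypothetical examples, not defaults shipped with the repo):
#   checkpoint_path: /data/ckpts/sam3_image.pt
#   base_experiment_log_dir: /data/logs/sam3_evals
#   base_annotation_path: /data/saco/gold_annotations
#   silver_img_path: /data/saco/silver_images
# Assuming the eval entry point is a standard Hydra application (as the
# `# @package _global_` directive suggests), each path can also be overridden on the
# command line, e.g. `paths.checkpoint_path=/data/ckpts/sam3_image.pt`.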
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
  use_presence_eval: True
  base_val_transform:
    - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
      transforms:
        ######## transforms for validation (begin) ########
        - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
          sizes: ${scratch.resolution} # originally `resolution: 1024`
          max_size:
            _target_: sam3.train.transforms.basic.get_random_resize_max_size
            size: ${scratch.resolution} # originally `resolution: 1024`
            square: true
          consistent_transform: False
        ######## transforms for validation (end) ########
        - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
        - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
          mean: ${scratch.val_norm_mean}
          std: ${scratch.val_norm_std}
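  # Note: assuming ToTensorAPI scales pixel values to [0, 1], NormalizeAPI with
  # mean 0.5 and std 0.5 then maps them to roughly [-1, 1] via out = (in - 0.5) / 0.5.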
  loss: null
  # Model parameters
  d_model: 256
  input_box_embedding_dim: ${add:${scratch.d_model},2}
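  # Assuming the custom `add` resolver performs arithmetic addition, this resolves to 256 + 2 = 258.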
  # Box processing
  original_box_postprocessor:
    _target_: sam3.eval.postprocessors.PostProcessImage
    max_dets_per_img: -1 # infinite detections
    use_original_ids: true
    use_original_sizes_box: true
    use_presence: ${scratch.use_presence_eval}
  box_postprocessor:
    _target_: sam3.eval.postprocessors.PostProcessImage
    max_dets_per_img: -1 # infinite detections
    use_original_ids: false
    use_original_sizes_box: false
    use_presence: ${scratch.use_presence_eval}
  box_postprocessor_thresholded:
    _target_: sam3.eval.postprocessors.PostProcessImage
    max_dets_per_img: -1 # infinite detections
    use_original_ids: false
    use_original_sizes_box: false
    detection_threshold: 0.3
    use_presence: ${scratch.use_presence_eval}
  mask_postprocessor_thresholded:
    _target_: sam3.eval.postprocessors.PostProcessImage
    max_dets_per_img: -1 # infinite detections
    iou_type: "segm"
    use_original_ids: false
    use_original_sizes_box: false
    use_original_sizes_mask: true
    convert_mask_to_rle: True
    detection_threshold: 0.3
    use_presence: ${scratch.use_presence_eval}
  # Image processing parameters
  resolution: 1008
  max_ann_per_img: 200
  # Normalization parameters
  train_norm_mean: [0.5, 0.5, 0.5]
  train_norm_std: [0.5, 0.5, 0.5]
  val_norm_mean: [0.5, 0.5, 0.5]
  val_norm_std: [0.5, 0.5, 0.5]
  # Training parameters
  train_batch_size: 1
  val_batch_size: 1
  num_train_workers: 0
  num_val_workers: 10 # change this depending on the number of cpu cores available
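  # e.g. on a node with 80 CPU cores and 8 GPUs, 80 / 8 = 10 workers per GPU process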
  max_data_epochs: 20
  target_epoch_size: 1500
  hybrid_repeats: 1
  context_length: 2
  # All reduce - this controls how the predictions are sent back to node 0.
  # If you have a lot of RAM, CPU gather is faster. Otherwise, we provide a fallback through the filesystem (e.g. NFS).
  # Switch to true if you get CPU OOMs during gather.
  gather_pred_via_filesys: false
  # Learning rate and scheduler parameters (unused for eval)
  lr_scale: 0.1
  lr_transformer: ${times:8e-4,${scratch.lr_scale}}
  lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
  lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
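  # Assuming the custom `times` resolver multiplies its arguments, these resolve to
  # 8e-4 * 0.1 = 8e-5, 2.5e-4 * 0.1 = 2.5e-5, and 5e-5 * 0.1 = 5e-6, respectively
  # (not actually used in eval mode).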
  lrd_vision_backbone: 0.9 # (lower for in-domain and higher for ood)
  wd: 0.1
  scheduler_timescale: 20
  scheduler_warmup: 20
  scheduler_cooldown: 20
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
  _target_: sam3.train.trainer.Trainer
  skip_saving_ckpts: true
  empty_gpu_mem_cache_after_eval: True
  skip_first_val: True
  max_epochs: ${scratch.max_data_epochs}
  accelerator: cuda
  seed_value: 123
  val_epoch_freq: 10
  mode: val
  distributed:
    backend: nccl
    find_unused_parameters: True
    gradient_as_bucket_view: True
  loss:
    all:
      _target_: sam3.train.loss.sam3_loss.DummyLoss
    default:
      _target_: sam3.train.loss.sam3_loss.DummyLoss
  data:
    train: null
    val: null
  model:
    _target_: sam3.model_builder.build_sam3_image_model
    bpe_path: ${paths.bpe_path}
    device: cpus
    eval_mode: true
    enable_segmentation: true # Warning: Enable this if using segmentation.
    checkpoint_path: ${paths.checkpoint_path}
  meters:
    val: null
  optim:
    amp:
      enabled: True
      amp_dtype: bfloat16
    optimizer:
      _target_: torch.optim.AdamW
    gradient_clip:
      _target_: sam3.train.optim.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2
    param_group_modifiers:
      - _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
        _partial_: True
        layer_decay_value: ${scratch.lrd_vision_backbone}
        apply_to: 'backbone.vision_backbone.trunk'
        overrides:
          - pattern: '*pos_embed*'
            value: 1.0
    options:
      lr:
        - scheduler: # transformer and class_embed
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_transformer}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
        - scheduler:
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_vision_backbone}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
          param_names:
            - 'backbone.vision_backbone.*'
        - scheduler:
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_language_backbone}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
          param_names:
            - 'backbone.language_backbone.*'
      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: ${scratch.wd}
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*bias*'
          module_cls_names: ['torch.nn.LayerNorm']
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0 # 0 means only the last checkpoint is saved
  logging:
    tensorboard_writer:
      _target_: sam3.train.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: True
    wandb_writer: null
    log_dir: ${launcher.experiment_log_dir}/logs/
    log_freq: 10
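    # TensorBoard event files land under ${launcher.experiment_log_dir}/tensorboard and
    # can be inspected with the standard CLI, e.g. `tensorboard --logdir <experiment_log_dir>/tensorboard`.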
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
  num_nodes: 4
  gpus_per_node: 8
  experiment_log_dir: ${paths.experiment_log_dir}
  multiprocessing_context: forkserver

submitit:
  account: null # Add your SLURM account if use_cluster is True
  partition: null
  qos: null # Add your QoS if use_cluster is True
  timeout_hour: 72
  use_cluster: True
  cpus_per_task: 10
  port_range: [10000, 65000]
  constraint: null
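# For illustration, a hypothetical SLURM setup could look as follows (the account,
# partition, and qos names are placeholders, not real defaults):
#   account: my_slurm_account
#   partition: gpu
#   qos: normal
#   use_cluster: True
# With use_cluster set to False, the eval presumably runs locally without going
# through submitit/SLURM, and the cluster settings above are ignored.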