# @package _global_
# ODinW13 evaluation config for text + visual prompting: queries keep their
# text prompt (keep_text_queries: true) and TextQueryToVisual is applied to
# every sample (probability: 1.0). See val_transforms.
defaults:
  - _self_
  4. # ============================================================================
  5. # Paths Configuration (Chage this to your own paths)
  6. # ============================================================================
  7. # python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
  8. paths:
  9. odinw_data_root: <YOUR_DATA_DIR>
  10. experiment_log_dir: <YOUR EXPERIMENET LOG_DIR>
  11. bpe_path: <BPE_PATH> # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz
  12. supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
  13. # Validation transforms pipeline
  14. val_transforms:
  15. - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
  16. transforms:
  17. - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
  18. sizes: ${scratch.resolution}
  19. max_size:
  20. _target_: sam3.train.transforms.basic.get_random_resize_max_size
  21. size: ${scratch.resolution}
  22. square: true
  23. consistent_transform: False
  24. - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
  25. - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
  26. mean: ${scratch.val_norm_mean}
  27. std: ${scratch.val_norm_std}
  28. - _target_: sam3.train.transforms.filter_query_transforms.TextQueryToVisual
  29. keep_text_queries: true # Note: set this to false if you only want visual
  30. probability: 1.0 # always
  31. # ============================================================================
  32. # Different helper parameters and functions
  33. # ============================================================================
  34. scratch:
  35. enable_segmentation: True
  36. # Box processing
  37. use_presence_eval: True
  38. original_box_postprocessor:
  39. _target_: sam3.eval.postprocessors.PostProcessImage
  40. max_dets_per_img: -1 # infinite detections
  41. use_original_ids: true
  42. use_original_sizes_box: true
  43. use_presence: ${scratch.use_presence_eval}
  44. # Image processing parameters
  45. resolution: 1008
  46. # Normalization parameters
  47. val_norm_mean: [0.5, 0.5, 0.5]
  48. val_norm_std: [0.5, 0.5, 0.5]
  49. # Training parameters
  50. val_batch_size: 2
  51. num_val_workers: 0
  52. gather_pred_via_filesys: false
  53. # ============================================================================
  54. # Trainer Configuration
  55. # ============================================================================
  56. trainer:
  57. _target_: sam3.train.trainer.Trainer
  58. skip_saving_ckpts: true
  59. empty_gpu_mem_cache_after_eval: True
  60. max_epochs: 1
  61. accelerator: cuda
  62. seed_value: 123
  63. mode: val
  64. distributed:
  65. backend: nccl
  66. find_unused_parameters: True
  67. gradient_as_bucket_view: True
  68. loss:
  69. default:
  70. _target_: sam3.train.loss.sam3_loss.DummyLoss

  data:
    val:
      _target_: sam3.train.data.torch_dataset.TorchDataset
      dataset:
        _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
        coco_json_loader:
          _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
          prompts: ${odinw35_prompts.${supercategory_tuple.name}}
          include_negatives: true
          # Note: since we compute AP on positives (+ve), all categories
          # must be included.
          category_chunk_size: 20
          # Hydra builds a functools.partial; the remaining arguments are
          # supplied later by the dataset.
          _partial_: true
        img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder}
        ann_file:
          _target_: sam3.eval.coco_reindex.reindex_coco_to_temp
          input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
        transforms: ${val_transforms}
        max_ann_per_img: 100000
        multiplier: 1
        training: false

      shuffle: false
      batch_size: ${scratch.val_batch_size}
      num_workers: ${scratch.num_val_workers}
      pin_memory: false
      drop_last: false
      collate_fn:
        _target_: sam3.train.data.collator.collate_fn_api
        _partial_: true
        repeats: 1
        # Kept equal to the key under trainer.meters.val (odinw35).
        dict_key: odinw35

  model:
    _target_: sam3.model_builder.build_sam3_image_model
    bpe_path: ${paths.bpe_path}
    # Was "cpus", which is not a valid torch device string. The model is
    # built on CPU; presumably the Trainer moves it to `accelerator` — TODO
    # confirm.
    device: cpu
    eval_mode: true  # set to false if training
    # Warning: enable this if using segmentation.
    enable_segmentation: ${scratch.enable_segmentation}

  meters:
    val:
      # Same key as trainer.data.val.collate_fn.dict_key.
      odinw35:
        detection:
          _target_: sam3.eval.coco_writer.PredictionDumper
          iou_type: "bbox"
          dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${supercategory_tuple.name}
          merge_predictions: true
          postprocessor: ${scratch.original_box_postprocessor}
          gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
          maxdets: 100
          pred_file_evaluators:
            - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
              # Same ground-truth JSON as trainer.data.val.dataset.ann_file.
              gt_path:
                _target_: sam3.eval.coco_reindex.reindex_coco_to_temp
                input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
              tide: false
              iou_type: "bbox"
              positive_split: true

  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0  # 0: only the last checkpoint is saved

  logging:
    tensorboard_writer:
      _target_: sam3.train.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: true
    wandb_writer: null
    # One log directory per evaluated supercategory.
    log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name}
    log_freq: 10
  137. # ============================================================================
  138. # Launcher and Submitit Configuration
  139. # ============================================================================
  140. launcher:
  141. num_nodes: 1
  142. gpus_per_node: 2
  143. experiment_log_dir: ${paths.experiment_log_dir}
  144. multiprocessing_context: forkserver
  145. submitit:
  146. account: null
  147. partition: null
  148. qos: null
  149. timeout_hour: 72
  150. use_cluster: True
  151. cpus_per_task: 10
  152. port_range: [10000, 65000]
  153. constraint: null
  154. job_array:
  155. num_tasks: 13
  156. task_index: 0
  157. # ============================================================================
  158. # ODinW13 Supercategories
  159. # ============================================================================
  160. all_odinw_supercategories:
  161. - name: AerialMaritimeDrone_large
  162. val:
  163. img_folder: AerialMaritimeDrone/large/test/
  164. json: AerialMaritimeDrone/large/test/annotations_without_background.json
  165. - name: Aquarium
  166. val:
  167. img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
  168. json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
  169. - name: CottontailRabbits
  170. val:
  171. img_folder: CottontailRabbits/test/
  172. json: CottontailRabbits/test/annotations_without_background.json
  173. - name: EgoHands_generic
  174. val:
  175. img_folder: EgoHands/generic/test/
  176. json: EgoHands/generic/test/annotations_without_background.json
  177. - name: NorthAmericaMushrooms
  178. val:
  179. img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
  180. json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
  181. - name: Packages
  182. val:
  183. img_folder: Packages/Raw/test/
  184. json: Packages/Raw/test/annotations_without_background.json
  185. - name: PascalVOC
  186. val:
  187. img_folder: PascalVOC/valid/
  188. json: PascalVOC/valid/annotations_without_background.json
  189. - name: Raccoon
  190. val:
  191. img_folder: Raccoon/Raccoon.v2-raw.coco/test/
  192. json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
  193. - name: ShellfishOpenImages
  194. val:
  195. img_folder: ShellfishOpenImages/raw/test/
  196. json: ShellfishOpenImages/raw/test/annotations_without_background.json
  197. - name: VehiclesOpenImages
  198. val:
  199. img_folder: VehiclesOpenImages/416x416/test/
  200. json: VehiclesOpenImages/416x416/test/annotations_without_background.json
  201. - name: pistols
  202. val:
  203. img_folder: pistols/export/
  204. json: pistols/export/test_annotations_without_background.json
  205. - name: pothole
  206. val:
  207. img_folder: pothole/test/
  208. json: pothole/test/annotations_without_background.json
  209. - name: thermalDogsAndPeople
  210. val:
  211. img_folder: thermalDogsAndPeople/test/
  212. json: thermalDogsAndPeople/test/annotations_without_background.json
  213. odinw35_prompts:
  214. AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
  215. {"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
  216. "supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
  217. {"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
  218. Aquarium: null
  219. CottontailRabbits: null
  220. EgoHands_generic: null
  221. NorthAmericaMushrooms: '[{''id'': 1, ''name'':
  222. ''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
  223. Packages: null
  224. PascalVOC: null
  225. Raccoon: null
  226. ShellfishOpenImages: null
  227. VehiclesOpenImages: null
  228. pistols: null
  229. pothole: null
  230. thermalDogsAndPeople: null