# roboflow_v100_eval.yaml
  1. # @package _global_
  2. defaults:
  3. - _self_
  4. # ============================================================================
  5. # Paths Configuration (Chage this to your own paths)
  6. # ============================================================================
  7. paths:
  8. roboflow_vl_100_root: <YOUR_DATASET_DIR>
  9. experiment_log_dir: <YOUR EXPERIMENET LOG_DIR>
  10. bpe_path: <BPE_PATH> # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz
  11. # Roboflow dataset configuration
  12. roboflow_train:
  13. num_images: 100 # Note: This is the number of images used for training. If null, all images are used.
  14. supercategory: ${all_roboflow_supercategories.${string:${submitit.job_array.task_index}}}
  # Training transforms pipeline
  train_transforms:
    - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
      transforms:
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
        - _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
          box_noise_std: 0.1
          box_noise_max: 20
        - _target_: sam3.train.transforms.segmentation.DecodeRle
        - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
          sizes:
            _target_: sam3.train.transforms.basic.get_random_resize_scales
            size: ${scratch.resolution}
            min_size: 480
            rounded: false
          max_size:
            _target_: sam3.train.transforms.basic.get_random_resize_max_size
            size: ${scratch.resolution}
          square: true
          consistent_transform: ${scratch.consistent_transform}
        - _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
          size: ${scratch.resolution}
          consistent_transform: ${scratch.consistent_transform}
        - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
        - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
          mean: ${scratch.train_norm_mean}
          std: ${scratch.train_norm_std}
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
            max_num_objects: ${scratch.max_ann_per_img}
  # Validation transforms pipeline
  val_transforms:
    - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
      transforms:
        - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
          sizes: ${scratch.resolution}
          max_size:
            _target_: sam3.train.transforms.basic.get_random_resize_max_size
            size: ${scratch.resolution}
          square: true
          consistent_transform: False
        - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
        - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
          mean: ${scratch.train_norm_mean}
          std: ${scratch.train_norm_std}
  # loss config (no mask loss)
  loss:
    _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
    matcher: ${scratch.matcher}
    o2m_weight: 2.0
    o2m_matcher:
      _target_: sam3.train.matcher.BinaryOneToManyMatcher
      alpha: 0.3
      threshold: 0.4
      topk: 4
    use_o2m_matcher_on_o2m_aux: false # Another option is true
    loss_fns_find:
      - _target_: sam3.train.loss.loss_fns.Boxes
        weight_dict:
          loss_bbox: 5.0
          loss_giou: 2.0
      - _target_: sam3.train.loss.loss_fns.IABCEMdetr
        weak_loss: False
        weight_dict:
          loss_ce: 20.0 # Another option is 100.0
          presence_loss: 20.0
        pos_weight: 10.0 # Another option is 5.0
        alpha: 0.25
        gamma: 2
        use_presence: True # Change
        pos_focal: false
        pad_n_queries: 200
        pad_scale_pos: 1.0
    loss_fn_semantic_seg: null
    scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
  # NOTE: Loss to be used for training in case of segmentation
  # loss:
  #   _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
  #   matcher: ${scratch.matcher}
  #   o2m_weight: 2.0
  #   o2m_matcher:
  #     _target_: sam3.train.matcher.BinaryOneToManyMatcher
  #     alpha: 0.3
  #     threshold: 0.4
  #     topk: 4
  #   use_o2m_matcher_on_o2m_aux: false
  #   loss_fns_find:
  #     - _target_: sam3.train.loss.loss_fns.Boxes
  #       weight_dict:
  #         loss_bbox: 5.0
  #         loss_giou: 2.0
  #     - _target_: sam3.train.loss.loss_fns.IABCEMdetr
  #       weak_loss: False
  #       weight_dict:
  #         loss_ce: 20.0 # Another option is 100.0
  #         presence_loss: 20.0
  #       pos_weight: 10.0 # Another option is 5.0
  #       alpha: 0.25
  #       gamma: 2
  #       use_presence: True # Change
  #       pos_focal: false
  #       pad_n_queries: 200
  #       pad_scale_pos: 1.0
  #     - _target_: sam3.train.loss.loss_fns.Masks
  #       focal_alpha: 0.25
  #       focal_gamma: 2.0
  #       weight_dict:
  #         loss_mask: 200.0
  #         loss_dice: 10.0
  #       compute_aux: false
  #   loss_fn_semantic_seg:
  #     _target_: sam3.losses.loss_fns.SemanticSegCriterion
  #     presence_head: True
  #     presence_loss: False # Change
  #     focal: True
  #     focal_alpha: 0.6
  #     focal_gamma: 2.0
  #     downsample: False
  #     weight_dict:
  #       loss_semantic_seg: 20.0
  #       loss_semantic_presence: 1.0
  #       loss_semantic_dice: 30.0
  #   scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
  147. # ============================================================================
  148. # Different helper parameters and functions
  149. # ============================================================================
  150. scratch:
  151. enable_segmentation: False # NOTE: This is the number of queries used for segmentation
  152. # Model parameters
  153. d_model: 256
  154. pos_embed:
  155. _target_: sam3.model.position_encoding.PositionEmbeddingSine
  156. num_pos_feats: ${scratch.d_model}
  157. normalize: true
  158. scale: null
  159. temperature: 10000
  160. # Box processing
  161. use_presence_eval: True
  162. original_box_postprocessor:
  163. _target_: sam3.eval.postprocessors.PostProcessImage
  164. max_dets_per_img: -1 # infinite detections
  165. use_original_ids: true
  166. use_original_sizes_box: true
  167. use_presence: ${scratch.use_presence_eval}
  168. # Matcher configuration
  169. matcher:
  170. _target_: sam3.train.matcher.BinaryHungarianMatcherV2
  171. focal: true # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher
  172. cost_class: 2.0
  173. cost_bbox: 5.0
  174. cost_giou: 2.0
  175. alpha: 0.25
  176. gamma: 2
  177. stable: False
  178. scale_by_find_batch_size: True
  179. # Image processing parameters
  180. resolution: 1008
  181. consistent_transform: False
  182. max_ann_per_img: 200
  183. # Normalization parameters
  184. train_norm_mean: [0.5, 0.5, 0.5]
  185. train_norm_std: [0.5, 0.5, 0.5]
  186. val_norm_mean: [0.5, 0.5, 0.5]
  187. val_norm_std: [0.5, 0.5, 0.5]
  188. # Training parameters
  189. num_train_workers: 10
  190. num_val_workers: 0
  191. max_data_epochs: 20
  192. target_epoch_size: 1500
  193. hybrid_repeats: 1
  194. context_length: 2
  195. gather_pred_via_filesys: false
  196. # Learning rate and scheduler parameters
  197. lr_scale: 0.1
  198. lr_transformer: ${times:8e-4,${scratch.lr_scale}}
  199. lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
  200. lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
  201. lrd_vision_backbone: 0.9
  202. wd: 0.1
  203. scheduler_timescale: 20
  204. scheduler_warmup: 20
  205. scheduler_cooldown: 20
  206. val_batch_size: 1
  207. collate_fn_val:
  208. _target_: sam3.train.data.collator.collate_fn_api
  209. _partial_: true
  210. repeats: ${scratch.hybrid_repeats}
  211. dict_key: roboflow100
  212. with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
  213. gradient_accumulation_steps: 1
  214. train_batch_size: 1
  215. collate_fn:
  216. _target_: sam3.train.data.collator.collate_fn_api
  217. _partial_: true
  218. repeats: ${scratch.hybrid_repeats}
  219. dict_key: all
  220. with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
  221. # ============================================================================
  222. # Trainer Configuration
  223. # ============================================================================
  224. trainer:
  225. _target_: sam3.train.trainer.Trainer
  226. skip_saving_ckpts: true
  227. empty_gpu_mem_cache_after_eval: True
  228. skip_first_val: True
  229. max_epochs: 20
  230. accelerator: cuda
  231. seed_value: 123
  232. val_epoch_freq: 10
  233. mode: val
  234. gradient_accumulation_steps: ${scratch.gradient_accumulation_steps}
  235. distributed:
  236. backend: nccl
  237. find_unused_parameters: True
  238. gradient_as_bucket_view: True
  239. loss:
  240. all: ${roboflow_train.loss}
  241. default:
  242. _target_: sam3.train.loss.sam3_loss.DummyLoss
  data:
    train:
      _target_: sam3.train.data.torch_dataset.TorchDataset
      dataset:
        _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
        limit_ids: ${roboflow_train.num_images}
        transforms: ${roboflow_train.train_transforms}
        load_segmentation: ${scratch.enable_segmentation}
        max_ann_per_img: 500000
        multiplier: 1
        max_train_queries: 50000
        max_val_queries: 50000
        training: true
        use_caching: False
        img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/
        ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/_annotations.coco.json
      shuffle: True
      batch_size: ${scratch.train_batch_size}
      num_workers: ${scratch.num_train_workers}
      pin_memory: True
      drop_last: True
      collate_fn: ${scratch.collate_fn}
    val:
      _target_: sam3.train.data.torch_dataset.TorchDataset
      dataset:
        _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
        load_segmentation: ${scratch.enable_segmentation}
        coco_json_loader:
          _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
          include_negatives: true
          category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
          _partial_: true
        img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/
        ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
        transforms: ${roboflow_train.val_transforms}
        max_ann_per_img: 100000
        multiplier: 1
        training: false
      shuffle: False
      batch_size: ${scratch.val_batch_size}
      num_workers: ${scratch.num_val_workers}
      pin_memory: True
      drop_last: False
      collate_fn: ${scratch.collate_fn_val}
  model:
    _target_: sam3.model_builder.build_sam3_image_model
    bpe_path: ${paths.bpe_path}
    device: cpus # NOTE(review): 'cpus' looks like a typo for 'cpu' — confirm against model_builder
    eval_mode: true
    enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
  meters:
    val:
      roboflow100:
        detection:
          _target_: sam3.eval.coco_writer.PredictionDumper
          iou_type: "bbox"
          dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${roboflow_train.supercategory}
          merge_predictions: True
          postprocessor: ${scratch.original_box_postprocessor}
          gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
          maxdets: 100
          pred_file_evaluators:
            - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
              gt_path: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
              tide: False
              iou_type: "bbox"
  optim:
    amp:
      enabled: True
      amp_dtype: bfloat16
    optimizer:
      _target_: torch.optim.AdamW
    gradient_clip:
      _target_: sam3.train.optim.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2
    param_group_modifiers:
      - _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
        _partial_: True
        layer_decay_value: ${scratch.lrd_vision_backbone}
        apply_to: 'backbone.vision_backbone.trunk'
        overrides:
          - pattern: '*pos_embed*'
            value: 1.0
    options:
      lr:
        - scheduler: # transformer and class_embed
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_transformer}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
        - scheduler:
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_vision_backbone}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
          param_names:
            - 'backbone.vision_backbone.*'
        - scheduler:
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_language_backbone}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
          param_names:
            - 'backbone.language_backbone.*'
      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: ${scratch.wd}
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*bias*'
          module_cls_names: ['torch.nn.LayerNorm']
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0 # 0 only last checkpoint is saved.
  logging:
    tensorboard_writer:
      _target_: sam3.train.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: True
    wandb_writer: null
    log_dir: ${launcher.experiment_log_dir}/logs/${roboflow_train.supercategory}
    log_freq: 10
  373. # ============================================================================
  374. # Launcher and Submitit Configuration
  375. # ============================================================================
  376. launcher:
  377. num_nodes: 1
  378. gpus_per_node: 2
  379. experiment_log_dir: ${paths.experiment_log_dir}
  380. multiprocessing_context: forkserver
  381. submitit:
  382. account: null
  383. partition: null
  384. qos: null
  385. timeout_hour: 72
  386. use_cluster: True
  387. cpus_per_task: 10
  388. port_range: [10000, 65000]
  389. constraint: null
  390. # Uncomment for job array configuration
  391. job_array:
  392. num_tasks: 100
  393. task_index: 0
  394. # ============================================================================
  395. # Available Roboflow Supercategories (for reference)
  396. # ============================================================================
  397. all_roboflow_supercategories:
  398. - -grccs
  399. - zebrasatasturias
  400. - cod-mw-warzone
  401. - canalstenosis
  402. - label-printing-defect-version-2
  403. - new-defects-in-wood
  404. - orionproducts
  405. - aquarium-combined
  406. - varroa-mites-detection--test-set
  407. - clashroyalechardetector
  408. - stomata-cells
  409. - halo-infinite-angel-videogame
  410. - pig-detection
  411. - urine-analysis1
  412. - aerial-sheep
  413. - orgharvest
  414. - actions
  415. - mahjong
  416. - liver-disease
  417. - needle-base-tip-min-max
  418. - wheel-defect-detection
  419. - aircraft-turnaround-dataset
  420. - xray
  421. - wildfire-smoke
  422. - spinefrxnormalvindr
  423. - ufba-425
  424. - speech-bubbles-detection
  425. - train
  426. - pill
  427. - truck-movement
  428. - car-logo-detection
  429. - inbreast
  430. - sea-cucumbers-new-tiles
  431. - uavdet-small
  432. - penguin-finder-seg
  433. - aerial-airport
  434. - bibdetection
  435. - taco-trash-annotations-in-context
  436. - bees
  437. - recode-waste
  438. - screwdetectclassification
  439. - wine-labels
  440. - aerial-cows
  441. - into-the-vale
  442. - gwhd2021
  443. - lacrosse-object-detection
  444. - defect-detection
  445. - dataconvert
  446. - x-ray-id
  447. - ball
  448. - tube
  449. - 2024-frc
  450. - crystal-clean-brain-tumors-mri-dataset
  451. - grapes-5
  452. - human-detection-in-floods
  453. - buoy-onboarding
  454. - apoce-aerial-photographs-for-object-detection-of-construction-equipment
  455. - l10ul502
  456. - floating-waste
  457. - deeppcb
  458. - ism-band-packet-detection
  459. - weeds4
  460. - invoice-processing
  461. - thermal-cheetah
  462. - tomatoes-2
  463. - marine-sharks
  464. - peixos-fish
  465. - sssod
  466. - aerial-pool
  467. - countingpills
  468. - asphaltdistressdetection
  469. - roboflow-trained-dataset
  470. - everdaynew
  471. - underwater-objects
  472. - soda-bottles
  473. - dentalai
  474. - jellyfish
  475. - deepfruits
  476. - activity-diagrams
  477. - circuit-voltages
  478. - all-elements
  479. - macro-segmentation
  480. - exploratorium-daphnia
  481. - signatures
  482. - conveyor-t-shirts
  483. - fruitjes
  484. - grass-weeds
  485. - infraredimageofpowerequipment
  486. - 13-lkc01
  487. - wb-prova
  488. - flir-camera-objects
  489. - paper-parts
  490. - football-player-detection
  491. - trail-camera
  492. - smd-components
  493. - water-meter
  494. - nih-xray
  495. - the-dreidel-project
  496. - electric-pylon-detection-in-rsi
  497. - cable-damage