distributed.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import datetime
  6. import functools
  7. import io
  8. import logging
  9. import os
  10. import random
  11. import tempfile
  12. import time
  13. from typing import Any, Callable, List, Tuple
  14. import torch
  15. import torch.autograd as autograd
  16. import torch.distributed as dist
# Default to GPU 0
_cuda_device_index: int = 0
# Setting _cuda_device_index to -1 internally implies that we should use CPU
_CPU_DEVICE_INDEX = -1
# Rank treated as the primary/main process in collective helpers below.
_PRIMARY_RANK = 0
  22. @functools.lru_cache()
  23. def _get_global_gloo_group():
  24. """
  25. Return a process group based on gloo backend, containing all the ranks
  26. The result is cached.
  27. """
  28. if dist.get_backend() == "nccl":
  29. # Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes
  30. # being much slower than others causing a timeout (which can happen in relation
  31. # or LVIS class mAP evaluation).
  32. timeout = 43200
  33. return dist.new_group(
  34. backend="gloo",
  35. timeout=datetime.timedelta(seconds=timeout),
  36. )
  37. return dist.group.WORLD
  38. def is_main_process():
  39. """Return true if the current process is the main one"""
  40. return get_rank() == 0
  41. def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
  42. """
  43. Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
  44. `all_gather` above, but using filesystem instead of collective ops.
  45. If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
  46. (and other ranks will have an empty list).
  47. """
  48. world_size = get_world_size()
  49. if world_size == 1:
  50. return [data]
  51. print("gathering via files")
  52. cpu_group = _get_global_gloo_group()
  53. # if unspecified, we will save to the current python file dir
  54. if filesys_save_dir is not None:
  55. save_dir = filesys_save_dir
  56. elif "EXP_DIR" in os.environ:
  57. save_dir = os.environ["EXP_DIR"]
  58. else:
  59. # try the same directory where the code is stored
  60. save_dir = filesys_save_dir or os.path.dirname(__file__)
  61. save_dir = os.path.join(save_dir, "all_gather_via_filesys")
  62. if is_main_process():
  63. os.makedirs(save_dir, exist_ok=True)
  64. # use a timestamp and salt to distinguish different all_gather
  65. timestamp = int(time.time()) if is_main_process() else 0
  66. salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
  67. # broadcast the timestamp and salt across ranks
  68. # (all-reduce will do the broadcasting since only rank 0 is non-zero)
  69. timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
  70. dist.all_reduce(timestamp_and_salt, group=cpu_group)
  71. timestamp, salt = timestamp_and_salt.tolist()
  72. # save the data to a file on the disk
  73. rank_save = get_rank()
  74. save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
  75. save_data_path = os.path.join(save_dir, save_data_filename)
  76. assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
  77. torch.save(data, save_data_path)
  78. dist.barrier(group=cpu_group)
  79. # read the data from the files
  80. data_list = []
  81. if rank_save == 0 or not gather_to_rank_0_only:
  82. for rank_load in range(world_size):
  83. load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
  84. load_data_path = os.path.join(save_dir, load_data_filename)
  85. assert os.path.exists(load_data_path), f"cannot read {save_data_path}"
  86. data_list.append(torch.load(load_data_path))
  87. dist.barrier(group=cpu_group)
  88. # delete the saved file
  89. os.remove(save_data_path)
  90. return data_list
def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    Args:
        data: any picklable object
        force_cpu: if True, gather through a gloo (CPU) process group
        force_filesys: if True, gather by staging pickles on the filesystem
        filesys_save_dir: staging directory for the filesystem path
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        # single process: nothing to gather
        return [data]

    # Filesystem-based fallbacks, selected via env vars or the force flag.
    if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
        return all_gather_via_filesys(
            data, filesys_save_dir, gather_to_rank_0_only=True
        )

    if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
        return all_gather_via_filesys(data, filesys_save_dir)

    # cpu_group is None -> gather over the default (e.g. NCCL/CUDA) group.
    cpu_group = None
    if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
        cpu_group = _get_global_gloo_group()

    # Serialize `data` into a byte tensor on the communication device.
    buffer = io.BytesIO()
    torch.save(data, buffer)
    data_view = buffer.getbuffer()
    device = "cuda" if cpu_group is None else "cpu"
    tensor = torch.ByteTensor(data_view).to(device)

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
    size_list = [
        torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
    ]
    if cpu_group is None:
        dist.all_gather(size_list, local_size)
    else:
        print("gathering on cpu")
        dist.all_gather(size_list, local_size, group=cpu_group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    assert isinstance(local_size.item(), int)
    local_size = int(local_size.item())

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
    if local_size != max_size:
        padding = torch.empty(
            size=(max_size - local_size,), dtype=torch.uint8, device=device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    if cpu_group is None:
        dist.all_gather(tensor_list, tensor)
    else:
        dist.all_gather(tensor_list, tensor, group=cpu_group)

    # Deserialize each rank's payload, trimming padding using the true sizes.
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
        buffer = io.BytesIO(tensor.cpu().numpy())
        obj = torch.load(buffer)
        data_list.append(obj)
    return data_list
  152. def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
  153. """
  154. For some backends, such as NCCL, communication only works if the
  155. tensor is on the GPU. This helper function converts to the correct
  156. device and returns the tensor + original device.
  157. """
  158. orig_device = "cpu" if not tensor.is_cuda else "gpu"
  159. if (
  160. torch.distributed.is_available()
  161. and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
  162. and not tensor.is_cuda
  163. ):
  164. tensor = tensor.cuda()
  165. return (tensor, orig_device)
  166. def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
  167. """
  168. For some backends, such as NCCL, communication only works if the
  169. tensor is on the GPU. This converts the tensor back to original device.
  170. """
  171. if tensor.is_cuda and orig_device == "cpu":
  172. tensor = tensor.cpu()
  173. return tensor
  174. def is_distributed_training_run() -> bool:
  175. return (
  176. torch.distributed.is_available()
  177. and torch.distributed.is_initialized()
  178. and (torch.distributed.get_world_size() > 1)
  179. )
  180. def is_primary() -> bool:
  181. """
  182. Returns True if this is rank 0 of a distributed training job OR if it is
  183. a single trainer job. Otherwise False.
  184. """
  185. return get_rank() == _PRIMARY_RANK
  186. def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
  187. """
  188. Wrapper over torch.distributed.all_reduce for performing mean reduction
  189. of tensor over all processes.
  190. """
  191. return all_reduce_op(
  192. tensor,
  193. torch.distributed.ReduceOp.SUM,
  194. lambda t: t / torch.distributed.get_world_size(),
  195. )
  196. def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
  197. """
  198. Wrapper over torch.distributed.all_reduce for performing sum
  199. reduction of tensor over all processes in both distributed /
  200. non-distributed scenarios.
  201. """
  202. return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM)
  203. def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
  204. """
  205. Wrapper over torch.distributed.all_reduce for performing min
  206. reduction of tensor over all processes in both distributed /
  207. non-distributed scenarios.
  208. """
  209. return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN)
  210. def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
  211. """
  212. Wrapper over torch.distributed.all_reduce for performing min
  213. reduction of tensor over all processes in both distributed /
  214. non-distributed scenarios.
  215. """
  216. return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)
  217. def all_reduce_op(
  218. tensor: torch.Tensor,
  219. op: torch.distributed.ReduceOp,
  220. after_op_func: Callable[[torch.Tensor], torch.Tensor] = None,
  221. ) -> torch.Tensor:
  222. """
  223. Wrapper over torch.distributed.all_reduce for performing
  224. reduction of tensor over all processes in both distributed /
  225. non-distributed scenarios.
  226. """
  227. if is_distributed_training_run():
  228. tensor, orig_device = convert_to_distributed_tensor(tensor)
  229. torch.distributed.all_reduce(tensor, op)
  230. if after_op_func is not None:
  231. tensor = after_op_func(tensor)
  232. tensor = convert_to_normal_tensor(tensor, orig_device)
  233. return tensor
  234. def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
  235. """
  236. Wrapper over torch.distributed.all_gather for performing
  237. 'gather' of 'tensor' over all processes in both distributed /
  238. non-distributed scenarios.
  239. """
  240. if tensor.ndim == 0:
  241. # 0 dim tensors cannot be gathered. so unsqueeze
  242. tensor = tensor.unsqueeze(0)
  243. if is_distributed_training_run():
  244. tensor, orig_device = convert_to_distributed_tensor(tensor)
  245. gathered_tensors = [
  246. torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
  247. ]
  248. torch.distributed.all_gather(gathered_tensors, tensor)
  249. gathered_tensors = [
  250. convert_to_normal_tensor(_tensor, orig_device)
  251. for _tensor in gathered_tensors
  252. ]
  253. else:
  254. gathered_tensors = [tensor]
  255. return gathered_tensors
  256. def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
  257. gathered_tensors = gather_tensors_from_all(tensor)
  258. gathered_tensor = torch.cat(gathered_tensors, 0)
  259. return gathered_tensor
  260. def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
  261. """
  262. Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source
  263. to all processes in both distributed / non-distributed scenarios.
  264. """
  265. if is_distributed_training_run():
  266. tensor, orig_device = convert_to_distributed_tensor(tensor)
  267. torch.distributed.broadcast(tensor, src)
  268. tensor = convert_to_normal_tensor(tensor, orig_device)
  269. return tensor
  270. def barrier() -> None:
  271. """
  272. Wrapper over torch.distributed.barrier, returns without waiting
  273. if the distributed process group is not initialized instead of throwing error.
  274. """
  275. if not torch.distributed.is_available() or not torch.distributed.is_initialized():
  276. return
  277. torch.distributed.barrier()
  278. def get_world_size() -> int:
  279. """
  280. Simple wrapper for correctly getting worldsize in both distributed
  281. / non-distributed settings
  282. """
  283. return (
  284. torch.distributed.get_world_size()
  285. if torch.distributed.is_available() and torch.distributed.is_initialized()
  286. else 1
  287. )
  288. def get_rank() -> int:
  289. """
  290. Simple wrapper for correctly getting rank in both distributed
  291. / non-distributed settings
  292. """
  293. return (
  294. torch.distributed.get_rank()
  295. if torch.distributed.is_available() and torch.distributed.is_initialized()
  296. else 0
  297. )
  298. def get_primary_rank() -> int:
  299. return _PRIMARY_RANK
  300. def set_cuda_device_index(idx: int) -> None:
  301. global _cuda_device_index
  302. _cuda_device_index = idx
  303. torch.cuda.set_device(_cuda_device_index)
  304. def set_cpu_device() -> None:
  305. global _cuda_device_index
  306. _cuda_device_index = _CPU_DEVICE_INDEX
  307. def get_cuda_device_index() -> int:
  308. return _cuda_device_index
  309. def init_distributed_data_parallel_model(
  310. model: torch.nn.Module,
  311. broadcast_buffers: bool = False,
  312. find_unused_parameters: bool = True,
  313. bucket_cap_mb: int = 25,
  314. ) -> torch.nn.parallel.DistributedDataParallel:
  315. global _cuda_device_index
  316. if _cuda_device_index == _CPU_DEVICE_INDEX:
  317. # CPU-only model, don't specify device
  318. return torch.nn.parallel.DistributedDataParallel(
  319. model,
  320. broadcast_buffers=broadcast_buffers,
  321. find_unused_parameters=find_unused_parameters,
  322. bucket_cap_mb=bucket_cap_mb,
  323. )
  324. else:
  325. # GPU model
  326. return torch.nn.parallel.DistributedDataParallel(
  327. model,
  328. device_ids=[_cuda_device_index],
  329. output_device=_cuda_device_index,
  330. broadcast_buffers=broadcast_buffers,
  331. find_unused_parameters=find_unused_parameters,
  332. bucket_cap_mb=bucket_cap_mb,
  333. )
def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
    """Broadcast an object from a source to all workers.

    Args:
        obj: Object to broadcast, must be serializable
        src: Source rank for broadcast (default is primary)
        use_disk: If enabled, removes redundant CPU memory copies by writing to
            disk

    The protocol is two broadcasts: first the payload length (so receivers can
    size their buffer), then the torch.save-serialized bytes themselves.
    """
    # Either broadcast from primary to the fleet (default),
    # or use the src setting as the original rank
    if get_rank() == src:
        # Emit data: serialize obj, then send length followed by the bytes.
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data_view = buffer.getbuffer()
        length_tensor = torch.LongTensor([len(data_view)])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.ByteTensor(data_view)
        data_tensor = broadcast(data_tensor, src=src)
    else:
        # Fetch from the source: receive the byte count, then the payload.
        length_tensor = torch.LongTensor([0])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
        data_tensor = broadcast(data_tensor, src=src)

    if use_disk:
        # Spool bytes through a temp file so the tensor copy can be freed
        # before deserialization, lowering peak CPU memory.
        with tempfile.TemporaryFile("r+b") as f:
            f.write(data_tensor.numpy())
            # remove reference to the data tensor and hope that Python garbage
            # collects it
            del data_tensor
            f.seek(0)
            obj = torch.load(f)
    else:
        buffer = io.BytesIO(data_tensor.numpy())
        obj = torch.load(buffer)
    return obj
  371. def all_gather_tensor(tensor: torch.Tensor, world_size=None):
  372. if world_size is None:
  373. world_size = get_world_size()
  374. # make contiguous because NCCL won't gather the tensor otherwise
  375. assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
  376. tensor, orig_device = convert_to_distributed_tensor(tensor)
  377. tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
  378. dist.all_gather(tensor_all, tensor, async_op=False) # performance opt
  379. tensor_all = [
  380. convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all
  381. ]
  382. return tensor_all
  383. def all_gather_batch(tensors: List[torch.Tensor]):
  384. """
  385. Performs all_gather operation on the provided tensors.
  386. """
  387. # Queue the gathered tensors
  388. world_size = get_world_size()
  389. # There is no need for reduction in the single-proc case
  390. if world_size == 1:
  391. return tensors
  392. tensor_list = []
  393. output_tensor = []
  394. for tensor in tensors:
  395. tensor_all = all_gather_tensor(tensor, world_size)
  396. tensor_list.append(tensor_all)
  397. for tensor_all in tensor_list:
  398. output_tensor.append(torch.cat(tensor_all, dim=0))
  399. return output_tensor
  400. class GatherLayer(autograd.Function):
  401. """
  402. Gather tensors from all workers with support for backward propagation:
  403. This implementation does not cut the gradients as torch.distributed.all_gather does.
  404. """
  405. @staticmethod
  406. def forward(ctx, x):
  407. output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
  408. dist.all_gather(output, x)
  409. return tuple(output)
  410. @staticmethod
  411. def backward(ctx, *grads):
  412. all_gradients = torch.stack(grads)
  413. dist.all_reduce(all_gradients)
  414. return all_gradients[dist.get_rank()]
  415. def all_gather_batch_with_grad(tensors):
  416. """
  417. Performs all_gather operation on the provided tensors.
  418. Graph remains connected for backward grad computation.
  419. """
  420. # Queue the gathered tensors
  421. world_size = get_world_size()
  422. # There is no need for reduction in the single-proc case
  423. if world_size == 1:
  424. return tensors
  425. tensor_list = []
  426. output_tensor = []
  427. for tensor in tensors:
  428. tensor_all = GatherLayer.apply(tensor)
  429. tensor_list.append(tensor_all)
  430. for tensor_all in tensor_list:
  431. output_tensor.append(torch.cat(tensor_all, dim=0))
  432. return output_tensor
  433. def unwrap_ddp_if_wrapped(model):
  434. if isinstance(model, torch.nn.parallel.DistributedDataParallel):
  435. return model.module
  436. return model
  437. def create_new_process_group(group_size):
  438. """
  439. Creates process groups of a gives `group_size` and returns
  440. process group that current GPU participates in.
  441. `group_size` must divide the total number of GPUs (world_size).
  442. Modified from
  443. https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60
  444. Args:
  445. group_size (int): number of GPU's to collaborate for sync bn
  446. """
  447. assert group_size > 0
  448. world_size = torch.distributed.get_world_size()
  449. if world_size <= 8:
  450. if group_size > world_size:
  451. logging.warning(
  452. f"Requested group size [{group_size}] > world size [{world_size}]. "
  453. "Assuming local debug run and capping it to world size."
  454. )
  455. group_size = world_size
  456. assert world_size >= group_size
  457. assert world_size % group_size == 0
  458. group = None
  459. for group_num in range(world_size // group_size):
  460. group_ids = range(group_num * group_size, (group_num + 1) * group_size)
  461. cur_group = torch.distributed.new_group(ranks=group_ids)
  462. if torch.distributed.get_rank() // group_size == group_num:
  463. group = cur_group
  464. # can not drop out and return here, every process must go through creation of all subgroups
  465. assert group is not None
  466. return group
  467. def is_dist_avail_and_initialized():
  468. if not dist.is_available():
  469. return False
  470. if not dist.is_initialized():
  471. return False
  472. return True