# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import functools
import io
import logging
import os
import random
import tempfile
import time
from typing import Any, Callable, List, Tuple

import torch
import torch.autograd as autograd
import torch.distributed as dist

# Default to GPU 0
_cuda_device_index: int = 0

# Setting _cuda_device_index to -1 internally implies that we should use CPU
_CPU_DEVICE_INDEX = -1
_PRIMARY_RANK = 0
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        # Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes
        # being much slower than others causing a timeout (which can happen in relation
        # or LVIS class mAP evaluation).
        timeout = 43200
        return dist.new_group(
            backend="gloo",
            timeout=datetime.timedelta(seconds=timeout),
        )

    return dist.group.WORLD
def is_main_process():
    """Return true if the current process is the main one"""
    return get_rank() == 0
def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
    `all_gather` below, but using the filesystem instead of collective ops.

    If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
    (and other ranks will have an empty list).
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    print("gathering via files")
    cpu_group = _get_global_gloo_group()

    # if unspecified, we will save to the current python file dir
    if filesys_save_dir is not None:
        save_dir = filesys_save_dir
    elif "EXP_DIR" in os.environ:
        save_dir = os.environ["EXP_DIR"]
    else:
        # try the same directory where the code is stored
        save_dir = os.path.dirname(__file__)
    save_dir = os.path.join(save_dir, "all_gather_via_filesys")
    if is_main_process():
        os.makedirs(save_dir, exist_ok=True)

    # use a timestamp and salt to distinguish different all_gather calls
    timestamp = int(time.time()) if is_main_process() else 0
    salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
    # broadcast the timestamp and salt across ranks
    # (all-reduce will do the broadcasting since only rank 0 is non-zero)
    timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
    dist.all_reduce(timestamp_and_salt, group=cpu_group)
    timestamp, salt = timestamp_and_salt.tolist()

    # save the data to a file on the disk
    rank_save = get_rank()
    save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
    save_data_path = os.path.join(save_dir, save_data_filename)
    assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
    torch.save(data, save_data_path)
    dist.barrier(group=cpu_group)

    # read the data from the files
    data_list = []
    if rank_save == 0 or not gather_to_rank_0_only:
        for rank_load in range(world_size):
            load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
            load_data_path = os.path.join(save_dir, load_data_filename)
            assert os.path.exists(load_data_path), f"cannot read {load_data_path}"
            data_list.append(torch.load(load_data_path, weights_only=False))
    dist.barrier(group=cpu_group)

    # delete the saved file
    os.remove(save_data_path)
    return data_list
def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        force_cpu: if True, gather on CPU via a gloo process group
        force_filesys: if True, gather via the filesystem instead of collective ops
        filesys_save_dir: directory used by the filesystem-based gather

    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
        return all_gather_via_filesys(
            data, filesys_save_dir, gather_to_rank_0_only=True
        )

    if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
        return all_gather_via_filesys(data, filesys_save_dir)

    cpu_group = None
    if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
        cpu_group = _get_global_gloo_group()

    buffer = io.BytesIO()
    torch.save(data, buffer)
    data_view = buffer.getbuffer()
    device = "cuda" if cpu_group is None else "cpu"

    tensor = torch.ByteTensor(data_view).to(device)

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
    size_list = [
        torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
    ]
    if cpu_group is None:
        dist.all_gather(size_list, local_size)
    else:
        print("gathering on cpu")
        dist.all_gather(size_list, local_size, group=cpu_group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    assert isinstance(local_size.item(), int)
    local_size = int(local_size.item())

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
    if local_size != max_size:
        padding = torch.empty(
            size=(max_size - local_size,), dtype=torch.uint8, device=device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    if cpu_group is None:
        dist.all_gather(tensor_list, tensor)
    else:
        dist.all_gather(tensor_list, tensor, group=cpu_group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
        buffer = io.BytesIO(tensor.cpu().numpy())
        obj = torch.load(buffer, weights_only=False)
        data_list.append(obj)

    return data_list
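# Illustrative sketch (an assumption, not part of the original module): a typical
# way to use `all_gather` to merge per-rank evaluation results. The helper name
# `_example_merge_eval_results` and the shape of `local_results` are hypothetical;
# the default process group is assumed to have been initialized elsewhere
# (e.g. via torch.distributed.init_process_group).
def _example_merge_eval_results(local_results: list) -> list:
    # Every rank contributes its own list of picklable results; `all_gather`
    # returns one entry per rank, ordered by rank.
    results_per_rank = all_gather(local_results)
    merged = []
    for rank_results in results_per_rank:
        merged.extend(rank_results)
    return merged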
def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This helper function converts to the correct
    device and returns the tensor + original device.
    """
    orig_device = "cpu" if not tensor.is_cuda else "gpu"
    if (
        torch.distributed.is_available()
        and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
        and not tensor.is_cuda
    ):
        tensor = tensor.cuda()
    return (tensor, orig_device)
def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This converts the tensor back to the original device.
    """
    if tensor.is_cuda and orig_device == "cpu":
        tensor = tensor.cpu()
    return tensor
def is_distributed_training_run() -> bool:
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and (torch.distributed.get_world_size() > 1)
    )
def is_primary() -> bool:
    """
    Returns True if this is rank 0 of a distributed training job OR if it is
    a single trainer job. Otherwise False.
    """
    return get_rank() == _PRIMARY_RANK
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing mean reduction
    of tensor over all processes.
    """
    return all_reduce_op(
        tensor,
        torch.distributed.ReduceOp.SUM,
        lambda t: t / torch.distributed.get_world_size(),
    )


def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing sum
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM)


def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing min
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN)


def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing max
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)
def all_reduce_op(
    tensor: torch.Tensor,
    op: torch.distributed.ReduceOp,
    after_op_func: Callable[[torch.Tensor], torch.Tensor] = None,
) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        torch.distributed.all_reduce(tensor, op)
        if after_op_func is not None:
            tensor = after_op_func(tensor)
        tensor = convert_to_normal_tensor(tensor, orig_device)
    return tensor
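# Illustrative sketch (an assumption, not part of the original module): averaging
# a scalar metric across ranks with `all_reduce_mean`. The helper name and the
# `loss` argument are hypothetical.
def _example_average_loss_across_ranks(loss: torch.Tensor) -> float:
    # `all_reduce_mean` sums the tensor over all processes and divides by the
    # world size; in a non-distributed run it returns the tensor unchanged.
    return all_reduce_mean(loss.detach()).item()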
def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    Wrapper over torch.distributed.all_gather for performing
    'gather' of 'tensor' over all processes in both distributed /
    non-distributed scenarios.
    """
    if tensor.ndim == 0:
        # 0 dim tensors cannot be gathered. so unsqueeze
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = [
            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(gathered_tensors, tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]

    return gathered_tensors


def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    gathered_tensors = gather_tensors_from_all(tensor)
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """
    Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source
    to all processes in both distributed / non-distributed scenarios.
    """
    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        torch.distributed.broadcast(tensor, src)
        tensor = convert_to_normal_tensor(tensor, orig_device)
    return tensor


def barrier() -> None:
    """
    Wrapper over torch.distributed.barrier; returns without waiting
    if the distributed process group is not initialized, instead of throwing an error.
    """
    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()
def get_world_size() -> int:
    """
    Simple wrapper for correctly getting worldsize in both distributed
    / non-distributed settings
    """
    return (
        torch.distributed.get_world_size()
        if torch.distributed.is_available() and torch.distributed.is_initialized()
        else 1
    )


def get_rank() -> int:
    """
    Simple wrapper for correctly getting rank in both distributed
    / non-distributed settings
    """
    return (
        torch.distributed.get_rank()
        if torch.distributed.is_available() and torch.distributed.is_initialized()
        else 0
    )


def get_primary_rank() -> int:
    return _PRIMARY_RANK


def set_cuda_device_index(idx: int) -> None:
    global _cuda_device_index
    _cuda_device_index = idx
    torch.cuda.set_device(_cuda_device_index)


def set_cpu_device() -> None:
    global _cuda_device_index
    _cuda_device_index = _CPU_DEVICE_INDEX


def get_cuda_device_index() -> int:
    return _cuda_device_index
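# Illustrative sketch (an assumption, not part of the original module): wiring the
# CUDA device helpers to the LOCAL_RANK environment variable that torchrun sets.
# The helper name is hypothetical.
def _example_setup_device_from_env() -> None:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        # records the index in this module and calls torch.cuda.set_device
        set_cuda_device_index(local_rank)
    else:
        set_cpu_device()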
def init_distributed_data_parallel_model(
    model: torch.nn.Module,
    broadcast_buffers: bool = False,
    find_unused_parameters: bool = True,
    bucket_cap_mb: int = 25,
) -> torch.nn.parallel.DistributedDataParallel:
    global _cuda_device_index

    if _cuda_device_index == _CPU_DEVICE_INDEX:
        # CPU-only model, don't specify device
        return torch.nn.parallel.DistributedDataParallel(
            model,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            bucket_cap_mb=bucket_cap_mb,
        )
    else:
        # GPU model
        return torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[_cuda_device_index],
            output_device=_cuda_device_index,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            bucket_cap_mb=bucket_cap_mb,
        )
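# Illustrative sketch (an assumption, not part of the original module): the usual
# order of operations when wrapping a model for DDP with the helpers above. The
# model construction is a placeholder.
def _example_wrap_model_in_ddp() -> torch.nn.parallel.DistributedDataParallel:
    model = torch.nn.Linear(16, 4)  # placeholder model
    if get_cuda_device_index() != _CPU_DEVICE_INDEX:
        model = model.cuda(get_cuda_device_index())
    # uses device_ids / output_device from the module-level _cuda_device_index
    return init_distributed_data_parallel_model(model)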
def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
    """Broadcast an object from a source to all workers.

    Args:
        obj: Object to broadcast, must be serializable
        src: Source rank for broadcast (default is primary)
        use_disk: If enabled, removes redundant CPU memory copies by writing to
            disk
    """
    # Either broadcast from primary to the fleet (default),
    # or use the src setting as the original rank
    if get_rank() == src:
        # Emit data
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data_view = buffer.getbuffer()
        length_tensor = torch.LongTensor([len(data_view)])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.ByteTensor(data_view)
        data_tensor = broadcast(data_tensor, src=src)
    else:
        # Fetch from the source
        length_tensor = torch.LongTensor([0])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
        data_tensor = broadcast(data_tensor, src=src)

        if use_disk:
            with tempfile.TemporaryFile("r+b") as f:
                f.write(data_tensor.numpy())
                # remove reference to the data tensor and hope that Python garbage
                # collects it
                del data_tensor
                f.seek(0)
                obj = torch.load(f, weights_only=False)
        else:
            buffer = io.BytesIO(data_tensor.numpy())
            obj = torch.load(buffer, weights_only=False)
    return obj
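# Illustrative sketch (an assumption, not part of the original module): broadcasting
# a configuration dict that only the primary rank builds, so every worker ends up
# with an identical copy. The config contents are placeholders.
def _example_broadcast_config() -> Any:
    config = {"lr": 1e-4, "epochs": 10} if is_primary() else None
    # non-source ranks pass a dummy object; they receive the primary's copy
    return broadcast_object(config, src=get_primary_rank())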
def all_gather_tensor(tensor: torch.Tensor, world_size=None):
    if world_size is None:
        world_size = get_world_size()
    # make contiguous because NCCL won't gather the tensor otherwise
    assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
    tensor, orig_device = convert_to_distributed_tensor(tensor)
    tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_all, tensor, async_op=False)  # performance opt
    tensor_all = [
        convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all
    ]
    return tensor_all


def all_gather_batch(tensors: List[torch.Tensor]):
    """
    Performs all_gather operation on the provided tensors.
    """
    # Queue the gathered tensors
    world_size = get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors
    tensor_list = []
    output_tensor = []
    for tensor in tensors:
        tensor_all = all_gather_tensor(tensor, world_size)
        tensor_list.append(tensor_all)
    for tensor_all in tensor_list:
        output_tensor.append(torch.cat(tensor_all, dim=0))
    return output_tensor
class GatherLayer(autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]


def all_gather_batch_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    # Queue the gathered tensors
    world_size = get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors
    tensor_list = []
    output_tensor = []
    for tensor in tensors:
        tensor_all = GatherLayer.apply(tensor)
        tensor_list.append(tensor_all)
    for tensor_all in tensor_list:
        output_tensor.append(torch.cat(tensor_all, dim=0))
    return output_tensor
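# Illustrative sketch (an assumption, not part of the original module): gathering
# image and text embeddings across ranks before a contrastive loss. Unlike
# `all_gather_batch`, gradients still flow back to the local shard through
# `GatherLayer`. The embedding tensors and helper name are placeholders.
def _example_gather_embeddings_for_contrastive_loss(
    image_emb: torch.Tensor, text_emb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    all_image_emb, all_text_emb = all_gather_batch_with_grad([image_emb, text_emb])
    return all_image_emb, all_text_emb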
def unwrap_ddp_if_wrapped(model):
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model
def create_new_process_group(group_size):
    """
    Creates process groups of a given `group_size` and returns the
    process group that the current GPU participates in.

    `group_size` must divide the total number of GPUs (world_size).

    Modified from
    https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60

    Args:
        group_size (int): number of GPUs to collaborate for sync bn
    """
    assert group_size > 0

    world_size = torch.distributed.get_world_size()
    if world_size <= 8:
        if group_size > world_size:
            logging.warning(
                f"Requested group size [{group_size}] > world size [{world_size}]. "
                "Assuming local debug run and capping it to world size."
            )
            group_size = world_size
    assert world_size >= group_size
    assert world_size % group_size == 0

    group = None
    for group_num in range(world_size // group_size):
        group_ids = range(group_num * group_size, (group_num + 1) * group_size)
        cur_group = torch.distributed.new_group(ranks=group_ids)
        if torch.distributed.get_rank() // group_size == group_num:
            group = cur_group
            # can not drop out and return here, every process must go through creation of all subgroups

    assert group is not None
    return group
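# Illustrative sketch (an assumption, not part of the original module): using
# `create_new_process_group` to build 8-GPU groups for SyncBatchNorm. The group
# size of 8 is an arbitrary example value.
def _example_convert_sync_batchnorm(model: torch.nn.Module) -> torch.nn.Module:
    process_group = create_new_process_group(group_size=8)
    return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group)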
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True
def gather_to_rank_0_via_filesys(data, filesys_save_dir=None):
    """
    Gather any picklable data to rank 0 via filesystem, using all_gather_via_filesys.
    """
    return all_gather_via_filesys(data, filesys_save_dir, gather_to_rank_0_only=True)
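# Illustrative sketch (an assumption, not part of the original module): collecting
# large per-rank prediction dicts on rank 0 via the filesystem path, which avoids
# holding world_size serialized copies in GPU/CPU memory at once. The helper name
# and the save directory are hypothetical.
def _example_collect_predictions_on_rank_0(predictions: dict):
    gathered = gather_to_rank_0_via_filesys(
        predictions, filesys_save_dir="/tmp/eval_gather"  # hypothetical directory
    )
    # only rank 0 receives the full list; other ranks get an empty list
    return gathered if is_main_process() else None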