- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- # pyre-unsafe
- import datetime
- import functools
- import io
- import logging
- import os
- import random
- import tempfile
- import time
- from typing import Any, Callable, List, Optional, Tuple
- import torch
- import torch.autograd as autograd
- import torch.distributed as dist
- # Default to GPU 0
- _cuda_device_index: int = 0
- # Setting _cuda_device_index to -1 internally implies that we should use CPU
- _CPU_DEVICE_INDEX = -1
- _PRIMARY_RANK = 0
- @functools.lru_cache()
- def _get_global_gloo_group():
- """
- Return a process group based on the gloo backend, containing all the ranks.
- The result is cached.
- """
- if dist.get_backend() == "nccl":
- # Increase the timeout from 1800 sec to 43200 sec (12 hr) to avoid a timeout
- # when some processes are much slower than others (which can happen, for
- # example, during relation or LVIS class mAP evaluation).
- timeout = 43200
- return dist.new_group(
- backend="gloo",
- timeout=datetime.timedelta(seconds=timeout),
- )
- return dist.group.WORLD
- def is_main_process():
- """Return true if the current process is the main one"""
- return get_rank() == 0
- def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
- """
- Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
- `all_gather` below, but using the filesystem instead of collective ops.
- If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
- (and other ranks will have an empty list).
- """
- world_size = get_world_size()
- if world_size == 1:
- return [data]
- print("gathering via files")
- cpu_group = _get_global_gloo_group()
- # if unspecified, fall back to EXP_DIR and then to the directory of this file
- if filesys_save_dir is not None:
- save_dir = filesys_save_dir
- elif "EXP_DIR" in os.environ:
- save_dir = os.environ["EXP_DIR"]
- else:
- # fall back to the directory where this code is stored
- save_dir = os.path.dirname(__file__)
- save_dir = os.path.join(save_dir, "all_gather_via_filesys")
- if is_main_process():
- os.makedirs(save_dir, exist_ok=True)
- # use a timestamp and salt to distinguish different all_gather
- timestamp = int(time.time()) if is_main_process() else 0
- salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
- # broadcast the timestamp and salt across ranks
- # (all-reduce will do the broadcasting since only rank 0 is non-zero)
- timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
- dist.all_reduce(timestamp_and_salt, group=cpu_group)
- timestamp, salt = timestamp_and_salt.tolist()
- # save the data to a file on the disk
- rank_save = get_rank()
- save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
- save_data_path = os.path.join(save_dir, save_data_filename)
- assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
- torch.save(data, save_data_path)
- dist.barrier(group=cpu_group)
- # read the data from the files
- data_list = []
- if rank_save == 0 or not gather_to_rank_0_only:
- for rank_load in range(world_size):
- load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
- load_data_path = os.path.join(save_dir, load_data_filename)
- assert os.path.exists(load_data_path), f"cannot read {load_data_path}"
- data_list.append(torch.load(load_data_path, weights_only=False))
- dist.barrier(group=cpu_group)
- # delete the saved file
- os.remove(save_data_path)
- return data_list
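- # Illustrative usage sketch (not part of the original module): assuming
- # torch.distributed is initialized and every rank can see the same shared
- # directory (the path below is hypothetical), each rank contributes one
- # picklable object and reads back the full per-rank list:
- #
- #   stats = {"rank": get_rank(), "n_samples": 128}
- #   gathered = all_gather_via_filesys(stats, filesys_save_dir="/shared/tmp")
- #   # gathered is a list of world_size dicts, ordered by rank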
- def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
- """
- Run all_gather on arbitrary picklable data (not necessarily tensors)
- Args:
- data: any picklable object
- force_cpu: if True, run the gather on CPU via a gloo process group
- force_filesys: if True, gather via the filesystem instead of collective ops
- filesys_save_dir: directory used for the filesystem-based gather
- Returns:
- list[data]: list of data gathered from each rank
- """
- world_size = get_world_size()
- if world_size == 1:
- return [data]
- if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
- return all_gather_via_filesys(
- data, filesys_save_dir, gather_to_rank_0_only=True
- )
- if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
- return all_gather_via_filesys(data, filesys_save_dir)
- cpu_group = None
- if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
- cpu_group = _get_global_gloo_group()
- buffer = io.BytesIO()
- torch.save(data, buffer)
- data_view = buffer.getbuffer()
- device = "cuda" if cpu_group is None else "cpu"
- tensor = torch.ByteTensor(data_view).to(device)
- # obtain Tensor size of each rank
- local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
- size_list = [
- torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
- ]
- if cpu_group is None:
- dist.all_gather(size_list, local_size)
- else:
- print("gathering on cpu")
- dist.all_gather(size_list, local_size, group=cpu_group)
- size_list = [int(size.item()) for size in size_list]
- max_size = max(size_list)
- assert isinstance(local_size.item(), int)
- local_size = int(local_size.item())
- # receiving Tensor from all ranks
- # we pad the tensor because torch all_gather does not support
- # gathering tensors of different shapes
- tensor_list = []
- for _ in size_list:
- tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
- if local_size != max_size:
- padding = torch.empty(
- size=(max_size - local_size,), dtype=torch.uint8, device=device
- )
- tensor = torch.cat((tensor, padding), dim=0)
- if cpu_group is None:
- dist.all_gather(tensor_list, tensor)
- else:
- dist.all_gather(tensor_list, tensor, group=cpu_group)
- data_list = []
- for size, tensor in zip(size_list, tensor_list):
- tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
- buffer = io.BytesIO(tensor.cpu().numpy())
- obj = torch.load(buffer, weights_only=False)
- data_list.append(obj)
- return data_list
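- # Illustrative usage sketch (assumes the default process group has been
- # initialized via torch.distributed.init_process_group; `local_preds` is a
- # hypothetical per-rank object): gather arbitrary picklable data from all
- # ranks, forcing the gather onto the gloo/CPU group:
- #
- #   local_preds = {"rank": get_rank(), "ids": [1, 2, 3]}
- #   all_preds = all_gather(local_preds, force_cpu=True)
- #   if is_main_process():
- #       merged_ids = [i for preds in all_preds for i in preds["ids"]]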
- def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
- """
- For some backends, such as NCCL, communication only works if the
- tensor is on the GPU. This helper function converts to the correct
- device and returns the tensor + original device.
- """
- orig_device = "cpu" if not tensor.is_cuda else "gpu"
- if (
- torch.distributed.is_available()
- and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
- and not tensor.is_cuda
- ):
- tensor = tensor.cuda()
- return (tensor, orig_device)
- def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
- """
- For some backends, such as NCCL, communication only works if the
- tensor is on the GPU. This converts the tensor back to original device.
- """
- if tensor.is_cuda and orig_device == "cpu":
- tensor = tensor.cpu()
- return tensor
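- # Illustrative round trip (sketch, not original code): move a CPU tensor to
- # the device required by the backend, run a collective, then restore it:
- #
- #   t = torch.zeros(3)
- #   t, orig_device = convert_to_distributed_tensor(t)  # on GPU under NCCL
- #   # ... collective ops on t ...
- #   t = convert_to_normal_tensor(t, orig_device)  # back on CPU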
- def is_distributed_training_run() -> bool:
- return (
- torch.distributed.is_available()
- and torch.distributed.is_initialized()
- and (torch.distributed.get_world_size() > 1)
- )
- def is_primary() -> bool:
- """
- Returns True if this is rank 0 of a distributed training job OR if it is
- a single trainer job. Otherwise False.
- """
- return get_rank() == _PRIMARY_RANK
- def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
- """
- Wrapper over torch.distributed.all_reduce for performing mean reduction
- of tensor over all processes.
- """
- return all_reduce_op(
- tensor,
- torch.distributed.ReduceOp.SUM,
- lambda t: t / torch.distributed.get_world_size(),
- )
- def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
- """
- Wrapper over torch.distributed.all_reduce for performing sum
- reduction of tensor over all processes in both distributed /
- non-distributed scenarios.
- """
- return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM)
- def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
- """
- Wrapper over torch.distributed.all_reduce for performing min
- reduction of tensor over all processes in both distributed /
- non-distributed scenarios.
- """
- return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN)
- def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
- """
- Wrapper over torch.distributed.all_reduce for performing max
- reduction of tensor over all processes in both distributed /
- non-distributed scenarios.
- """
- return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)
- def all_reduce_op(
- tensor: torch.Tensor,
- op: torch.distributed.ReduceOp,
- after_op_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
- ) -> torch.Tensor:
- """
- Wrapper over torch.distributed.all_reduce for performing
- reduction of tensor over all processes in both distributed /
- non-distributed scenarios.
- """
- if is_distributed_training_run():
- tensor, orig_device = convert_to_distributed_tensor(tensor)
- torch.distributed.all_reduce(tensor, op)
- if after_op_func is not None:
- tensor = after_op_func(tensor)
- tensor = convert_to_normal_tensor(tensor, orig_device)
- return tensor
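- # Illustrative usage sketch (assumes a distributed run; `loss` and `n_correct`
- # are hypothetical per-rank values): average a scalar loss and sum a count
- # across processes. all_reduce modifies the tensor in place, so clone first if
- # the local value is still needed:
- #
- #   mean_loss = all_reduce_mean(loss.detach().clone())
- #   total_correct = all_reduce_sum(torch.tensor([n_correct], dtype=torch.float))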
- def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
- """
- Wrapper over torch.distributed.all_gather for performing
- 'gather' of 'tensor' over all processes in both distributed /
- non-distributed scenarios.
- """
- if tensor.ndim == 0:
- # 0-dim tensors cannot be gathered, so unsqueeze first
- tensor = tensor.unsqueeze(0)
- if is_distributed_training_run():
- tensor, orig_device = convert_to_distributed_tensor(tensor)
- gathered_tensors = [
- torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
- ]
- torch.distributed.all_gather(gathered_tensors, tensor)
- gathered_tensors = [
- convert_to_normal_tensor(_tensor, orig_device)
- for _tensor in gathered_tensors
- ]
- else:
- gathered_tensors = [tensor]
- return gathered_tensors
- def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
- gathered_tensors = gather_tensors_from_all(tensor)
- gathered_tensor = torch.cat(gathered_tensors, 0)
- return gathered_tensor
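- # Illustrative usage sketch (assumes a distributed run and that every rank
- # contributes a tensor of the same shape): concatenate per-rank feature
- # batches along dim 0, e.g. before computing a retrieval metric:
- #
- #   feats = torch.randn(8, 256, device="cuda")
- #   all_feats = gather_from_all(feats)  # shape (8 * world_size, 256)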
- def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
- """
- Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source
- to all processes in both distributed / non-distributed scenarios.
- """
- if is_distributed_training_run():
- tensor, orig_device = convert_to_distributed_tensor(tensor)
- torch.distributed.broadcast(tensor, src)
- tensor = convert_to_normal_tensor(tensor, orig_device)
- return tensor
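- # Illustrative usage sketch (assumes a distributed run; under NCCL the tensor
- # is moved to GPU for the collective and back afterwards): make rank 0's
- # value visible on every rank:
- #
- #   seed = torch.tensor([1234]) if is_primary() else torch.tensor([0])
- #   seed = broadcast(seed, src=0)  # every rank now holds 1234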
- def barrier() -> None:
- """
- Wrapper over torch.distributed.barrier, returns without waiting
- if the distributed process group is not initialized instead of throwing error.
- """
- if not torch.distributed.is_available() or not torch.distributed.is_initialized():
- return
- torch.distributed.barrier()
- def get_world_size() -> int:
- """
- Simple wrapper for correctly getting world size in both distributed
- / non-distributed settings
- """
- return (
- torch.distributed.get_world_size()
- if torch.distributed.is_available() and torch.distributed.is_initialized()
- else 1
- )
- def get_rank() -> int:
- """
- Simple wrapper for correctly getting rank in both distributed
- / non-distributed settings
- """
- return (
- torch.distributed.get_rank()
- if torch.distributed.is_available() and torch.distributed.is_initialized()
- else 0
- )
- def get_primary_rank() -> int:
- return _PRIMARY_RANK
- def set_cuda_device_index(idx: int) -> None:
- global _cuda_device_index
- _cuda_device_index = idx
- torch.cuda.set_device(_cuda_device_index)
- def set_cpu_device() -> None:
- global _cuda_device_index
- _cuda_device_index = _CPU_DEVICE_INDEX
- def get_cuda_device_index() -> int:
- return _cuda_device_index
- def init_distributed_data_parallel_model(
- model: torch.nn.Module,
- broadcast_buffers: bool = False,
- find_unused_parameters: bool = True,
- bucket_cap_mb: int = 25,
- ) -> torch.nn.parallel.DistributedDataParallel:
- global _cuda_device_index
- if _cuda_device_index == _CPU_DEVICE_INDEX:
- # CPU-only model, don't specify device
- return torch.nn.parallel.DistributedDataParallel(
- model,
- broadcast_buffers=broadcast_buffers,
- find_unused_parameters=find_unused_parameters,
- bucket_cap_mb=bucket_cap_mb,
- )
- else:
- # GPU model
- return torch.nn.parallel.DistributedDataParallel(
- model,
- device_ids=[_cuda_device_index],
- output_device=_cuda_device_index,
- broadcast_buffers=broadcast_buffers,
- find_unused_parameters=find_unused_parameters,
- bucket_cap_mb=bucket_cap_mb,
- )
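- # Illustrative usage sketch (assumes torch.distributed is initialized;
- # `model` and `local_rank` are hypothetical): pin this process to its local
- # GPU, then wrap the model in DDP using the module-level device index:
- #
- #   set_cuda_device_index(local_rank)
- #   model = model.cuda()
- #   ddp_model = init_distributed_data_parallel_model(model)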
- def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
- """Broadcast an object from a source to all workers.
- Args:
- obj: Object to broadcast, must be serializable
- src: Source rank for broadcast (default is primary)
- use_disk: If enabled, removes redundant CPU memory copies by writing to
- disk
- """
- # Either broadcast from primary to the fleet (default),
- # or use the src setting as the original rank
- if get_rank() == src:
- # Emit data
- buffer = io.BytesIO()
- torch.save(obj, buffer)
- data_view = buffer.getbuffer()
- length_tensor = torch.LongTensor([len(data_view)])
- length_tensor = broadcast(length_tensor, src=src)
- data_tensor = torch.ByteTensor(data_view)
- data_tensor = broadcast(data_tensor, src=src)
- else:
- # Fetch from the source
- length_tensor = torch.LongTensor([0])
- length_tensor = broadcast(length_tensor, src=src)
- data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
- data_tensor = broadcast(data_tensor, src=src)
- if use_disk:
- with tempfile.TemporaryFile("r+b") as f:
- f.write(data_tensor.numpy())
- # remove reference to the data tensor and hope that Python garbage
- # collects it
- del data_tensor
- f.seek(0)
- obj = torch.load(f, weights_only=False)
- else:
- buffer = io.BytesIO(data_tensor.numpy())
- obj = torch.load(buffer, weights_only=False)
- return obj
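- # Illustrative usage sketch (assumes a distributed run; build_config is a
- # hypothetical helper): construct an object once on the primary rank and
- # share it with all other workers:
- #
- #   cfg = build_config() if is_primary() else None
- #   cfg = broadcast_object(cfg, src=get_primary_rank())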
- def all_gather_tensor(tensor: torch.Tensor, world_size=None):
- if world_size is None:
- world_size = get_world_size()
- # the tensor must be contiguous, since NCCL cannot gather it otherwise
- assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
- tensor, orig_device = convert_to_distributed_tensor(tensor)
- tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
- dist.all_gather(tensor_all, tensor, async_op=False)  # synchronous gather
- tensor_all = [
- convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all
- ]
- return tensor_all
- def all_gather_batch(tensors: List[torch.Tensor]):
- """
- Performs all_gather operation on the provided tensors.
- """
- # Queue the gathered tensors
- world_size = get_world_size()
- # There is no need for reduction in the single-proc case
- if world_size == 1:
- return tensors
- tensor_list = []
- output_tensor = []
- for tensor in tensors:
- tensor_all = all_gather_tensor(tensor, world_size)
- tensor_list.append(tensor_all)
- for tensor_all in tensor_list:
- output_tensor.append(torch.cat(tensor_all, dim=0))
- return output_tensor
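- # Illustrative usage sketch (assumes a distributed run; `img_emb` and
- # `txt_emb` are hypothetical contiguous per-rank tensors of identical shape
- # across ranks): gather several tensors in one call. No gradients flow
- # through the gathered copies; see all_gather_batch_with_grad below for that.
- #
- #   img_all, txt_all = all_gather_batch([img_emb, txt_emb])
- #   # each output has shape (local_batch * world_size, dim)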
- class GatherLayer(autograd.Function):
- """
- Gather tensors from all workers with support for backward propagation:
- This implementation does not cut the gradients as torch.distributed.all_gather does.
- """
- @staticmethod
- def forward(ctx, x):
- output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
- dist.all_gather(output, x)
- return tuple(output)
- @staticmethod
- def backward(ctx, *grads):
- all_gradients = torch.stack(grads)
- dist.all_reduce(all_gradients)
- return all_gradients[dist.get_rank()]
- def all_gather_batch_with_grad(tensors):
- """
- Performs all_gather operation on the provided tensors.
- Graph remains connected for backward grad computation.
- """
- # Queue the gathered tensors
- world_size = get_world_size()
- # There is no need for reduction in the single-proc case
- if world_size == 1:
- return tensors
- tensor_list = []
- output_tensor = []
- for tensor in tensors:
- tensor_all = GatherLayer.apply(tensor)
- tensor_list.append(tensor_all)
- for tensor_all in tensor_list:
- output_tensor.append(torch.cat(tensor_all, dim=0))
- return output_tensor
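- # Illustrative usage sketch (assumes a distributed run; `features` is a
- # hypothetical local embedding that requires grad and contrastive_loss is a
- # hypothetical loss function): gather while keeping the autograd graph
- # connected, e.g. for a contrastive loss computed across ranks:
- #
- #   all_features = all_gather_batch_with_grad([features])[0]
- #   loss = contrastive_loss(features, all_features)
- #   loss.backward()  # gradients flow back to the local `features`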
- def unwrap_ddp_if_wrapped(model):
- if isinstance(model, torch.nn.parallel.DistributedDataParallel):
- return model.module
- return model
- def create_new_process_group(group_size):
- """
- Creates process groups of a given `group_size` and returns the
- process group that the current GPU participates in.
- `group_size` must divide the total number of GPUs (world_size).
- Modified from
- https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60
- Args:
- group_size (int): number of GPUs to collaborate for sync BN
- """
- assert group_size > 0
- world_size = torch.distributed.get_world_size()
- if world_size <= 8:
- if group_size > world_size:
- logging.warning(
- f"Requested group size [{group_size}] > world size [{world_size}]. "
- "Assuming local debug run and capping it to world size."
- )
- group_size = world_size
- assert world_size >= group_size
- assert world_size % group_size == 0
- group = None
- for group_num in range(world_size // group_size):
- group_ids = range(group_num * group_size, (group_num + 1) * group_size)
- cur_group = torch.distributed.new_group(ranks=group_ids)
- if torch.distributed.get_rank() // group_size == group_num:
- group = cur_group
- # cannot return early here: every process must go through the creation of all subgroups
- assert group is not None
- return group
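- # Illustrative usage sketch (assumes 8 GPUs in total, so group_size=4 yields
- # two groups; `model` is a hypothetical module): restrict SyncBatchNorm
- # statistics to a subset of GPUs:
- #
- #   bn_group = create_new_process_group(group_size=4)
- #   model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=bn_group)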
- def is_dist_avail_and_initialized():
- if not dist.is_available():
- return False
- if not dist.is_initialized():
- return False
- return True
- def gather_to_rank_0_via_filesys(data, filesys_save_dir=None):
- """
- Gather any picklable data to rank 0 via filesystem, using all_gather_via_filesys.
- """
- return all_gather_via_filesys(data, filesys_save_dir, gather_to_rank_0_only=True)