#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import socket
from typing import Optional
import torch
import torch.distributed as dist
from utils import logger
def is_master(opts) -> bool:
    """Return True when this process has rank 0.

    Args:
        opts: Options object. The rank is read from the attribute literally
            named ``"ddp.rank"``; because the name contains a dot it can only
            be reached through ``getattr``. A missing attribute is treated as
            rank 0.

    Returns:
        ``True`` if the stored rank equals 0, ``False`` otherwise.
    """
    return getattr(opts, "ddp.rank", 0) == 0
def dist_barrier() -> None:
    """Synchronize processes by calling ``torch.distributed.barrier()``.

    The call is forwarded with no arguments, so it applies to the default
    process group. No check is made here that a process group has been
    initialized; any error raised by ``torch.distributed`` propagates.
    """
    dist.barrier()
def dist_monitored_barrier(
    timeout: Optional[float] = None,
    wait_all_ranks: Optional[bool] = False,
    group: Optional = None,
) -> None:
    """Thin wrapper around ``torch.distributed.monitored_barrier``.

    All three arguments are forwarded unchanged as keyword arguments.

    Args:
        timeout: Barrier timeout; ``None`` lets torch pick its default.
            NOTE(review): torch documents this argument as a
            ``datetime.timedelta`` -- confirm that a float (seconds) is
            accepted by the torch version in use.
        wait_all_ranks: Forwarded to torch's ``wait_all_ranks`` flag.
        group: Process group to synchronize; ``None`` is passed through as-is.
    """
    dist.monitored_barrier(
        group=group,
        timeout=timeout,
        wait_all_ranks=wait_all_ranks,
    )
def is_start_rank_node(opts) -> bool:
    """Return True when this process's rank equals the configured start rank.

    Args:
        opts: Options object. Reads the attributes literally named
            ``"ddp.rank"`` and ``"ddp.start_rank"`` (dotted names, hence
            ``getattr``); each defaults to 0 when absent.

    Returns:
        ``True`` if the two values compare equal, ``False`` otherwise.
    """
    current_rank = getattr(opts, "ddp.rank", 0)
    start_rank = getattr(opts, "ddp.start_rank", 0)
    return current_rank == start_rank
def get_world_size() -> int:
    """Return ``torch.distributed.get_world_size()`` for the default group.

    No initialization check is performed here; if no process group exists the
    error raised by ``torch.distributed`` propagates to the caller.
    """
    return dist.get_world_size()
def get_node_rank() -> int:
    """Return ``torch.distributed.get_rank()`` for the default group.

    No initialization check is performed here; if no process group exists the
    error raised by ``torch.distributed`` propagates to the caller.
    """
    return dist.get_rank()
def distributed_init(opts) -> int:
    """Initialize the default process group if needed and return this rank.

    All options are stored under attribute names that contain a dot, so they
    are accessed with ``getattr`` / ``setattr``.

    Options read from ``opts``:
        * ``ddp.dist_url``: init-method URL. When absent or ``None`` it is
          built as ``tcp://<local hostname>:<ddp.dist_port>`` (port defaults
          to 6006) and written back to ``opts``.
        * ``ddp.rank`` (default 0) and ``ddp.world_size`` (default 0): passed
          to ``init_process_group``.
        * ``ddp.backend`` (default ``"nccl"``): if it is explicitly ``None``,
          NCCL is chosen when available, otherwise Gloo.

    If a process group is already initialized, only a warning is logged and
    initialization is skipped.

    Side effects:
        May write ``ddp.dist_url``; always writes the rank reported by
        ``torch.distributed`` back to ``ddp.rank``.

    Returns:
        The rank reported by ``torch.distributed.get_rank()``.
    """
    ddp_url = getattr(opts, "ddp.dist_url", None)
    is_master_node = is_master(opts)

    if ddp_url is None:
        # No URL supplied: derive a TCP rendezvous address from this host.
        ddp_port = getattr(opts, "ddp.dist_port", 6006)
        hostname = socket.gethostname()
        ddp_url = "tcp://{}:{}".format(hostname, ddp_port)
        setattr(opts, "ddp.dist_url", ddp_url)

    node_rank = getattr(opts, "ddp.rank", 0)
    world_size = getattr(opts, "ddp.world_size", 0)

    if dist.is_initialized():
        logger.warning("DDP is already initialized and cannot be initialize twice!")
    else:
        logger.info("distributed init (rank {}): {}".format(node_rank, ddp_url))

        dist_backend = getattr(opts, "ddp.backend", "nccl")  # "gloo"

        # Auto-selection only happens when the option exists and is None.
        if dist_backend is None and dist.is_nccl_available():
            dist_backend = "nccl"
            if is_master_node:
                logger.log(
                    "Using NCCL as distributed backend with version={}".format(
                        torch.cuda.nccl.version()
                    )
                )
        elif dist_backend is None:
            dist_backend = "gloo"

        dist.init_process_group(
            backend=dist_backend,
            init_method=ddp_url,
            world_size=world_size,
            rank=node_rank,
        )

        # perform a dummy all-reduce to initialize the NCCL communicator
        if torch.cuda.is_available():
            dist.all_reduce(torch.zeros(1).cuda())

    # Trust the rank reported by the process group over the configured value.
    node_rank = dist.get_rank()
    setattr(opts, "ddp.rank", node_rank)
    return node_rank