import os
import socket

import torch
import torch.distributed as dist


def run(rank, size, hostname, gpu, ngpus_per_node):
    print(f"I am {rank} of {size} in {hostname}")
    # Group of ranks taking part in the all-reduce; this assumes a world size of 4.
    group = dist.new_group([0, 1, 2, 3])
    # The tensor lands on the GPU selected by torch.cuda.set_device() in init_processes().
    tensor = torch.ones(1).cuda()
    # Sum the tensor across all ranks in the group.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    print(f"Rank {rank} has data {tensor[0]}")


def init_processes(rank, size, gpu, ngpus_per_node, hostname, fn, backend='mpi'):
    """Initialize the distributed environment."""
    # Bind this process to its local GPU before creating any CUDA tensors.
    torch.cuda.set_device(gpu)
    # With the MPI backend, rank and world_size are determined by the MPI runtime;
    # the values passed here are ignored.
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, hostname, gpu, ngpus_per_node)


if __name__ == "__main__":
    # Open MPI exports these environment variables for every process it launches.
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    ngpus_per_node = torch.cuda.device_count()
    hostname = socket.gethostname()
    init_processes(world_rank, world_size, gpu, ngpus_per_node, hostname, run, backend='mpi')
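
# Example launch: a minimal sketch, assuming Open MPI (which sets the
# OMPI_COMM_WORLD_* variables read above), 4 processes with one GPU each,
# and a PyTorch build compiled with MPI support. The script name is a
# placeholder; adjust it and -np to your setup.
#
#   mpirun -np 4 python your_script.py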