#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Distributed helpers."""

import multiprocessing
import os
import signal
import threading
import traceback

import torch
from pycls.core.config import cfg


def is_master_proc():
    """Determines if the current process is the master process.

    The master process is responsible for logging and for writing and loading
    checkpoints. In the multi-GPU setting, the master role is assigned to the
    rank 0 process. When training with a single GPU, there is a single process,
    which is considered the master.
    """
    return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
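
# A minimal usage sketch (hypothetical caller code, not part of this module):
#
#   import pycls.core.distributed as dist
#
#   if dist.is_master_proc():
#       print("rank 0 only: safe to log and write checkpoints here")
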
def init_process_group(proc_rank, world_size):
    """Initializes the default process group."""
    # Set the GPU to use
    torch.cuda.set_device(proc_rank)
    # Initialize the process group
    torch.distributed.init_process_group(
        backend=cfg.DIST_BACKEND,
        init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
        world_size=world_size,
        rank=proc_rank,
    )
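
# A hedged sketch of a direct call (in practice multi_proc_run below invokes
# this once per spawned process, passing that process's rank):
#
#   init_process_group(proc_rank=0, world_size=cfg.NUM_GPUS)
#
# cfg.DIST_BACKEND names a torch.distributed backend, e.g. "nccl" or "gloo".
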
def destroy_process_group():
    """Destroys the default process group."""
    torch.distributed.destroy_process_group()


def scaled_all_reduce(tensors):
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of the
    process group (equivalent to cfg.NUM_GPUS).
    """
    # There is no need for reduction in the single-proc case
    if cfg.NUM_GPUS == 1:
        return tensors
    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / cfg.NUM_GPUS)
    return tensors
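
# A hedged usage sketch (hypothetical names; assumes each process computed its
# own per-GPU loss and error tensors during a training step):
#
#   loss, top1_err = scaled_all_reduce([loss, top1_err])
#
# After the call every process holds the same values: the per-process tensors
# summed across the group and scaled by 1 / cfg.NUM_GPUS, i.e. their mean.
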
class ChildException(Exception):
    """Wraps an exception from a child process."""

    def __init__(self, child_trace):
        super(ChildException, self).__init__(child_trace)


class ErrorHandler(object):
    """Multiprocessing error handler (based on fairseq's).

    Listens for errors in child processes and propagates the tracebacks to the parent.
    """

    def __init__(self, error_queue):
        # Shared error queue
        self.error_queue = error_queue
        # Children processes sharing the error queue
        self.children_pids = []
        # Start a thread listening to errors
        self.error_listener = threading.Thread(target=self.listen, daemon=True)
        self.error_listener.start()
        # Register the signal handler (note: SIGUSR1 is available on POSIX only)
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        """Registers a child process."""
        self.children_pids.append(pid)

    def listen(self):
        """Listens for errors in the error queue."""
        # Wait until there is an error in the queue
        child_trace = self.error_queue.get()
        # Put the error back for the signal handler
        self.error_queue.put(child_trace)
        # Invoke the signal handler
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, _sig_num, _stack_frame):
        """Signal handler."""
        # Kill children processes
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)
        # Propagate the error from the child process
        raise ChildException(self.error_queue.get())
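
# How the pieces fit together: a child that raises puts its formatted
# traceback on the shared queue (see run below); the listener thread wakes,
# re-queues the traceback, and signals the parent with SIGUSR1; the handler
# then SIGINTs all children and re-raises the traceback as a ChildException.
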
def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
    """Runs a function from a child process."""
    try:
        # Initialize the process group
        init_process_group(proc_rank, world_size)
        # Run the function
        fun(*fun_args, **fun_kwargs)
    except KeyboardInterrupt:
        # Killed by the parent process
        pass
    except Exception:
        # Propagate exception to the parent process
        error_queue.put(traceback.format_exc())
    finally:
        # Destroy the process group
        destroy_process_group()


def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
    """Runs a function in a multi-proc setting (unless num_proc == 1)."""
    fun_kwargs = fun_kwargs if fun_kwargs else {}
    # There is no need for multi-proc in the single-proc case
    if num_proc == 1:
        fun(*fun_args, **fun_kwargs)
        return
    # Handle errors from training subprocesses
    error_queue = multiprocessing.SimpleQueue()
    error_handler = ErrorHandler(error_queue)
    # Run each training subprocess
    ps = []
    for i in range(num_proc):
        p_i = multiprocessing.Process(
            target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
        )
        ps.append(p_i)
        p_i.start()
        error_handler.add_child(p_i.pid)
    # Wait for each subprocess to finish
    for p in ps:
        p.join()
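
# A minimal usage sketch (hypothetical; `train` stands in for any function
# that can run in a subprocess):
#
#   def train():
#       ...  # build the model and data loaders, then run the training loop
#
#   multi_proc_run(num_proc=cfg.NUM_GPUS, fun=train)
#
# Each worker receives its index as proc_rank and joins the same process group.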