|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import queue
|
| import collections
|
| import threading
|
|
|
| __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
|
|
|
|
class FutureResult(object):
    """A thread-safe, single-slot future. Used only as a one-to-one pipe.

    One side deposits a value with `put` and the other side collects it with
    `get`. `None` is used internally to mark the slot as empty, so `None`
    itself cannot be handed over as a result.
    """

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        # The condition shares `_lock`, so holding the lock is enough to
        # wait on or notify the condition.
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store `result` in the slot and wake up one waiting `get` call.

        Args:
            result: the value to hand over; must not be `None`, because
                `None` denotes an empty slot.

        Raises:
            AssertionError: if the previously stored result is still unfetched.
        """
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result has been stored, then return it.

        The slot is emptied before returning, so the next `put` is accepted.

        Returns: the value passed to the matching `put` call.
        """
        with self._lock:
            # Re-check the predicate after every wakeup: `wait()` may return
            # even though no result has been stored (e.g. a stray notify), and
            # a single `if` would then hand back `None` to the caller.
            while self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res
|
|
|
|
|
# Record type for per-slave registry entries; its single field `result`
# carries the object used to deliver the master's reply.
_MasterRegistry = collections.namedtuple('MasterRegistry', ('result',))

# Immutable base record for `SlavePipe`: the slave's identifier, the queue
# it writes to, and the object it reads its reply from.
_SlavePipeBase = collections.namedtuple(
    '_SlavePipeBase',
    ('identifier', 'queue', 'result'),
)
|
|
|
|
|
class SlavePipe(_SlavePipeBase):
    """Slave-side endpoint of the master-slave communication channel."""

    def run_slave(self, msg):
        """Send `msg` to the master and return the master's reply.

        The message is queued together with this pipe's identifier, then the
        call blocks on `self.result.get()` until a reply is available, and
        finally `True` is queued as an acknowledgement that the reply has
        been fetched.

        Args:
            msg: the message to pass to the master.

        Returns: the value obtained from `self.result.get()`.
        """
        tagged_msg = (self.identifier, msg)
        self.queue.put(tagged_msg)

        reply = self.result.get()

        # Acknowledge that the reply was consumed.
        self.queue.put(True)
        return reply
|
|
|
|
|
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
      call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        # Shared channel: slaves push `(identifier, msg)` tuples and, later, `True` acknowledgements.
        self._queue = queue.Queue()
        # Maps a slave identifier to its `_MasterRegistry` entry, in registration order.
        self._registry = collections.OrderedDict()
        # Set by `run_master`; tells `register_slave` that the registry belongs to a previous round.
        self._activated = False

    def __getstate__(self):
        """Return the picklable state: only the callback is kept."""
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        """Restore from `state` by re-running `__init__`, which recreates the queue and an empty registry."""
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        If `run_master` has been called since the last registration, the existing registry is cleared first;
        the queue must be empty at that point.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        Raises:
            AssertionError: if the queue still holds items when the registry is about to be reset.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        Raises:
            AssertionError: if the first result returned by the callback does not carry identifier 0, or if
                an item other than `True` is read while collecting acknowledgements.
        """
        self._activated = True

        # The master's own message always comes first, under identifier 0.
        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        # Deliver each slave's reply; the master's own reply is returned at the end.
        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for _ in range(self.nr_slaves):
            # Fetch the acknowledgement outside of the `assert`: assert statements are removed under
            # `python -O`, and the queue must be drained regardless so that stale acknowledgements
            # are not mistaken for slave messages in the next round.
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of currently registered slave devices."""
        return len(self._registry)
|
|
|