| """ |
| Streaming module API that should be implemented by all Streaming components, |
| """ |
|
|
| from contextlib import contextmanager |
| import typing as tp |
| from torch import nn |
| import torch |
|
|
|
|
# Streaming state: a flat dict mapping names to tensors. By convention the
# first dim of every tensor is the batch size.
State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming children modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` after.
    """
    def __init__(self) -> None:
        super().__init__()
        # Per-module streaming state; only meaningful while `_is_streaming` is True.
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Callable[[str, 'StreamingModule'], None]) -> None:
        """Apply `fn(name, module)` to self and every `StreamingModule` descendant.

        Names follow `nn.Module.named_modules` conventions: the root module is
        named `""` and children use dotted paths (e.g. `"encoder.conv"`).
        """
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool) -> None:
        """Set `_is_streaming` on self and all streaming submodules."""
        # Named `_set` to avoid shadowing the enclosing method name.
        def _set(name: str, module: 'StreamingModule') -> None:
            module._is_streaming = streaming
        self._apply_named_streaming(_set)

    @contextmanager
    def streaming(self):
        """Context manager to enter streaming mode. Reset streaming state on exit."""
        self._set_streaming(True)
        try:
            yield
        finally:
            # Always leave streaming mode and drop stale state, even on error.
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self) -> None:
        """Reset the streaming state of self and all streaming submodules."""
        def _reset(name: str, module: 'StreamingModule') -> None:
            module._streaming_state.clear()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules.

        Submodule entries are key-prefixed with their dotted module path
        (e.g. `"encoder.conv.key"`), mirroring `state_dict` naming.
        """
        state: State = {}

        def _add(name: str, module: 'StreamingModule') -> None:
            if name:
                name += "."
            for key, value in module._streaming_state.items():
                state[name + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State) -> None:
        """Set the streaming state, including that of sub-modules.

        Args:
            state: mapping as returned by `get_streaming_state`.

        Raises:
            ValueError: if some keys in `state` match no streaming submodule.
        """
        state = dict(state)  # Shallow copy: consumed entries are deleted below.

        def _set(name: str, module: 'StreamingModule') -> None:
            if name:
                name += "."
            module._streaming_state.clear()
            for key, value in list(state.items()):
                # Only claim keys that belong directly to this module: after
                # stripping the module prefix, keys destined for a child
                # still contain a dot and are left for that child to consume.
                if key.startswith(name):
                    local_key = key[len(name):]
                    if '.' not in local_key:
                        module._streaming_state[local_key] = value
                        del state[key]

        self._apply_named_streaming(_set)
        if state:
            # Explicit raise rather than `assert` so the check survives `python -O`.
            raise ValueError(f"Some states were not consumed: {list(state.keys())}")

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.
        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        spitted out a flushed out buffer.
        """
        if x is None:
            return None
        else:
            return self(x)