| |
import asyncio
import contextlib
import logging
import os
import time
from typing import List, Optional

import torch
|
|
# Module-level logger, following the standard `logging.getLogger(__name__)`
# convention so records carry this module's name.
logger = logging.getLogger(__name__)

# Toggles CUDA event timing and extra timing logs in `completed` below.
# NOTE(review): any non-empty env value — including '0' or 'false' — is
# truthy here and enables debug timing; confirm that is intended.
DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False))
|
|
|
|
@contextlib.asynccontextmanager
async def completed(trace_name: str = '',
                    name: str = '',
                    sleep_interval: float = 0.05,
                    streams: Optional[List[torch.cuda.Stream]] = None):
    """Async context manager that waits for work to complete on given CUDA
    streams.

    After the wrapped body runs, an end event is recorded on every stream
    and polled cooperatively (``asyncio.sleep`` between polls) so other
    coroutines can make progress while the GPU drains, instead of blocking
    the thread with a synchronous wait.

    :param trace_name: label prepended to log messages.
    :param name: secondary label for log messages.
    :param sleep_interval: seconds to sleep between event polls.
    :param streams: CUDA streams to wait on; falsy entries (and a falsy
        argument overall) fall back to the stream current at entry.
    """
    # No CUDA available: nothing to synchronize, run the body untouched.
    if not torch.cuda.is_available():
        yield
        return

    stream_before_context_switch = torch.cuda.current_stream()
    if not streams:
        streams = [stream_before_context_switch]
    else:
        # Substitute the entry stream for any falsy placeholder entries.
        streams = [s if s else stream_before_context_switch for s in streams]

    # One completion event per stream. Timing is only enabled when the
    # DEBUG_COMPLETED_TIME env toggle is set, since timed events cost more.
    end_events = [
        torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
    ]

    if DEBUG_COMPLETED_TIME:
        # Start marker on the entry stream, paired with each stream's end
        # event via elapsed_time() in the report at the bottom.
        start = torch.cuda.Event(enable_timing=True)
        stream_before_context_switch.record_event(start)
        cpu_start = time.monotonic()
    logger.debug('%s %s starting, streams: %s', trace_name, name, streams)
    grad_enabled_before = torch.is_grad_enabled()
    try:
        yield
    finally:
        # The body must not leave a different stream current on exit.
        current_stream = torch.cuda.current_stream()
        assert current_stream == stream_before_context_switch

        if DEBUG_COMPLETED_TIME:
            cpu_end = time.monotonic()
        # Record end events only now, after the body's work was enqueued;
        # querying them tells us when the enqueued GPU work has finished.
        for i, stream in enumerate(streams):
            event = end_events[i]
            stream.record_event(event)

        grad_enabled_after = torch.is_grad_enabled()

        # Sanity check: context switches inside the body must not leak a
        # changed autograd mode back to the caller.
        assert (grad_enabled_before == grad_enabled_after
                ), 'Unexpected is_grad_enabled() value change'

        are_done = [e.query() for e in end_events]
        logger.debug('%s %s completed: %s streams: %s', trace_name, name,
                     are_done, streams)
        # Cooperative wait: yield to the event loop between polls rather
        # than blocking the thread with event.synchronize().
        with torch.cuda.stream(stream_before_context_switch):
            while not all(are_done):
                await asyncio.sleep(sleep_interval)
                are_done = [e.query() for e in end_events]
                logger.debug(
                    '%s %s completed: %s streams: %s',
                    trace_name,
                    name,
                    are_done,
                    streams,
                )

        # Re-verify the current stream after awaiting (other coroutines ran).
        current_stream = torch.cuda.current_stream()
        assert current_stream == stream_before_context_switch

        if DEBUG_COMPLETED_TIME:
            cpu_time = (cpu_end - cpu_start) * 1000
            stream_times_ms = ''
            for i, stream in enumerate(streams):
                elapsed_time = start.elapsed_time(end_events[i])
                stream_times_ms += f' {stream} {elapsed_time:.2f} ms'
            logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time,
                        stream_times_ms)
|
|
|
|
@contextlib.asynccontextmanager
async def concurrent(streamqueue: asyncio.Queue,
                     trace_name='concurrent',
                     name='stream'):
    """Run the wrapped body on a CUDA stream borrowed from a shared pool.

    :param streamqueue: asyncio.Queue instance.

    Queue tasks define the pool of streams used for concurrent execution.
    """
    # Without CUDA there is no stream pool to draw from; run the body as-is.
    if not torch.cuda.is_available():
        yield
        return

    entry_stream = torch.cuda.current_stream()

    with torch.cuda.stream(entry_stream):
        # Borrow a stream from the pool (may suspend until one is free).
        borrowed = await streamqueue.get()
        assert isinstance(borrowed, torch.cuda.Stream)

        try:
            with torch.cuda.stream(borrowed):
                logger.debug('%s %s is starting, stream: %s', trace_name, name,
                             borrowed)
                yield
                # The body must not have switched the current stream.
                assert torch.cuda.current_stream() == borrowed
                logger.debug('%s %s has finished, stream: %s', trace_name,
                             name, borrowed)
        finally:
            # Always hand the stream back, even if the body raised.
            streamqueue.task_done()
            streamqueue.put_nowait(borrowed)
|
|