import asyncio
import contextlib
import logging
import os
import time
from typing import List, Optional

import torch
|
# Module-level logger, standard per-module pattern.
logger = logging.getLogger(__name__)

# Enables per-stream CUDA timing diagnostics when the DEBUG_COMPLETED_TIME
# environment variable is set to a truthy value.
# NOTE: the previous `bool(os.environ.get(...))` treated ANY non-empty string
# (including '0' and 'false') as True; parse common spellings explicitly.
DEBUG_COMPLETED_TIME = os.environ.get('DEBUG_COMPLETED_TIME',
                                      '').lower() in ('1', 'true', 'yes', 'on')
| |
|
| |
|
@contextlib.asynccontextmanager
async def completed(trace_name='',
                    name='',
                    sleep_interval=0.05,
                    streams: Optional[List[torch.cuda.Stream]] = None):
    """Async context manager that waits for work to complete on given CUDA
    streams.

    The managed body is executed normally; on exit, CUDA events are recorded
    on each stream and polled asynchronously (yielding control to the event
    loop between polls) until all of them report completion.

    Args:
        trace_name: Label used in log messages to identify the trace.
        name: Label used in log messages to identify this section.
        sleep_interval: Seconds to wait between CUDA event polls while
            yielding control back to the asyncio event loop.
        streams: CUDA streams to wait on. ``None`` entries (and an empty or
            ``None`` argument) fall back to the current stream.
    """
    if not torch.cuda.is_available():
        # Nothing to synchronize on CPU-only setups; run the body as-is.
        yield
        return

    stream_before_context_switch = torch.cuda.current_stream()
    if not streams:
        streams = [stream_before_context_switch]
    else:
        streams = [s if s else stream_before_context_switch for s in streams]

    # One completion event per stream; timing is only enabled for debugging.
    end_events = [
        torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
    ]

    if DEBUG_COMPLETED_TIME:
        start = torch.cuda.Event(enable_timing=True)
        stream_before_context_switch.record_event(start)
        # Only needed for the debug timing report, so only computed then
        # (previously computed unconditionally).
        cpu_start = time.monotonic()

    logger.debug('%s %s starting, streams: %s', trace_name, name, streams)
    grad_enabled_before = torch.is_grad_enabled()
    try:
        yield
    finally:
        # The body must not leave a different stream active; the polling
        # below assumes we are back on the entry stream.
        current_stream = torch.cuda.current_stream()
        assert current_stream == stream_before_context_switch

        if DEBUG_COMPLETED_TIME:
            cpu_end = time.monotonic()

        # Record a completion marker on every stream we wait for.
        for i, stream in enumerate(streams):
            stream.record_event(end_events[i])

        grad_enabled_after = torch.is_grad_enabled()

        # Context switches between coroutines must not leak grad-mode state.
        assert (grad_enabled_before == grad_enabled_after
                ), 'Unexpected is_grad_enabled() value change'

        are_done = [e.query() for e in end_events]
        logger.debug('%s %s completed: %s streams: %s', trace_name, name,
                     are_done, streams)
        # Poll the events without blocking the event loop; sleeping yields
        # control so other coroutines can run while the GPU works.
        with torch.cuda.stream(stream_before_context_switch):
            while not all(are_done):
                await asyncio.sleep(sleep_interval)
                are_done = [e.query() for e in end_events]
                logger.debug(
                    '%s %s completed: %s streams: %s',
                    trace_name,
                    name,
                    are_done,
                    streams,
                )

        current_stream = torch.cuda.current_stream()
        assert current_stream == stream_before_context_switch

        if DEBUG_COMPLETED_TIME:
            cpu_time = (cpu_end - cpu_start) * 1000
            stream_times_ms = ''
            for i, stream in enumerate(streams):
                elapsed_time = start.elapsed_time(end_events[i])
                stream_times_ms += f' {stream} {elapsed_time:.2f} ms'
            logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time,
                        stream_times_ms)
| |
|
| |
|
@contextlib.asynccontextmanager
async def concurrent(streamqueue: asyncio.Queue,
                     trace_name='concurrent',
                     name='stream'):
    """Run the managed body on a CUDA stream borrowed from a shared pool.

    A stream is taken from ``streamqueue``, made current for the duration of
    the body, and always returned to the queue afterwards — even when the
    body raises.

    :param streamqueue: asyncio.Queue instance.

    Queue tasks define the pool of streams used for concurrent execution.
    """
    if not torch.cuda.is_available():
        # No CUDA: run the body directly, nothing to schedule.
        yield
        return

    entry_stream = torch.cuda.current_stream()

    with torch.cuda.stream(entry_stream):
        borrowed = await streamqueue.get()
        assert isinstance(borrowed, torch.cuda.Stream)

        try:
            with torch.cuda.stream(borrowed):
                logger.debug('%s %s is starting, stream: %s',
                             trace_name, name, borrowed)
                yield
                # The body must leave the borrowed stream current.
                active = torch.cuda.current_stream()
                assert active == borrowed
                logger.debug('%s %s has finished, stream: %s', trace_name,
                             name, borrowed)
        finally:
            # Always hand the stream back so the pool never shrinks.
            streamqueue.task_done()
            streamqueue.put_nowait(borrowed)
| |
|