| |
| |
|
|
| import gc |
| import logging |
| import os |
|
|
| import composer |
| import pytest |
| import torch |
| from composer.devices import DeviceCPU, DeviceGPU |
| from composer.utils import dist, reproducibility |
|
|
|
|
@pytest.fixture(autouse=True)
def clear_cuda_cache(request: pytest.FixtureRequest):
    """Release cached CUDA memory before each GPU test.

    Only tests carrying the ``gpu`` marker are affected, and only when CUDA is
    available. CPU tests skip the work entirely, including the comparatively
    slow ``gc.collect()``.

    Args:
        request: The pytest request for the test about to run; used to look up
            its closest ``gpu`` marker.
    """
    marker = request.node.get_closest_marker('gpu')
    if marker is not None and torch.cuda.is_available():
        # Collect first: tensors kept alive only by reference cycles are freed
        # back into PyTorch's caching allocator, and the empty_cache() call that
        # follows can then release those blocks. Emptying the cache before
        # collecting would leave that freshly freed memory cached.
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
|
@pytest.fixture(autouse=True)
def reset_mlflow_tracking_dir():
    """Reset the MLflow tracking URI so it does not persist across tests.

    mlflow is an optional dependency; when it is not installed this fixture
    does nothing.
    """
    # Guard only the optional import, so errors raised by mlflow itself are
    # not mistaken for "mlflow is not installed".
    try:
        import mlflow
    except ModuleNotFoundError:
        return
    # Clear any tracking URI configured by a previous test.
    mlflow.set_tracking_uri(None)
|
|
|
|
@pytest.fixture(scope='session')
def cleanup_dist():
    """Synchronize all ranks when the test session tears down.

    Setup does nothing. At session teardown, ``dist.barrier()`` makes every
    rank wait for the others, so no rank moves on while another is still
    finishing its distributed tests.
    """
    # Empty setup phase: hand control straight to the test session.
    yield
    # Teardown: block until every rank has reached this point.
    dist.barrier()
|
|
|
|
@pytest.fixture(autouse=True, scope='session')
def configure_dist(request: pytest.FixtureRequest):
    """Initialize distributed state once per session when running on several ranks.

    Configuring dist globally means individual tests that do not go through the
    trainer need not set it up themselves. With a world size of 1 this is a no-op.

    The device is chosen from the first collected test: ``DeviceGPU`` if it
    carries the ``gpu`` marker, otherwise ``DeviceCPU``.

    Args:
        request: The session-level pytest request; used to inspect the
            collected test items.

    Raises:
        RuntimeError: If no test items were collected, so no device can be chosen.
    """
    if dist.get_world_size() == 1:
        return

    first_item = next(iter(request.session.items), None)
    if first_item is None:
        # An explicit raise (unlike ``assert``) still fires under ``python -O``,
        # so ``initialize_dist`` is never handed a ``None`` device.
        raise RuntimeError('Cannot configure dist: no test items were collected.')
    device = DeviceCPU() if first_item.get_closest_marker('gpu') is None else DeviceGPU()

    if not dist.is_initialized():
        dist.initialize_dist(device, timeout=300.0)

    # Hold every rank here until all ranks arrive, so no rank starts a test while
    # another is still finishing the previous one (a source of random timeouts).
    dist.barrier()
|
|
|
|
@pytest.fixture(autouse=True)
def set_log_levels():
    """Enable DEBUG logging for the ``composer`` package logger.

    ``logging.basicConfig()`` runs first; it is a no-op if the root logger
    already has handlers. Only composer's own logger has its level lowered;
    other loggers are left untouched.
    """
    logging.basicConfig()
    package_logger = logging.getLogger(composer.__name__)
    package_logger.setLevel(logging.DEBUG)
|
|
|
|
@pytest.fixture(autouse=True)
def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch):
    """Pin Composer's seeding so every test is deterministic.

    For the duration of the test, ``reproducibility.get_random_seed`` is patched
    to always return ``rank_zero_seed``. The RNGs are then seeded through
    ``reproducibility.seed_all`` with the rank-local seed
    ``rank_zero_seed + global_rank``.

    Args:
        rank_zero_seed: Seed shared by all ranks (presumably provided by a
            ``rank_zero_seed`` fixture defined elsewhere).
        monkeypatch: pytest's monkeypatch fixture; it undoes the patch after the test.
    """

    def _constant_seed():
        return rank_zero_seed

    monkeypatch.setattr(reproducibility, 'get_random_seed', _constant_seed)

    rank_local_seed = rank_zero_seed + dist.get_global_rank()
    reproducibility.seed_all(rank_local_seed)
|
|
|
|
@pytest.fixture(autouse=True)
def remove_run_name_env_var():
    """Hide ``COMPOSER_RUN_NAME`` and ``RUN_NAME`` from each test, then restore them.

    Both variables are removed from ``os.environ`` before the test runs. On
    teardown, each one is reset to its pre-test value. If it was not set before
    the test, it is removed again, so a value the test set itself does not leak
    past the test.
    """
    run_name_keys = ('COMPOSER_RUN_NAME', 'RUN_NAME')
    # pop() both records the original value (None if unset) and removes the key.
    saved = {key: os.environ.pop(key, None) for key in run_name_keys}

    yield

    for key, original in saved.items():
        if original is None:
            # Originally unset: drop anything the test may have set.
            os.environ.pop(key, None)
        else:
            os.environ[key] = original
|
|