| |
| import importlib.util |
| import os |
| import subprocess |
| import sys |
| from typing import Dict, List, Optional |
|
|
| from swift.utils import get_logger |
|
|
# Module-level logger from swift.utils.get_logger; _compat_web_ui uses it to
# tell users to switch to `swift app`.
logger = get_logger()
|
|
# Sub-command name (hyphenated, as typed after `swift`) -> dotted module path
# whose source file cli_main executes for that command. Every command lives at
# `swift.cli.<command>`, with hyphens in the command name turned into underscores
# (e.g. 'merge-lora' -> 'swift.cli.merge_lora'). Insertion order is preserved.
ROUTE_MAPPING: Dict[str, str] = {
    command: 'swift.cli.' + command.replace('-', '_')
    for command in (
        'pt',
        'sft',
        'infer',
        'merge-lora',
        'web-ui',
        'deploy',
        'rollout',
        'rlhf',
        'sample',
        'export',
        'eval',
        'app',
    )
}
|
|
|
|
def use_torchrun() -> bool:
    """Report whether the environment asks for a torchrun launch.

    Returns True as soon as either ``NPROC_PER_NODE`` or ``NNODES`` is present
    in the environment (any value, even an empty string). Returns False only
    when both are unset.
    """
    return any(os.getenv(name) is not None for name in ('NPROC_PER_NODE', 'NNODES'))
|
|
|
|
def get_torchrun_args() -> Optional[List[str]]:
    """Translate distributed-launch environment variables into torchrun flags.

    Returns:
        ``None`` when ``use_torchrun()`` is False. Otherwise a flat list of
        ``['--<lower-cased name>', value, ...]`` pairs, one pair for each of
        ``NPROC_PER_NODE``, ``MASTER_PORT``, ``NNODES``, ``NODE_RANK`` and
        ``MASTER_ADDR`` that is set, in that order. Unset variables are skipped.
    """
    if not use_torchrun():
        return None
    flags: List[str] = []
    for name in ('NPROC_PER_NODE', 'MASTER_PORT', 'NNODES', 'NODE_RANK', 'MASTER_ADDR'):
        value = os.getenv(name)
        if value is not None:
            flags.extend((f'--{name.lower()}', value))
    return flags
|
|
|
|
| def _compat_web_ui(argv): |
| |
| method_name = argv[0] |
| if method_name in {'web-ui', 'web_ui'} and ('--model' in argv or '--adapters' in argv or '--ckpt_dir' in argv): |
| argv[0] = 'app' |
| logger.warning('Please use `swift app`.') |
|
|
|
|
def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None:
    """Dispatch ``swift <command> [args...]`` to the module registered for ``<command>``.

    How it works:

    - The command is read from ``sys.argv[1]``, after the ``web-ui`` -> ``app``
      compatibility rewrite, with underscores normalised to hyphens.
    - It is looked up in ``route_mapping``; ``ROUTE_MAPPING`` is used when the
      argument is omitted or empty.
    - The resolved module's source file is run in a child Python process with
      the remaining arguments.
    - For ``pt``/``sft``/``rlhf``/``infer``, the child is launched through
      ``torch.distributed.run`` when ``get_torchrun_args`` returns flags.

    Args:
        route_mapping: Optional command -> dotted-module-path mapping that
            replaces the default routes.

    Raises:
        SystemExit: if no command is given, the command is unknown, or its
            module cannot be located. Also raised when the child exits non-zero,
            carrying the child's status, or 128 + N if the child was killed by
            signal N.
    """
    route_mapping = route_mapping or ROUTE_MAPPING
    argv = sys.argv[1:]
    if not argv:
        sys.exit(f'usage: swift <command> [args...]\navailable commands: {", ".join(route_mapping)}')
    _compat_web_ui(argv)
    method_name = argv[0].replace('_', '-')
    argv = argv[1:]
    module_name = route_mapping.get(method_name)
    if module_name is None:
        sys.exit(f'unknown command {method_name!r}; available commands: {", ".join(route_mapping)}')
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        sys.exit(f'cannot locate module {module_name!r} for command {method_name!r}')
    file_path = spec.origin
    torchrun_args = get_torchrun_args()
    python_cmd = sys.executable
    if torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}:
        args = [python_cmd, file_path, *argv]
    else:
        # Multi-process launch: the env-derived torchrun flags precede the script path.
        args = [python_cmd, '-m', 'torch.distributed.run', *torchrun_args, file_path, *argv]
    print(f"run sh: `{' '.join(args)}`", flush=True)
    result = subprocess.run(args)
    if result.returncode != 0:
        code = result.returncode
        # A negative returncode means the child was killed by signal -code. Report
        # it like a shell does (128 + N) instead of letting sys.exit wrap it mod 256.
        sys.exit(128 - code if code < 0 else code)
|
|
|
|
# Allow executing this module directly (e.g. `python <this file> sft ...`);
# dispatch uses the default ROUTE_MAPPING.
if __name__ == '__main__':
    cli_main()
|
|