| | |
| | import importlib.util |
| | import os |
| | import subprocess |
| | import sys |
| | from typing import Dict, List, Optional |
| |
|
| | from swift.utils import get_logger |
| |
|
# Module-level logger shared by the CLI dispatch helpers below.
logger = get_logger()

# Maps each user-facing `swift <command>` name to the module that implements it.
# Keys use the '-' spelling; cli_main() normalizes '_' to '-' before lookup.
ROUTE_MAPPING: Dict[str, str] = {
    'pt': 'swift.cli.pt',
    'sft': 'swift.cli.sft',
    'infer': 'swift.cli.infer',
    'merge-lora': 'swift.cli.merge_lora',
    'web-ui': 'swift.cli.web_ui',
    'deploy': 'swift.cli.deploy',
    'rollout': 'swift.cli.rollout',
    'rlhf': 'swift.cli.rlhf',
    'sample': 'swift.cli.sample',
    'export': 'swift.cli.export',
    'eval': 'swift.cli.eval',
    'app': 'swift.cli.app',
}
| |
|
| |
|
def use_torchrun() -> bool:
    """Return True when the environment requests a distributed (torchrun) launch.

    A launch is requested iff at least one of the ``NPROC_PER_NODE`` or
    ``NNODES`` environment variables is set.
    """
    return os.getenv('NPROC_PER_NODE') is not None or os.getenv('NNODES') is not None
| |
|
| |
|
def get_torchrun_args() -> Optional[List[str]]:
    """Collect ``torch.distributed.run`` CLI flags from the environment.

    Returns ``None`` when no distributed launch is requested (neither
    ``NPROC_PER_NODE`` nor ``NNODES`` is set); otherwise returns a flat list of
    ``--flag value`` pairs for every recognized env var that is present.
    """
    # Distributed launch is requested iff either of these env vars is set.
    if os.getenv('NPROC_PER_NODE') is None and os.getenv('NNODES') is None:
        return None
    flags: List[str] = []
    for key in ('NPROC_PER_NODE', 'MASTER_PORT', 'NNODES', 'NODE_RANK', 'MASTER_ADDR'):
        value = os.getenv(key)
        if value is not None:
            # Env var name maps directly onto the torchrun flag name, lowercased.
            flags.extend((f'--{key.lower()}', value))
    return flags
| |
|
| |
|
| | def _compat_web_ui(argv): |
| | |
| | method_name = argv[0] |
| | if method_name in {'web-ui', 'web_ui'} and ('--model' in argv or '--adapters' in argv or '--ckpt_dir' in argv): |
| | argv[0] = 'app' |
| | logger.warning('Please use `swift app`.') |
| |
|
| |
|
def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None:
    """Dispatch ``swift <command> [args...]`` to the matching CLI module.

    Resolves the subcommand via ``route_mapping``, then re-executes the target
    module's file in a subprocess — either directly or, for the training/infer
    commands when distributed env vars are set, via ``torch.distributed.run``.

    Args:
        route_mapping: Optional command -> module-path mapping; defaults to
            ``ROUTE_MAPPING``.

    Side effects:
        Spawns a subprocess; exits this process with the child's return code on
        failure, or with code 1 on a missing/unknown subcommand.
    """
    route_mapping = route_mapping or ROUTE_MAPPING
    argv = sys.argv[1:]
    # Fail with a usage message instead of an IndexError when no subcommand is given.
    if not argv:
        print(f"usage: swift <command> [args...]\navailable commands: {', '.join(route_mapping)}", file=sys.stderr)
        sys.exit(1)
    _compat_web_ui(argv)
    # Accept both 'merge_lora' and 'merge-lora' spellings.
    method_name = argv[0].replace('_', '-')
    argv = argv[1:]
    # Fail with a clear message instead of a raw KeyError traceback.
    if method_name not in route_mapping:
        print(f"unknown command: {method_name!r}\navailable commands: {', '.join(route_mapping)}", file=sys.stderr)
        sys.exit(1)
    # Locate the source file of the target module so it can be run as a script.
    file_path = importlib.util.find_spec(route_mapping[method_name]).origin
    torchrun_args = get_torchrun_args()
    python_cmd = sys.executable
    # Only training/inference commands are launched under torch.distributed.run.
    if torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}:
        args = [python_cmd, file_path, *argv]
    else:
        args = [python_cmd, '-m', 'torch.distributed.run', *torchrun_args, file_path, *argv]
    print(f"run sh: `{' '.join(args)}`", flush=True)
    # List-form argv, shell=False: arguments are passed through unescaped and safely.
    result = subprocess.run(args)
    if result.returncode != 0:
        # Propagate the child's failure code to our caller (e.g. shell scripts).
        sys.exit(result.returncode)
| |
|
| |
|
# Script entry point: delegate straight to the CLI dispatcher.
if __name__ == '__main__':
    cli_main()
| |
|