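"""
Run a single command as one Ray task (e.g. to launch a training run on a Ray cluster).

The command given after ``--`` is executed with ``subprocess`` inside a Ray
worker that reserves the requested CPUs/GPUs (see ``--num_cpus``/``--num_gpus``).

Usage:
Training:
python <this_script>.py [OPTIONS] -- python train.py --config-name=train_diffusion_lowdim_workspace logger.mode=online
"""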
| import os |
| import ray |
| import click |
|
|
def worker_fn(command_args, data_src=None, unbuffer_python=False, use_shell=False):
    """Run ``command_args`` as a child process and block until it exits.

    Meant to be wrapped with ``ray.remote`` (see ``main``) so the command runs
    inside a Ray worker. Imports are kept function-local, presumably so the
    function stays self-contained when Ray serializes it.

    Args:
        command_args: Sequence of command tokens passed to ``subprocess.Popen``.
        data_src: If not None, a path symlinked as ``./data`` in the current
            working directory before the command starts. An existing
            ``./data`` entry is left untouched.
        unbuffer_python: If True, set ``PYTHONUNBUFFERED=TRUE`` in the child's
            environment so a Python child does not buffer its output.
        use_shell: If True, join ``command_args`` with single spaces (no
            quoting) and run the resulting string through the shell.

    Returns:
        The child's return code: ``0`` on success, or ``-signal.SIGINT`` when
        the child was terminated by SIGINT (treated as an intentional stop).

    Raises:
        RuntimeError: If the child exits with any other return code.
    """
    import os
    import subprocess
    import signal
    import time

    # Expose the data source as ./data in the worker's working directory.
    if data_src is not None:
        src = os.path.expanduser(data_src)
        dst = os.path.join(os.getcwd(), 'data')
        try:
            os.symlink(src=src, dst=dst)
        except FileExistsError:
            # Best effort: an existing ./data (link or directory) is reused.
            pass

    process_env = os.environ.copy()
    if unbuffer_python:
        # Disable stdout/stderr buffering in a Python child so printed output
        # shows up without delay.
        process_env['PYTHONUNBUFFERED'] = 'TRUE'

    def preexec_function():
        # Runs in the child between fork and exec. The parent (presumably a
        # Ray worker) may have SIGINT blocked; unblock it so the child can be
        # interrupted. Uses the enclosing ``signal`` module instead of
        # re-importing inside the forked child.
        signal.pthread_sigmask(signal.SIG_UNBLOCK, {signal.SIGINT})

    if use_shell:
        # NOTE: tokens are joined without quoting, so shell syntax (pipes,
        # redirects, globs) in command_args is interpreted by the shell.
        command_args = ' '.join(command_args)

    # No stdout/stderr redirection: the child inherits this process's streams.
    process = subprocess.Popen(
        args=command_args,
        env=process_env,
        preexec_fn=preexec_function,
        shell=use_shell)

    while process.poll() is None:
        try:
            # Short sleeps keep the loop responsive to KeyboardInterrupt
            # (e.g. raised into the task by a non-forced ray.cancel).
            time.sleep(0.01)
        except KeyboardInterrupt:
            # Forward the interrupt and keep waiting so the child can exit cleanly.
            process.send_signal(signal.SIGINT)
            print('SIGINT sent to subprocess')
        except Exception:
            # Stop the child and reap it so it does not linger as a zombie;
            # escalate to SIGKILL if it ignores SIGTERM.
            process.terminate()
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
            raise

    # A child killed by SIGINT was interrupted on purpose; not a failure.
    if process.returncode not in (0, -signal.SIGINT):
        print("Failed execution!")
        raise RuntimeError("Failed execution.")
    return process.returncode
|
|
|
|
@click.command()
@click.option('--ray_address', '-ra', default='auto')
@click.option('--num_cpus', '-nc', default=7, type=float)
@click.option('--num_gpus', '-ng', default=1, type=float)
@click.option('--max_retries', '-mr', default=0, type=int)
@click.option('--data_src', '-d', default='./data', type=str)
@click.option('--unbuffer_python', '-u', is_flag=True, default=False)
@click.argument('command_args', nargs=-1, type=str)
def main(ray_address,
        num_cpus, num_gpus, max_retries,
        data_src, unbuffer_python,
        command_args):
    """Run COMMAND_ARGS as a single Ray task and print its return code."""
    # Resolve the data path against the *local* cwd before it is sent to the
    # worker, whose cwd differs.
    if data_src is not None:
        data_src = os.path.abspath(os.path.expanduser(data_src))

    # Upload this script's directory as the task's working directory.
    # abspath first: on Python < 3.9 ``__file__`` may be relative, and
    # os.path.dirname('script.py') == '' would yield an empty working_dir.
    root_dir = os.path.dirname(os.path.abspath(__file__))
    runtime_env = {
        'working_dir': root_dir,
        'excludes': ['.git'],
    }
    ray.init(
        address=ray_address,
        runtime_env=runtime_env,
    )

    worker_ray = ray.remote(worker_fn).options(
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        max_retries=max_retries,
        # Also retry on application-level exceptions (e.g. the RuntimeError
        # worker_fn raises for a failed command), up to max_retries times.
        retry_exceptions=True,
    )

    task_ref = worker_ray.remote(command_args, data_src, unbuffer_python)

    try:
        result = ray.get(task_ref)
        print('Return code: ', result)
    except KeyboardInterrupt:
        # Ctrl-C on the driver: a non-forced cancel interrupts the running
        # task, and worker_fn forwards SIGINT to the child. Wait for the task
        # to wind down and report its return code.
        ray.cancel(task_ref, force=False)
        result = ray.get(task_ref)
        print('Return code: ', result)
    except Exception:
        # Any other failure: kill the task outright before propagating.
        ray.cancel(task_ref, force=True)
        raise
|
|
if __name__ == "__main__":
    # Script entry point: click parses sys.argv and invokes main.
    main()
|
|