| | import contextlib |
| | import functools |
| | import os |
| | from typing import Callable, Dict, Optional |
| |
|
| | import torch |
| |
|
| | from loguru import logger |
| |
|
| | """ |
| | Usage: |
| | |
| | 1. Control through environment variable (at startup): |
| | export ENABLE_TORCH_COMPILE=true |
| | python your_script.py |
| | |
| | 2. Control through environment variable (disable): |
| | export ENABLE_TORCH_COMPILE=false # or not set |
| | python your_script.py |
| | |
| | 3. Dynamically control in code: |
| | compile_manager.set_compile_enabled(True) # enable |
| | compile_manager.set_compile_enabled(False) # disable |
| | |
| | 4. Select version at runtime: |
| | # use the version configured |
| | result = my_function(args) |
| | |
| | # force use the original version |
| | result = my_function.original(args) |
| | |
| | # force use the compiled version |
| | result = my_function.compiled(args) |
| | """ |
| |
|
| | |
| | |
# Initial on/off switch for torch.compile, read once at import time from the
# ENABLE_TORCH_COMPILE environment variable. Only the value "true" (in any
# letter case) turns it on; an unset variable or any other value leaves it off.
ENABLE_TORCH_COMPILE = os.environ.get("ENABLE_TORCH_COMPILE", "false").lower() == "true"
| |
|
| |
|
class CompileManager:
    """Holds the global torch.compile switch and per-function bookkeeping.

    Attributes:
        compile_enabled: Current switch state; starts as ENABLE_TORCH_COMPILE.
        compiled_functions: Compiled callables, keyed by a name string.
        original_functions: Uncompiled callables, keyed by a name string.
    """

    def __init__(self):
        self.original_functions: Dict[str, Callable] = {}
        self.compiled_functions: Dict[str, Callable] = {}
        self.compile_enabled = ENABLE_TORCH_COMPILE

    def set_compile_enabled(self, enabled: bool):
        """Turn compile on or off at runtime."""
        self.compile_enabled = enabled

    def get_compile_status(self):
        """Return whether compile is currently enabled."""
        return self.compile_enabled

    @contextlib.contextmanager
    def compile_disabled(self):
        """Switch compile off for the duration of a ``with`` block.

        The state that was active on entry is restored on exit, including
        when the body raises.
        """
        previous = self.compile_enabled
        self.compile_enabled = False
        try:
            yield
        finally:
            self.compile_enabled = previous
| |
|
| |
|
| | |
# Module-level CompileManager instance: smart_compile reads its switch and
# registers functions on it, and callers toggle compile through it.
compile_manager = CompileManager()
| |
|
| |
|
def smart_compile(func: Optional[Callable] = None, **compile_kwargs):
    """
    Smart compile decorator.

    Wraps a callable so that each call is routed to either its
    ``torch.compile``-d version or the original, depending on
    ``compile_manager.compile_enabled`` at call time. Usable bare
    (``@smart_compile``) or with options (``@smart_compile(dynamic=True)``).

    The ``torch.compile`` wrapper is created at decoration time when compile
    is already enabled, and otherwise lazily on the first call that needs it.
    Enabling compile later via ``compile_manager.set_compile_enabled(True)``
    therefore also affects functions decorated while it was disabled.

    Args:
        func: The function to decorate. ``None`` when the decorator is used
            with keyword options.
        **compile_kwargs: Other compile parameters, see https://pytorch.org/docs/stable/generated/torch.compile.html

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator.
        The wrapped function exposes ``.original`` (the undecorated callable)
        and ``.compiled`` (always uses the compiled version, creating it on
        first use if it does not exist yet).
    """

    def decorator(fn: Callable) -> Callable:
        # Callable instances (as opposed to plain functions) may not expose
        # __qualname__; fall back to the type's name instead of raising.
        module_name = getattr(fn, "__module__", None)
        qual_name = getattr(fn, "__qualname__", type(fn).__qualname__)
        func_name = f"{module_name}.{qual_name}"
        compile_manager.original_functions[func_name] = fn

        compiled_func = None  # filled in on demand by _get_compiled()

        def _get_compiled() -> Callable:
            """Return the compiled callable, creating it on first use."""
            nonlocal compiled_func
            if compiled_func is None:
                try:
                    compiled_func = torch.compile(fn, **compile_kwargs)
                    compile_manager.compiled_functions[func_name] = compiled_func
                except Exception as e:
                    logger.warning(f"[WARNING] Failed to compile function {func_name}: {e}")
                    # Best effort: keep running with the uncompiled function.
                    compiled_func = fn
            return compiled_func

        if compile_manager.compile_enabled:
            # Compile is on at decoration time: build the wrapper right away so
            # setup failures are logged where the decorator is applied.
            _get_compiled()

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # The switch is read on every call so runtime toggling takes effect.
            if compile_manager.compile_enabled:
                return _get_compiled()(*args, **kwargs)
            return fn(*args, **kwargs)

        @functools.wraps(fn)
        def force_compiled(*args, **kwargs):
            # Ignores the switch: backs ``.compiled`` when nothing was built yet.
            return _get_compiled()(*args, **kwargs)

        wrapper.original = fn
        wrapper.compiled = compiled_func if compiled_func is not None else force_compiled
        return wrapper

    # Bare usage (@smart_compile) passes the function directly; usage with
    # options (@smart_compile(...)) must return the decorator itself.
    if func is not None:
        return decorator(func)
    return decorator
| |
|