| import torch | |
| from torch.utils.checkpoint import get_device_states, set_device_states | |
class RandContext:
    """Replay a snapshot of the RNG state inside a ``with`` block.

    At construction, the CPU RNG state is captured. The RNG states of the
    devices holding ``tensors`` are captured too, as collected by
    ``torch.utils.checkpoint.get_device_states``; CPU-only tensors contribute
    no devices. Entering the context forks the RNG with
    ``torch.random.fork_rng`` and then restores the captured snapshot. Random
    ops inside the block therefore draw the same numbers they would have drawn
    right after construction. Exiting the context lets ``fork_rng`` restore
    the RNG state that was current just before entry. The same instance can
    be entered repeatedly, and each entry replays the same snapshot.

    NOTE(review): presumably used so that a re-run of a forward pass (e.g.
    dropout) matches the original run. Confirm against callers.
    NOTE(review): recent ``fork_rng`` versions fork CUDA generators by default
    (``device_type="cuda"``). Confirm this matches the device type inferred by
    ``get_device_states`` if non-CUDA accelerators are used.
    """

    def __init__(self, *tensors):
        # Snapshot taken once, here; it is replayed on every __enter__.
        self.fwd_cpu_state = torch.get_rng_state()
        self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
        # Active fork_rng context while inside the ``with`` block, else None.
        self._fork = None

    def __enter__(self):
        """Fork the RNG, then restore the snapshot taken in ``__init__``."""
        fork = torch.random.fork_rng(devices=self.fwd_gpu_devices, enabled=True)
        fork.__enter__()
        try:
            torch.set_rng_state(self.fwd_cpu_state)
            set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)
        except BaseException as exc:
            # Python does not call __exit__ when __enter__ raises, so the fork
            # must be unwound here. Otherwise the RNG state it saved is never
            # restored and the fork is left dangling.
            fork.__exit__(type(exc), exc, exc.__traceback__)
            raise
        # Only publish the fork once entry has fully succeeded.
        self._fork = fork

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Let the fork restore the pre-entry RNG state; exceptions propagate."""
        try:
            self._fork.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Clear even if restoring fails, so a stale fork is never reused.
            self._fork = None