directionality_probe / protify /seed_utils.py
nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
"""
Global seed management utilities for reproducible experiments.
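This module provides a centralized way to set random seeds across all
random number generators used in the platform (Python's random, NumPy and
torch). scikit-learn is not seeded directly; estimators created with
random_state=None draw from NumPy's global RNG and so follow the NumPy seed.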
"""
import os
import time
import random
import numpy as np
from typing import Optional
# Seed most recently applied by set_global_seed(); returned by get_global_seed().
# Stays None until set_global_seed() has been called in this process.
_GLOBAL_SEED: Optional[int] = None
def get_global_seed() -> Optional[int]:
    """Return the seed most recently installed by ``set_global_seed``.

    Returns:
        The stored seed, or ``None`` when no seed has been set in this process.
    """
    current_seed = _GLOBAL_SEED
    return current_seed
def set_cublas_workspace_config():
    """Ensure ``CUBLAS_WORKSPACE_CONFIG`` holds a deterministic-safe value.

    NVIDIA documents two workspace settings that make cuBLAS deterministic:
    ``":4096:8"`` (recommended) and ``":16:8"`` (smallest workspace).
    A value already present in the environment is respected; otherwise
    ``":4096:8"`` is written. Call this before torch first uses cuBLAS;
    doing it before ``import torch`` is the safest option.
    """
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
def seed_worker(worker_id: int):
    """``worker_init_fn`` for ``torch.utils.data.DataLoader``.

    Seeds Python's ``random`` and NumPy's global RNG from the calling
    process's torch seed (``torch.initial_seed()``), so those libraries follow
    the per-worker torch seed. ``worker_id`` is required by the DataLoader
    callback signature and is not used.
    """
    import torch

    # NumPy only accepts seeds in [0, 2**32); keep the low 32 bits.
    derived_seed = torch.initial_seed() & 0xFFFFFFFF
    random.seed(derived_seed)
    np.random.seed(derived_seed)
def dataloader_generator(seed: Optional[int] = None):
    """Build a seeded ``torch.Generator`` for ``DataLoader(generator=...)``.

    Passing the returned generator to ``torch.utils.data.DataLoader`` makes its
    random sampling order (and the base seed it hands to worker processes)
    reproducible. To also seed NumPy/random inside workers, combine it with
    ``worker_init_fn=seed_worker``.

    Args:
        seed: Seed for the generator. When None, the seed already registered
            via set_global_seed() is reused. If no seed has been registered
            yet, set_global_seed() is called to create a time-based one; this
            also seeds random/NumPy/torch.

    Returns:
        A ``torch.Generator`` seeded with the resolved seed.
    """
    import torch

    if seed is None:
        # Reuse an existing global seed rather than re-seeding every RNG with a
        # fresh time-based value, which would silently discard the reproducible
        # state a caller established earlier with set_global_seed(...).
        seed = get_global_seed()
        if seed is None:
            seed = set_global_seed()
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
def set_global_seed(seed: Optional[int] = None) -> int:
    """Seed Python's ``random``, NumPy and PyTorch from a single value.

    The seed is validated before any generator or the stored global seed is
    touched. An invalid value therefore leaves the previous state, including
    what get_global_seed() reports, unchanged.

    Args:
        seed: Integer seed in ``[0, 2**32 - 1]`` (the range NumPy accepts).
            If None, a seed is derived from the current time in microseconds,
            reduced modulo ``2**31``.

    Returns:
        The seed that was applied.

    Raises:
        TypeError: If ``seed`` is not an integer (e.g. a float or string).
        ValueError: If ``seed`` is outside ``[0, 2**32 - 1]``.
    """
    import operator

    global _GLOBAL_SEED
    if seed is None:
        # Generate seed from current time if not provided
        seed = int(time.time() * 1000000) % (2**31)
    else:
        # operator.index accepts int and NumPy integer scalars but rejects
        # floats/strings, which np.random.seed would otherwise reject only
        # after random.seed and _GLOBAL_SEED had already been updated.
        seed = operator.index(seed)
        if not 0 <= seed < 2**32:
            raise ValueError(
                f"seed must be in [0, 2**32 - 1] (NumPy's accepted range), got {seed}"
            )

    _GLOBAL_SEED = seed
    random.seed(seed)
    np.random.seed(seed)

    # Import torch lazily to avoid initializing CUDA before env is set elsewhere
    import torch

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Explicitly seed every visible GPU (multi-GPU setups).
        torch.cuda.manual_seed_all(seed)
    return seed
def set_determinism():
    """Configure PyTorch for reproducible (deterministic) execution.

    Steps:
    - Makes sure ``CUBLAS_WORKSPACE_CONFIG`` has a deterministic-safe value.
      An existing value is left untouched.
    - Forces deterministic cuDNN kernels and disables cuDNN autotuning.
    - Where available, enables ``torch.use_deterministic_algorithms`` so ops
      lacking a deterministic implementation raise instead of varying
      between runs.

    Note: this can significantly slow down operations; only use it when runs
    must be exactly reproducible.
    """
    # With deterministic algorithms enabled, CUDA >= 10.2 matmuls raise unless
    # CUBLAS_WORKSPACE_CONFIG is set, so provide a default before torch touches
    # cuBLAS. The helper never overrides a value the user already exported.
    set_cublas_workspace_config()

    # Import torch only after the env var has been set
    import torch

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if not hasattr(torch, 'use_deterministic_algorithms'):
        print(
            'torch.use_deterministic_algorithms is not available in torch '
            f'{torch.__version__}; only cuDNN determinism was enabled. '
            'Make sure you are using the correct version of torch'
        )
        return
    try:
        torch.use_deterministic_algorithms(True, warn_only=False)
    except TypeError:
        # Older torch releases do not accept the ``warn_only`` keyword; the
        # single-argument form already raises on nondeterministic ops, which
        # matches warn_only=False.
        torch.use_deterministic_algorithms(True)