| """
|
| Tensor operations for distributed computing.
|
| """
|
| import torch
|
| import numpy as np
|
| from typing import Dict, List, Optional, Union, Tuple
|
|
|
class TensorOps:
    """Stateless utilities for distributed tensor operations.

    Every method is a ``@staticmethod`` operating only on its arguments, so
    the helpers can be used identically on worker and coordinator processes.
    """

    @staticmethod
    def split_tensor(tensor: torch.Tensor, num_parts: int) -> List[torch.Tensor]:
        """Split ``tensor`` along dim 0 into at most ``num_parts`` chunks.

        Note: ``torch.chunk`` may return *fewer* than ``num_parts`` chunks
        when ``tensor.size(0) < num_parts``.  Wrapped in ``list`` so the
        return value matches the annotated ``List`` type.
        """
        return list(torch.chunk(tensor, num_parts))

    @staticmethod
    def merge_tensors(tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        """Concatenate ``tensors`` along ``dim`` (inverse of :meth:`split_tensor`)."""
        return torch.cat(tensors, dim=dim)

    @staticmethod
    def average_gradients(gradients: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Element-wise average of per-worker gradient dicts.

        Assumes every dict has the same keys and matching tensor shapes.
        Returns an empty dict for an empty input (previously raised
        ``IndexError`` on ``gradients[0]``).
        """
        if not gradients:
            return {}
        return {
            key: torch.mean(torch.stack([g[key] for g in gradients]), dim=0)
            for key in gradients[0]
        }

    @staticmethod
    def serialize_tensor(tensor: torch.Tensor) -> Dict[str, Union[List, str]]:
        """Serialize a tensor to a JSON-friendly dict (data, shape, dtype).

        ``detach()`` is required: calling ``.numpy()`` on a tensor that
        requires grad raises ``RuntimeError``.
        """
        return {
            'data': tensor.detach().cpu().numpy().tolist(),
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype)
        }

    @staticmethod
    def deserialize_tensor(tensor_dict: Dict[str, Union[List, str]]) -> torch.Tensor:
        """Rebuild a tensor from the dict produced by :meth:`serialize_tensor`."""
        data = np.array(tensor_dict['data'])
        shape = tensor_dict['shape']
        # 'torch.float32' -> torch.float32 via the attribute name after the dot.
        dtype = getattr(torch, tensor_dict['dtype'].split('.')[-1])
        return torch.tensor(data, dtype=dtype).reshape(shape)

    @staticmethod
    def gradient_clipping(gradients: Dict[str, torch.Tensor], max_norm: float) -> Dict[str, torch.Tensor]:
        """Clip each gradient tensor's L2 norm to ``max_norm``, in place.

        Bug fix: the previous implementation called
        ``torch.nn.utils.clip_grad_norm_(v, max_norm)``, which clips
        ``v.grad`` — ``None`` for raw gradient tensors — making it a silent
        no-op.  We now scale the tensors themselves, using the same
        ``max_norm / (norm + eps)`` coefficient ``clip_grad_norm_`` uses.
        """
        for grad in gradients.values():
            if grad is None:
                continue
            norm = grad.norm()
            clip_coef = max_norm / (norm + 1e-6)
            if clip_coef < 1:
                grad.mul_(clip_coef)
        return gradients

    @staticmethod
    def reduce_precision(tensor: torch.Tensor, bits: int = 16) -> torch.Tensor:
        """Cast ``tensor`` to float16 (``bits=16``) or float32 (``bits=32``).

        Raises:
            ValueError: for any other ``bits`` value.
        """
        if bits == 16:
            return tensor.half()
        elif bits == 32:
            return tensor.float()
        else:
            raise ValueError(f"Unsupported precision bits: {bits}")

    @staticmethod
    def shard_tensor(tensor: torch.Tensor, shard_size: int) -> List[torch.Tensor]:
        """Slice ``tensor`` along dim 0 into shards of ``shard_size`` rows.

        The last shard may be smaller when ``size(0)`` is not a multiple of
        ``shard_size``.
        """
        return [tensor[i:i + shard_size] for i in range(0, tensor.size(0), shard_size)]

    @staticmethod
    def compute_parameter_norm(parameters: Dict[str, torch.Tensor]) -> float:
        """Return the global L2 norm over all parameter tensors.

        Computed as sqrt(sum of squared per-tensor norms), i.e. the norm of
        the flattened concatenation of all parameters.
        """
        return sum(p.norm().item() ** 2 for p in parameters.values()) ** 0.5