| import torch |
| from torchvision import transforms |
| from math import pi |
| import torchvision.transforms.functional as TF |
|
|
|
|
| |
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
|
|
|
|
def uniq(arr):
    """Return the distinct elements of ``arr`` as a dict-keys view.

    Elements must be hashable; the view lists each element once, in order
    of first occurrence.
    """
    return dict.fromkeys(arr, True).keys()
|
|
|
|
def default(val, d):
    """Return ``val`` when ``exists(val)`` is true, otherwise the fallback ``d``."""
    if exists(val):
        return val
    return d
|
|
|
|
def max_neg_value(t):
    """Return the negated largest finite value of ``t``'s floating-point dtype.

    Args:
        t (torch.Tensor): tensor whose dtype is looked up via ``torch.finfo``

    Returns:
        float: ``-torch.finfo(t.dtype).max``
    """
    dtype_info = torch.finfo(t.dtype)
    return -dtype_info.max
|
|
|
|
def cast_tuple(val, depth=1):
    """Coerce ``val`` to a tuple.

    Tuples are returned unchanged, lists are converted to tuples, and any
    other value is repeated ``depth`` times.

    Args:
        val: value to coerce
        depth (int): repetition count for non-sequence values (default=1)

    Returns:
        tuple: the coerced value
    """
    if isinstance(val, tuple):
        return val
    if isinstance(val, list):
        return tuple(val)
    return (val,) * depth
|
|
|
|
def is_empty(t):
    """Return True if the tensor ``t`` holds zero elements."""
    element_count = t.nelement()
    return element_count == 0
|
|
|
|
def masked_mean(t, mask, dim=1):
    """
    Average ``t`` along ``dim`` while ignoring masked-out positions.

    Args:
        t (torch.Tensor): input tensor, indexed as (batch_size, seq_len, hidden_dim)
        mask (torch.Tensor): boolean tensor of shape (batch_size, seq_len); True marks
            the positions that take part in the mean
        dim (int): dimension of ``t`` to average over (default=1, the sequence axis)

    Returns:
        torch.Tensor: ``t`` reduced along ``dim``; with the default this has shape
        (batch_size, hidden_dim). For floating-point input, a slice with no kept
        position evaluates 0 / 0 and comes out as NaN.
    """
    keep = mask[:, :, None]
    # Zero the ignored positions so they contribute nothing to the sum.
    total = t.masked_fill(~keep, 0.0).sum(dim=dim)
    # Count kept entries along the same axis as the sum; expanding the mask to
    # t's shape makes the count line up with the reduced result for any dim.
    count = keep.expand_as(t).sum(dim=dim)
    return total / count
|
|
|
|
def set_requires_grad(model, value):
    """
    Switch gradient tracking on or off for every parameter of a model

    Args:
        model (torch.nn.Module): module whose ``parameters()`` are updated in place
        value (bool): new ``requires_grad`` flag assigned to each parameter
    """
    for weight in model.parameters():
        weight.requires_grad = value
|
|
|
|
def eval_decorator(fn):
    """
    Wrap ``fn`` so it runs with the model temporarily switched to eval mode

    The wrapper records ``model.training``, calls ``model.eval()``, runs
    ``fn`` and then restores the recorded mode with ``model.train(...)``.

    Args:
        fn (callable): function whose first positional argument is the model

    Returns:
        callable: wrapper with the same call signature and metadata as ``fn``
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # Restore the previous mode even when fn raises, so a failed call
            # cannot leave the model stuck in eval mode.
            model.train(was_training)

    return inner
|
|
|
|
def log(t, eps=1e-20):
    """
    Natural logarithm with a small offset added to the input first

    Args:
        t (torch.Tensor): input tensor
        eps (float): offset added to ``t`` before the log, guarding against log(0) (default=1e-20)

    Returns:
        torch.Tensor: ``torch.log(t + eps)``
    """
    shifted = t + eps
    return torch.log(shifted)
|
|
|
|
def gumbel_noise(t):
    """
    Draw Gumbel noise shaped like ``t``

    Uniform samples on [0, 1) are pushed through ``-log(-log(u))`` using this
    module's ``log`` helper.

    Args:
        t (torch.Tensor): tensor providing the shape, dtype and device of the noise

    Returns:
        torch.Tensor: noise tensor with the same shape as ``t``
    """
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = log(uniform)
    return -log(-inner)
|
|
|
|
def gumbel_sample(t, temperature=0.9, dim=-1):
    """
    Perturb temperature-scaled logits with Gumbel noise

    Args:
        t (torch.Tensor): input logits
        temperature (float): divisor applied to ``t``; floored at 1e-10 to avoid
            division by zero (default=0.9)
        dim (int): accepted for interface compatibility; this implementation does
            not use it (default=-1)

    Returns:
        torch.Tensor: ``t / temperature`` plus Gumbel noise, same shape as ``t``.
        No argmax is taken here.
    """
    safe_temperature = max(temperature, 1e-10)
    scaled = t / safe_temperature
    return scaled + gumbel_noise(t)
|
|
|
|
def top_k(logits, thres=0.5):
    """
    Keep the largest entries along the last dimension and set the rest to -inf

    The number kept is ``max(int((1 - thres) * num_logits), 1)``, where
    ``num_logits`` is the size of the last dimension.

    Args:
        logits (torch.Tensor): input tensor; filtering happens along its last dimension
        thres (float): fraction of entries to discard (default=0.5)

    Returns:
        torch.Tensor: tensor shaped like ``logits`` holding the kept values at their
        original positions and negative infinity everywhere else
    """
    vocab_size = logits.shape[-1]
    keep = max(int((1 - thres) * vocab_size), 1)
    top_values, top_indices = torch.topk(logits, keep)
    filtered = torch.full_like(logits, float("-inf"))
    return filtered.scatter_(-1, top_indices, top_values)
|
|
|
|
def gamma_func(mode="cosine", scale=0.15):
    """
    Build a schedule function ``r -> value`` for the selected mode

    Args:
        mode (str): one of "linear", "cosine", "square", "cubic" or "scaled-cosine"
            (default="cosine")
        scale (float): multiplier used only by the "scaled-cosine" mode (default=0.15)

    Returns:
        callable: function of a single argument ``r``. The cosine modes call
        ``torch.cos`` and therefore need a tensor; the polynomial modes use plain
        arithmetic.

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported names
    """
    if mode == "linear":
        return lambda r: 1 - r
    if mode == "cosine":
        return lambda r: torch.cos(r * pi / 2)
    if mode == "square":
        return lambda r: 1 - r**2
    if mode == "cubic":
        return lambda r: 1 - r**3
    if mode == "scaled-cosine":
        return lambda r: scale * (torch.cos(r * pi / 2))

    # Name the rejected mode so a typo in a config is easy to track down.
    raise NotImplementedError(
        f"unsupported gamma mode {mode!r}; expected one of "
        "'linear', 'cosine', 'square', 'cubic', 'scaled-cosine'"
    )
|
|
|
|
class always:
    """Callable that ignores its arguments and hands back a fixed value"""

    def __init__(self, val):
        # Value returned by every call.
        self.val = val

    def __call__(self, x, *_args, **_kwargs):
        """Return the stored value; ``x`` and any extra arguments are ignored."""
        return self.val
|
|
|
|
class DivideMax(torch.nn.Module):
    """Module that divides its input by the maximum taken along a fixed dimension"""

    def __init__(self, dim):
        super().__init__()
        # Dimension along which the maximum is computed.
        self.dim = dim

    def forward(self, x):
        """Return ``x`` divided by its ``amax`` over ``self.dim`` (kept for broadcasting)."""
        peak = torch.amax(x, dim=self.dim, keepdim=True)
        # The maximum is detached, so it acts as a constant scale in autograd.
        return x / peak.detach()
| |
def replace_outliers(image, percentile=0.0001):
    """
    Clamp extreme values of ``image`` in place

    Values below the ``percentile`` quantile or above the ``1 - percentile``
    quantile are treated as outliers and clipped to the smallest / largest
    value found among the remaining (valid) pixels.

    Args:
        image (torch.Tensor): tensor accepted by ``torch.quantile``; it is
            modified in place
        percentile (float): tail fraction cut off at each end (default=0.0001)

    Returns:
        torch.Tensor: the same ``image`` object, with outliers clipped
    """
    lower_bound = torch.quantile(image, percentile)
    upper_bound = torch.quantile(image, 1 - percentile)
    mask = (image <= upper_bound) & (image >= lower_bound)

    valid_pixels = image[mask]
    # Tensor reductions rather than Python's built-in min()/max(), which would
    # walk the tensor element by element in the interpreter.
    valid_min, valid_max = valid_pixels.min(), valid_pixels.max()

    outliers = ~mask
    image[outliers] = torch.clip(image[outliers], valid_min, valid_max)
    return image
|
|
|
|
# Per-dataset, per-channel value pairs, keyed by (dataset, image_type).
# NOTE(review): presumably (mean, std) normalisation statistics -- confirm.
# process_image does not apply them; they are kept here for reference.
_NORMALIZATION_STATS = {
    ("HPA", "nucleus"): (0.0655, 0.0650),
    ("HPA", "protein"): (0.1732, 0.1208),
    ("OpenCell", "nucleus"): (0.0272, 0.0244),
    ("OpenCell", "protein"): (0.0486, 0.0671),
}


def process_image(image, dataset=None, image_type=None):
    """
    Rescale an image by its maximum and take a random 256x256 crop

    Args:
        image (torch.Tensor): image tensor; the caller's tensor is not modified
        dataset (str, optional): dataset name such as "HPA" or "OpenCell";
            accepted for interface compatibility, it does not affect the result
        image_type (str, optional): channel name such as "nucleus" or "protein";
            accepted for interface compatibility, it does not affect the result

    Returns:
        torch.Tensor: the rescaled image after ``transforms.RandomCrop(256)``.
        For floating-point input, an all-zero image divides 0 by 0 and yields NaN.
    """
    # Out-of-place division so the tensor passed in by the caller is left untouched.
    image = image / image.max()

    # Kept as a list + Compose so further transforms can be appended later.
    t_forms = [transforms.RandomCrop(256)]
    return transforms.Compose(t_forms)(image)
|
|