import pdb  # NOTE(review): imported but unused in this file — presumably kept for interactive debugging
import scipy
import numpy as np


# Compatibility shim: `scipy.inf` was removed in SciPy >= 1.12, but a
# downstream dependency (presumably msaf, imported below) still references
# it — TODO confirm which import actually needs this.
scipy.inf = np.inf


import torch
import torch.nn as nn
import numpy as np  # NOTE(review): duplicate import — numpy is already imported above
import torch.nn.functional as F
from dataset.custom_types import MsaInfo
from msaf.eval import compute_results
from postprocessing.functional import postprocess_functional_structure
from x_transformers import Encoder
import bisect
|
|
|
|
class Head(nn.Module):
    """Per-timestep MLP projection head.

    Stacks Linear layers (with an activation between hidden layers, none
    after the output layer) and applies them independently to every time
    step of a (B, T, C) input.
    """

    _ACTIVATIONS = {"relu": nn.ReLU, "silu": nn.SiLU, "gelu": nn.GELU}

    def __init__(self, input_dim, output_dim, hidden_dims=None, activation="silu"):
        super().__init__()
        act_cls = self._ACTIVATIONS.get(activation.lower())
        if act_cls is None:
            raise ValueError(f"Unsupported activation: {activation}")

        widths = [input_dim, *(hidden_dims or []), output_dim]
        modules = []
        last_linear = len(widths) - 2
        for idx, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(d_in, d_out))
            if idx < last_linear:  # no activation after the final projection
                modules.append(act_cls())
        self.net = nn.Sequential(*modules)

    def reset_parameters(self, confidence):
        """Set the final layer's bias to the logit of `confidence`.

        Focal-loss-style prior initialization: sigmoid(bias) == confidence.
        """
        prior_logit = -torch.log(torch.tensor((1 - confidence) / confidence))
        self.net[-1].bias.data.fill_(prior_logit.item())

    def forward(self, x):
        """Project (B, T, C) -> (B, T, output_dim), timestep-wise."""
        batch_size, seq_len, channels = x.shape
        flat = self.net(x.reshape(batch_size * seq_len, channels))
        return flat.reshape(batch_size, seq_len, -1)
|
|
|
|
class WrapedTransformerEncoder(nn.Module):
    """Transformer encoder (x-transformers) with an optional input projection.

    When ``input_dim != transformer_input_dim`` the input is first projected
    through a small MLP (Linear -> LayerNorm -> GELU -> Dropout -> Linear);
    otherwise it passes through unchanged.
    """

    def __init__(
        self, input_dim, transformer_input_dim, num_layers=1, nhead=8, dropout=0.1
    ):
        super().__init__()
        self.input_dim = input_dim
        self.transformer_input_dim = transformer_input_dim

        if input_dim != transformer_input_dim:
            self.input_proj = nn.Sequential(
                nn.Linear(input_dim, transformer_input_dim),
                nn.LayerNorm(transformer_input_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),  # lighter dropout on the projection
                nn.Linear(transformer_input_dim, transformer_input_dim),
            )
        else:
            self.input_proj = nn.Identity()

        self.transformer = Encoder(
            dim=transformer_input_dim,
            depth=num_layers,
            heads=nhead,
            layer_dropout=dropout,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True,
            rotary_pos_emb=True,
        )

    def forward(self, x, src_key_padding_mask=None):
        """Encode ``x`` (B, T, input_dim).

        Args:
            src_key_padding_mask: optional (B, T) boolean mask where True marks
                MASKED (padded) positions, following the PyTorch convention.
                x-transformers uses the opposite convention (False = masked),
                so the mask is inverted before being passed on.
        """
        x = self.input_proj(x)
        if src_key_padding_mask is None:
            mask = None
        else:
            # torch.as_tensor avoids the data copy (and the UserWarning)
            # that torch.tensor() incurs when handed an existing tensor.
            mask = ~torch.as_tensor(
                src_key_padding_mask, dtype=torch.bool, device=x.device
            )
        return self.transformer(x, mask=mask)
|
|
|
|
def prefix_dict(d, prefix: str):
    """Return a copy of `d` with `prefix` prepended to every key.

    If `prefix` is empty (or None), `d` is returned unchanged. The previous
    implementation had the condition inverted (`if prefix: return d`), which
    made the function a no-op exactly when a prefix WAS supplied.
    """
    if not prefix:
        return d
    return {prefix + key: value for key, value in d.items()}
|
|
|
|
class TimeDownsample(nn.Module):
    """Temporal downsampling block with a pooled residual connection.

    Operates on (B, T, C) input. Main path: depthwise conv followed by a
    pointwise (1x1) conv. Residual path: average-pooling over the same
    window, with an optional 1x1 conv when the channel count changes.
    The sum is LayerNorm'd, GELU-activated and dropped out.
    """

    def __init__(
        self, dim_in, dim_out=None, kernel_size=5, stride=5, padding=0, dropout=0.1
    ):
        super().__init__()
        self.dim_out = dim_out or dim_in
        assert self.dim_out % 2 == 0

        # Depthwise-separable convolution: per-channel temporal filter,
        # then a 1x1 channel mixer.
        self.depthwise_conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=dim_in,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=dim_in,
            bias=False,
        )
        self.pointwise_conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=self.dim_out,
            kernel_size=1,
            bias=False,
        )
        self.pool = nn.AvgPool1d(kernel_size, stride, padding=padding)
        self.norm1 = nn.LayerNorm(self.dim_out)
        self.act1 = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)

        # 1x1 conv to match residual channels only when they differ.
        if dim_in != self.dim_out:
            self.residual_conv = nn.Conv1d(
                dim_in, self.dim_out, kernel_size=1, bias=False
            )
        else:
            self.residual_conv = None

    def forward(self, x):
        """Downsample (B, T, C_in) -> (B, T', dim_out)."""
        channels_first = x.transpose(1, 2)

        # Main path (channels-first layout for Conv1d).
        main = self.pointwise_conv(self.depthwise_conv(channels_first))

        # Residual path: pool over the same window, project channels if needed.
        skip = self.pool(channels_first)
        if self.residual_conv is not None:
            skip = self.residual_conv(skip)

        fused = (main + skip).transpose(1, 2)
        return self.dropout1(self.act1(self.norm1(fused)))
|
|
|
|
class AddFuse(nn.Module):
    """Fuses a conditioning tensor into features by element-wise addition."""

    def __init__(self):
        super().__init__()

    def forward(self, x, cond):
        """Return ``x + cond`` (standard broadcasting rules apply)."""
        return torch.add(x, cond)
|
|
|
|
class TVLoss1D(nn.Module):
    """Total-variation smoothness penalty for 1-D boundary-score curves.

    Penalizes ``|pred[t+1] - pred[t]| ** beta`` along the time axis. When a
    target curve is supplied, transitions adjacent to labeled boundary
    regions (target > boundary_threshold on either side) are down-weighted
    by ``reduction_weight``, so sharp changes are tolerated exactly where
    boundaries are expected.
    """

    def __init__(
        self, beta=1.0, lambda_tv=0.4, boundary_threshold=0.01, reduction_weight=0.1
    ):
        """
        Args:
            beta: exponent applied to the absolute first difference
                (recommended 0.5~1.0).
            lambda_tv: overall weight multiplying the final loss.
            boundary_threshold: label value above which a position counts as
                being inside a boundary region.
            reduction_weight: multiplier for TV terms near boundaries
                (e.g. 0.1 keeps only 10% of the penalty there).
        """
        super().__init__()
        self.beta = beta
        self.lambda_tv = lambda_tv
        self.boundary_threshold = boundary_threshold
        self.reduction_weight = reduction_weight

    def forward(self, pred, target=None):
        """Return the scalar weighted TV loss.

        Args:
            pred: (B, T) or (B, T, 1) float boundary scores.
            target: optional (B, T) or (B, T, 1) ground-truth labels used
                for spatial weighting.
        """
        if pred.dim() == 3:
            pred = pred.squeeze(-1)
        if target is not None and target.dim() == 3:
            target = target.squeeze(-1)

        step = pred[:, 1:] - pred[:, :-1]
        tv = (step.abs() + 1e-8) ** self.beta  # epsilon keeps pow stable at 0

        if target is None:
            return self.lambda_tv * tv.mean()

        # A transition is "near a boundary" if either endpoint's label
        # exceeds the threshold; those terms get the reduced weight.
        touches_boundary = (target[:, :-1] > self.boundary_threshold) | (
            target[:, 1:] > self.boundary_threshold
        )
        weights = torch.ones_like(tv)
        weights[touches_boundary] = self.reduction_weight
        return self.lambda_tv * (tv * weights).mean()
|
|
|
|
class SoftmaxFocalLoss(nn.Module):
    """Focal loss over a softmax distribution (single-label, multi-class).

    Accepts either hard integer targets [B, T] (dtype long) or soft target
    distributions [B, T, C]. No reduction is applied: the per-position loss
    [B, T] is returned so the caller can mask and average as needed.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Args:
            pred: [B, T, C] raw logits.
            targets: [B, T, C] soft distribution, or [B, T] hard labels (long).

        Returns:
            [B, T] per-position loss values.
        """
        log_p = F.log_softmax(pred, dim=-1)
        p = log_p.exp()

        # Normalize hard labels to a one-hot "soft" distribution.
        if targets.dtype == torch.long:
            soft = F.one_hot(targets, num_classes=pred.size(-1)).float()
        else:
            soft = targets

        # Probability mass assigned to the target class(es), clamped for
        # numerical safety in the focal modulator.
        p_t = (p * soft).sum(dim=-1).clamp(1e-8, 1.0 - 1e-8)

        # Class-balancing weight: alpha for the target class, (1 - alpha)
        # elsewhere; collapsed onto the target via the soft distribution.
        if self.alpha > 0:
            alpha_t = self.alpha * soft + (1 - self.alpha) * (1 - soft)
            alpha_w = (alpha_t * soft).sum(dim=-1)
        else:
            alpha_w = 1.0

        modulator = (1 - p_t) ** self.gamma  # down-weights easy examples
        ce = -(log_p * soft).sum(dim=-1)
        return alpha_w * modulator * ce
|
|
|
|
class Model(nn.Module):
    """Music structure analysis model: boundary detection + section labeling.

    Pipeline: raw frame embeddings -> linear projection -> LayerNorm ->
    temporal conv downsampling -> additive dataset-ID conditioning ->
    transformer encoder -> two heads (per-frame boundary score and
    per-frame section/function class logits).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.input_norm = nn.LayerNorm(config.input_dim)
        # Projects the raw (mixed-window) feature dim down to the model input dim.
        self.mixed_win_downsample = nn.Linear(config.input_dim_raw, config.input_dim)
        # One learned embedding per source dataset, added as conditioning.
        self.dataset_class_prefix = nn.Embedding(
            num_embeddings=config.num_dataset_classes,
            embedding_dim=config.transformer_encoder_input_dim,
        )
        self.down_sample_conv = TimeDownsample(
            dim_in=config.input_dim,
            dim_out=config.transformer_encoder_input_dim,
            kernel_size=config.down_sample_conv_kernel_size,
            stride=config.down_sample_conv_stride,
            dropout=config.down_sample_conv_dropout,
            padding=config.down_sample_conv_padding,
        )
        self.AddFuse = AddFuse()
        self.transformer = WrapedTransformerEncoder(
            input_dim=config.transformer_encoder_input_dim,
            transformer_input_dim=config.transformer_input_dim,
            num_layers=config.num_transformer_layers,
            nhead=config.transformer_nhead,
            dropout=config.transformer_dropout,
        )
        self.boundary_TVLoss1D = TVLoss1D(
            beta=config.boundary_tv_loss_beta,
            lambda_tv=config.boundary_tv_loss_lambda,
            boundary_threshold=config.boundary_tv_loss_boundary_threshold,
            reduction_weight=config.boundary_tv_loss_reduction_weight,
        )
        self.label_focal_loss = SoftmaxFocalLoss(
            alpha=config.label_focal_loss_alpha, gamma=config.label_focal_loss_gamma
        )
        self.boundary_head = Head(config.transformer_input_dim, 1)
        self.function_head = Head(config.transformer_input_dim, config.num_classes)

    def cal_metrics(self, gt_info: MsaInfo, msa_info: MsaInfo):
        """Compute MSAF segmentation metrics between ground truth and estimate.

        Both inputs are lists of (time, label) pairs whose final entry marks
        the track end with the label "end". Boundary times are converted to
        (start, end) interval arrays for msaf's ``compute_results``.
        """
        assert gt_info[-1][1] == "end" and msa_info[-1][1] == "end", (
            "gt_info and msa_info should end with 'end'"
        )
        gt_info_labels = [label for time_, label in gt_info][:-1]
        gt_info_inters = [time_ for time_, label in gt_info]
        # Consecutive boundary times -> (N, 2) array of segment intervals.
        gt_info_inters = np.column_stack(
            [np.array(gt_info_inters[:-1]), np.array(gt_info_inters[1:])]
        )

        msa_info_labels = [label for time_, label in msa_info][:-1]
        msa_info_inters = [time_ for time_, label in msa_info]
        msa_info_inters = np.column_stack(
            [np.array(msa_info_inters[:-1]), np.array(msa_info_inters[1:])]
        )
        result = compute_results(
            ann_inter=gt_info_inters,
            est_inter=msa_info_inters,
            ann_labels=gt_info_labels,
            est_labels=msa_info_labels,
            bins=11,
            est_file="test.txt",
            weight=0.58,
        )
        return result

    def cal_acc(
        self, ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3
    ):
        """Duration-weighted label-agreement accuracy between two segmentations.

        Times are converted to integer ticks at 10**post_digit resolution so
        boundaries compare exactly. The merged timeline of both segmentations
        is swept, and accuracy is the fraction of total overlapped duration on
        which the two active labels agree.
        """
        scale = 10**post_digit
        # Round in the tick domain. The previous form
        # int(round(t, d) * 10**d) could truncate, because the rounded float
        # is stored in binary: round(1.111, 3) * 1000 -> 1110.999... -> 1110.
        ann_info_time = [int(round(time_ * scale)) for time_, label in ann_info]
        est_info_time = [int(round(time_ * scale)) for time_, label in est_info]

        # Only compare over the time span covered by BOTH segmentations.
        common_start_time = max(ann_info_time[0], est_info_time[0])
        common_end_time = min(ann_info_time[-1], est_info_time[-1])

        # Merged set of boundary ticks inside the common span.
        time_points = {common_start_time, common_end_time}
        time_points.update(
            {
                time_
                for time_ in ann_info_time
                if common_start_time <= time_ <= common_end_time
            }
        )
        time_points.update(
            {
                time_
                for time_ in est_info_time
                if common_start_time <= time_ <= common_end_time
            }
        )

        time_points = sorted(time_points)
        total_duration, total_score = 0, 0

        for idx in range(len(time_points) - 1):
            duration = time_points[idx + 1] - time_points[idx]
            # Label active at the start of this elementary interval:
            # the segment whose start tick is the rightmost one <= the point.
            ann_label = ann_info[
                bisect.bisect_right(ann_info_time, time_points[idx]) - 1
            ][1]
            est_label = est_info[
                bisect.bisect_right(est_info_time, time_points[idx]) - 1
            ][1]
            total_duration += duration
            if ann_label == est_label:
                total_score += duration
        return total_score / total_duration

    def infer_with_metrics(self, batch, prefix: str = None):
        """Run inference on a batch and compute evaluation metrics.

        Uses the first item's ground truth (`batch["msa_infos"][0]`) —
        presumably evaluation batches hold a single track; verify with caller.

        Returns a dict of loss, MSAF hit-rate / pairwise / entropy scores and
        duration-weighted label accuracy, optionally key-prefixed.
        """
        with torch.no_grad():
            logits = self.forward_func(batch)

        losses = self.compute_losses(logits, batch, prefix=None)

        # Suppress classes that are invalid for this dataset before decoding.
        # label_id_masks: True marks classes to exclude — assumes shape
        # broadcastable to (B, T, C); TODO confirm (B, 1, C).
        expanded_mask = batch["label_id_masks"].expand(
            -1, logits["function_logits"].size(1), -1
        )
        logits["function_logits"] = logits["function_logits"].masked_fill(
            expanded_mask, -float("inf")
        )

        msa_info = postprocess_functional_structure(
            logits=logits, config=self.config
        )
        gt_info = batch["msa_infos"][0]
        results = self.cal_metrics(gt_info=gt_info, msa_info=msa_info)

        ret_results = {
            "loss": losses["loss"].item(),
            "HitRate_3P": results["HitRate_3P"],
            "HitRate_3R": results["HitRate_3R"],
            "HitRate_3F": results["HitRate_3F"],
            "HitRate_0.5P": results["HitRate_0.5P"],
            "HitRate_0.5R": results["HitRate_0.5R"],
            "HitRate_0.5F": results["HitRate_0.5F"],
            "PWF": results["PWF"],
            "PWP": results["PWP"],
            "PWR": results["PWR"],
            "Sf": results["Sf"],
            "So": results["So"],
            "Su": results["Su"],
            "acc": self.cal_acc(ann_info=gt_info, est_info=msa_info),
        }
        if prefix:
            ret_results = prefix_dict(ret_results, prefix)

        return ret_results

    def infer(
        self,
        input_embeddings,
        dataset_ids,
        label_id_masks,
        prefix: str = None,
        with_logits=False,
    ):
        """Run inference from raw tensors (no batch dict, no loss computation).

        Args:
            input_embeddings: raw frame features, (B, T, input_dim_raw) —
                assumed from mixed_win_downsample; TODO confirm.
            dataset_ids: dataset-class indices for the conditioning embedding.
            label_id_masks: boolean mask of invalid classes (True = suppress).
            prefix: unused here; kept for interface compatibility.
            with_logits: when True, also return the raw logits dict.

        Returns:
            Decoded msa_info, or (msa_info, logits) when with_logits is True.
        """
        with torch.no_grad():
            input_embeddings = self.mixed_win_downsample(input_embeddings)
            input_embeddings = self.input_norm(input_embeddings)
            logits = self.down_sample_conv(input_embeddings)

            # Add the dataset-conditioning embedding to every time step.
            dataset_prefix = self.dataset_class_prefix(dataset_ids)
            dataset_prefix_expand = dataset_prefix.unsqueeze(1).expand(
                logits.size(0), 1, -1
            )
            logits = self.AddFuse(x=logits, cond=dataset_prefix_expand)
            logits = self.transformer(x=logits, src_key_padding_mask=None)

            function_logits = self.function_head(logits)
            boundary_logits = self.boundary_head(logits).squeeze(-1)

            logits = {
                "function_logits": function_logits,
                "boundary_logits": boundary_logits,
            }

            # Suppress dataset-invalid classes before decoding.
            expanded_mask = label_id_masks.expand(
                -1, logits["function_logits"].size(1), -1
            )
            logits["function_logits"] = logits["function_logits"].masked_fill(
                expanded_mask, -float("inf")
            )

            msa_info = postprocess_functional_structure(
                logits=logits, config=self.config
            )

            return (msa_info, logits) if with_logits else msa_info

    def compute_losses(self, outputs, batch, prefix: str = None):
        """Compute boundary-detection and section-labeling losses.

        Args:
            outputs: dict with "boundary_logits" (B, T) and
                "function_logits" (B, T, C).
            batch: supplies "widen_true_boundaries", "true_functions",
                "masks" (True = padded) and optional
                "boundary_mask"/"function_mask" (True = excluded).

        Returns:
            dict with total "loss", "loss_section", "loss_function",
            optionally key-prefixed.
        """
        loss = 0.0
        losses = {}

        # Boundary loss: per-frame BCE plus a TV smoothness penalty
        # (scalar, broadcast-added onto the per-frame tensor).
        loss_section = F.binary_cross_entropy_with_logits(
            outputs["boundary_logits"],
            batch["widen_true_boundaries"],
            reduction="none",
        )
        loss_section += self.config.boundary_tvloss_weight * self.boundary_TVLoss1D(
            pred=outputs["boundary_logits"],
            target=batch["widen_true_boundaries"],
        )

        # Function loss: per-frame CE (class dim moved to dim 1 as
        # F.cross_entropy expects) plus a weighted focal term.
        loss_function = F.cross_entropy(
            outputs["function_logits"].transpose(1, 2),
            batch["true_functions"].transpose(1, 2),
            reduction="none",
        )
        focal_term = self.config.label_focal_loss_weight * self.label_focal_loss(
            pred=outputs["function_logits"], targets=batch["true_functions"]
        )
        loss_function += focal_term

        # masks: True = padded, so invert to get "valid frame" weights.
        float_masks = (~batch["masks"]).float()
        boundary_mask = batch.get("boundary_mask", None)
        function_mask = batch.get("function_mask", None)
        if boundary_mask is not None:
            boundary_mask = (~boundary_mask).float()
        else:
            boundary_mask = 1

        if function_mask is not None:
            function_mask = (~function_mask).float()
        else:
            function_mask = 1

        # NOTE(review): mean() averages over padded frames too, which dilutes
        # the loss for short sequences — confirm this is intentional.
        loss_section = torch.mean(boundary_mask * float_masks * loss_section)
        loss_function = torch.mean(function_mask * float_masks * loss_function)

        loss_section *= self.config.loss_weight_section
        loss_function *= self.config.loss_weight_function

        if self.config.learn_label:
            loss += loss_function
        if self.config.learn_segment:
            loss += loss_section

        losses.update(
            loss=loss,
            loss_section=loss_section,
            loss_function=loss_function,
        )
        if prefix:
            losses = prefix_dict(losses, prefix)
        return losses

    def forward_func(self, batch):
        """Shared forward pass: batch dict -> logits dict (training path)."""
        input_embeddings = batch["input_embeddings"]
        input_embeddings = self.mixed_win_downsample(input_embeddings)
        input_embeddings = self.input_norm(input_embeddings)
        logits = self.down_sample_conv(input_embeddings)

        # Add the dataset-conditioning embedding (broadcast over time).
        dataset_prefix = self.dataset_class_prefix(batch["dataset_ids"])
        logits = self.AddFuse(x=logits, cond=dataset_prefix.unsqueeze(1))
        src_key_padding_mask = batch["masks"]
        logits = self.transformer(x=logits, src_key_padding_mask=src_key_padding_mask)

        function_logits = self.function_head(logits)
        boundary_logits = self.boundary_head(logits).squeeze(-1)

        logits = {
            "function_logits": function_logits,
            "boundary_logits": boundary_logits,
        }
        return logits

    def forward(self, batch):
        """Training entry point: returns (logits, total_loss, losses dict)."""
        logits = self.forward_func(batch)
        losses = self.compute_losses(logits, batch, prefix=None)
        return logits, losses["loss"], losses
|
|