| | import torch |
| | from flow_matching.utils import categorical |
| | import math |
| | import inspect |
| | import random |
| |
|
def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
    """Return every point of the simplex lattice with ``num_div`` divisions.

    Each row has ``num_obj`` entries, each a non-negative multiple of
    ``1 / num_div``, and the entries of a row sum to 1 (up to float32 rounding).
    There are ``C(num_div + num_obj - 1, num_obj - 1)`` rows. They are listed
    in ascending lexicographic order of the underlying integer compositions.

    Args:
        num_obj: Number of objectives (vector length); must be >= 1.
        num_div: Number of divisions of the unit interval; must be >= 1.

    Returns:
        A ``float32`` tensor of shape ``(num_points, num_obj)``.

    Raises:
        ValueError: If ``num_obj < 1`` (the recursion would never reach its
            base case) or ``num_div < 1`` (the result would be 0/0 = NaN).
    """
    if num_obj < 1:
        raise ValueError(f"num_obj must be >= 1, got {num_obj}")
    if num_div < 1:
        raise ValueError(f"num_div must be >= 1, got {num_div}")

    def compositions(n, total):
        # All ways to write `total` as an ordered sum of `n` non-negative ints.
        if n == 1:
            return [[total]]
        return [
            [head] + tail
            for head in range(total + 1)
            for tail in compositions(n - 1, total - head)
        ]

    points = compositions(num_obj, num_div)
    return torch.tensor(points, dtype=torch.float32) / num_div
| |
|
def select_random_weight_vector(num_obj: int, num_div: int):
    """Draw one weight vector uniformly at random from the simplex lattice.

    The lattice comes from ``generate_simplex_lattice_points(num_obj, num_div)``.
    The row index is drawn with ``torch.randint``, so the draw follows torch's
    global RNG state.

    Returns:
        ``(weight_vector, weight_vectors)``: the sampled row and the full lattice.
    """
    lattice = generate_simplex_lattice_points(num_obj, num_div)
    num_points = lattice.size(0)
    chosen_row = torch.randint(0, num_points, (1,)).item()
    return lattice[chosen_row], lattice
| |
|
def z_score_norm(tensor, eps=1e-8):
    """Standardise ``tensor`` along its last dimension.

    The result is the value minus the mean, divided by the population
    (``unbiased=False``) standard deviation. The standard deviation is clamped
    below at ``eps``, so a constant slice maps to zeros instead of dividing by
    zero.
    """
    centered = tensor - tensor.mean(dim=-1, keepdim=True)
    spread = torch.clamp(tensor.std(dim=-1, unbiased=False, keepdim=True), min=eps)
    return centered / spread
| |
|
def _score_improvements(s_models, x_t, new_x_flat, t, importance, B, n_cand):
    """Score candidates with every model and stack the weighted score improvements.

    Each model is called on the candidate batch ``new_x_flat`` and on the current
    sequences ``x_t``, passing ``t`` only if its signature has a ``t`` parameter.
    A model may return a single score tensor or a tuple of them; every returned
    tensor is one objective. Objective ``k`` (counted across all models, in order)
    is multiplied by ``importance[k]``.

    Returns:
        A float tensor of shape ``(B, n_cand, N)``, where ``N`` is the total
        number of objectives.
    """
    columns = []
    objective = 0
    with torch.no_grad():
        for s in s_models:
            sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
            if 't' in sig.parameters:
                candidate_scores = s(new_x_flat, t)
                base_score = s(x_t, t)
            else:
                candidate_scores = s(new_x_flat)
                base_score = s(x_t)

            if isinstance(candidate_scores, tuple):
                pairs = [(candidate_scores[k], base_score[k]) for k in range(len(candidate_scores))]
            else:
                pairs = [(candidate_scores, base_score)]

            for cand, base in pairs:
                improvement = (cand.view(B, n_cand) - base.unsqueeze(1)).float()
                improvement *= importance[objective]
                columns.append(improvement.unsqueeze(2))
                objective += 1
    return torch.cat(columns, dim=2)


def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args, *,
                              excluded_token=23, excluded_position=6, num_special_tokens=4):
    """Rescale single-token substitution rates at one randomly chosen position.

    One position ``p``, shared by the whole batch, is drawn with ``random.choice``.
    It is drawn from ``1 .. L-3``, skipping ``excluded_position``. Positions where
    any row already holds ``excluded_token`` are also skipped (see below).

    For every row, each token other than the current token at ``p`` and
    ``excluded_token`` is a candidate. Each candidate sequence is scored by every
    model in ``s_models``. The per-objective improvements over the current
    sequence are combined as::

        delta_S = z(mean rank of the improvements) + args.lambda_ * z(improvements @ w)

    The off-diagonal rates ``u_t[b, p, cand]`` are multiplied by
    ``min(exp(args.beta * delta_S), 100)``. The diagonal entry
    ``[b, p, current]`` is then reset to minus the sum of the row's other
    entries.

    Args:
        x_t: Token ids indexed ``x_t[b, pos]``, i.e. shape ``(B, L)``.
        u_t: Rates of shape ``(B, L, vocab_size)``. Not modified; a clone is
            returned.
        w: Weight vector with one entry per objective.
        s_models: Score callables; see ``_score_improvements``.
        t: Time argument forwarded to models whose signature has ``t``.
        importance: Per-objective multipliers, indexed in objective order.
        args: Namespace read for ``is_peptide``, ``lambda_`` and ``beta``.
        excluded_token: Token id never proposed as a candidate. Must be
            ``< vocab_size`` so exactly two tokens are removed per row.
        excluded_position: Sequence position never chosen for editing.
        num_special_tokens: When ``args.is_peptide`` is set, candidates whose
            token id is below this value get improvement ``-10`` for every
            objective.

    Returns:
        ``(guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S)``
        with shapes ``(B, L, V)``, ``(1,)``, ``(B, V-2)``, ``(B, V-2, N)`` and
        ``(B, V-2)``.

    Raises:
        ValueError: If no position is eligible for editing.
    """
    B, L, vocab_size = u_t.shape
    device = x_t.device
    n_cand = vocab_size - 2  # all tokens except the current one and excluded_token
    guided_u_t = u_t.clone()

    # --- choose the position to edit -------------------------------------
    eligible = [i for i in range(1, L - 2) if i != excluded_position]
    if eligible:
        # A row that already holds excluded_token at a position would lose only
        # one token from the vocabulary there (V-1 candidates instead of V-2).
        # Its candidate set could then not be stacked with the other rows, so
        # such positions are skipped. When none are skipped, the list, and hence
        # the random draw, is unchanged.
        keep = (x_t[:, eligible] != excluded_token).all(dim=0).tolist()
        eligible = [i for i, ok in zip(eligible, keep) if ok]
    if not eligible:
        raise ValueError(
            "guided_transition_scoring: no editable position (sequence too short, "
            "or excluded_token present at every eligible position)"
        )
    pos = random.choice(eligible)
    pos_indices = torch.tensor([pos], device=device)

    batch_idx = torch.arange(B, device=device)
    current_tokens = x_t[:, pos]

    # --- build candidate tokens and candidate sequences -------------------
    vocab = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size)
    mask = (vocab != current_tokens.unsqueeze(1)) & (vocab != excluded_token)
    cand_tokens = torch.masked_select(vocab, mask).view(B, n_cand)

    # One copy of each row per candidate, with position `pos` substituted.
    new_x = x_t.unsqueeze(1).expand(B, vocab_size, L)[mask].view(B, n_cand, L)
    new_x[:, :, pos] = cand_tokens
    new_x_flat = new_x.view(B * n_cand, L)

    # --- score candidates --------------------------------------------------
    improvement_values = _score_improvements(
        s_models, x_t, new_x_flat, t, importance, B, n_cand
    )
    if args.is_peptide:
        # Mask by token id. The previous slot-based mask [:, :num_special_tokens]
        # hit a non-special token whenever the current token was itself special,
        # because the candidate list skips the current token.
        special = (cand_tokens < num_special_tokens).unsqueeze(-1)
        improvement_values = improvement_values.masked_fill(special, -10)

    # Rank of each candidate per objective (1 = worst), normalised and averaged.
    ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1
    avg_I = (ranks / float(n_cand)).mean(dim=2)
    norm_avg_I = z_score_norm(avg_I)

    # Weighted sum of the raw improvements along the preference vector.
    D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    norm_D = z_score_norm(D)

    delta_S = norm_avg_I + args.lambda_ * norm_D

    # exp(.) is positive; only the upper bound limits the rescaling.
    factor = torch.exp(args.beta * delta_S).clamp(max=100)

    # --- rescale rates and restore the diagonal ----------------------------
    rows = batch_idx.unsqueeze(1)
    guided_u_t[rows, pos, cand_tokens] = u_t[rows, pos, cand_tokens] * factor

    row_rates = guided_u_t[:, pos, :]
    off_diag = row_rates.sum(dim=1) - row_rates[batch_idx, current_tokens]
    guided_u_t[batch_idx, pos, current_tokens] = -off_diag

    return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
| |
|
def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
    """Select one candidate per row inside a cone around ``w`` and adapt the cone angle.

    Each candidate's improvement vector is compared to ``w`` by angle:

    * valid: angle < pi/2;
    * accepted: valid and angle <= ``Phi``.

    Per row, the returned token is the ``delta_S``-maximising accepted candidate.
    If no candidate is accepted, it is the best valid candidate. If no candidate
    is valid, it is ``-1``.

    The rejection rate ``r_t`` is the fraction of valid candidates that were not
    accepted, averaged over the rows that have at least one valid candidate (0.0
    if none do). When any candidate is valid, the EMA and the cone angle are
    updated as::

        ema  = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
        Phi' = clamp(Phi * exp(args.eta * (ema - args.tau)), args.Phi_min, args.Phi_max)

    Otherwise both are returned unchanged. ``ema_r_t=None`` starts from
    ``args.tau``.

    Returns:
        ``(best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t)``.
        ``best_candidate`` is a long tensor of shape ``(B,)``. The masks have
        shape ``(B, num_candidates)``. ``new_Phi`` is a Python float after an
        update.
    """
    batch_size, _, _ = improvement_values.shape
    dev = improvement_values.device
    eps = 1e-8

    # Angle between each improvement vector and w.
    dot = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    denom = torch.norm(improvement_values.float(), dim=2) * (torch.norm(w) + eps) + eps
    angles = torch.acos((dot / denom).clamp(-1.0, 1.0))

    valid_mask = angles < math.pi / 2
    accepted_mask = valid_mask & (angles <= Phi)

    # Search the accepted set where it is non-empty, otherwise the valid set.
    row_has_valid = valid_mask.any(dim=1)
    pool = torch.where(accepted_mask.any(dim=1, keepdim=True), accepted_mask, valid_mask)
    masked_scores = delta_S.masked_fill(~pool, float('-inf'))
    chosen = cand_tokens.gather(1, masked_scores.argmax(dim=1, keepdim=True)).squeeze(1).long()
    best_candidate = torch.where(row_has_valid, chosen, torch.full_like(chosen, -1))

    # Mean rejection rate over rows that have at least one valid candidate.
    valid_counts = valid_mask.sum(dim=1).tolist()
    accepted_counts = accepted_mask.sum(dim=1).tolist()
    per_row = [(nv - na) / nv for nv, na in zip(valid_counts, accepted_counts) if nv > 0]
    r_t = sum(per_row) / len(per_row) if per_row else 0.0

    if ema_r_t is None:
        ema_r_t = args.tau

    if not valid_mask.any():
        return best_candidate, accepted_mask, valid_mask, Phi, ema_r_t

    new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
    growth = torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=dev))
    new_Phi = (Phi * growth).clamp(args.Phi_min, args.Phi_max).item()
    return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t
| |
|
def get_best_candidate(improvement_values, cand_tokens, delta_S):
    """Return, for every row, the candidate token with the largest ``delta_S``.

    There is no cone filtering: the argmax is taken over all candidates. Ties go
    to the first maximal entry. ``improvement_values`` is only used for its
    leading (batch) size.

    Returns:
        A long tensor of shape ``(B,)`` holding entries of ``cand_tokens``.
    """
    batch_size, _, _ = improvement_values.shape
    top_idx = delta_S[:batch_size].argmax(dim=1, keepdim=True)
    return cand_tokens[:batch_size].gather(1, top_idx).squeeze(1).to(torch.long)
| |
|
def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h):
    """Stochastically apply each row's chosen single-token substitution in place.

    For a row ``b`` with ``best_candidate[b] != -1``, let the rate be
    ``r = guided_u_t[b, pos_indices[b], best_candidate[b]]``. The row jumps with
    probability ``1 - exp(-r)``, tested against one ``torch.rand`` draw per row.
    When it jumps, ``x_t[b, pos_indices[b]]`` becomes ``best_candidate[b]``. Rows
    with ``best_candidate == -1`` never change.

    Args:
        x_t: Token ids of shape ``(B, L)``; modified in place.
        pos_indices: Edited position, either one per row ``(B,)`` or a
            single-element tensor shared by all rows.
        best_candidate: Shape ``(B,)``; the token to jump to, or ``-1`` for no move.
        guided_u_t: Rates of shape ``(B, L, V)``.
        h: Step size. NOTE(review): currently unused, so the jump probability
            does not scale with the step. Confirm whether ``1 - exp(-h * r)``
            was intended.

    Returns:
        ``x_t`` (the same tensor object, after the in-place update).
    """
    B, L, V = guided_u_t.shape
    device = x_t.device

    # A one-element position tensor is shared by the whole batch. Broadcast it so
    # the per-row indexing below also works for B > 1; previously,
    # pos_indices[valid_idx] raised an IndexError there.
    pos = pos_indices.reshape(-1)
    if pos.numel() == 1:
        pos = pos.expand(B)

    valid_mask = best_candidate != -1

    # Only the chosen transition contributes, so the total jump intensity is
    # just that single rate.
    intensity = torch.zeros(B, device=device)
    if valid_mask.any():
        rows = torch.nonzero(valid_mask).squeeze(-1)
        intensity[rows] = guided_u_t[rows, pos[rows], best_candidate[rows]]

    p_jump = 1 - torch.exp(-1 * intensity)
    rand_val = torch.rand(B, device=device)

    jump_decision = (rand_val < p_jump) & valid_mask
    x_t[jump_decision, pos[jump_decision]] = best_candidate[jump_decision]
    return x_t
| |
|