import re import torch import numpy as np # Implement Union-Find operator for constructing ui patches class UnionFind: def __init__(self, size): self.parent = np.arange(size) def find(self, x): if self.parent[x] != x: self.parent[x] = self.find(self.parent[x]) # Path compression return self.parent[x] def union(self, x, y): px = self.find(x) py = self.find(y) if px != py: self.parent[py] = px def get_select_mask(tensor, skip_ratio=0, rand=False): # Use tensor operations for efficiency if type(tensor) == torch.Tensor: retain_mask = (tensor == -1).clone() unique_vals, counts = torch.unique(tensor, return_counts=True) for i, (val, count) in enumerate(zip(unique_vals, counts)): if val == -1: continue positions = (tensor == val).nonzero(as_tuple=True)[0] num_positions = len(positions) if num_positions == 1: retain_mask[positions] = True else: num_to_skip = int(round(num_positions * skip_ratio)) num_to_retain = max(1, num_positions - num_to_skip) if rand: # rand means random select subset of selective tokens for layer-wise perm = torch.randperm(num_positions, device=tensor.device) positions_to_retain = positions[perm[:num_to_retain]] else: indices = torch.linspace(0, num_positions - 1, steps=num_to_retain).long() positions_to_retain = positions[indices] retain_mask[positions_to_retain] = True else: assert type(tensor) == np.ndarray retain_mask = (tensor == -1).copy() unique_vals, counts = np.unique(tensor, return_counts=True) for val, count in zip(unique_vals, counts): if val == -1: continue positions = np.nonzero(tensor == val)[0] num_positions = len(positions) if num_positions == 1: retain_mask[positions] = True else: num_to_skip = int(round(num_positions * skip_ratio)) num_to_retain = max(1, num_positions - num_to_skip) if rand: perm = np.random.permutation(num_positions) positions_to_retain = positions[perm[:num_to_retain]] else: indices = np.linspace(0, num_positions - 1, num=num_to_retain, dtype=int) positions_to_retain = positions[indices] retain_mask[positions_to_retain] = True return retain_mask def parse_layer_type(str_ranges, L, default=0): # 0 is without layer token selection, 1 is with layer token selection result = [default] * L matches = re.findall(r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]', str_ranges) for start, end, value in matches: start, end, value = int(start) - 1, int(end) - 1, int(value) if end >= L: end = L - 1 result[start:end + 1] = [value] * (end - start + 1) return result