|
|
import re |
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
class UnionFind: |
|
|
def __init__(self, size): |
|
|
self.parent = np.arange(size) |
|
|
def find(self, x): |
|
|
if self.parent[x] != x: |
|
|
self.parent[x] = self.find(self.parent[x]) |
|
|
return self.parent[x] |
|
|
def union(self, x, y): |
|
|
px = self.find(x) |
|
|
py = self.find(y) |
|
|
if px != py: |
|
|
self.parent[py] = px |
|
|
|
|
|
def get_select_mask(tensor, skip_ratio=0, rand=False): |
|
|
|
|
|
if type(tensor) == torch.Tensor: |
|
|
retain_mask = (tensor == -1).clone() |
|
|
unique_vals, counts = torch.unique(tensor, return_counts=True) |
|
|
|
|
|
for i, (val, count) in enumerate(zip(unique_vals, counts)): |
|
|
if val == -1: |
|
|
continue |
|
|
positions = (tensor == val).nonzero(as_tuple=True)[0] |
|
|
num_positions = len(positions) |
|
|
|
|
|
if num_positions == 1: |
|
|
retain_mask[positions] = True |
|
|
else: |
|
|
num_to_skip = int(round(num_positions * skip_ratio)) |
|
|
num_to_retain = max(1, num_positions - num_to_skip) |
|
|
if rand: |
|
|
|
|
|
perm = torch.randperm(num_positions, device=tensor.device) |
|
|
positions_to_retain = positions[perm[:num_to_retain]] |
|
|
else: |
|
|
indices = torch.linspace(0, num_positions - 1, steps=num_to_retain).long() |
|
|
positions_to_retain = positions[indices] |
|
|
|
|
|
retain_mask[positions_to_retain] = True |
|
|
else: |
|
|
assert type(tensor) == np.ndarray |
|
|
retain_mask = (tensor == -1).copy() |
|
|
unique_vals, counts = np.unique(tensor, return_counts=True) |
|
|
|
|
|
for val, count in zip(unique_vals, counts): |
|
|
if val == -1: |
|
|
continue |
|
|
positions = np.nonzero(tensor == val)[0] |
|
|
num_positions = len(positions) |
|
|
|
|
|
if num_positions == 1: |
|
|
retain_mask[positions] = True |
|
|
else: |
|
|
num_to_skip = int(round(num_positions * skip_ratio)) |
|
|
num_to_retain = max(1, num_positions - num_to_skip) |
|
|
if rand: |
|
|
perm = np.random.permutation(num_positions) |
|
|
positions_to_retain = positions[perm[:num_to_retain]] |
|
|
else: |
|
|
indices = np.linspace(0, num_positions - 1, num=num_to_retain, dtype=int) |
|
|
positions_to_retain = positions[indices] |
|
|
|
|
|
retain_mask[positions_to_retain] = True |
|
|
return retain_mask |
|
|
|
|
|
def parse_layer_type(str_ranges, L, default=0): |
|
|
|
|
|
result = [default] * L |
|
|
matches = re.findall(r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]', str_ranges) |
|
|
for start, end, value in matches: |
|
|
start, end, value = int(start) - 1, int(end) - 1, int(value) |
|
|
if end >= L: |
|
|
end = L - 1 |
|
|
result[start:end + 1] = [value] * (end - start + 1) |
|
|
return result |