File size: 3,174 Bytes
bb14d6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import numpy as np
import sam3
from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from .masks import rle_to_mask

def get_index(dataset, image_id):
    """Return the ``dataset.metadata`` index label of the row for *image_id*.

    Raises:
        ValueError: if no row, or more than one row, has that ``image_id``.
    """
    metadata = dataset.metadata
    matches = metadata[metadata['image_id'] == image_id]
    if len(matches) != 1:
        raise ValueError('image_id not found or found multiple times.')
    return matches.index[0]

def mask_centroid(mask):
    """Return the centroid of the nonzero pixels of *mask* as ``[x, y]``.

    ``np.nonzero`` yields (row, column) index arrays, so the column mean is
    the x coordinate and the row mean the y coordinate. Unpacking into two
    arrays means *mask* must be 2-D. An all-zero mask produces NaN
    coordinates (NumPy warns about the mean of an empty slice).
    """
    row_idx, col_idx = np.nonzero(mask)
    return np.array([np.mean(col_idx), np.mean(row_idx)])

def rle_centroid(rle):
    """Decode *rle* with ``rle_to_mask`` and return its ``[x, y]`` centroid."""
    decoded_mask = rle_to_mask(rle)
    return mask_centroid(decoded_mask)

def assign_flippers(df):
    """Relabel generic ``'flipper'`` rows as front/rear, left/right flippers.

    Orientation comes from the single ``'head'`` row: "forward" points from
    the mean flipper centroid towards the head centroid. Flippers are then
    ordered along that axis (rear -> front) and across it (left -> right).

    Args:
        df: DataFrame with a ``'label'`` column and a ``'mask'`` column whose
            values are accepted by ``rle_centroid`` (presumably RLE-encoded
            masks -- confirm against ``rle_to_mask``).

    Returns:
        A copy of ``df`` in which ``'flipper'`` labels may be replaced by
        ``'flipper_fl'``, ``'flipper_fr'``, ``'flipper_rl'`` or
        ``'flipper_rr'``:

        * 2 flippers: both are treated as the front pair.
        * 3 flippers: the two most forward are the front pair; the remaining
          one is rear-left or rear-right depending on which side of the
          turtle centre it lies.
        * 4 flippers: front pair and rear pair.

        The copy is returned without relabelling when there is not exactly
        one head, when there are fewer than two or more than four flippers,
        or when no forward direction can be derived.
    """
    df = df.copy()

    # Orientation needs exactly one head as the "forward" reference.
    head_rows = df[df['label'] == 'head']
    if len(head_rows) != 1:
        return df

    flippers = df[df['label'] == 'flipper']
    n_flippers = len(flippers)
    # 0 flippers: nothing to label. 1 flipper: the turtle centre computed
    # below would be that flipper itself, so its side cannot be inferred and
    # it is left as 'flipper'. More than 4: left untouched.
    if n_flippers < 2 or n_flippers > 4:
        return df

    head_center = rle_centroid(head_rows.iloc[0]['mask'])
    flipper_centers = np.vstack([rle_centroid(rle) for rle in flippers['mask']])

    # "Forward" points from the mean flipper position towards the head.
    turtle_center = flipper_centers.mean(axis=0)
    forward_vec = head_center - turtle_center
    norm = np.linalg.norm(forward_vec)
    if not np.isfinite(norm) or norm == 0:
        # Head on top of the flipper centroid, or a NaN centroid (e.g. from
        # an empty mask): there is no usable forward direction.
        return df
    forward_vec = forward_vec / norm

    # Unit vector perpendicular to forward: (x, y) -> (-y, x). In image
    # coordinates (y pointing down) this is forward turned 90 degrees
    # clockwise on screen. Flippers with the smaller lateral coordinate are
    # labelled 'l'. NOTE(review): whether that is the animal's anatomical
    # left depends on the viewpoint (dorsal vs ventral) -- confirm.
    lateral_vec = np.array([-forward_vec[1], forward_vec[0]])

    # Project offsets from the turtle centre (not raw pixel positions) so the
    # sign of the lateral coordinate says which side of the midline a flipper
    # is on; the ordering along each axis is unaffected by the shift.
    offsets = flipper_centers - turtle_center
    forward_proj = offsets @ forward_vec
    lateral_proj = offsets @ lateral_vec

    def label_pair(idxs, position):
        # Label the two flippers at positions ``idxs`` (into ``flippers``) as
        # the left/right member of the pair by lateral coordinate. argsort
        # keeps the two distinct even if their lateral coordinates tie.
        by_side = idxs[np.argsort(lateral_proj[idxs], kind='stable')]
        df.loc[flippers.index[by_side[0]], 'label'] = f'flipper_{position}l'
        df.loc[flippers.index[by_side[-1]], 'label'] = f'flipper_{position}r'

    order_fwd = np.argsort(forward_proj, kind='stable')  # rear -> front
    if n_flippers == 2:
        # Two visible flippers are assumed to be the front pair.
        label_pair(order_fwd, 'f')
        return df

    # 3 or 4 flippers: the two most forward form the front pair and the rest
    # (1 or 2) are rear flippers. The slices must not overlap, otherwise the
    # middle flipper of three would be labelled front and then rear.
    label_pair(order_fwd[-2:], 'f')
    rear_idxs = order_fwd[:n_flippers - 2]
    if len(rear_idxs) == 2:
        label_pair(rear_idxs, 'r')
    else:
        # A single rear flipper: its side comes from the sign of its lateral
        # offset from the turtle centre.
        idx = rear_idxs[0]
        side = 'l' if lateral_proj[idx] < 0 else 'r'
        df.loc[flippers.index[idx], 'label'] = f'flipper_r{side}'
    return df

def initialize_sam3(confidence_threshold=0.5):
    """Build the SAM 3 image model and a ``Sam3Processor`` around it.

    The BPE vocabulary file is located relative to the installed ``sam3``
    package, at ``<package dir>/../sam3/assets/bpe_simple_vocab_16e6.txt.gz``.

    Args:
        confidence_threshold: Forwarded to ``Sam3Processor``. The default of
            0.5 matches the value this function always used before.

    Returns:
        Tuple ``(model, processor)``.
    """
    sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
    bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
    model = build_sam3_image_model(bpe_path=bpe_path)
    processor = Sam3Processor(model, confidence_threshold=confidence_threshold)
    return model, processor