File size: 1,709 Bytes
c183d33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
import axengine as axe


class SAMDecoder:
    """Run a compiled SAM-style mask decoder on a single point or box prompt.

    Every call feeds the decoder a fixed-size prompt of five (x, y)
    coordinates with one label each, an all-zero 256x256 prior mask and a
    ``has_mask_input`` flag of 0, i.e. no previous mask is ever supplied.
    """

    # Number of prompt slots the decoder inputs are built with.
    _NUM_PROMPT_POINTS = 5
    # Label values follow the Segment Anything ONNX decoder convention:
    # 1 = foreground click, 0 = background click, 2 / 3 = top-left /
    # bottom-right box corner, -1 = padding ("not a point"). Padding with 0
    # would instead inject background clicks at pixel (0, 0).
    _PAD_LABEL = -1.0

    def __init__(self, model_path):
        """Load the decoder model and print its input/output signatures.

        Args:
            model_path: Model file handed to ``axengine.InferenceSession``.
        """
        self.sess = axe.InferenceSession(model_path)
        for node in self.sess.get_inputs():
            print(node.name, node.shape)
        for node in self.sess.get_outputs():
            print(node.name, node.shape)

        # Empty prior mask; with has_mask_input == 0 the SAM decoder
        # convention is to ignore mask_input.
        self.mask = np.zeros((1, 1, 256, 256), np.float32)
        self.has_mask = np.array([0], np.float32)

    @classmethod
    def _build_prompt(cls, point=None, box=None, scale=None):
        """Encode a point or box prompt as fixed-size decoder inputs.

        Args:
            point: (x, y) click; takes precedence over ``box`` if both given.
            box: (x, y, w, h) box given as top-left corner plus size.
            scale: Factor applied to all coordinates; ``None`` means 1.0.

        Returns:
            ``(point_coords, point_labels)`` as float32 arrays of shape
            (1, 5, 2) and (1, 5).

        Raises:
            ValueError: If neither ``point`` nor ``box`` is provided, or the
                prompt does not have 2 (point) / 4 (box) values.
        """
        factor = np.float32(1.0) if scale is None else scale
        if point is not None:
            coords = [np.asarray(point, np.float32).reshape(2) * factor]
            labels = [1.0]
        elif box is not None:
            x, y, w, h = np.asarray(box, np.float32).reshape(4) * factor
            coords = [
                (x + w / 2, y + h / 2),  # box centre, sent as a foreground click
                (x, y),                  # top-left corner
                (x + w, y + h),          # bottom-right corner
            ]
            labels = [1.0, 2.0, 3.0]
        else:
            raise ValueError("Either point or box must be provided.")

        n_pad = cls._NUM_PROMPT_POINTS - len(coords)
        coords += [(0.0, 0.0)] * n_pad
        labels += [cls._PAD_LABEL] * n_pad
        point_coords = np.array(coords, np.float32).reshape((1, -1, 2))
        point_labels = np.array(labels, np.float32).reshape((1, -1))
        return point_coords, point_labels

    def decode(self, image_embedding, point=None, box=None, scale=None):
        """Predict masks for one prompt on a precomputed image embedding.

        Args:
            image_embedding: Encoder output passed through unchanged as
                ``image_embeddings`` (presumably (1, 256, 64, 64) float32
                from the matching encoder -- TODO confirm).
            point: (x, y) click; used instead of ``box`` if both are given.
            box: (x, y, w, h) box (top-left corner plus size).
            scale: Factor mapping caller coordinates into the decoder's input
                frame (presumably the encoder's resize ratio -- confirm with
                caller). ``None`` means coordinates are used as-is.

        Returns:
            The result of ``self.sess.run(None, inputs)`` -- all model
            outputs.

        Raises:
            ValueError: If neither ``point`` nor ``box`` is provided.
        """
        point_coords, point_labels = self._build_prompt(point, box, scale)
        inputs = {
            "image_embeddings": image_embedding,
            "point_coords": point_coords,
            "point_labels": point_labels,
            "mask_input": self.mask,
            "has_mask_input": self.has_mask,
        }
        return self.sess.run(None, inputs)