File size: 1,709 Bytes
c183d33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import numpy as np
import axengine as axe
class SAMDecoder:
    """Wrap an axengine inference session for a prompt-driven mask decoder.

    The session is fed an image embedding plus a fixed-size set of prompt
    points (built from either a single point or a box) and a blank mask
    input. The model's exact semantics are not visible here; the input names
    used in ``decode`` match the ones this class passes to the session.
    """

    # Number of prompt-point slots sent to the model on every call; slots
    # that are not used stay at coordinate (0, 0) with label 0.
    _NUM_POINTS = 5

    def __init__(self, model_path):
        """Load the model at *model_path* and print its input/output specs."""
        self.sess = axe.InferenceSession(model_path)
        for model_input in self.sess.get_inputs():
            print(model_input.name, model_input.shape)
        for model_output in self.sess.get_outputs():
            print(model_output.name, model_output.shape)
        # Blank mask prompt and its accompanying flag, reused for every call.
        # NOTE(review): presumably has_mask == 0 tells the model to ignore
        # the mask input -- confirm against the model export.
        self.mask = np.zeros((1, 1, 256, 256), np.float32)
        self.has_mask = np.array([0], np.float32)

    def decode(self, image_embedding, point=None, box=None, scale=None):
        """Run the decoder for a point prompt or a box prompt.

        Args:
            image_embedding: Passed to the session unchanged as
                ``"image_embeddings"``.
            point: Optional ``(x, y)`` prompt. Takes precedence over *box*
                when both are given.
            box: Optional ``(x, y, w, h)`` prompt; converted to its center,
                top-left and bottom-right points.
            scale: Factor multiplied into the prompt coordinates. ``None``
                (the default) leaves the coordinates unscaled.

        Returns:
            Whatever ``self.sess.run(None, inputs)`` returns.

        Raises:
            ValueError: If neither *point* nor *box* is provided, or if the
                prompt does not have the expected number of elements.
        """
        if point is None and box is None:
            raise ValueError("Either point or box must be provided.")

        # The default scale=None used to crash with a TypeError on the
        # multiplication below; treat it as "no scaling".
        if scale is None:
            scale = 1.0

        point_coords = np.zeros((1, self._NUM_POINTS, 2), np.float32)
        point_labels = np.zeros((1, self._NUM_POINTS), np.float32)

        if point is not None:
            scaled_point = np.asarray(point, dtype=np.float32) * scale
            # reshape(2) rejects anything that is not exactly one (x, y) pair.
            point_coords[0, 0] = scaled_point.reshape(2)
            point_labels[0, 0] = 1
        else:
            x, y, w, h = np.asarray(box, dtype=np.float32) * scale
            point_coords[0, 0] = (x + w / 2, y + h / 2)  # box center
            point_coords[0, 1] = (x, y)                  # top-left corner
            point_coords[0, 2] = (x + w, y + h)          # bottom-right corner
            # NOTE(review): labels 2 and 3 presumably mark the box corners
            # in the model's prompt encoding -- confirm against the export.
            point_labels[0, :3] = (1, 2, 3)

        inputs = {
            "image_embeddings": image_embedding,
            "point_coords": point_coords,
            "point_labels": point_labels,
            "mask_input": self.mask,
            "has_mask_input": self.has_mask,
        }
        return self.sess.run(None, inputs)
|