Spaces:
Runtime error
Runtime error
Initial Commit
Browse files- .gitattributes +2 -0
- app.py +20 -8
- weights/sam_vit_h_4b8939.pth +3 -0
- weights/sam_vit_l_0b3195.pth +3 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
weights/sam_vit_b_01ec64.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
weights/sam_vit_b_01ec64.pth filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
weights/sam_vit_h_4b8939.pth filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
weights/sam_vit_l_0b3195.pth filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -6,17 +6,26 @@ from segment_anything import sam_model_registry, SamPredictor
|
|
| 6 |
from preprocess import show_mask, show_points, show_box
|
| 7 |
import gradio as gr
|
| 8 |
|
| 9 |
-
sam_checkpoint =
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
-
|
| 13 |
-
sam.to(device=device)
|
| 14 |
-
predictor = SamPredictor(sam)
|
| 15 |
|
| 16 |
def get_coords(evt: gr.SelectData):
|
| 17 |
return f"{evt.index[0]}, {evt.index[1]}"
|
| 18 |
|
| 19 |
-
def inference(image, input_label):
|
|
|
|
|
|
|
|
|
|
| 20 |
predictor.set_image(image)
|
| 21 |
input_point = np.array([[int(input_label['label'].split(',')[0]), int(input_label['label'].split(',')[1])]])
|
| 22 |
input_label = np.array([1])
|
|
@@ -38,8 +47,10 @@ with my_app:
|
|
| 38 |
with gr.TabItem("Select your image"):
|
| 39 |
with gr.Row():
|
| 40 |
with gr.Column():
|
| 41 |
-
img_source = gr.Image(label="Please select picture
|
|
|
|
| 42 |
coords = gr.Label(label="Image Coordinate")
|
|
|
|
| 43 |
infer = gr.Button(label="Segment")
|
| 44 |
with gr.Column():
|
| 45 |
img_output = gr.Image(label="Output Mask", shape=(1024, 1024))
|
|
@@ -49,7 +60,8 @@ with my_app:
|
|
| 49 |
inference,
|
| 50 |
[
|
| 51 |
img_source,
|
| 52 |
-
coords
|
|
|
|
| 53 |
],
|
| 54 |
[
|
| 55 |
img_output
|
|
|
|
| 6 |
from preprocess import show_mask, show_points, show_box
|
| 7 |
import gradio as gr
|
| 8 |
|
| 9 |
+
sam_checkpoint = {
|
| 10 |
+
"ViT-base": "weights/sam_vit_b_01ec64.pth",
|
| 11 |
+
"ViT-large": "weights/sam_vit_l_0b3195.pth",
|
| 12 |
+
"ViT-huge": "weights/sam_vit_h_4b8939.pth",
|
| 13 |
+
}
|
| 14 |
+
model_type = {
|
| 15 |
+
"ViT-base": "vit_b",
|
| 16 |
+
"ViT-large": "vit_l",
|
| 17 |
+
"ViT-huge": "vit_h",
|
| 18 |
+
}
|
| 19 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 20 |
+
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def get_coords(evt: gr.SelectData):
|
| 23 |
return f"{evt.index[0]}, {evt.index[1]}"
|
| 24 |
|
| 25 |
+
def inference(image, input_label, model_choice):
|
| 26 |
+
sam = sam_model_registry[model_type[model_choice]](checkpoint=sam_checkpoint[model_choice])
|
| 27 |
+
sam.to(device=device)
|
| 28 |
+
predictor = SamPredictor(sam)
|
| 29 |
predictor.set_image(image)
|
| 30 |
input_point = np.array([[int(input_label['label'].split(',')[0]), int(input_label['label'].split(',')[1])]])
|
| 31 |
input_label = np.array([1])
|
|
|
|
| 47 |
with gr.TabItem("Select your image"):
|
| 48 |
with gr.Row():
|
| 49 |
with gr.Column():
|
| 50 |
+
img_source = gr.Image(label="Please select picture and click the part to segment",
|
| 51 |
+
value='./images/truck.jpg', shape=(1024, 1024))
|
| 52 |
coords = gr.Label(label="Image Coordinate")
|
| 53 |
+
model_choice = gr.Dropdown(['ViT-base', 'ViT-large', 'ViT-huge'], label='Model Backbone')
|
| 54 |
infer = gr.Button(label="Segment")
|
| 55 |
with gr.Column():
|
| 56 |
img_output = gr.Image(label="Output Mask", shape=(1024, 1024))
|
|
|
|
| 60 |
inference,
|
| 61 |
[
|
| 62 |
img_source,
|
| 63 |
+
coords,
|
| 64 |
+
model_choice
|
| 65 |
],
|
| 66 |
[
|
| 67 |
img_output
|
weights/sam_vit_h_4b8939.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
|
| 3 |
+
size 2564550879
|
weights/sam_vit_l_0b3195.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622
|
| 3 |
+
size 1249524607
|