KurtLin committed on
Commit 00a3d3d · 1 Parent(s): 17191a0

Initial Commit

.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 weights/sam_vit_b_01ec64.pth filter=lfs diff=lfs merge=lfs -text
+weights/sam_vit_h_4b8939.pth filter=lfs diff=lfs merge=lfs -text
+weights/sam_vit_l_0b3195.pth filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -6,17 +6,26 @@ from segment_anything import sam_model_registry, SamPredictor
 from preprocess import show_mask, show_points, show_box
 import gradio as gr
 
-sam_checkpoint = "weights/sam_vit_b_01ec64.pth"
-model_type = "vit_b"
+sam_checkpoint = {
+    "ViT-base": "weights/sam_vit_b_01ec64.pth",
+    "ViT-large": "weights/sam_vit_l_0b3195.pth",
+    "ViT-huge": "weights/sam_vit_h_4b8939.pth",
+}
+model_type = {
+    "ViT-base": "vit_b",
+    "ViT-large": "vit_l",
+    "ViT-huge": "vit_h",
+}
 device = "cuda" if torch.cuda.is_available() else "cpu"
-sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
-sam.to(device=device)
-predictor = SamPredictor(sam)
+
 
 def get_coords(evt: gr.SelectData):
     return f"{evt.index[0]}, {evt.index[1]}"
 
-def inference(image, input_label):
+def inference(image, input_label, model_choice):
+    sam = sam_model_registry[model_type[model_choice]](checkpoint=sam_checkpoint[model_choice])
+    sam.to(device=device)
+    predictor = SamPredictor(sam)
     predictor.set_image(image)
     input_point = np.array([[int(input_label['label'].split(',')[0]), int(input_label['label'].split(',')[1])]])
     input_label = np.array([1])
@@ -38,8 +47,10 @@ with my_app:
         with gr.TabItem("Select your image"):
             with gr.Row():
                 with gr.Column():
-                    img_source = gr.Image(label="Please select picture.", value='./images/truck.jpg', shape=(1024, 1024))
+                    img_source = gr.Image(label="Please select picture and click the part to segment",
+                                          value='./images/truck.jpg', shape=(1024, 1024))
                     coords = gr.Label(label="Image Coordinate")
+                    model_choice = gr.Dropdown(['ViT-base', 'ViT-large', 'ViT-huge'], label='Model Backbone')
                     infer = gr.Button(label="Segment")
                 with gr.Column():
                     img_output = gr.Image(label="Output Mask", shape=(1024, 1024))
@@ -49,7 +60,8 @@ with my_app:
             inference,
             [
                 img_source,
-                coords
+                coords,
+                model_choice
             ],
             [
                 img_output
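
Note: the first hunk cuts off right after the point prompt is assembled, so the actual mask prediction is not shown above. As a rough sketch of how that step usually looks with the segment-anything API — predictor, input_point, and input_label are the names from the hunk, and nothing below is taken from this commit:

import numpy as np

# Sketch (assumed continuation, not from this commit): run SAM on the clicked point.
masks, scores, _ = predictor.predict(
    point_coords=input_point,   # (1, 2) array holding the clicked pixel
    point_labels=input_label,   # 1 marks the click as a foreground point
    multimask_output=True,      # SAM returns multiple candidate masks
)
best_mask = masks[np.argmax(scores)]  # keep the highest-scoring mask

The callback would then render best_mask over the input (for example via the imported show_mask helper) and return the result through img_output.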
weights/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
weights/sam_vit_l_0b3195.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622
+size 1249524607
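
Because the two new .pth files are committed only as Git LFS pointers, a clone without the LFS objects still needs the real checkpoints. A minimal sketch for fetching them, assuming the download URLs published in the facebookresearch/segment-anything README are still valid (the URLs are not part of this commit; verify them before relying on this):

import os
import urllib.request

# Assumed URLs from the segment-anything README; the downloads are large
# (roughly 1.2 GB and 2.6 GB per the LFS pointers above).
CHECKPOINTS = {
    "weights/sam_vit_b_01ec64.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    "weights/sam_vit_l_0b3195.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "weights/sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}

os.makedirs("weights", exist_ok=True)
for path, url in CHECKPOINTS.items():
    if not os.path.exists(path):        # skip files already present
        urllib.request.urlretrieve(url, path)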