import argparse
import os
import json

import torch
import torch.nn.functional as F
import monai.transforms as transforms
import streamlit as st

from model.segment_anything_volumetric import sam_model_registry
from model.network.model import SegVol
from model.data_process.demo_data_process import process_ct_gt
from model.utils.monai_inferers_utils import (
    sliding_window_inference,
    generate_box,
    select_points,
    build_binary_cube,
    build_binary_points,
    logits2roi_coor,
)
from model.utils.visualize import draw_result
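
# Streamlit demo for SegVol: volumetric CT segmentation driven by text, box,
# and point prompts, via coarse-to-fine (zoom-out / zoom-in) inference.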


def set_parse():
    parser = argparse.ArgumentParser()
    # Note: type=bool and type=tuple do not parse CLI strings sensibly
    # (argparse calls bool()/tuple() on the raw string, so even "False" is
    # truthy); these flags are intended to be used with their defaults.
    parser.add_argument("--test_mode", default=True, type=bool)
    parser.add_argument("--resume", type=str, default='SegVol_v1.pth')
    parser.add_argument("-infer_overlap", default=0.0, type=float, help="sliding window inference overlap")
    parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
    parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
    parser.add_argument('-work_dir', type=str, default='./work_dir')

    parser.add_argument("--clip_ckpt", type=str, default='model/config/clip')
    args = parser.parse_args()
    return args
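

# Coarse-to-fine inference: run SegVol once on the resized (zoom-out) volume,
# locate the foreground ROI, then refine that ROI at full resolution
# (zoom-in) with sliding-window inference.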
def zoom_in_zoom_out(args, segvol_model, image, image_resize, text_prompt, point_prompt, box_prompt):
    image_single_resize = image_resize
    image_single = image[0, 0]
    ori_shape = image_single.shape
    resize_shape = image_single_resize.shape[2:]

    text_single = None if text_prompt is None else [text_prompt]
    points_single = None
    box_single = None
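
    # Build the prompts, plus binary masks of them at the resized (zoom-out)
    # resolution so they can later be mapped back to the original grid.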
    if args.use_point_prompt:
        point, point_label = point_prompt
        points_single = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
        binary_points_resize = build_binary_points(point, point_label, resize_shape)
    if args.use_box_prompt:
        box_single = box_prompt.unsqueeze(0).float()
        binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=resize_shape)

    print('--- zoom out inference ---')
    print(text_single)
    print(f'use text-prompt [{text_single is not None}], use box-prompt [{box_single is not None}], use point-prompt [{points_single is not None}]')
    with torch.no_grad():
        logits_global_single = segvol_model(image_single_resize,
                                            text=text_single,
                                            boxes=box_single,
                                            points=points_single)
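
    # Map the global logits (and prompt masks, if any) back to the original resolution.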
    logits_global_single = F.interpolate(
        logits_global_single.cpu(),
        size=ori_shape, mode='nearest')[0][0]

    if args.use_point_prompt:
        binary_points = F.interpolate(
            binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
            size=ori_shape, mode='nearest')[0][0]
    if args.use_box_prompt:
        binary_cube = F.interpolate(
            binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
            size=ori_shape, mode='nearest')[0][0]

    if not args.use_zoom_in:
        return logits_global_single
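
    # Zoom-in stage: locate the foreground ROI found by the zoom-out pass,
    # then refine it at full resolution.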
    min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
    if min_d is None:
        print('Fail to detect foreground!')
        return logits_global_single

    # Crop the full-resolution image and binarize the global prediction inside the ROI.
    image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
    global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]) > 0.5).long()

    # Box and point prompts are mutually exclusive in this demo.
    assert not (args.use_box_prompt and args.use_point_prompt)

    # Pair the cropped prompt mask with the cropped global prediction; the
    # sliding-window pass uses this pair as spatial guidance.
    prompt_reflection = None
    if args.use_box_prompt:
        binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
        prompt_reflection = (
            binary_cube_cropped.unsqueeze(0).unsqueeze(0),
            global_preds.unsqueeze(0).unsqueeze(0)
        )
    if args.use_point_prompt:
        binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
        prompt_reflection = (
            binary_points_cropped.unsqueeze(0).unsqueeze(0),
            global_preds.unsqueeze(0).unsqueeze(0)
        )

    with torch.no_grad():
        logits_single_cropped = sliding_window_inference(
            image_single_cropped, prompt_reflection,
            args.spatial_size, 1, segvol_model, args.infer_overlap,
            text=text_single,
            use_box=args.use_box_prompt,
            use_point=args.use_point_prompt,
            logits_global_single=logits_global_single,
        )
    # Write the refined ROI logits back into the global prediction; this also
    # covers the edge case where the ROI spans the entire volume.
    logits_single_cropped = logits_single_cropped.cpu().squeeze()
    logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped

    return logits_global_single
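

# st.cache_resource keeps one model instance alive across Streamlit reruns,
# so the SegVol checkpoint is loaded only once per server process.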
@st.cache_resource
def build_model():
    st.write('building model')
    clip_ckpt = 'model/config/clip'
    resume = 'SegVol_v1.pth'
    sam_model = sam_model_registry['vit']()
    segvol_model = SegVol(
        image_encoder=sam_model.image_encoder,
        mask_decoder=sam_model.mask_decoder,
        prompt_encoder=sam_model.prompt_encoder,
        clip_ckpt=clip_ckpt,
        roi_size=(32, 256, 256),
        patch_size=(4, 16, 16),
        test_mode=True,
    )
    # Wrap in DataParallel before loading so the checkpoint's 'module.'-prefixed
    # keys match the wrapped model's state dict (the load uses strict=True).
    segvol_model = torch.nn.DataParallel(segvol_model)
    segvol_model.eval()

    if os.path.isfile(resume):
        checkpoint = torch.load(resume, map_location='cpu')
        segvol_model.load_state_dict(checkpoint['model'], strict=True)
        print("loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch']))
    print('model build done!')
    return segvol_model
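

# st.cache_data memoizes inference results. Parameters whose names start with
# an underscore (the tensors here) are excluded from Streamlit's hashing, so
# in practice the cache key comes from text_prompt.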
@st.cache_data
def inference_case(_image, _image_zoom_out, _point_prompt, text_prompt, _box_prompt):
    args = set_parse()
    args.use_zoom_in = True
    args.use_text_prompt = text_prompt is not None
    args.use_box_prompt = _box_prompt is not None
    args.use_point_prompt = _point_prompt is not None

    segvol_model = build_model()

    logits = zoom_in_zoom_out(
        args, segvol_model,
        _image.unsqueeze(0), _image_zoom_out.unsqueeze(0),
        text_prompt, _point_prompt, _box_prompt)
    print(logits.shape)

    # Resize the prediction to a fixed cube for display. Note: AddChannel has
    # been removed from recent MONAI releases; there, use
    # EnsureChannelFirst(channel_dim='no_channel') instead.
    resize_transform = transforms.Compose([
        transforms.AddChannel(),
        transforms.Resize((325, 325, 325), mode='trilinear'),
    ])
    logits_resize = resize_transform(logits)[0]
    # Return binary masks at display resolution and at the original resolution.
    return (torch.sigmoid(logits_resize) > 0.5).int().numpy(), (torch.sigmoid(logits) > 0.5).int().numpy()
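

# Minimal usage sketch (illustrative only: the preprocessing step and the
# variable names below are assumptions, not part of this module):
#
#     image, image_zoom_out = ...  # full-resolution CT tensor and resized copy
#     seg_display, seg_full = inference_case(
#         image, image_zoom_out,
#         _point_prompt=None, text_prompt='liver', _box_prompt=None)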