| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | from PIL import Image, ImageEnhance, ImageDraw |
| | import torch |
| | import streamlit as st |
| | from model.inference_cpu import inference_case |
| |
|
# Default drawing state holding a single 100x100 rectangle at (50, 50).
# NOTE(review): the field names look like a Fabric.js (v4.4.0) canvas
# serialization — confirm against the canvas component that consumes it.
initial_rectangle = {
    "version": "4.4.0",
    "objects": [
        {
            # Geometry of the rectangle, anchored at its top-left corner.
            "type": "rect",
            "version": "4.4.0",
            "originX": "left",
            "originY": "top",
            "left": 50,
            "top": 50,
            "width": 100,
            "height": 100,
            # Appearance: translucent orange fill with a solid blue outline.
            "fill": "rgba(255, 165, 0, 0.3)",
            "stroke": "#2909F1",
            "strokeWidth": 3,
            "strokeDashArray": None,
            "strokeLineCap": "butt",
            "strokeDashOffset": 0,
            "strokeLineJoin": "miter",
            "strokeUniform": True,
            "strokeMiterLimit": 4,
            # Transform: unscaled, unrotated, not mirrored.
            "scaleX": 1,
            "scaleY": 1,
            "angle": 0,
            "flipX": False,
            "flipY": False,
            "opacity": 1,
            "shadow": None,
            "visible": True,
            "backgroundColor": "",
            "fillRule": "nonzero",
            "paintFirst": "fill",
            "globalCompositeOperation": "source-over",
            "skewX": 0,
            "skewY": 0,
            "rx": 0,
            "ry": 0,
        },
    ],
}
| |
|
def run():
    """Collect the enabled prompts from the session and run one inference.

    Reads the ``image`` and ``zoom_out_image`` entries of
    ``st.session_state.data_item`` (converted with ``.float()``), builds the
    text, point and box prompts that are switched on in the session state,
    writes the prompts and the image shape to the page, and stores the two
    values returned by ``inference_case`` in ``st.session_state.preds_3D``
    and ``st.session_state.preds_3D_ori``.
    """
    state = st.session_state
    image = state.data_item["image"].float()
    image_zoom_out = state.data_item["zoom_out_image"].float()

    # A prompt stays None unless its switch is on in the session state.
    text_prompt = state.text_prompt if state.use_text_prompt else None

    point_prompt = None
    if state.use_point_prompt and len(state.points) > 0:
        point_prompt = reflect_points_into_model(state.points)

    box_prompt = None
    if state.use_box_prompt:
        box_prompt = reflect_box_into_model(state.rectangle_3Dbox)

    # NOTE(review): presumably clears a Streamlit cache wrapped around
    # inference_case so the call below recomputes — confirm in
    # model.inference_cpu.
    inference_case.clear()

    st.write(f"text_prompt: {text_prompt}")
    st.write(f"box_prompt: {box_prompt}")
    st.write(f"point_prompt: {point_prompt}")
    st.write(f"image shape: {image.shape}")

    preds, preds_ori = inference_case(
        image,
        image_zoom_out,
        text_prompt=text_prompt,
        _point_prompt=point_prompt,
        _box_prompt=box_prompt,
    )
    state.preds_3D = preds
    state.preds_3D_ori = preds_ori
| |
|
def reflect_box_into_model(box_3d):
    """Rescale a 3D box to the coordinate grid used for the box prompt.

    Args:
        box_3d: Six values ordered ``(z1, y1, x1, z2, y2, x2)``.

    Returns:
        A 1-D ``torch`` tensor with the six rescaled coordinates in the same
        order. ``y``/``x`` values are multiplied by ``256 / 325`` and ``z``
        values by ``32 / 325``; each result is truncated with ``int()``.

    NOTE(review): 325 is presumably the on-screen canvas size and
    256 x 256 x 32 the model input grid — confirm against the UI code.
    """
    z1, y1, x1, z2, y2, x2 = box_3d
    # (coordinate, target extent) pairs, already in output order.
    pairs = (
        (z1, 32.0),
        (y1, 256.0),
        (x1, 256.0),
        (z2, 32.0),
        (y2, 256.0),
        (x2, 256.0),
    )
    scaled = [int(value * extent / 325.0) for value, extent in pairs]
    return torch.tensor(np.array(scaled))
| |
|
def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first drawn rectangle into ``st.session_state.rectangle_3Dbox``.

    Only the ``'xy'`` view is handled: slots 1 and 2 of the box receive the
    rectangle's ``top`` and ``left``; slots 4 and 5 receive ``top`` plus the
    scaled height and ``left`` plus the scaled width. Any other view leaves
    the box untouched. The box is printed to stdout afterwards.

    Args:
        json_data: Mapping with an ``'objects'`` list whose first entry has
            ``top``, ``left``, ``width``, ``height``, ``scaleX`` and
            ``scaleY`` keys.
        view: Name of the view the rectangle was drawn in.
    """
    if view == 'xy':
        rect = json_data['objects'][0]
        box = st.session_state.rectangle_3Dbox
        # Assign one slot at a time, in the same order as before, so a
        # missing key leaves the same partially-updated box behind.
        box[1] = rect['top']
        box[2] = rect['left']
        box[4] = rect['top'] + rect['height'] * rect['scaleY']
        box[5] = rect['left'] + rect['width'] * rect['scaleX']
    print(st.session_state.rectangle_3Dbox)
| |
|
def reflect_points_into_model(points):
    """Rescale clicked points to the grid used for the point prompt.

    Args:
        points: Iterable of ``(z, y, x)`` triples.

    Returns:
        A ``(coords, labels)`` tuple of ``torch`` tensors. ``coords`` holds
        one ``[z, y, x]`` row per point, with ``y``/``x`` multiplied by
        ``256 / 325`` and ``z`` by ``32 / 325``, each truncated with
        ``int()``. ``labels`` is a vector of ones, one entry per point.

    The rescaled points and labels are also printed to stdout.
    """
    rescaled = [
        [
            int(z * 32.0 / 325.0),
            int(y * 256.0 / 325.0),
            int(x * 256.0 / 325.0),
        ]
        for z, y, x in points
    ]
    coords = np.array(rescaled)
    # Every point gets label 1.
    labels = np.ones(coords.shape[0])
    print(coords, labels)
    return (torch.tensor(coords), torch.tensor(labels))
| |
|
def show_points(points_ax, points_label, ax):
    """Scatter one point on ``ax``, coloured by its label.

    Args:
        points_ax: Indexable whose first two items are passed to
            ``ax.scatter`` as its two positional arguments.
        points_label: A label equal to 0 is drawn red; anything else blue.
        ax: Object exposing a ``scatter`` method (e.g. a matplotlib axes).
    """
    if points_label == 0:
        color = 'red'
    else:
        color = 'blue'
    first, second = points_ax[0], points_ax[1]
    ax.scatter(first, second, c=color, marker='o', s=200)
| |
|
def make_fig(image, preds, point_axs=None, current_idx=None, view=None):
    """Render one slice as an RGB image with mask and point overlays.

    Args:
        image: 2-D array for the slice; it is multiplied by 255 and cast to
            ``uint8`` (assumes values in [0, 1] — TODO confirm with caller).
        preds: Optional array matching the slice. Entries equal to 1 become
            a yellow overlay (red and green at 255, blue at 0) that is
            blended in with ``alpha=st.session_state.transparency``.
        point_axs: Optional iterable of ``(z, y, x)`` points.
        current_idx: Index of the displayed slice; a point is drawn only if
            its ``z`` (view ``'xy'``) or ``y`` (view ``'xz'``) equals it.
        view: ``'xy'`` or ``'xz'``; with any other value no point is drawn.

    Returns:
        The composed ``PIL.Image.Image``.
    """
    # Base layer: 8-bit slice converted to RGB with contrast doubled.
    base = Image.fromarray((image * 255).astype(np.uint8)).convert("RGB")
    image = ImageEnhance.Contrast(base).enhance(2.0)

    if preds is not None:
        on = np.where(preds == 1, 255, 0).astype(np.uint8)
        off = np.zeros_like(on, dtype=np.uint8)
        # R and G carry the mask, B stays empty, giving a yellow overlay.
        channels = tuple(Image.fromarray(band) for band in (on, on, off))
        overlay = Image.merge("RGB", channels)
        image = Image.blend(
            image.convert("RGB"),
            overlay,
            alpha=st.session_state.transparency,
        )

    if point_axs is not None:
        draw = ImageDraw.Draw(image)
        radius = 5
        for z, y, x in point_axs:
            # Pick the in-plane coordinates for points on the shown slice.
            if view == 'xy' and z == current_idx:
                col, row = x, y
            elif view == 'xz' and y == current_idx:
                col, row = x, z
            else:
                continue
            draw.ellipse(
                (col - radius, row - radius, col + radius, row + radius),
                fill="blue",
            )
    return image