| | """ |
| | Demonstrates integrating Rerun visualization with Gradio and HF ZeroGPU. |
| | """ |
| |
|
| | import uuid |
| | import gradio as gr |
| | import rerun as rr |
| | import rerun.blueprint as rrb |
| | from gradio_rerun import Rerun |
| | import spaces |
| | from transformers import DetrImageProcessor, DetrForObjectDetection |
| | import torch |
| |
|
# Load the DETR-ResNet-50 detector and its matching image pre/post-processor
# once at import time, so every request reuses the same weights instead of
# re-downloading/re-initializing per call.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
| |
|
| |
|
| | |
| | |
| | |
def get_recording(recording_id: str) -> rr.RecordingStream:
    """Return a Rerun recording stream for this app, keyed by *recording_id*.

    The same id always maps to the same logical recording, which lets a
    viewer session reconnect to its own stream.
    """
    recording = rr.RecordingStream(
        application_id="rerun_example_gradio",
        recording_id=recording_id,
    )
    return recording
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
@spaces.GPU
def streaming_object_detection(recording_id: str, img):
    """Run DETR object detection on *img* and stream results to the viewer.

    Args:
        recording_id: Id tying this recording to one viewer session.
        img: Input image array; ``img.shape[:2]`` is read as (height, width).
            ``None`` is rejected with a ``gr.Error``.

    Yields:
        Binary chunks of the Rerun recording, consumed by the gradio_rerun
        ``Rerun`` component in streaming mode.

    Raises:
        gr.Error: If no image was provided.
    """
    rec = get_recording(recording_id)
    stream = rec.binary_stream()

    if img is None:
        # Fixed error message: original read "an image detect objects in".
        raise gr.Error("Must provide an image to detect objects in.")

    # A single 2D view rooted at "image" shows the picture with overlays.
    blueprint = rrb.Blueprint(
        rrb.Horizontal(
            rrb.Spatial2DView(origin="image"),
        ),
        collapse_panels=True,
    )
    rec.send_blueprint(blueprint)

    # Log the raw image first and yield immediately, so the viewer shows
    # something while the (slower) model inference runs.
    rec.set_time("iteration", sequence=0)
    rec.log("image", rr.Image(img))
    yield stream.read()

    # inference_mode disables autograd bookkeeping for the forward pass.
    with torch.inference_mode():
        inputs = processor(images=img, return_tensors="pt")
        outputs = model(**inputs)

    # Rescale predicted boxes from model space back to the original image size.
    height, width = img.shape[:2]
    target_sizes = torch.tensor([[height, width]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.85
    )[0]

    rec.log(
        "image/objects",
        rr.Boxes2D(
            array=results["boxes"],
            array_format=rr.Box2DFormat.XYXY,
            labels=[model.config.id2label[label.item()] for label in results["labels"]],
            # Deterministic pseudo-random RGB per class id, so the same class
            # always gets the same color.
            colors=[
                (
                    label.item() * 50 % 255,
                    (label.item() * 80 + 40) % 255,
                    (label.item() * 120 + 100) % 255,
                )
                for label in results["labels"]
            ],
        ),
    )

    # Ensure the detection overlay is in the final chunk before yielding it.
    stream.flush()
    yield stream.read()
| |
|
| |
|
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Your image", open=True):
                img = gr.Image(interactive=True, label="Image")
            detect_objects = gr.Button("Detect objects")

        with gr.Column(scale=4):
            # Streaming Rerun viewer; side panels hidden/collapsed to give
            # the canvas the full width of the column.
            viewer = Rerun(
                streaming=True,
                panel_states={
                    "time": "collapsed",
                    "blueprint": "hidden",
                    "selection": "hidden",
                },
                height=700,
            )

    # Use a callable default so Gradio generates a fresh id for each user
    # session. A bare `uuid.uuid4()` would be evaluated once at app-definition
    # time and every visitor would then share the same recording. `str(...)`
    # matches get_recording's `recording_id: str` contract.
    recording_id = gr.State(lambda: str(uuid.uuid4()))

    detect_objects.click(
        streaming_object_detection,
        inputs=[recording_id, img],
        outputs=[viewer],
    )

if __name__ == "__main__":
    # NOTE(review): SSR is disabled here — presumably required for the
    # streaming Rerun component; confirm against gradio_rerun docs.
    demo.launch(ssr_mode=False)
| |
|