import tempfile
from collections import Counter

import gradio as gr
import numpy as np
import scipy.io.wavfile as wavfile
import torch
from PIL import Image, ImageDraw
from transformers import pipeline
|
|
| |
| |
# Both pipelines run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Text-to-speech model that narrates the scene description.
narrator = pipeline(task="text-to-speech", model="facebook/mms-tts-eng", device=device)

# DETR (ResNet-50 backbone) object detector used on uploaded images.
the_detector = pipeline(task="object-detection", model="facebook/detr-resnet-50", device=device)
|
|
| |
def draw_bounding_boxes(image, detections):
    """Return a copy of *image* with every detection outlined and labelled.

    Args:
        image: PIL image to annotate. It is copied; the input is not modified.
        detections: iterable of dicts, each with a ``'box'`` dict holding
            ``xmin``/``ymin``/``xmax``/``ymax`` pixel coordinates, a
            ``'label'`` string and a numeric ``'score'``.

    Returns:
        The annotated copy of *image*.
    """
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    for detection in detections:
        box = detection['box']
        xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']

        # The coordinates describe an axis-aligned box, so outline it as a
        # rectangle (an ellipse misses the corner the label tab is anchored to).
        draw.rectangle([(xmin, ymin), (xmax, ymax)], outline='red', width=3)

        text = f"{detection['label']} {detection['score']:.2f}"

        # Size the label tab to the rendered text rather than a fixed width,
        # so long labels are not cut off.
        try:
            left, top, right, bottom = draw.textbbox((0, 0), text)
        except (AttributeError, ValueError):
            # Older Pillow versions either lack textbbox or cannot measure the
            # default bitmap font; fall back to the original fixed 80x15 tab.
            left, top, right, bottom = 0, 0, 76, 11
        tab_w = right - left + 4
        tab_h = bottom - top + 4

        # Place the tab above the box. If that would go past the top edge of
        # the image, tuck it just inside the box instead.
        tab_y = ymin - tab_h if ymin - tab_h >= 0 else ymin
        draw.rectangle([(xmin, tab_y), (xmin + tab_w, tab_y + tab_h)], fill='red')
        draw.text((xmin + 2 - left, tab_y + 2 - top), text, fill='white')

    return draw_image
|
|
| |
def generate_audio(text):
    """Speak *text* with the narrator model and save the result as a WAV file.

    Args:
        text: the sentence to synthesise.

    Returns:
        Path to a new ``.wav`` file in the system temp directory. The file is
        created with ``delete=False`` so it survives this call and Gradio can
        serve it; nothing in this function removes it afterwards.
    """
    audio_data = narrator(text)

    # np.squeeze drops singleton axes. Presumably the model returns a
    # (1, n_samples) array (TODO confirm); after squeezing, wavfile.write
    # gets a 1-D mono signal.
    waveform = np.squeeze(audio_data['audio'])
    sampling_rate = audio_data['sampling_rate']

    # Write through the open handle and let the context manager close it.
    # The original never closed the handle, leaking one file descriptor per
    # request. Writing via the handle, not by reopening the path, also works
    # on Windows, where an open NamedTemporaryFile cannot be opened a second
    # time.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        wavfile.write(temp_file, sampling_rate, waveform)

    return temp_file.name
|
|
| |
# Plurals that the suffix rules in _pluralize would get wrong. Keys are
# detector labels (COCO class names) that are irregular or already plural.
_IRREGULAR_PLURALS = {
    "person": "people",
    "mouse": "mice",
    "knife": "knives",
    "sheep": "sheep",
    "skis": "skis",
    "scissors": "scissors",
}


def _pluralize(noun):
    """Return the English plural of *noun*, inflecting only its last word."""
    head, sep, last = noun.rpartition(" ")
    lowered = last.lower()
    if lowered in _IRREGULAR_PLURALS:
        plural = _IRREGULAR_PLURALS[lowered]
    elif lowered.endswith(("s", "x", "z", "ch", "sh")):
        # bus -> buses, glass -> glasses, couch -> couches
        plural = last + "es"
    elif len(lowered) > 1 and lowered.endswith("y") and lowered[-2] not in "aeiou":
        # consonant + y -> ies
        plural = last[:-1] + "ies"
    else:
        plural = last + "s"
    return head + sep + plural


def read_objects(detection_objects):
    """Describe the detected objects as one English sentence.

    Args:
        detection_objects: iterable of dicts, each with a ``'label'`` key,
            or an empty/``None`` value when nothing was detected.

    Returns:
        A sentence such as ``"This picture contains 2 people, 1 bus and
        3 wine glasses."``. Labels are listed in the order they were first
        seen, and multiple items are joined with commas plus a final "and".
        If nothing was detected, a fixed "couldn't find" message is returned.
    """
    if not detection_objects:
        return "I couldn't find any objects in this picture."

    # Counter keeps first-seen insertion order, so labels appear in detection order.
    counts = Counter(detection['label'] for detection in detection_objects)
    phrases = [
        f"{count} {_pluralize(label) if count > 1 else label}"
        for label, count in counts.items()
    ]

    if len(phrases) == 1:
        listing = phrases[0]
    else:
        listing = ", ".join(phrases[:-1]) + " and " + phrases[-1]
    return f"This picture contains {listing}."
|
|
| |
def collaborator(img):
    """Gradio handler: detect objects in *img*, annotate it and narrate the result.

    Args:
        img: the uploaded PIL image, or ``None`` if the user submitted
            without uploading anything.

    Returns:
        A ``(annotated_image, audio_path)`` tuple for the two output components.

    Raises:
        gr.Error: when no image was provided. Gradio shows the message to the user.
    """
    if img is None:
        # Report the missing upload instead of handing None to the detector pipeline.
        raise gr.Error("Please upload an image first.")

    output = the_detector(img)
    gen_image = draw_bounding_boxes(image=img, detections=output)
    natural_text = read_objects(output)
    audio_path = generate_audio(natural_text)
    return gen_image, audio_path
|
|
| |
# Gradio UI: one image in; the annotated image and a spoken description out.
_input_image = gr.Image(label='Upload Image', type='pil')
_output_components = [
    gr.Image(label='Detected Objects', type='pil'),
    gr.Audio(label='Description Audio'),
]

demo = gr.Interface(
    fn=collaborator,
    inputs=_input_image,
    outputs=_output_components,
    title='VisionTalk: Object Detector with Audio',
    description='Upload an image to see what objects are inside and hear a generated description.',
)


if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    demo.launch()