# ariffurrahman's picture
# Create app.py
# 39f0297 verified
# Raw
# History Blame Contribute Delete
# 3.34 kB
import torch
from transformers import pipeline
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw
import scipy.io.wavfile as wavfile
import tempfile
# 1. Loading Models
# Run inference on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Speech synthesizer used to read the generated description aloud.
# Note: the 'text-to-audio' task name works too, but 'text-to-speech' is the
# standard spelling.
narrator = pipeline(
    task="text-to-speech",
    model="facebook/mms-tts-eng",
    device=device,
)

# Object detector whose results feed both the drawing and the narration steps.
the_detector = pipeline(
    task="object-detection",
    model="facebook/detr-resnet-50",
    device=device,
)
# 2. The Boundary Maker
def draw_bounding_boxes(image, detections):
    """Return a copy of *image* with every detection outlined and captioned.

    Args:
        image: Image object providing ``copy()``; the input itself is never
            modified, all drawing happens on the copy.
        detections: Iterable of dicts, each with a ``'label'``, a numeric
            ``'score'`` and a ``'box'`` dict holding ``'xmin'``, ``'ymin'``,
            ``'xmax'`` and ``'ymax'``.

    Returns:
        The annotated copy of the image.
    """
    label_height = 15  # height in pixels of the caption background strip
    label_padding = 2  # horizontal gap between the strip edge and the text

    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    for detection in detections:
        box = detection['box']
        xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']

        # Draw elliptical boundary
        draw.ellipse([(xmin, ymin), (xmax, ymax)], outline='red', width=3)

        text = f"{detection['label']} {detection['score']:.2f}"

        # The caption normally sits just above the box. When the box touches
        # the top edge that position would be off-canvas, so clamp it to 0 and
        # let the caption overlap the top of the box instead of vanishing.
        label_top = max(ymin - label_height, 0)

        # Size the background to the rendered text so long captions such as
        # "traffic light 0.99" are not drawn past the edge of the strip.
        label_width = draw.textlength(text) + 2 * label_padding

        draw.rectangle(
            [(xmin, label_top), (xmin + label_width, label_top + label_height)],
            fill='red',
        )
        draw.text((xmin + label_padding, label_top), text, fill='white')

    return draw_image
# 3. The Audio Artist (Fixed for Gradio Compatibility)
def generate_audio(text):
    """Synthesize *text* to speech and return the path of the written WAV file.

    Args:
        text: The sentence to speak; passed unchanged to the module-level
            ``narrator`` pipeline.

    Returns:
        Path (``str``) of a ``.wav`` file. The file is created with
        ``delete=False``, so it is not removed automatically.
    """
    audio_data = narrator(text)
    sampling_rate = audio_data['sampling_rate']

    # Standardize waveform shape: drop any singleton dimensions before writing.
    waveform = np.squeeze(audio_data['audio'])

    # Reserve a unique path and close the handle immediately. Leaving it open
    # would leak one file descriptor per request, and writing to a path that
    # is still held open by another handle is not portable across platforms.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        audio_path = temp_file.name

    # Save to the temporary file for Gradio to read.
    wavfile.write(audio_path, sampling_rate, waveform)
    return audio_path
# 4. Generate Natural Text
# Labels whose plural is not formed by a suffix rule.
_IRREGULAR_PLURALS = {
    "person": "people",
    "knife": "knives",
    "sheep": "sheep",
    "skis": "skis",
    "scissors": "scissors",
}


def _pluralize(label):
    """Return the English plural of *label* for use in the spoken sentence."""
    if label in _IRREGULAR_PLURALS:
        return _IRREGULAR_PLURALS[label]
    # Sibilant endings take "es": bus -> buses, bench -> benches.
    if label.endswith(("s", "sh", "ch", "x", "z")):
        return f"{label}es"
    # Consonant + "y" becomes "ies"; vowel + "y" just takes "s".
    if label.endswith("y") and len(label) > 1 and label[-2] not in "aeiou":
        return f"{label[:-1]}ies"
    return f"{label}s"


def read_objects(detection_objects):
    """Describe the detected objects in one natural-language sentence.

    Detections are grouped by label in order of first appearance and rendered
    as ``"<count> <label>"`` phrases, e.g.
    ``"This picture contains 2 cats, 1 dog and 1 bus."``.

    Args:
        detection_objects: Sequence of dicts that each carry a ``'label'``
            key. ``None`` or an empty sequence means nothing was detected.

    Returns:
        The description sentence, or a fixed fallback sentence when there are
        no detections.
    """
    if not detection_objects:
        return "I couldn't find any objects in this picture."

    # A plain dict keeps insertion order, so labels are listed in the order
    # they were first detected.
    object_counts = {}
    for detection in detection_objects:
        label = detection['label']
        object_counts[label] = object_counts.get(label, 0) + 1

    phrases = [
        f"{count} {_pluralize(label) if count > 1 else label}"
        for label, count in object_counts.items()
    ]

    if len(phrases) == 1:
        return f"This picture contains {phrases[0]}."
    # Comma-separate all but the last phrase and join the last with "and"
    # (no comma before "and").
    return f"This picture contains {', '.join(phrases[:-1])} and {phrases[-1]}."
# 5. The Collaborator
def collaborator(img):
    """Run detection on *img* and return the annotated image plus narration.

    Args:
        img: The uploaded image (the interface is configured with
            ``type='pil'``), or ``None`` when nothing was provided.

    Returns:
        A ``(annotated_image, audio_path)`` tuple: the image with detections
        drawn on it, and the path of the WAV file that speaks the description.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        # Fail with a readable message in the UI instead of an opaque error
        # raised from inside the detection pipeline.
        raise gr.Error("Please upload an image first.")

    detections = the_detector(img)
    gen_image = draw_bounding_boxes(image=img, detections=detections)
    natural_text = read_objects(detections)
    audio_path = generate_audio(natural_text)
    return gen_image, audio_path
# 6. UI Interface
# One image goes in; the annotated image and the spoken description come out,
# in the same order as the tuple returned by `collaborator`.
demo = gr.Interface(
    title="VisionTalk: Object Detector with Audio",
    description="Upload an image to see what objects are inside and hear a generated description.",
    fn=collaborator,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Audio(label="Description Audio"),
    ],
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()