import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
from diffusers import StableDiffusionPipeline
import tempfile
from groq import Groq
import os
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import json
import time

# Tokenizer data for nltk.word_tokenize (newer NLTK releases also need 'punkt_tab')
nltk.download('punkt')
nltk.download('punkt_tab')

# Groq client for LLM-backed Q&A (requires the GROQ_API_KEY environment variable)
client = Groq(api_key=os.getenv('GROQ_API_KEY'))
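# The key must be available before this line runs; in Colab it can be set in a prior cell,
# e.g. os.environ['GROQ_API_KEY'] = '<your-key>' (illustrative placeholder, not a real key).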

# BLIP model for image captioning
captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# Stable Diffusion for text-to-image generation; float16 only works on GPU, so pick dtype by device
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=dtype)
pipe = pipe.to(device)
pipe.enable_attention_slicing()  # trades a little speed for lower peak VRAM
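# Optional: if GPU memory is still tight, diffusers also offers pipe.enable_model_cpu_offload()
# (requires the `accelerate` package); it is left disabled here.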

# In-memory caches and interaction history
caption_cache = {}
qa_cache = {}
history = []
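# Note: these caches and the history list live only in process memory and are cleared whenever the app restarts.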


def generate_caption(image, progress=gr.Progress()):
    """Caption an uploaded image with BLIP and return the caption plus simple metrics."""
    try:
        if image is None:
            return "Please upload an image.", {}
        progress(0.2, "Processing image...")
        pil_image = Image.open(image) if isinstance(image, str) else image
        cache_key = hash(pil_image.tobytes())
        if cache_key in caption_cache:
            return caption_cache[cache_key], {}
        caption = captioner(pil_image)[0]['generated_text']
        enhanced_caption = f"A creative take: {caption}."
        metrics = {"length": len(enhanced_caption.split()), "unique_words": len(set(enhanced_caption.split()))}
        caption_cache[cache_key] = enhanced_caption
        history.append({"action": "caption", "time": time.time()})
        progress(1.0, "Caption generated!")
        return enhanced_caption, metrics
    except Exception as e:
        return f"Error: {str(e)}", {}
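# Caching note: hashing the raw pixel bytes means re-uploading the identical image reuses the
# cached caption, while any edit to the image produces a new cache key.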


def generate_image_from_caption(caption, progress=gr.Progress()):
    """Generate an image from a text caption with Stable Diffusion and save it for download."""
    if not caption or not caption.strip():
        raise gr.Error("Please enter a caption to generate from.")
    try:
        progress(0.1, "Generating image...")
        image = pipe(caption, num_inference_steps=25, guidance_scale=7.5).images[0]
        progress(0.8, "Saving image...")
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        image.save(temp_file.name)
        temp_file.close()
        history.append({"action": "image_gen", "time": time.time()})
        progress(1.0, "Image ready for download!")
        return image, temp_file.name
    except Exception as e:
        # A plain error string cannot be returned to the gr.File output, so surface it as a UI error
        raise gr.Error(f"Image generation failed: {e}")


def answer_question(image, question, progress=gr.Progress()):
    """Answer a question, optionally grounded in a caption of the uploaded image, via Groq."""
    try:
        if not question.strip():
            return "Please enter a question.", {}
        progress(0.2, "Analyzing context...")
        start_time = time.time()
        context = ""
        if image is not None:
            pil_image = Image.open(image) if isinstance(image, str) else image
            caption_result = captioner(pil_image)[0]['generated_text']
            context = f"Based on the image description: '{caption_result}'. "
        cache_key = (context, question)
        if cache_key in qa_cache:
            return qa_cache[cache_key], {}
        prompt = f"{context}Question: {question}\nAnswer:"
        progress(0.5, "Querying AI...")
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama-3.1-8b-instant",
        )
        answer = chat_completion.choices[0].message.content.strip()
        response_time = time.time() - start_time
        metrics = {"response_time": response_time, "length": len(answer.split())}
        qa_cache[cache_key] = answer
        # Record response_time so generate_report() can average it later
        history.append({"action": "qa", "time": time.time(), "response_time": response_time})
        progress(1.0, "Answer ready!")
        return answer, metrics
    except Exception as e:
        return f"Error: {str(e)}", {}
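# Caching note: the cache key pairs the question with the image-derived context, so the same
# question about a different image is answered fresh rather than served from the cache.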


def evaluate_caption(caption, reference="A sample reference caption for evaluation."):
    """Score a caption against a reference with sentence-level BLEU."""
    try:
        if not caption:
            return "No caption to evaluate."
        reference_tokens = nltk.word_tokenize(reference.lower())
        candidate_tokens = nltk.word_tokenize(caption.lower())
        # Smoothing avoids zero scores (and warnings) when short captions share few n-grams
        bleu = sentence_bleu([reference_tokens], candidate_tokens, smoothing_function=SmoothingFunction().method1)
        return f"BLEU Score: {bleu:.2f}, Length: {len(candidate_tokens)} words"
    except Exception as e:
        return f"Error: {str(e)}"
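# Note: the default `reference` above is only a placeholder, so scores against it are illustrative.
# For meaningful BLEU, pass a real ground-truth caption, e.g. (hypothetical values):
#   evaluate_caption("a dog runs on the beach", reference="a dog running along the beach")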


def batch_caption(images):
    """Caption each uploaded file in turn and return the results as one text block."""
    try:
        if not images:
            return "Please upload at least one image."
        results = []
        for img_file in images:
            if img_file:
                # gr.File may hand back plain paths or file wrappers, depending on the Gradio version
                path = img_file if isinstance(img_file, str) else img_file.name
                pil_image = Image.open(path)
                caption = captioner(pil_image)[0]['generated_text']
                results.append(f"Image: {caption}")
        history.append({"action": "batch_caption", "time": time.time()})
        return "\n".join(results)
    except Exception as e:
        return f"Error: {str(e)}"


def generate_report():
    """Summarize usage: interaction count, average Q&A response time, and the action log."""
    try:
        total_interactions = len(history)
        # Only Q&A entries record a response_time, so average over those rather than all entries
        response_times = [h["response_time"] for h in history if "response_time" in h]
        avg_response_time = sum(response_times) / len(response_times) if response_times else 0
        report = {
            "total_interactions": total_interactions,
            "average_response_time": avg_response_time,
            "actions": [h["action"] for h in history]
        }
        return json.dumps(report, indent=2)
    except Exception as e:
        return f"Error generating report: {str(e)}"
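# Illustrative (hypothetical) report shape returned by generate_report():
# {
#   "total_interactions": 3,
#   "average_response_time": 1.42,
#   "actions": ["caption", "qa", "image_gen"]
# }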


with gr.Blocks(title="ColabCraft: Advanced AI Image Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🧠 ColabCraft: Advanced AI Image Assistant
**A Multimodal GenAI Project for Image Captioning, Q&A, and Generation**

*Upload images, generate captions, ask questions, create images, and evaluate results. Built with Hugging Face, Stable Diffusion, and Groq's Llama 3.1 8B.*

**Ethical Note:** This tool promotes positive AI use. Avoid uploading sensitive images. Citations: BLIP (Salesforce), Stable Diffusion (CompVis), Llama (Meta via Groq).
""")

    image_input = gr.Image(type="pil", label="Upload Image (Shared for Captioning & Q&A)", elem_id="upload_img")

    with gr.Tabs():
        with gr.TabItem("📸 Image Captioning", elem_id="caption_tab"):
            gr.Markdown("### Generate Creative Captions from Images")
            with gr.Row():
                with gr.Column():
                    caption_output = gr.Textbox(label="Generated Caption", interactive=False)
                    metrics_output = gr.JSON(label="Metrics")
                    generate_btn = gr.Button("Generate Caption", variant="primary")
            generate_btn.click(generate_caption, inputs=image_input, outputs=[caption_output, metrics_output])

        with gr.TabItem("❓ Q&A Assistant", elem_id="qa_tab"):
            gr.Markdown("### Ask Questions About Images or General Topics")
            with gr.Row():
                question_input = gr.Textbox(label="Enter Question", placeholder="e.g., What is in the image?")
                answer_output = gr.Textbox(label="AI Answer", interactive=False)
            qa_metrics = gr.JSON(label="Metrics")
            ask_btn = gr.Button("Get Answer", variant="primary")
            ask_btn.click(answer_question, inputs=[image_input, question_input], outputs=[answer_output, qa_metrics])
with gr.TabItem("π¨ Image Generation", elem_id="gen_tab"): |
|
|
gr.Markdown("### Create Images from Text Captions") |
|
|
with gr.Row(): |
|
|
text_input = gr.Textbox(label="Enter Caption for Generation", placeholder="e.g., A sunny beach with palm trees") |
|
|
image_output = gr.Image(label="Generated Image") |
|
|
download_file = gr.File(label="π₯ Download Image") |
|
|
generate_img_btn = gr.Button("πΌοΈ Generate Image", variant="primary") |
|
|
generate_img_btn.click(generate_image_from_caption, inputs=text_input, outputs=[image_output, download_file]) |
|
|
|
|
|
with gr.TabItem("π Evaluation & Batch", elem_id="eval_tab"): |
|
|
gr.Markdown("### Evaluate Captions and Process Batches") |
|
|
with gr.Row(): |
|
|
eval_caption_input = gr.Textbox(label="Caption to Evaluate") |
|
|
eval_output = gr.Textbox(label="Evaluation Results", interactive=False) |
|
|
eval_btn = gr.Button("π Evaluate") |
|
|
eval_btn.click(evaluate_caption, inputs=eval_caption_input, outputs=eval_output) |
|
|
|
|
|

            gr.Markdown("### Batch Captioning")
            batch_input = gr.File(file_count="multiple", label="Upload Multiple Images")
            batch_output = gr.Textbox(label="Batch Results", interactive=False, lines=10)
            batch_btn = gr.Button("Process Batch")
            batch_btn.click(batch_caption, inputs=batch_input, outputs=batch_output)

        with gr.TabItem("Report & Help", elem_id="report_tab"):
            gr.Markdown("### Project Report & Help")
            report_output = gr.Textbox(label="Generated Report", interactive=False, lines=10)
            report_btn = gr.Button("Generate Report")
            report_btn.click(generate_report, inputs=[], outputs=report_output)

            gr.Markdown("""
**Help & Features:**
- **Captioning:** Uses BLIP for accurate descriptions.
- **Q&A:** Powered by Llama 3.1 8B via Groq for contextual answers.
- **Generation:** Stable Diffusion for high-quality images.
- **Evaluation:** BLEU scores for caption quality.
- **Batch:** Process multiple images at once.
- **Report:** Summarizes usage metrics.

**For Submission:** Export the notebook as a PDF and include the demo video and metrics in the report.
""")
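
# Note: share=True asks Gradio to create a temporary public link (handy in Colab); use share=False for local-only runs.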


if __name__ == "__main__":
    demo.launch(share=True)