import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
from diffusers import StableDiffusionPipeline
import tempfile
from groq import Groq
import os  # Replaced google.colab with os for environment variable access
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import json
import time

# Download NLTK tokenizer data for BLEU
nltk.download('punkt')
nltk.download('punkt_tab')  # required by word_tokenize on newer NLTK releases

# Initialize Groq client
client = Groq(api_key=os.getenv('GROQ_API_KEY'))  # Updated to use os.getenv

# Load models
captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # fp16 weights are only supported on GPU
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=dtype)
pipe = pipe.to(device)
pipe.enable_attention_slicing()

# Caching for performance
caption_cache = {}
qa_cache = {}
history = []  # Global history for report


def generate_caption(image, progress=gr.Progress()):
    try:
        if image is None:
            return "Please upload an image.", {}
        progress(0.2, "Processing image...")
        pil_image = Image.open(image) if isinstance(image, str) else image
        cache_key = hash(pil_image.tobytes())
        if cache_key in caption_cache:
            return caption_cache[cache_key], {}
        caption = captioner(pil_image)[0]['generated_text']
        enhanced_caption = f"A creative take: {caption}."
        metrics = {"length": len(enhanced_caption.split()), "unique_words": len(set(enhanced_caption.split()))}
        caption_cache[cache_key] = enhanced_caption
        history.append({"action": "caption", "time": time.time()})
        progress(1.0, "Caption generated!")
        return enhanced_caption, metrics
    except Exception as e:
        return f"Error: {str(e)}", {}


def generate_image_from_caption(caption, progress=gr.Progress()):
    try:
        progress(0.1, "Generating image...")
        image = pipe(caption, num_inference_steps=25, guidance_scale=7.5).images[0]
        progress(0.8, "Saving image...")
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        image.save(temp_file.name)
        history.append({"action": "image_gen", "time": time.time()})
        progress(1.0, "Image ready for download!")
        return image, temp_file.name
    except Exception as e:
        # gr.File expects a file path, so surface errors in the UI instead of returning a string
        raise gr.Error(f"Error: {str(e)}")


def answer_question(image, question, progress=gr.Progress()):
    try:
        if not question.strip():
            return "Please enter a question.", {}
        progress(0.2, "Analyzing context...")
        start_time = time.time()
        context = ""
        if image is not None:
            pil_image = Image.open(image) if isinstance(image, str) else image
            caption_result = captioner(pil_image)[0]['generated_text']
            context = f"Based on the image description: '{caption_result}'. "
        cache_key = (context, question)
        if cache_key in qa_cache:
            return qa_cache[cache_key], {}
        prompt = f"{context}Question: {question}\nAnswer:"
        progress(0.5, "Querying AI...")
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama-3.1-8b-instant",
        )
        answer = chat_completion.choices[0].message.content.strip()
        response_time = time.time() - start_time
        metrics = {"response_time": response_time, "length": len(answer.split())}
        qa_cache[cache_key] = answer
        history.append({"action": "qa", "time": time.time(), "response_time": response_time})  # latency feeds the usage report
        progress(1.0, "Answer ready!")
        return answer, metrics
    except Exception as e:
        return f"Error: {str(e)}", {}


def evaluate_caption(caption, reference="A sample reference caption for evaluation."):
    try:
        if not caption:
            return "No caption to evaluate."
        reference_tokens = nltk.word_tokenize(reference.lower())
        candidate_tokens = nltk.word_tokenize(caption.lower())
        # Smoothing keeps short captions from collapsing to a zero BLEU score
        bleu = sentence_bleu([reference_tokens], candidate_tokens, smoothing_function=SmoothingFunction().method1)
        return f"BLEU Score: {bleu:.2f}, Length: {len(candidate_tokens)} words"
    except Exception as e:
        return f"Error: {str(e)}"


def batch_caption(images):
    try:
        if not images:
            return "Please upload at least one image."
        results = []
        for img_path in images:
            if img_path:
                pil_image = Image.open(img_path)
                caption = captioner(pil_image)[0]['generated_text']
                results.append(f"Image: {caption}")
        history.append({"action": "batch_caption", "time": time.time()})
        return "\n".join(results)
    except Exception as e:
        return f"Error: {str(e)}"


def generate_report():
    try:
        total_interactions = len(history)
        # Average latency over the entries that recorded one (Q&A calls)
        timed = [h["response_time"] for h in history if "response_time" in h]
        avg_response_time = sum(timed) / len(timed) if timed else 0
        report = {
            "total_interactions": total_interactions,
            "average_response_time": avg_response_time,
            "actions": [h["action"] for h in history]
        }
        return json.dumps(report, indent=2)
    except Exception as e:
        return f"Error generating report: {str(e)}"


# Gradio UI with enhancements
with gr.Blocks(title="ColabCraft: Advanced AI Image Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🧠 ColabCraft: Advanced AI Image Assistant
**A Multimodal GenAI Project for Image Captioning, Q&A, and Generation**

*Upload images, generate captions, ask questions, create images, and evaluate results. Built with Hugging Face, Stable Diffusion, and Groq's Llama 3.1 8B.*

**Ethical Note:** This tool promotes positive AI use. Avoid uploading sensitive images. Citations: BLIP (Salesforce), Stable Diffusion (CompVis), Llama (Meta via Groq).
""")

    # Shared image input
    image_input = gr.Image(type="pil", label="Upload Image (Shared for Captioning & Q&A)", elem_id="upload_img")

    with gr.Tabs():
        with gr.TabItem("📸 Image Captioning", elem_id="caption_tab"):
            gr.Markdown("### Generate Creative Captions from Images")
            with gr.Row():
                with gr.Column():
                    caption_output = gr.Textbox(label="Generated Caption", interactive=False)
                    metrics_output = gr.JSON(label="Metrics")
            generate_btn = gr.Button("🚀 Generate Caption", variant="primary")
            generate_btn.click(generate_caption, inputs=image_input, outputs=[caption_output, metrics_output])
            # Removed the problematic gr.Examples line

        with gr.TabItem("❓ Q&A Assistant", elem_id="qa_tab"):
            gr.Markdown("### Ask Questions About Images or General Topics")
            with gr.Row():
                question_input = gr.Textbox(label="Enter Question", placeholder="e.g., What is in the image?")
                answer_output = gr.Textbox(label="AI Answer", interactive=False)
            qa_metrics = gr.JSON(label="Metrics")
            ask_btn = gr.Button("🔍 Get Answer", variant="primary")
            ask_btn.click(answer_question, inputs=[image_input, question_input], outputs=[answer_output, qa_metrics])

        with gr.TabItem("🎨 Image Generation", elem_id="gen_tab"):
            gr.Markdown("### Create Images from Text Captions")
            with gr.Row():
                text_input = gr.Textbox(label="Enter Caption for Generation", placeholder="e.g., A sunny beach with palm trees")
                image_output = gr.Image(label="Generated Image")
            download_file = gr.File(label="📥 Download Image")
            generate_img_btn = gr.Button("🖼️ Generate Image", variant="primary")
            generate_img_btn.click(generate_image_from_caption, inputs=text_input, outputs=[image_output, download_file])

        with gr.TabItem("📊 Evaluation & Batch", elem_id="eval_tab"):
            gr.Markdown("### Evaluate Captions and Process Batches")
            with gr.Row():
                eval_caption_input = gr.Textbox(label="Caption to Evaluate")
                eval_output = gr.Textbox(label="Evaluation Results", interactive=False)
            eval_btn = gr.Button("📈 Evaluate")
            eval_btn.click(evaluate_caption, inputs=eval_caption_input, outputs=eval_output)

            gr.Markdown("### Batch Captioning")
            batch_input = gr.File(file_count="multiple", label="Upload Multiple Images")
            batch_output = gr.Textbox(label="Batch Results", interactive=False, lines=10)
            batch_btn = gr.Button("🔄 Process Batch")
            batch_btn.click(batch_caption, inputs=batch_input, outputs=batch_output)

        with gr.TabItem("📋 Report & Help", elem_id="report_tab"):
            gr.Markdown("### Project Report & Help")
            report_output = gr.Textbox(label="Generated Report", interactive=False, lines=10)
            report_btn = gr.Button("📄 Generate Report")
            report_btn.click(generate_report, inputs=[], outputs=report_output)
            gr.Markdown("""
**Help & Features:**
- **Captioning:** Uses BLIP for accurate descriptions.
- **Q&A:** Powered by Llama 3.1 8B via Groq for contextual answers.
- **Generation:** Stable Diffusion for high-quality images.
- **Evaluation:** BLEU scores for caption quality.
- **Batch:** Process multiple images at once.
- **Report:** Summarizes usage metrics.

**For Submission:** Export notebook as PDF. Include demo video and metrics in report.
""")

if __name__ == "__main__":
    demo.launch(share=True)
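# ---------------------------------------------------------------------------
# Assumed environment (sketch only; exact package versions may differ): the
# app needs the libraries imported above plus a Groq API key exposed as the
# GROQ_API_KEY environment variable, e.g.
#
#   pip install gradio transformers diffusers torch groq nltk pillow
#   export GROQ_API_KEY=<your key>
#
# A CUDA GPU is recommended; on CPU, Stable Diffusion falls back to float32
# and generation is considerably slower.
# ---------------------------------------------------------------------------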