Spaces:
Sleeping
Sleeping
| import os | |
| import requests | |
| import json | |
| import gradio as gr | |
| from openai import OpenAI | |
| import uuid | |
| import time | |
| import io | |
| import yaml | |
| import random | |
| import traceback | |
| import tempfile | |
| from PIL import Image | |
| from PIL.PngImagePlugin import PngInfo | |
| from huggingface_hub import InferenceClient | |
| from google import genai | |
| from google.genai import types | |
| from .config import ( | |
| GEMINI_API_KEY, OLLAMA_HOST, OLLAMA_PORT, COMFY_URL, | |
| COMFY_WORKFLOW_FILE, PROMPTS_FILE, HF_TOKEN, HF_BASE_URL, | |
| HF_TEXT_MODEL, HF_IMAGE_MODEL, GEMINI_TEXT_MODEL, | |
| GEMINI_IMAGE_MODEL | |
| ) | |
# Setup Gemini: a client is only created when an API key is configured.
# `gemini_active` gates every Gemini call further down in this module.
client = None
gemini_active = False
if GEMINI_API_KEY:
    try:
        client = genai.Client(api_key=GEMINI_API_KEY)
        gemini_active = True
    except Exception as e:
        # Stay usable with the other backends, but report why Gemini is off
        # instead of silently disabling it.
        print(f"Gemini client initialization failed: {e}")
# Setup Hugging Face Router: an OpenAI-compatible client pointed at
# HF_BASE_URL, created only when HF_TOKEN is configured. `hf_active` records
# whether that succeeded.
hf_client = None
hf_active = False
if HF_TOKEN:
    try:
        hf_client = OpenAI(
            base_url=HF_BASE_URL,
            api_key=HF_TOKEN,
        )
        hf_active = True
    except Exception as e:
        # Report the failure; a per-request token can still be supplied later.
        print(f"Hugging Face client initialization failed: {e}")
def load_system_prompt(key="refinement"):
    """Load the ``system_instructions`` text for *key* from ``PROMPTS_FILE``.

    The YAML file is read with ``yaml.safe_load`` and is expected to map each
    key (e.g. ``"refinement"``, ``"naming"``) to a mapping that contains a
    ``system_instructions`` entry.

    Args:
        key: Top-level key in the prompts file to look up.

    Returns:
        The instructions string, or ``""`` when the file cannot be read or
        parsed, the key is missing, or the entry has an unexpected shape.
        Never returns None, so callers can interpolate the result directly.
    """
    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) fails on prompt files containing non-ASCII characters.
        with open(PROMPTS_FILE, "r", encoding="utf-8") as f:
            prompts = yaml.safe_load(f)
    except Exception as e:
        print(f"Error loading system prompt: {e}")
        return ""
    # safe_load returns None for an empty file; check each level so a
    # malformed file degrades to "no system prompt" rather than raising.
    if not isinstance(prompts, dict):
        return ""
    entry = prompts.get(key)
    if not isinstance(entry, dict):
        return ""
    instructions = entry.get("system_instructions", "")
    # A YAML key with no value parses as None; normalise it to "".
    return instructions if isinstance(instructions, str) else ""
def get_ollama_models():
    """List the model names installed on the configured Ollama server.

    This also serves as a liveness check. An empty list means one of: the
    server could not be reached within 2 seconds, it answered with a
    non-200 status, or its reply did not have the expected shape.
    """
    try:
        tags_endpoint = f"http://{OLLAMA_HOST}:{OLLAMA_PORT}/api/tags"
        reply = requests.get(tags_endpoint, timeout=2)
        if reply.status_code != 200:
            return []
        installed = reply.json().get("models", [])
        return [entry["name"] for entry in installed]
    except Exception:
        return []
def check_comfy_availability():
    """Return True when ComfyUI answers ``GET /system_stats`` with HTTP 200.

    Any request failure (refused connection, 2 s timeout, ...) is reported
    as "not available" rather than raised.
    """
    try:
        stats_url = f"{COMFY_URL}/system_stats"
        is_up = requests.get(stats_url, timeout=2).status_code == 200
    except Exception:
        is_up = False
    return is_up
def refine_with_gemini(prompt, mode="refinement"):
    """Refine *prompt* with the Gemini text model (``GEMINI_TEXT_MODEL``).

    Args:
        prompt: Technical prompt to rewrite.
        mode: Key in the prompts file selecting the system instructions
            (e.g. ``"refinement"`` or ``"naming"``). When that entry is empty,
            a built-in prompt-engineering instruction is used instead.

    Returns:
        The model's text, stripped, on success. On failure, a string that
        starts with ``"ERROR:"``, the prefix ``refine_master`` uses to
        detect errors.
    """
    if not gemini_active:
        # Prefixed so callers treat it as a failure rather than as output.
        return "ERROR: Gemini API key not found in .env file."
    system_prompt = load_system_prompt(mode) or (
        "You are an expert prompt engineer for AI image generators. "
        "Your task is to take the provided technical prompt and refine it into a more vivid, "
        "artistic, and detailed description while maintaining all the core features."
    )
    try:
        response = client.models.generate_content(
            model=GEMINI_TEXT_MODEL,
            config=types.GenerateContentConfig(
                system_instruction=system_prompt,
                temperature=0.7,
            ),
            contents=[prompt],
        )
        # `response.text` can be None when the response carries no text part;
        # treat that as an empty answer instead of crashing on None.strip().
        return (response.text or "").strip()
    except Exception as e:
        print(f"Gemini Refinement Error: {e}")
        return f"ERROR: {e}"
def refine_with_ollama(prompt, model, mode="refinement", timeout=300):
    """Refine *prompt* with a model served by the local Ollama instance.

    Sends a single non-streamed request to ``/api/generate``. The body is the
    system instructions for *mode* followed by the original prompt.

    Args:
        prompt: Technical prompt to rewrite.
        model: Ollama model name (e.g. one returned by ``get_ollama_models``).
        mode: Key in the prompts file selecting the system instructions.
        timeout: Seconds to wait for the generation. Without a limit, a
            stalled server would block the caller indefinitely.

    Returns:
        The generated text with a surrounding Markdown code fence removed,
        or a string starting with ``"ERROR:"`` on failure.
    """
    system_prompt = load_system_prompt(mode)
    url = f"http://{OLLAMA_HOST}:{OLLAMA_PORT}/api/generate"
    user_part = f"Original Prompt: {prompt}"
    payload = {
        "model": model,
        # Avoid a dangling blank header when no system prompt is configured.
        "prompt": f"{system_prompt}\n\n{user_part}" if system_prompt else user_part,
        "stream": False,
    }
    try:
        response = requests.post(url, json=payload, timeout=timeout)
        response.raise_for_status()
        text = (response.json().get("response") or "").strip()
    except Exception as e:
        print(f"Ollama Refinement Error: {e}")
        return f"ERROR: {e}"
    # Models sometimes wrap the answer in a ```...``` block; drop the fence lines.
    if text.startswith("```"):
        lines = text.splitlines()[1:]
        if lines and lines[-1].startswith("```"):
            lines = lines[:-1]
        text = "\n".join(lines).strip()
    return text
def refine_with_hf(prompt, model_id=None, provider=None, token=None, mode="refinement"):
    """Refine *prompt* via the Hugging Face router's OpenAI-compatible chat API.

    Args:
        prompt: Technical prompt to rewrite.
        model_id: Model to call; defaults to ``HF_TEXT_MODEL``.
        provider: Inference provider name; blank or None means ``"auto"``.
        token: Optional user token. When given, a one-off client is built
            with it instead of using the module-level ``hf_client``.
        mode: Key in the prompts file selecting the system instructions.

    Returns:
        The model's answer, stripped, or a string starting with ``"ERROR:"``
        on failure (the prefix ``refine_master`` checks for).
    """
    active_client = hf_client
    # A manually supplied token takes precedence over the configured client.
    if token:
        try:
            active_client = OpenAI(
                base_url=HF_BASE_URL,
                api_key=token,
            )
        except Exception as e:
            return f"ERROR: Failed to initialize manual HF Client: {e}"
    if not active_client:
        return "ERROR: Hugging Face token not found. Please log in or provide a token."

    system_prompt = load_system_prompt(mode)
    active_model = model_id or HF_TEXT_MODEL
    active_provider = provider.strip() if provider and provider.strip() else "auto"

    # No point sending an empty system turn when the prompts file has none.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": f"Original Prompt: {prompt}"})
    try:
        response = active_client.chat.completions.create(
            model=active_model,
            messages=messages,
            max_tokens=2048,  # generous budget: thinking models spend tokens on reasoning first
            temperature=0.7,
            # Forwarded in the request body; whether the route honours it
            # depends on the router/model, but it is harmless otherwise.
            extra_body={"provider": active_provider},
        )
        msg = response.choices[0].message
        content = getattr(msg, "content", "") or ""
        reasoning = getattr(msg, "reasoning", "") or getattr(msg, "reasoning_content", "") or ""
        # Thinking models may put the whole answer in a reasoning field and
        # leave `content` empty; fall back to it in that case.
        final_text = content if content.strip() else reasoning
        return final_text.strip()
    except Exception as e:
        error_msg = str(e)
        if "provider you have enabled" in error_msg:
            error_msg = f"Hugging Face Provider Error: {error_msg}. Check your HF Space settings or Token permissions."
        print(f"HF Refinement Error: {error_msg}")
        return f"ERROR: {error_msg}"
def refine_master(prompt, backend, ollama_model, hf_text_model, hf_text_provider, oauth_token=None, character_name=None):
    """Route prompt refinement to the selected backend.

    Args:
        prompt: Technical prompt to refine.
        backend: ``"Ollama (Local)"``, ``"Hugging Face (Cloud)"``; any other
            value selects Gemini.
        ollama_model: Model name used for the Ollama backend.
        hf_text_model: Model id used for the Hugging Face backend.
        hf_text_provider: Provider name used for the Hugging Face backend.
        oauth_token: Optional HF token; blank values are ignored.
        character_name: Accepted for interface compatibility; not used.

    Returns:
        Always a ``(refined_prompt, status_message)`` pair:
        ``("", "")`` for a blank prompt, ``(text, "")`` on success, and
        ``(None, "⚠️ Refinement failed: ...")`` when the backend reports an error.
    """
    # Return a 2-tuple on every path so a handler wired to two outputs
    # always receives the same shape.
    if not prompt or not prompt.strip():
        return "", ""
    # A manually supplied token takes priority.
    hf_token = oauth_token.strip() if oauth_token and oauth_token.strip() else None
    if backend == "Ollama (Local)":
        result = refine_with_ollama(prompt, ollama_model, mode="refinement")
    elif backend == "Hugging Face (Cloud)":
        result = refine_with_hf(prompt, hf_text_model, hf_text_provider, hf_token, mode="refinement")
    else:
        result = refine_with_gemini(prompt, mode="refinement")
    if not isinstance(result, str):
        return None, "⚠️ Refinement failed: Internal process error. Check logs."
    if result.startswith("ERROR:"):
        # Strip only the leading marker, not any later occurrence in the text.
        return None, f"⚠️ Refinement failed: {result[len('ERROR:'):].strip()}"
    return result, ""
def generate_name_master(prompt, backend, ollama_model, hf_text_model, hf_text_provider, oauth_token=None):
    """Generate a thematic character name from the current prompt.

    Uses the same backend routing as ``refine_master``, but with the
    ``"naming"`` system prompt.

    Returns:
        The stripped name, or ``"Unnamed Hero"`` when the prompt is blank,
        the backend reports an error, or it returns nothing usable.
    """
    fallback = "Unnamed Hero"
    if not prompt or not prompt.strip():
        return fallback
    hf_token = oauth_token.strip() if oauth_token and oauth_token.strip() else None
    if backend == "Ollama (Local)":
        result = refine_with_ollama(prompt, ollama_model, mode="naming")
    elif backend == "Hugging Face (Cloud)":
        result = refine_with_hf(prompt, hf_text_model, hf_text_provider, hf_token, mode="naming")
    else:
        result = refine_with_gemini(prompt, mode="naming")
    # Backends signal failure with an "ERROR:"-prefixed string; never let
    # that end up as the character's name (and in the PNG metadata).
    if not isinstance(result, str) or result.startswith("ERROR:"):
        return fallback
    name = result.strip()
    return name if name else fallback
def generate_image_with_gemini(refined_prompt, technical_prompt, aspect_ratio, character_name="Unnamed Hero"):
    """Generate an image with the Gemini image model (``GEMINI_IMAGE_MODEL``).

    The refined prompt is used when it is non-blank; otherwise the technical
    prompt is used. The prompt and character name are embedded as PNG text
    chunks. The file is written to a fresh temporary directory.

    Args:
        refined_prompt: Preferred prompt; may be None or blank.
        technical_prompt: Fallback prompt; may be None or blank.
        aspect_ratio: Passed straight through to ``GenerateImagesConfig``.
        character_name: Stored in the metadata and used for the filename.

    Returns:
        ``(PIL.Image, png_path, status_message)`` on success, otherwise
        ``(None, None, error_message)``.
    """
    if not gemini_active:
        return None, None, "Gemini API key not found in .env file."
    refined = (refined_prompt or "").strip()
    technical = (technical_prompt or "").strip()
    # Decide once which prompt is used, so the status message never has to
    # re-inspect (and possibly crash on) a None input after generation.
    used_refined = bool(refined)
    final_prompt = refined if used_refined else technical
    if not final_prompt:
        return None, None, "No prompt available for generation."
    name = character_name or "Unnamed Hero"
    try:
        response = client.models.generate_images(
            model=GEMINI_IMAGE_MODEL,
            prompt=final_prompt,
            config=types.GenerateImagesConfig(
                aspect_ratio=aspect_ratio,
                output_mime_type='image/png',
            ),
        )
        if not response.generated_images:
            return None, None, "Gemini Image generation did not return any images."
        img = Image.open(io.BytesIO(response.generated_images[0].image.image_bytes))
        # Embed the prompt and name so the PNG is self-describing.
        metadata = PngInfo()
        metadata.add_text("Comment", final_prompt)
        metadata.add_text("CharacterName", name)
        safe_name = "".join(c if c.isalnum() else "_" for c in name).strip("_")
        filename = f"{safe_name}_portrait_gemini.png" if safe_name else "rpg_portrait_gemini.png"
        img_path = os.path.join(tempfile.mkdtemp(), filename)
        img.save(img_path, "PNG", pnginfo=metadata)
    except Exception as e:
        traceback.print_exc()
        return None, None, f"Image Generation Error: {e}"
    source = "refined" if used_refined else "technical"
    return img, img_path, f"Image generated using {source} prompt!"
def generate_image_with_comfy(prompt, aspect_ratio, character_name="Unnamed Hero", max_polls=60):
    """Generate an image through a local ComfyUI server.

    Steps:
      1. Load the workflow JSON from ``COMFY_WORKFLOW_FILE``.
      2. Write in the prompt, resolution and a random seed.
      3. Queue the workflow with ``POST /prompt``.
      4. Poll ``GET /history/<prompt_id>`` about once per second until an
         image output appears.
      5. Download the image via ``GET /view``.

    The node ids are hard-wired to the bundled workflow: node ``"6"``
    receives the prompt text, ``"13"`` the width/height and ``"38"`` the seed.

    Args:
        prompt: Text prompt to render.
        aspect_ratio: ``"1:1"``, ``"16:9"``, ``"9:16"``, ``"4:3"`` or ``"3:4"``;
            any other value falls back to 1024x1024.
        character_name: Stored in the PNG metadata and used for the filename.
        max_polls: Number of history polls (about one second apart) before
            giving up with a timeout message.

    Returns:
        ``(PIL.Image, png_path, status_message)`` on success, otherwise
        ``(None, None, error_message)``.
    """
    if not os.path.exists(COMFY_WORKFLOW_FILE):
        return None, None, f"Error: Workflow file {COMFY_WORKFLOW_FILE} not found."
    name = character_name or "Unnamed Hero"
    res_map = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "4:3": (1152, 864),
        "3:4": (864, 1152),
    }
    width, height = res_map.get(aspect_ratio, (1024, 1024))
    try:
        with open(COMFY_WORKFLOW_FILE, "r", encoding="utf-8") as f:
            workflow = json.load(f)
        try:
            workflow["6"]["inputs"]["text"] = prompt
            workflow["13"]["inputs"]["width"] = width
            workflow["13"]["inputs"]["height"] = height
            workflow["38"]["inputs"]["seed"] = random.randint(1, 1125899906842624)
        except KeyError as e:
            # Clearer than the bare "'6'" a raw KeyError would produce.
            return None, None, f"ComfyUI Error: workflow file is missing expected node/input {e}."

        payload = {"prompt": workflow, "client_id": str(uuid.uuid4())}
        response = requests.post(f"{COMFY_URL}/prompt", json=payload, timeout=30)
        response.raise_for_status()
        prompt_id = response.json().get("prompt_id")
        if not prompt_id:
            return None, None, "ComfyUI Error: server did not return a prompt_id."

        for _ in range(max_polls):
            hist_resp = requests.get(f"{COMFY_URL}/history/{prompt_id}", timeout=10)
            entry = hist_resp.json().get(prompt_id) if hist_resp.status_code == 200 else None
            if entry:
                # First image produced by any output node.
                image_data = next(
                    (out["images"][0] for out in (entry.get("outputs") or {}).values() if out.get("images")),
                    None,
                )
                if image_data is not None:
                    # Let requests URL-encode the query; filenames/subfolders
                    # may contain spaces, '&' or other reserved characters.
                    img_resp = requests.get(
                        f"{COMFY_URL}/view",
                        params={
                            "filename": image_data["filename"],
                            "subfolder": image_data.get("subfolder", ""),
                            "type": image_data.get("type", "output"),
                        },
                        timeout=30,
                    )
                    img_resp.raise_for_status()
                    img = Image.open(io.BytesIO(img_resp.content))
                    # Embed the prompt and name so the PNG is self-describing.
                    metadata = PngInfo()
                    metadata.add_text("Comment", prompt)
                    metadata.add_text("CharacterName", name)
                    safe_name = "".join(c if c.isalnum() else "_" for c in name).strip("_")
                    filename = f"{safe_name}_portrait_comfy.png" if safe_name else "rpg_portrait_comfy.png"
                    img_path = os.path.join(tempfile.mkdtemp(), filename)
                    img.save(img_path, "PNG", pnginfo=metadata)
                    return img, img_path, "Image generated via ComfyUI!"
                # Stop early instead of polling until timeout when the job has
                # already ended without an image. Fields are read defensively
                # in case the server omits them.
                status = entry.get("status") or {}
                if status.get("status_str") == "error":
                    return None, None, "ComfyUI Error: workflow execution failed. Check the ComfyUI console."
                if status.get("completed"):
                    return None, None, "ComfyUI Error: workflow finished without producing an image."
            time.sleep(1)
        return None, None, "ComfyUI generation timed out."
    except Exception as e:
        traceback.print_exc()
        return None, None, f"ComfyUI Error: {e}"
def generate_image_with_hf(prompt, aspect_ratio, model_id=None, provider=None, token=None, character_name="Unnamed Hero"):
    """Generate an image with the Hugging Face Inference API.

    Args:
        prompt: Text prompt to render; blank or None is rejected up front.
        aspect_ratio: ``"1:1"``, ``"16:9"``, ``"9:16"``, ``"4:3"`` or ``"3:4"``;
            any other value falls back to 1024x1024.
        model_id: Model to call; defaults to ``HF_IMAGE_MODEL``.
        provider: Inference provider; blank or None means ``"auto"``.
        token: Optional user token; falls back to ``HF_TOKEN``.
        character_name: Stored in the PNG metadata and used for the filename.

    Returns:
        ``(PIL.Image, png_path, status_message)`` on success, otherwise
        ``(None, None, error_message)``.
    """
    active_token = token or HF_TOKEN
    if not active_token:
        return None, None, "Error: Hugging Face token not found. Please log in or provide a token."
    final_prompt = (prompt or "").strip()
    if not final_prompt:
        # Don't spend an API call (and quota) on an empty prompt.
        return None, None, "No prompt available for generation."
    active_model = model_id or HF_IMAGE_MODEL
    active_provider = provider.strip() if provider and provider.strip() else "auto"
    name = character_name or "Unnamed Hero"
    res_map = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "4:3": (1152, 864),
        "3:4": (864, 1152),
    }
    width, height = res_map.get(aspect_ratio, (1024, 1024))
    try:
        # Named `inference` so it does not shadow the module-level Gemini `client`.
        inference = InferenceClient(api_key=active_token, provider=active_provider)
        # width/height are honoured only by providers that support them.
        img = inference.text_to_image(
            final_prompt,
            model=active_model,
            width=width,
            height=height,
        )
        # Embed the prompt and name so the PNG is self-describing.
        metadata = PngInfo()
        metadata.add_text("Comment", final_prompt)
        metadata.add_text("CharacterName", name)
        safe_name = "".join(c if c.isalnum() else "_" for c in name).strip("_")
        filename = f"{safe_name}_portrait_hf.png" if safe_name else "rpg_portrait_hf.png"
        img_path = os.path.join(tempfile.mkdtemp(), filename)
        img.save(img_path, "PNG", pnginfo=metadata)
        return img, img_path, f"Image generated via Hugging Face ({active_model})!"
    except Exception as e:
        traceback.print_exc()
        return None, None, f"Hugging Face Image Error: {e}"
def generate_image_master(refined_prompt, technical_prompt, aspect_ratio, backend, hf_image_model, hf_image_provider, oauth_token=None, character_name="Unnamed Hero"):
    """Route image generation to the selected backend.

    The refined prompt wins when it is non-blank; otherwise the technical
    prompt is used. Either may be None.

    Args:
        backend: ``"ComfyUI (Local)"``, ``"Hugging Face (Cloud)"``; any other
            value selects Gemini.
        oauth_token: Optional HF token; blank values are ignored.

    Returns:
        The backend's ``(image, png_path, status_message)`` triple, or
        ``(None, None, "No prompt available for generation.")`` when both
        prompts are blank.
    """
    refined = (refined_prompt or "").strip()
    technical = (technical_prompt or "").strip()
    final_prompt = refined or technical
    if not final_prompt:
        return None, None, "No prompt available for generation."
    # A manually supplied token takes priority.
    hf_token = oauth_token.strip() if oauth_token and oauth_token.strip() else None
    if backend == "ComfyUI (Local)":
        return generate_image_with_comfy(final_prompt, aspect_ratio, character_name)
    if backend == "Hugging Face (Cloud)":
        return generate_image_with_hf(final_prompt, aspect_ratio, hf_image_model, hf_image_provider, hf_token, character_name)
    # Gemini gets both prompts so its status message can say which one it used.
    return generate_image_with_gemini(refined, technical, aspect_ratio, character_name)