# Standard library
import base64
import io
import json
import math
import random
from typing import Dict, Any, List, Optional, Union

# Third-party
import diffusers
import numpy as np
import svgwrite
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from transformers import CLIPTextModel, CLIPTokenizer
| |
|
class EndpointHandler:
    """Inference handler that renders a text prompt as a Bezier-stroke sketch.

    The handler loads a Stable Diffusion pipeline at construction time, but the
    sketch itself is produced by a heuristic: random cubic Bezier strokes are
    perturbed according to keywords found in the prompt. The diffusion model
    is only used to compute text embeddings, which the current "optimization"
    loop does not consume.
    """

    def __init__(self, path=""):
        """Load the diffusion pipeline, degrading gracefully if that fails.

        Args:
            path: Accepted for interface compatibility; not read by this
                implementation.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_id = "runwayml/stable-diffusion-v1-5"

        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                # Half precision only on GPU; CPU stays in float32.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None,
                requires_safety_checker=False
            ).to(self.device)

            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)

            # Reuse the pipeline's own text encoder/tokenizer for embeddings.
            self.clip_model = self.pipe.text_encoder
            self.clip_tokenizer = self.pipe.tokenizer

            print("DiffSketcher handler initialized successfully!")
        except Exception as e:
            # Best effort: the heuristic sketch path still works without the
            # model, so a load failure must not prevent construction.
            print(f"Warning: Could not load diffusion model: {e}")
            self.pipe = None
            self.clip_model = None
            self.clip_tokenizer = None

    def __call__(self, inputs: Union[str, Dict[str, Any]]) -> "Image.Image":
        """Generate a sketch image for a prompt.

        Args:
            inputs: Either the prompt string itself, or a dict holding the
                prompt under ``"inputs"`` (or ``"prompt"``) and optional
                ``"parameters"`` (``num_paths``, ``num_iter``, ``width``,
                ``height``, ``guidance_scale``, ``seed``).

        Returns:
            A PIL image of the rasterized sketch. The SVG source and request
            details are attached to ``image.info``. On any failure a 224x224
            fallback image is returned with the error text in
            ``image.info['error']``.
        """
        # Bound before the try block so the error path can always name a prompt.
        prompt = "error"
        try:
            if isinstance(inputs, str):
                prompt = inputs
                parameters = {}
            else:
                requested = inputs.get("inputs")
                if requested is None:
                    requested = inputs.get("prompt")
                if requested is None:
                    requested = "a simple sketch"
                prompt = str(requested)
                # An explicit ``"parameters": null`` is treated like a missing key.
                parameters = inputs.get("parameters") or {}

            # JSON clients may send numbers as strings or floats; the helpers
            # below need real ints (``range``) and floats.
            num_paths = int(parameters.get("num_paths", 64))
            num_iter = int(parameters.get("num_iter", 500))
            width = int(parameters.get("width", 224))
            height = int(parameters.get("height", 224))
            guidance_scale = float(parameters.get("guidance_scale", 7.5))
            seed = parameters.get("seed", None)

            if seed is not None:
                seed = int(seed)
                torch.manual_seed(seed)
                # NumPy only accepts seeds in [0, 2**32 - 1].
                np.random.seed(seed % (2 ** 32))
                random.seed(seed)

            print(f"Generating sketch for: '{prompt}' with {num_paths} paths")

            # The metadata dict is not attached to the image; the same values
            # are stored individually in ``info`` below.
            svg_content, _metadata = self.generate_diffsketcher_svg(
                prompt, width, height, num_paths, num_iter, guidance_scale
            )

            pil_image = self.svg_to_pil_image(svg_content, width, height)

            pil_image.info['svg_content'] = svg_content
            pil_image.info['prompt'] = prompt
            # ``default=str`` keeps a non-serializable parameter from
            # discarding an otherwise finished sketch.
            pil_image.info['parameters'] = json.dumps(parameters, default=str)
            pil_image.info['num_paths'] = str(num_paths)
            pil_image.info['method'] = 'diffsketcher'

            return pil_image

        except Exception as e:
            # Top-level boundary: report the failure and still return an image.
            print(f"Error in DiffSketcher handler: {e}")

            fallback_svg = self.create_fallback_svg(prompt, 224, 224)
            fallback_image = self.svg_to_pil_image(fallback_svg, 224, 224)
            fallback_image.info['error'] = str(e)
            return fallback_image

    def generate_diffsketcher_svg(self, prompt: str, width: int, height: int,
                                  num_paths: int, num_iter: int, guidance_scale: float):
        """Run the full prompt -> SVG pipeline.

        Returns:
            ``(svg_content, metadata)`` where ``svg_content`` is the SVG
            document as a string and ``metadata`` echoes the request settings.
        """
        text_embeddings = self.get_text_embeddings(prompt)

        paths = self.initialize_paths(num_paths, width, height)

        optimized_paths = self.optimize_paths_with_diffusion(
            paths, text_embeddings, prompt, width, height, num_iter, guidance_scale
        )

        svg_content = self.paths_to_svg(optimized_paths, width, height)

        metadata = {
            "method": "diffsketcher",
            "prompt": prompt,
            "num_paths": num_paths,
            "num_iter": num_iter,
            "guidance_scale": guidance_scale,
            "width": width,
            "height": height
        }

        return svg_content, metadata

    def _zero_text_embeddings(self):
        """Return the placeholder embedding used when no encoder is usable."""
        # Shape mirrors the (uncond, cond) pair produced by the real encoder
        # path; allocated on the handler's device so callers see one device.
        return torch.zeros((2, 77, 768), device=self.device)

    def get_text_embeddings(self, prompt: str):
        """Encode ``prompt`` with the pipeline's text encoder.

        Returns:
            The unconditional ("" prompt) and conditional embeddings
            concatenated along the first dimension, or an all-zero tensor of
            shape ``(2, 77, 768)`` if the encoder is unavailable or fails.
        """
        if self.clip_model is None or self.clip_tokenizer is None:
            return self._zero_text_embeddings()

        try:
            with torch.no_grad():
                text_inputs = self.clip_tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=self.clip_tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt"
                ).to(self.device)

                text_embeddings = self.clip_model(text_inputs.input_ids)[0]

                # Empty prompt gives the unconditional half of the pair.
                uncond_inputs = self.clip_tokenizer(
                    "",
                    padding="max_length",
                    max_length=self.clip_tokenizer.model_max_length,
                    return_tensors="pt"
                ).to(self.device)

                uncond_embeddings = self.clip_model(uncond_inputs.input_ids)[0]

                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

            return text_embeddings
        except Exception as e:
            print(f"Error getting text embeddings: {e}")
            return self._zero_text_embeddings()

    def initialize_paths(self, num_paths: int, width: int, height: int):
        """Create ``num_paths`` random cubic Bezier strokes inside the canvas.

        Each stroke is a dict with ``start``, ``cp1``, ``cp2`` and ``end``
        points, a grey ``color`` triple, a ``stroke_width`` and an ``opacity``.
        """
        paths = []

        for _ in range(num_paths):
            # Start points stay inside the central 80% of the canvas.
            start_x = random.uniform(0.1 * width, 0.9 * width)
            start_y = random.uniform(0.1 * height, 0.9 * height)

            # Control points and end point are offsets from the start point.
            cp1_x = start_x + random.uniform(-width * 0.2, width * 0.2)
            cp1_y = start_y + random.uniform(-height * 0.2, height * 0.2)
            cp2_x = start_x + random.uniform(-width * 0.2, width * 0.2)
            cp2_y = start_y + random.uniform(-height * 0.2, height * 0.2)

            end_x = start_x + random.uniform(-width * 0.3, width * 0.3)
            end_y = start_y + random.uniform(-height * 0.3, height * 0.3)

            # Offsets can leave the canvas, so clamp them back in.
            cp1_x = max(0, min(width, cp1_x))
            cp1_y = max(0, min(height, cp1_y))
            cp2_x = max(0, min(width, cp2_x))
            cp2_y = max(0, min(height, cp2_y))
            end_x = max(0, min(width, end_x))
            end_y = max(0, min(height, end_y))

            # One intensity for all three channels gives a grey stroke.
            grey = int(random.uniform(0.1, 0.7) * 255)
            color = (grey, grey, grey)

            stroke_width = random.uniform(0.5, 3.0)

            paths.append({
                'start': (start_x, start_y),
                'cp1': (cp1_x, cp1_y),
                'cp2': (cp2_x, cp2_y),
                'end': (end_x, end_y),
                'color': color,
                'stroke_width': stroke_width,
                'opacity': random.uniform(0.3, 0.8)
            })

        return paths

    def optimize_paths_with_diffusion(self, paths: List[Dict], text_embeddings: torch.Tensor,
                                      prompt: str, width: int, height: int,
                                      num_iter: int, guidance_scale: float):
        """Iteratively adjust the strokes using prompt-keyword heuristics.

        ``text_embeddings`` and ``guidance_scale`` are accepted for interface
        compatibility but are not used: no diffusion-model call happens here.
        The loop runs ``min(num_iter // 10, 50)`` times, so ``num_iter`` below
        10 leaves the paths untouched.
        """
        semantic_features = self.extract_semantic_features(prompt)

        for iteration in range(min(num_iter // 10, 50)):
            paths = self.apply_semantic_guidance(paths, semantic_features, width, height)

            # Re-layer opacity on every fifth pass (including the first).
            if iteration % 5 == 0:
                paths = self.apply_aesthetic_refinement(paths, width, height)

        return paths

    def extract_semantic_features(self, prompt: str):
        """Derive coarse style flags from keywords in ``prompt``.

        Matching is by case-insensitive substring, so e.g. "street" also
        matches the keyword "tree".

        Returns:
            Dict with keys ``complexity`` ('low'/'medium'/'high'), ``style``,
            ``density``, ``organic``, ``geometric`` and ``detailed``.
        """
        features = {
            'complexity': 'medium',
            'style': 'sketch',
            'density': 'medium',
            'organic': False,
            'geometric': False,
            'detailed': False
        }

        prompt_lower = prompt.lower()

        def mentions(words):
            return any(word in prompt_lower for word in words)

        # "Complex" keywords win over "simple" ones when both appear.
        if mentions(['detailed', 'intricate', 'complex', 'elaborate']):
            features['complexity'] = 'high'
            features['detailed'] = True
        elif mentions(['simple', 'minimal', 'basic', 'clean']):
            features['complexity'] = 'low'

        if mentions(['sketch', 'drawing', 'pencil', 'charcoal']):
            features['style'] = 'sketch'
        elif mentions(['painting', 'artistic', 'painted']):
            features['style'] = 'artistic'

        # Organic and geometric are independent flags; both may be set.
        if mentions(['tree', 'flower', 'animal', 'person', 'face', 'natural', 'organic']):
            features['organic'] = True
        if mentions(['building', 'house', 'geometric', 'square', 'circle', 'triangle']):
            features['geometric'] = True

        return features

    def apply_semantic_guidance(self, paths: List[Dict], features: Dict, width: int, height: int):
        """Return adjusted copies of ``paths`` according to ``features``.

        * ``complexity == 'high'``: jitter both control points.
        * ``complexity == 'low'``: place control points on the start-end
          segment, which straightens the stroke.
        * ``organic``: randomly vary stroke width and opacity.
        * ``geometric`` (only when not organic): snap all points to a
          20-pixel grid.

        All points are clamped to the canvas and opacity is kept in [0, 1].
        The input dicts are not modified.
        """
        point_keys = ('start', 'cp1', 'cp2', 'end')
        modified_paths = []

        for path in paths:
            new_path = path.copy()

            if features['complexity'] == 'high':
                variation = 0.15
                new_path['cp1'] = (
                    new_path['cp1'][0] + random.uniform(-width * variation, width * variation),
                    new_path['cp1'][1] + random.uniform(-height * variation, height * variation)
                )
                new_path['cp2'] = (
                    new_path['cp2'][0] + random.uniform(-width * variation, width * variation),
                    new_path['cp2'][1] + random.uniform(-height * variation, height * variation)
                )
            elif features['complexity'] == 'low':
                start_x, start_y = new_path['start']
                end_x, end_y = new_path['end']
                new_path['cp1'] = (
                    start_x + (end_x - start_x) * 0.33,
                    start_y + (end_y - start_y) * 0.33
                )
                new_path['cp2'] = (
                    start_x + (end_x - start_x) * 0.66,
                    start_y + (end_y - start_y) * 0.66
                )

            if features['organic']:
                new_path['stroke_width'] *= random.uniform(0.8, 1.2)
                # This method runs many times per request; without a clamp the
                # multiplicative walk can push opacity past the valid maximum.
                new_path['opacity'] = min(
                    1.0, max(0.0, new_path['opacity'] * random.uniform(0.9, 1.1))
                )
            elif features['geometric']:
                grid_size = 20
                for key in point_keys:
                    x, y = new_path[key]
                    new_path[key] = (
                        round(x / grid_size) * grid_size,
                        round(y / grid_size) * grid_size
                    )

            # Jitter and grid snapping can both leave the canvas.
            for key in point_keys:
                x, y = new_path[key]
                new_path[key] = (
                    max(0, min(width, x)),
                    max(0, min(height, y))
                )

            modified_paths.append(new_path)

        return modified_paths

    def apply_aesthetic_refinement(self, paths: List[Dict], width: int, height: int):
        """Order strokes outermost-first and fade the later (inner) ones.

        Strokes are sorted by the distance of their start point from the
        canvas centre, farthest first. Opacity is then scaled by a factor that
        falls linearly from 1.0 towards 0.7 along that order and capped at 0.9.

        Returns:
            A new list of new dicts; the caller's list and dicts are left
            unmodified.
        """
        center_x, center_y = width / 2, height / 2

        def distance_from_center(path):
            start_x, start_y = path['start']
            return math.hypot(start_x - center_x, start_y - center_y)

        # Copy each dict so the opacity update below cannot leak to the caller.
        ordered = [
            dict(path)
            for path in sorted(paths, key=distance_from_center, reverse=True)
        ]

        total = len(ordered)
        for i, path in enumerate(ordered):
            layer_factor = 1.0 - (i / total) * 0.3
            path['opacity'] = min(0.9, path['opacity'] * layer_factor)

        return ordered

    def paths_to_svg(self, paths: List[Dict], width: int, height: int):
        """Serialize the strokes as an SVG document on a white background."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        for path in paths:
            start_x, start_y = path['start']
            cp1_x, cp1_y = path['cp1']
            cp2_x, cp2_y = path['cp2']
            end_x, end_y = path['end']

            # One cubic Bezier segment: move to start, curve to end.
            path_data = f"M {start_x},{start_y} C {cp1_x},{cp1_y} {cp2_x},{cp2_y} {end_x},{end_y}"

            color = path['color']
            stroke_color = f"rgb({color[0]},{color[1]},{color[2]})"

            dwg.add(dwg.path(
                d=path_data,
                stroke=stroke_color,
                stroke_width=path['stroke_width'],
                stroke_opacity=path['opacity'],
                fill='none',
                stroke_linecap='round',
                stroke_linejoin='round'
            ))

        return dwg.tostring()

    def svg_to_pil_image(self, svg_content: str, width: int, height: int):
        """Rasterize ``svg_content`` to an RGB PIL image of the given size.

        Returns a blank white image when ``cairosvg`` is not installed or the
        conversion fails, so callers always receive an image.
        """
        try:
            # Imported lazily: cairosvg is optional.
            import cairosvg

            png_bytes = cairosvg.svg2png(
                bytestring=svg_content.encode('utf-8'),
                output_width=width,
                output_height=height
            )

            return Image.open(io.BytesIO(png_bytes)).convert('RGB')

        except ImportError:
            print("cairosvg not available, creating simple image representation")
        except Exception as e:
            print(f"Error converting SVG to image: {e}")

        # Shared best-effort fallback for both failure modes above.
        return Image.new('RGB', (width, height), 'white')

    def create_fallback_svg(self, prompt: str, width: int, height: int):
        """Build a placeholder SVG showing the first 30 characters of ``prompt``."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        # ``str()`` guards the error path against a non-string prompt.
        label = str(prompt)[:30]
        dwg.add(dwg.text(
            f"DiffSketcher\n{label}...",
            insert=(width/2, height/2),
            text_anchor="middle",
            font_size="12px",
            fill="black"
        ))

        return dwg.tostring()