import base64
import io
import json
import math
import random
import re
from typing import Any, Dict, List, Optional, Tuple, Union

import diffusers
import numpy as np
import svgwrite
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers import DDIMScheduler, StableDiffusionPipeline
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer


class EndpointHandler:
    """Endpoint handler that turns text prompts into simple rule-based SVG sketches.

    A request (plain prompt string, JSON string, or dict) selects one of the
    edit modes ``generate``, ``replace``, ``refine`` or ``reweight``. The
    handler draws an SVG with ``svgwrite``, rasterises it to a PIL image, and
    attaches the SVG source plus edit metadata to ``image.info``.

    NOTE(review): ``__init__`` loads a Stable Diffusion pipeline, but no method
    in this class reads ``self.pipe`` / ``self.clip_model`` afterwards; every
    sketch is produced from keyword rules only.
    """

    _DEFAULT_PROMPT = "a simple sketch"
    _DEFAULT_SIZE = 224

    # Colour words understood by word-replacement edits.
    _COLOR_MAP = {
        "red": "#FF0000",
        "blue": "#0000FF",
        "green": "#00FF00",
        "yellow": "#FFFF00",
        "purple": "#800080",
        "orange": "#FFA500",
    }

    # Ordered subject table: the first category whose keywords occur wins.
    _SUBJECT_KEYWORDS = (
        ("person", ("person", "people", "human", "man", "woman", "men", "women")),
        ("animal", ("animal", "cat", "dog", "bird", "horse")),
        ("building", ("house", "building", "architecture")),
        ("nature", ("tree", "nature", "landscape")),
        ("vehicle", ("car", "vehicle", "transport")),
    )

    # Style flags extracted from a prompt; key order is the returned dict's order.
    _FEATURE_KEYWORDS = {
        "detailed": ("detailed", "complex", "intricate"),
        "simple": ("simple", "minimal", "basic"),
        "colorful": ("colorful", "bright", "vibrant"),
        "large": ("large", "big", "huge"),
        "small": ("small", "tiny", "mini"),
    }

    # "(word:1.5)" raises attention, "[word:0.5]" lowers it. The word part may
    # not contain the delimiters, so "[a] b [c:0.5]" captures only "c", and the
    # number must contain at least one digit so float() cannot fail.
    _WEIGHT_NUMBER = r"\s*(\d+(?:\.\d*)?|\.\d+)\s*"
    _INCREASE_RE = re.compile(r"\(([^:()]+):" + _WEIGHT_NUMBER + r"\)")
    _DECREASE_RE = re.compile(r"\[([^:\[\]]+):" + _WEIGHT_NUMBER + r"\]")

    def __init__(self, path=""):
        """Pick a device and try to load the Stable Diffusion pipeline.

        Args:
            path: Not used by this implementation.

        Loading is best-effort: on any failure a warning is printed and
        ``pipe``, ``clip_model`` and ``clip_tokenizer`` are set to ``None``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_id = "runwayml/stable-diffusion-v1-5"

        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            ).to(self.device)

            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)

            self.clip_model = self.pipe.text_encoder
            self.clip_tokenizer = self.pipe.tokenizer

            print("DiffSketchEdit handler initialized successfully!")
        except Exception as e:
            print(f"Warning: Could not load diffusion model: {e}")
            self.pipe = None
            self.clip_model = None
            self.clip_tokenizer = None

    # ------------------------------------------------------------------ #
    # Request handling
    # ------------------------------------------------------------------ #

    def __call__(self, inputs: Union[str, Dict[str, Any]]) -> "Image.Image":
        """Run one sketch-editing request and return the rendered image.

        Args:
            inputs: A plain prompt string, a JSON string encoding a request
                dict, or the request dict itself. A request dict may carry
                ``inputs`` (a prompt string, or a dict with ``prompt`` /
                ``prompts`` and ``edit_type``) and ``parameters`` (``width``,
                ``height``, ``seed``, ``input_svg``).

        Returns:
            An RGB PIL image whose ``info`` holds ``svg_content`` and the edit
            metadata. On failure a fallback image is returned with
            ``info['error']`` and ``info['edit_type']`` set; nothing is raised.
        """
        # Bind everything the fallback branch reads *before* anything can fail,
        # so the except-clause can never hit an unbound local.
        prompts = [self._DEFAULT_PROMPT]
        edit_type = "generate"
        width = height = self._DEFAULT_SIZE
        try:
            prompts, edit_type, parameters = self._parse_request(inputs)
            width = self._as_dimension(parameters.get("width"), self._DEFAULT_SIZE)
            height = self._as_dimension(parameters.get("height"), self._DEFAULT_SIZE)
            seed = parameters.get("seed")
            input_svg = parameters.get("input_svg")

            if seed is not None:
                seed = int(seed)
                torch.manual_seed(seed)
                np.random.seed(seed)
                random.seed(seed)

            print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")

            prompt = prompts[0]
            if edit_type == "replace" and len(prompts) >= 2:
                svg_content, metadata = self.word_replacement_edit(prompt, prompts[1], width, height, input_svg)
            elif edit_type == "reweight":
                svg_content, metadata = self.attention_reweighting_edit(prompt, width, height, input_svg)
            elif edit_type == "generate":
                svg_content, metadata = self.simple_generation(prompt, width, height)
            else:
                # "refine", unknown edit types and a single-prompt "replace"
                # all fall back to prompt refinement.
                svg_content, metadata = self.prompt_refinement_edit(prompt, width, height, input_svg)

            pil_image = self.svg_to_pil_image(svg_content, width, height)
            pil_image.info["svg_content"] = svg_content
            for key, value in metadata.items():
                pil_image.info[key] = json.dumps(value) if isinstance(value, (dict, list)) else str(value)
            return pil_image

        except Exception as e:
            # Endpoint boundary: report the failure inside the returned image
            # instead of raising into the serving framework.
            print(f"Error in handler: {e}")
            fallback_svg = self.create_fallback_svg(prompts[0] if prompts else "error", width, height)
            fallback_image = self.svg_to_pil_image(fallback_svg, width, height)
            fallback_image.info["error"] = str(e)
            fallback_image.info["edit_type"] = edit_type
            return fallback_image

    def _parse_request(self, inputs: Any) -> Tuple[List[str], Any, Dict[str, Any]]:
        """Normalise a raw request into ``(prompts, edit_type, parameters)``.

        ``prompts`` is always a non-empty list of strings and ``parameters``
        is always a dict, so callers can index and ``.get`` without guards.
        """
        if isinstance(inputs, str):
            try:
                parsed = json.loads(inputs)
            except ValueError:  # not JSON: the whole string is the prompt
                parsed = None
            if not isinstance(parsed, dict):
                return [inputs], "generate", {}
            inputs = parsed

        if not isinstance(inputs, dict):
            return [self._DEFAULT_PROMPT], "generate", {}

        input_data = inputs.get("inputs", inputs)
        if isinstance(input_data, str):
            prompts, edit_type = [input_data], "generate"
        elif isinstance(input_data, dict):
            prompts = input_data.get("prompts", [input_data.get("prompt", self._DEFAULT_PROMPT)])
            edit_type = input_data.get("edit_type", "generate")
        else:
            prompts, edit_type = [self._DEFAULT_PROMPT], "generate"

        # A bare string under "prompts" would otherwise be indexed per character.
        if not isinstance(prompts, (list, tuple)):
            prompts = [prompts]
        prompts = [str(p) for p in prompts if p is not None] or [self._DEFAULT_PROMPT]

        parameters = inputs.get("parameters", {})
        if not isinstance(parameters, dict):
            parameters = {}
        return prompts, edit_type, parameters

    @staticmethod
    def _as_dimension(value: Any, default: int) -> int:
        """Coerce a width/height parameter to a positive int, else return ``default``."""
        try:
            size = int(value)
        except (TypeError, ValueError):
            return default
        return size if size > 0 else default

    # ------------------------------------------------------------------ #
    # Edit modes
    # ------------------------------------------------------------------ #

    def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int,
                              input_svg: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Redraw the sketch for ``target_prompt``, driven by words added relative to ``source_prompt``.

        Returns ``(svg_string, metadata)``; on error a fallback SVG and an
        ``error`` entry in the metadata.

        NOTE(review): ``apply_word_replacement`` draws a fresh canvas and does
        not read the base SVG, so ``input_svg`` currently has no visual effect.
        """
        try:
            print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")

            added_words, removed_words = self.analyze_word_differences(source_prompt, target_prompt)
            print(f"Added words: {added_words}, Removed words: {removed_words}")

            base_svg = input_svg if input_svg else self.generate_base_svg(source_prompt, width, height)
            edited_svg = self.apply_word_replacement(base_svg, source_prompt, target_prompt,
                                                     added_words, removed_words, width, height)

            metadata = {
                "edit_type": "replace",
                "source_prompt": source_prompt,
                "target_prompt": target_prompt,
                # Sorted so the metadata is deterministic (sets have no order).
                "added_words": sorted(added_words),
                "removed_words": sorted(removed_words),
            }
            return edited_svg, metadata

        except Exception as e:
            print(f"Error in word_replacement_edit: {e}")
            fallback_svg = self.create_fallback_svg(source_prompt, width, height)
            return fallback_svg, {"edit_type": "replace", "error": str(e)}

    def prompt_refinement_edit(self, prompt: str, width: int, height: int,
                               input_svg: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Redraw the sketch for ``prompt``, adding detail when the prompt asks for it.

        Returns ``(svg_string, metadata)``; on error a fallback SVG and an
        ``error`` entry in the metadata.
        """
        try:
            print(f"Prompt refinement for: '{prompt}'")

            base_svg = input_svg if input_svg else self.generate_base_svg(prompt, width, height)
            refined_svg = self.apply_refinement(base_svg, prompt, width, height)

            return refined_svg, {"edit_type": "refine", "prompt": prompt}

        except Exception as e:
            print(f"Error in prompt_refinement_edit: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return fallback_svg, {"edit_type": "refine", "error": str(e)}

    def attention_reweighting_edit(self, prompt: str, width: int, height: int,
                                   input_svg: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Draw the sketch for a prompt carrying ``(word:w)`` / ``[word:w]`` weights.

        Returns ``(svg_string, metadata)``; on error a fallback SVG and an
        ``error`` entry in the metadata.
        """
        try:
            print(f"Attention reweighting for: '{prompt}'")

            weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
            print(f"Weighted prompt: '{weighted_prompt}', weights: {attention_weights}")

            base_svg = input_svg if input_svg else self.generate_base_svg(weighted_prompt, width, height)
            reweighted_svg = self.apply_attention_reweighting(base_svg, weighted_prompt, attention_weights,
                                                              width, height)

            metadata = {
                "edit_type": "reweight",
                "prompt": prompt,
                "weighted_prompt": weighted_prompt,
                "attention_weights": attention_weights,
            }
            return reweighted_svg, metadata

        except Exception as e:
            print(f"Error in attention_reweighting_edit: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return fallback_svg, {"edit_type": "reweight", "error": str(e)}

    def simple_generation(self, prompt: str, width: int, height: int) -> Tuple[str, Dict[str, Any]]:
        """Draw a fresh sketch for ``prompt``; returns ``(svg_string, metadata)``."""
        try:
            print(f"Simple generation for: '{prompt}'")
            svg_content = self.generate_base_svg(prompt, width, height)
            return svg_content, {"edit_type": "generate", "prompt": prompt}

        except Exception as e:
            print(f"Error in simple_generation: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return fallback_svg, {"edit_type": "generate", "error": str(e)}

    # ------------------------------------------------------------------ #
    # Prompt analysis
    # ------------------------------------------------------------------ #

    @staticmethod
    def _contains_any_word(text: str, words) -> bool:
        """Return True if any of ``words`` occurs in ``text`` as a whole word.

        Matching is case-insensitive and tolerates a plain plural suffix
        ("cats", "houses"). Unlike a substring test, it does not fire on words
        that merely contain a keyword ("scary" for "car", "minimal" for "mini").
        """
        alternatives = "|".join(re.escape(word) for word in words)
        return re.search(rf"\b(?:{alternatives})(?:s|es)?\b", text.lower()) is not None

    def _classify_subject(self, prompt: str) -> str:
        """Return the first matching subject category for ``prompt``, or ``"abstract"``."""
        for subject, keywords in self._SUBJECT_KEYWORDS:
            if self._contains_any_word(prompt, keywords):
                return subject
        return "abstract"

    def analyze_word_differences(self, source: str, target: str):
        """Return ``(added_words, removed_words)`` as sets of lower-cased tokens.

        Tokens are runs of word characters/apostrophes, so trailing punctuation
        ("cat." vs "cat") does not register as a difference.
        """
        source_words = set(re.findall(r"[\w']+", source.lower()))
        target_words = set(re.findall(r"[\w']+", target.lower()))
        return target_words - source_words, source_words - target_words

    def parse_attention_weights(self, prompt: str):
        """Strip ``(word:w)`` and ``[word:w]`` markers from ``prompt``.

        Returns ``(clean_prompt, {word: weight})``. Increase markers are
        processed before decrease markers, and a later marker for the same
        word overwrites the earlier weight.
        """
        attention_weights: Dict[str, float] = {}

        def _unwrap(match):
            word = match.group(1).strip()
            attention_weights[word] = float(match.group(2))
            return word

        weighted_prompt = self._INCREASE_RE.sub(_unwrap, prompt)
        weighted_prompt = self._DECREASE_RE.sub(_unwrap, weighted_prompt)
        return weighted_prompt.strip(), attention_weights

    def extract_semantic_features(self, prompt: str):
        """Return boolean style flags (detailed/simple/colorful/large/small) for ``prompt``."""
        return {name: self._contains_any_word(prompt, words)
                for name, words in self._FEATURE_KEYWORDS.items()}

    # ------------------------------------------------------------------ #
    # SVG composition
    # ------------------------------------------------------------------ #

    @staticmethod
    def _blank_drawing(width, height):
        """Return a new ``svgwrite.Drawing`` with a white background rectangle."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill="white"))
        return dwg

    def _draw_subject(self, dwg, prompt: str, width: int, height: int, features) -> None:
        """Add the element group matching the prompt's subject category to ``dwg``."""
        drawers = {
            "person": self.add_person_elements,
            "animal": self.add_animal_elements,
            "building": self.add_building_elements,
            "nature": self.add_nature_elements,
            "vehicle": self.add_vehicle_elements,
            "abstract": self.add_abstract_elements,
        }
        drawers[self._classify_subject(prompt)](dwg, width, height, features)

    def generate_base_svg(self, prompt: str, width: int, height: int) -> str:
        """Return an SVG string with the subject elements implied by ``prompt``."""
        dwg = self._blank_drawing(width, height)
        self._draw_subject(dwg, prompt, width, height, self.extract_semantic_features(prompt))
        return dwg.tostring()

    def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str,
                               added_words: set, removed_words: set, width: int, height: int) -> str:
        """Draw a new SVG reflecting the words added in ``target_prompt``.

        Colour words take priority, then size words, otherwise the subject of
        ``target_prompt`` is drawn. ``base_svg`` and ``removed_words`` are not
        used by the current implementation.
        """
        features = self.extract_semantic_features(target_prompt)
        dwg = self._blank_drawing(width, height)

        if any(word in self._COLOR_MAP for word in added_words):
            self.add_colored_elements(dwg, width, height, added_words)
        elif any(word in self._FEATURE_KEYWORDS["large"] for word in added_words):
            self.add_large_elements(dwg, width, height, features)
        elif any(word in self._FEATURE_KEYWORDS["small"] for word in added_words):
            self.add_small_elements(dwg, width, height, features)
        else:
            self.add_content_based_on_prompt(dwg, target_prompt, width, height)

        return dwg.tostring()

    def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int) -> str:
        """Draw a new SVG for ``prompt``; detailed prompts get extra decoration.

        ``base_svg`` is not used by the current implementation.
        """
        features = self.extract_semantic_features(prompt)
        dwg = self._blank_drawing(width, height)

        if features.get("detailed", False):
            self.add_detailed_elements(dwg, width, height, features)
        else:
            self.add_content_based_on_prompt(dwg, prompt, width, height)

        return dwg.tostring()

    def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict,
                                    width: int, height: int) -> str:
        """Draw weight markers (weight > 1 emphasised, < 1 de-emphasised) plus the prompt subject.

        ``base_svg`` is not used by the current implementation.
        """
        dwg = self._blank_drawing(width, height)

        for word, weight in attention_weights.items():
            if weight > 1.0:
                self.add_emphasized_element(dwg, word, weight, width, height)
            elif weight < 1.0:
                self.add_deemphasized_element(dwg, word, weight, width, height)

        self.add_content_based_on_prompt(dwg, prompt, width, height)
        return dwg.tostring()

    def add_content_based_on_prompt(self, dwg, prompt: str, width: int, height: int) -> None:
        """Add the subject elements implied by ``prompt`` to an existing drawing."""
        self._draw_subject(dwg, prompt, width, height, self.extract_semantic_features(prompt))

    # ------------------------------------------------------------------ #
    # Element primitives
    # ------------------------------------------------------------------ #

    def add_person_elements(self, dwg, width, height, features):
        """Add a stick figure (head, body, arms, legs) centred on the canvas."""
        center_x, center_y = width // 2, height // 2

        head_radius = 20
        dwg.add(dwg.circle(center=(center_x, center_y - 40), r=head_radius,
                           fill='#FDBCB4', stroke='black', stroke_width=2))

        body_height = 60
        body_width = 30
        dwg.add(dwg.rect(
            insert=(center_x - body_width // 2, center_y - 10),
            size=(body_width, body_height),
            fill='#4A90E2',
            stroke='black',
            stroke_width=2,
        ))

        # Arms
        dwg.add(dwg.line(start=(center_x - body_width // 2, center_y), end=(center_x - 40, center_y + 20),
                         stroke='black', stroke_width=3))
        dwg.add(dwg.line(start=(center_x + body_width // 2, center_y), end=(center_x + 40, center_y + 20),
                         stroke='black', stroke_width=3))

        # Legs
        dwg.add(dwg.line(start=(center_x - 10, center_y + body_height - 10),
                         end=(center_x - 20, center_y + body_height + 30), stroke='black', stroke_width=3))
        dwg.add(dwg.line(start=(center_x + 10, center_y + body_height - 10),
                         end=(center_x + 20, center_y + body_height + 30), stroke='black', stroke_width=3))

    def add_animal_elements(self, dwg, width, height, features):
        """Add a four-legged animal (ellipse body, round head, legs, curved tail)."""
        center_x, center_y = width // 2, height // 2

        dwg.add(dwg.ellipse(center=(center_x, center_y), r=(40, 25), fill='#8B4513', stroke='black', stroke_width=2))
        dwg.add(dwg.circle(center=(center_x - 30, center_y - 10), r=20, fill='#A0522D', stroke='black', stroke_width=2))

        for x_offset in (-20, -10, 10, 20):
            dwg.add(dwg.line(
                start=(center_x + x_offset, center_y + 25),
                end=(center_x + x_offset, center_y + 45),
                stroke='black',
                stroke_width=3,
            ))

        dwg.add(dwg.path(
            d=f"M {center_x + 40},{center_y} Q {center_x + 60},{center_y - 20} {center_x + 50},{center_y - 35}",
            stroke='black',
            stroke_width=3,
            fill='none',
        ))

    def add_building_elements(self, dwg, width, height, features):
        """Add a house: walls, triangular roof, a grid of windows and a door."""
        building_width = width * 0.6
        building_height = height * 0.7
        x = (width - building_width) // 2
        y = height - building_height - 10

        dwg.add(dwg.rect(insert=(x, y), size=(building_width, building_height),
                         fill='#CD853F', stroke='black', stroke_width=2))

        roof_points = [(x, y), (x + building_width // 2, y - 30), (x + building_width, y)]
        dwg.add(dwg.polygon(points=roof_points, fill='#8B0000', stroke='black', stroke_width=2))

        window_size = 15
        for i in range(3):
            for j in range(4):
                wx = x + 15 + i * 30
                wy = y + 15 + j * 25
                # Keep windows above the bottom edge of the walls.
                if wy < y + building_height - 20:
                    dwg.add(dwg.rect(insert=(wx, wy), size=(window_size, window_size),
                                     fill='#87CEEB', stroke='black', stroke_width=1))

        door_width = 20
        door_height = 40
        door_x = x + building_width // 2 - door_width // 2
        door_y = y + building_height - door_height
        dwg.add(dwg.rect(insert=(door_x, door_y), size=(door_width, door_height),
                         fill='#8B4513', stroke='black', stroke_width=2))

    def add_nature_elements(self, dwg, width, height, features):
        """Add a tree: a trunk rectangle and a crown of overlapping circles."""
        center_x, center_y = width // 2, height // 2

        trunk_width = 15
        trunk_height = height // 3
        trunk_x = center_x - trunk_width // 2
        trunk_y = height - trunk_height - 10
        dwg.add(dwg.rect(insert=(trunk_x, trunk_y), size=(trunk_width, trunk_height),
                         fill='#8B4513', stroke='black', stroke_width=1))

        crown_radius = 30
        for i, (dx, dy) in enumerate([(-15, -20), (15, -20), (0, -35), (-10, -50), (10, -50)]):
            dwg.add(dwg.circle(
                center=(center_x + dx, center_y + dy),
                r=crown_radius - i * 3,  # upper circles get progressively smaller
                fill='#228B22',
                stroke='#006400',
                stroke_width=1,
                opacity=0.8,
            ))

    def add_vehicle_elements(self, dwg, width, height, features):
        """Add a car: rounded body, windshield and two wheels."""
        center_y = height // 2

        car_width = width * 0.6
        car_height = height * 0.3
        car_x = (width - car_width) // 2
        car_y = center_y + 10
        dwg.add(dwg.rect(insert=(car_x, car_y), size=(car_width, car_height),
                         fill='#FF4500', stroke='black', stroke_width=2, rx=5))

        windshield_width = car_width * 0.6
        windshield_height = car_height * 0.4
        windshield_x = car_x + (car_width - windshield_width) // 2
        windshield_y = car_y - windshield_height + 5
        dwg.add(dwg.rect(insert=(windshield_x, windshield_y), size=(windshield_width, windshield_height),
                         fill='#87CEEB', stroke='black', stroke_width=1))

        wheel_radius = 12
        wheel_y = car_y + car_height - 5
        dwg.add(dwg.circle(center=(car_x + 25, wheel_y), r=wheel_radius, fill='black'))
        dwg.add(dwg.circle(center=(car_x + car_width - 25, wheel_y), r=wheel_radius, fill='black'))

    def add_abstract_elements(self, dwg, width, height, features):
        """Add five random circles/rectangles/lines using the module ``random`` state.

        Position ranges are clamped so tiny canvases do not make
        ``random.randint`` raise; on normal canvases the draws are unchanged.
        """
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']

        for _ in range(5):
            shape_type = random.choice(['circle', 'rect', 'path'])
            color = random.choice(colors)

            if shape_type == 'circle':
                radius = random.randint(10, 30)
                x = random.randint(radius, max(radius, width - radius))
                y = random.randint(radius, max(radius, height - radius))
                dwg.add(dwg.circle(center=(x, y), r=radius, fill=color, opacity=0.7))
            elif shape_type == 'rect':
                w = random.randint(20, 60)
                h = random.randint(20, 60)
                x = random.randint(0, max(0, width - w))
                y = random.randint(0, max(0, height - h))
                dwg.add(dwg.rect(insert=(x, y), size=(w, h), fill=color, opacity=0.7))
            else:
                start_x = random.randint(0, width)
                start_y = random.randint(0, height)
                end_x = random.randint(0, width)
                end_y = random.randint(0, height)
                dwg.add(dwg.line(start=(start_x, start_y), end=(end_x, end_y), stroke=color, stroke_width=3))

    def add_colored_elements(self, dwg, width, height, color_words):
        """Add one randomly placed circle per recognised colour word.

        Words are visited in sorted order so a seeded run is reproducible
        (iterating a set of strings is not).
        """
        center_x, center_y = width // 2, height // 2

        for word in sorted(color_words):
            color = self._COLOR_MAP.get(word)
            if color is None:
                continue
            dwg.add(dwg.circle(
                center=(center_x + random.randint(-50, 50), center_y + random.randint(-50, 50)),
                r=random.randint(15, 35),
                fill=color,
                opacity=0.8,
            ))

    def add_large_elements(self, dwg, width, height, features):
        """Add one large centred circle (radius a third of the shorter side)."""
        center_x, center_y = width // 2, height // 2
        dwg.add(dwg.circle(center=(center_x, center_y), r=min(width, height) // 3,
                           fill='#4A90E2', stroke='black', stroke_width=3))

    def add_small_elements(self, dwg, width, height, features):
        """Add eight small random dots (ranges clamped for tiny canvases)."""
        for _ in range(8):
            x = random.randint(10, max(10, width - 10))
            y = random.randint(10, max(10, height - 10))
            dwg.add(dwg.circle(center=(x, y), r=random.randint(3, 8), fill='#E74C3C', opacity=0.7))

    def add_detailed_elements(self, dwg, width, height, features):
        """Add abstract shapes plus four purple dots on a radius-40 ring around the centre."""
        self.add_abstract_elements(dwg, width, height, features)

        center_x, center_y = width // 2, height // 2
        for i in range(4):
            angle = i * math.pi / 2
            x = center_x + 40 * math.cos(angle)
            y = center_y + 40 * math.sin(angle)
            dwg.add(dwg.circle(center=(x, y), r=8, fill='#9B59B6', opacity=0.6))

    def add_emphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
        """Add a red circle whose radius scales with ``weight`` (``weight`` > 1)."""
        center_x, center_y = width // 2, height // 2
        size = int(20 * weight)
        dwg.add(dwg.circle(
            center=(center_x + random.randint(-30, 30), center_y + random.randint(-30, 30)),
            r=size,
            fill='#FF6B6B',
            opacity=min(1.0, weight / 2),
            stroke='black',
            stroke_width=2,
        ))

    def add_deemphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
        """Add a faint grey circle whose size and opacity scale with ``weight`` (``weight`` < 1)."""
        center_x, center_y = width // 2, height // 2
        size = int(15 * weight)
        dwg.add(dwg.circle(
            center=(center_x + random.randint(-40, 40), center_y + random.randint(-40, 40)),
            r=max(3, size),
            fill='#CCCCCC',
            opacity=weight,
            stroke='gray',
            stroke_width=1,
        ))

    # ------------------------------------------------------------------ #
    # Output
    # ------------------------------------------------------------------ #

    def svg_to_pil_image(self, svg_content: str, width: int, height: int) -> "Image.Image":
        """Rasterise ``svg_content`` to an RGB PIL image via cairosvg.

        If cairosvg is missing or conversion fails, a blank white image of the
        requested size is returned instead.
        """
        try:
            import cairosvg

            png_bytes = cairosvg.svg2png(
                bytestring=svg_content.encode('utf-8'),
                output_width=width,
                output_height=height,
            )
            return Image.open(io.BytesIO(png_bytes)).convert('RGB')

        except ImportError:
            print("cairosvg not available, creating simple image representation")
            return Image.new('RGB', (width, height), 'white')
        except Exception as e:
            print(f"Error converting SVG to image: {e}")
            return Image.new('RGB', (width, height), 'white')

    def create_fallback_svg(self, prompt: str, width: int, height: int) -> str:
        """Return a white SVG labelled "DiffSketchEdit" with the first 30 prompt characters.

        The two lines are separate text elements because SVG collapses a
        literal newline inside one text element into a space.
        """
        dwg = self._blank_drawing(width, height)

        label = str(prompt) if prompt else "error"
        snippet = label[:30] + ("..." if len(label) > 30 else "")
        for offset, line in ((-8, "DiffSketchEdit"), (8, snippet)):
            dwg.add(dwg.text(
                line,
                insert=(width / 2, height / 2 + offset),
                text_anchor="middle",
                font_size="12px",
                fill="black",
            ))

        return dwg.tostring()