| | import asyncio |
| | import json |
| | import logging |
| | import random |
| | import urllib.parse |
| | import urllib.request |
| | from typing import Optional |
| |
|
| | import websocket |
| | from open_webui.env import SRC_LOG_LEVELS |
| | from pydantic import BaseModel |
| |
|
| | log = logging.getLogger(__name__) |
| | log.setLevel(SRC_LOG_LEVELS["COMFYUI"]) |
| |
|
| | default_headers = {"User-Agent": "Mozilla/5.0"} |
| |
|
| |
|
| | def queue_prompt(prompt, client_id, base_url): |
| | log.info("queue_prompt") |
| | p = {"prompt": prompt, "client_id": client_id} |
| | data = json.dumps(p).encode("utf-8") |
| | log.debug(f"queue_prompt data: {data}") |
| | try: |
| | req = urllib.request.Request( |
| | f"{base_url}/prompt", data=data, headers=default_headers |
| | ) |
| | response = urllib.request.urlopen(req).read() |
| | return json.loads(response) |
| | except Exception as e: |
| | log.exception(f"Error while queuing prompt: {e}") |
| | raise e |
| |
|
| |
|
| | def get_image(filename, subfolder, folder_type, base_url): |
| | log.info("get_image") |
| | data = {"filename": filename, "subfolder": subfolder, "type": folder_type} |
| | url_values = urllib.parse.urlencode(data) |
| | req = urllib.request.Request( |
| | f"{base_url}/view?{url_values}", headers=default_headers |
| | ) |
| | with urllib.request.urlopen(req) as response: |
| | return response.read() |
| |
|
| |
|
| | def get_image_url(filename, subfolder, folder_type, base_url): |
| | log.info("get_image") |
| | data = {"filename": filename, "subfolder": subfolder, "type": folder_type} |
| | url_values = urllib.parse.urlencode(data) |
| | return f"{base_url}/view?{url_values}" |
| |
|
| |
|
| | def get_history(prompt_id, base_url): |
| | log.info("get_history") |
| |
|
| | req = urllib.request.Request( |
| | f"{base_url}/history/{prompt_id}", headers=default_headers |
| | ) |
| | with urllib.request.urlopen(req) as response: |
| | return json.loads(response.read()) |
| |
|
| |
|
| | def get_images(ws, prompt, client_id, base_url): |
| | prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"] |
| | output_images = [] |
| | while True: |
| | out = ws.recv() |
| | if isinstance(out, str): |
| | message = json.loads(out) |
| | if message["type"] == "executing": |
| | data = message["data"] |
| | if data["node"] is None and data["prompt_id"] == prompt_id: |
| | break |
| | else: |
| | continue |
| |
|
| | history = get_history(prompt_id, base_url)[prompt_id] |
| | for o in history["outputs"]: |
| | for node_id in history["outputs"]: |
| | node_output = history["outputs"][node_id] |
| | if "images" in node_output: |
| | for image in node_output["images"]: |
| | url = get_image_url( |
| | image["filename"], image["subfolder"], image["type"], base_url |
| | ) |
| | output_images.append({"url": url}) |
| | return {"data": output_images} |
| |
|
| |
|
| | class ComfyUINodeInput(BaseModel): |
| | type: Optional[str] = None |
| | node_ids: list[str] = [] |
| | key: Optional[str] = "text" |
| | value: Optional[str] = None |
| |
|
| |
|
| | class ComfyUIWorkflow(BaseModel): |
| | workflow: str |
| | nodes: list[ComfyUINodeInput] |
| |
|
| |
|
| | class ComfyUIGenerateImageForm(BaseModel): |
| | workflow: ComfyUIWorkflow |
| |
|
| | prompt: str |
| | negative_prompt: Optional[str] = None |
| | width: int |
| | height: int |
| | n: int = 1 |
| |
|
| | steps: Optional[int] = None |
| | seed: Optional[int] = None |
| |
|
| |
|
| | async def comfyui_generate_image( |
| | model: str, payload: ComfyUIGenerateImageForm, client_id, base_url |
| | ): |
| | ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") |
| | workflow = json.loads(payload.workflow.workflow) |
| |
|
| | for node in payload.workflow.nodes: |
| | if node.type: |
| | if node.type == "model": |
| | for node_id in node.node_ids: |
| | workflow[node_id]["inputs"][node.key] = model |
| | elif node.type == "prompt": |
| | for node_id in node.node_ids: |
| | workflow[node_id]["inputs"]["text"] = payload.prompt |
| | elif node.type == "negative_prompt": |
| | for node_id in node.node_ids: |
| | workflow[node_id]["inputs"]["text"] = payload.negative_prompt |
| | elif node.type == "width": |
| | for node_id in node.node_ids: |
| | workflow[node_id]["inputs"]["width"] = payload.width |
| | elif node.type == "height": |
| | for node_id in node.node_ids: |
| | workflow[node_id]["inputs"]["height"] = payload.height |
| | elif node.type == "n": |
| | for node_id in node.node_ids: |
| | workflow[node_id]["inputs"]["batch_size"] = payload.n |
| | elif node.type == "steps": |
| | for node_id in node.node_ids: |
| | workflow[node_id]["inputs"]["steps"] = payload.steps |
| | elif node.type == "seed": |
| | seed = ( |
| | payload.seed |
| | if payload.seed |
| | else random.randint(0, 18446744073709551614) |
| | ) |
| | for node_id in node.node_ids: |
| | workflow[node_id]["inputs"][node.key] = seed |
| | else: |
| | for node_id in node.node_ids: |
| | workflow[node_id]["inputs"][node.key] = node.value |
| |
|
| | try: |
| | ws = websocket.WebSocket() |
| | ws.connect(f"{ws_url}/ws?clientId={client_id}") |
| | log.info("WebSocket connection established.") |
| | except Exception as e: |
| | log.exception(f"Failed to connect to WebSocket server: {e}") |
| | return None |
| |
|
| | try: |
| | log.info("Sending workflow to WebSocket server.") |
| | log.info(f"Workflow: {workflow}") |
| | images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url) |
| | except Exception as e: |
| | log.exception(f"Error while receiving images: {e}") |
| | images = None |
| |
|
| | ws.close() |
| |
|
| | return images |
| |
|