| | import base64 |
| | import io |
| | from PIL import Image |
| | import torch |
| | from diffusers import StableDiffusionImg2ImgPipeline |
| |
|
| | |
# Pick the compute device once at import time: use CUDA when torch reports it
# as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

# Module-level slot for the diffusion pipeline. It starts empty and is filled
# by EndpointHandler.init(), so the model is built at most once per process.
pipe = None
| |
|
class EndpointHandler:
    """Image-to-image request handler built on a Stable Diffusion pipeline.

    The pipeline lives in the module-level ``pipe`` slot and is created on
    first use, so constructing a handler is cheap and requests that fail
    validation never trigger a model load.
    """

    def __init__(self, model_dir: str):
        """Create the handler without loading any weights.

        Args:
            model_dir: Accepted to keep the constructor signature; it is not
                read here. The pipeline is loaded by ``init`` from a
                hard-coded repository id.
        """

    def init(self):
        """Load the img2img pipeline into the module-level ``pipe`` once.

        Calling this again after the pipeline exists is a no-op.
        """
        global pipe
        if pipe is not None:
            return

        # Half precision is only usable on the GPU; running a float16
        # pipeline on the CPU fails at inference time, so fall back to
        # float32 when CUDA was not selected.
        on_gpu = str(torch_device).startswith("cuda")
        dtype = torch.float16 if on_gpu else torch.float32

        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "karthikAI/InstantID-i2i",
            revision="main",
            torch_dtype=dtype,
        ).to(torch_device)

    def inference(self, model_inputs: dict) -> dict:
        """Run one img2img generation for a request payload.

        Args:
            model_inputs: Request dict. ``"inputs"`` holds the base64-encoded
                source image; the optional ``"parameters"`` dict may carry a
                ``"prompt"`` string (defaults to an empty prompt).

        Returns:
            ``{"generated_image_base64": <base64 PNG>}`` on success, or
            ``{"error": <message>}`` when the input image is missing or
            cannot be decoded.
        """
        b64 = model_inputs.get("inputs")
        if b64 is None:
            return {"error": "No 'inputs' key with base64 image provided."}

        # Decode in two steps so each failure mode gets its own message and
        # malformed requests are reported the same way as a missing image.
        try:
            raw = base64.b64decode(b64)
        except (ValueError, TypeError):
            # binascii.Error is a ValueError; TypeError covers non-string input.
            return {"error": "'inputs' is not valid base64 data."}

        try:
            img = Image.open(io.BytesIO(raw)).convert("RGB")
        except OSError:
            # Pillow signals unreadable or truncated image data with OSError
            # (UnidentifiedImageError is a subclass).
            return {"error": "'inputs' could not be decoded as an image."}

        # Tolerate "parameters": null and "prompt": null in the payload.
        params = model_inputs.get("parameters") or {}
        prompt = params.get("prompt") or ""

        # Make sure the pipeline exists even if init() was never called.
        self.init()
        out = pipe(prompt=prompt, image=img)
        result_img = out.images[0]

        buf = io.BytesIO()
        result_img.save(buf, format="PNG")
        b64_out = base64.b64encode(buf.getvalue()).decode()
        return {"generated_image_base64": b64_out}

    def __call__(self, model_inputs: dict) -> dict:
        """Alias for ``inference`` so a handler instance can be called directly.

        NOTE(review): added for serving stacks that invoke the handler as a
        callable (e.g. Hugging Face Inference Endpoints custom handlers);
        confirm against the actual deployment target.
        """
        return self.inference(model_inputs)