import io
import os
import time
from pathlib import Path
from urllib.parse import urlparse

import requests
from PIL import Image
|
|
# Base URL of the BFL API; endpoint paths such as "/v1/image" are appended to it.
API_ENDPOINT: str = "https://api.bfl.ml"
|
|
|
|
| class ApiException(Exception): |
| def __init__(self, status_code: int, detail: str | list[dict] | None = None): |
| super().__init__() |
| self.detail = detail |
| self.status_code = status_code |
|
|
| def __str__(self) -> str: |
| return self.__repr__() |
|
|
| def __repr__(self) -> str: |
| if self.detail is None: |
| message = None |
| elif isinstance(self.detail, str): |
| message = self.detail |
| else: |
| message = "[" + ",".join(d["msg"] for d in self.detail) + "]" |
| return f"ApiException({self.status_code=}, {message=}, detail={self.detail})" |
|
|
|
|
class ImageRequest:
    """Handle for a single image generation request to the API at ``API_ENDPOINT``.

    The request id, result payload, image URL and image bytes are fetched
    lazily and cached on the instance after the first successful retrieval.
    """

    # Seconds ``requests`` may wait on the network per call. Without a timeout a
    # stalled connection would block indefinitely.
    _timeout: float = 60.0

    def __init__(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        name: str = "flux.1-pro",
        num_steps: int = 50,
        prompt_upsampling: bool = False,
        seed: int | None = None,
        validate: bool = True,
        launch: bool = True,
        api_key: str | None = None,
    ):
        """
        Manage an image generation request to the API.

        Args:
            prompt: Prompt to sample.
            width: Width of the image in pixels (multiple of 32 in [256, 1440]).
            height: Height of the image in pixels (multiple of 32 in [256, 1440]).
            name: Name of the model; only "flux.1-pro" passes validation.
            num_steps: Number of network evaluations, in [1, 50].
            prompt_upsampling: Use prompt upsampling.
            seed: Fix the generation seed; omitted from the request when None.
            validate: Run the input checks listed above.
            launch: Submit the request immediately via ``request()``.
            api_key: Your API key; defaults to the ``BFL_API_KEY`` environment variable.

        Raises:
            ValueError: For invalid input (only when ``validate`` is True).
            ApiException: For errors raised from the API (only when ``launch`` is True).
        """
        if validate:
            if name not in ["flux.1-pro"]:
                raise ValueError(f"Invalid model {name}")
            elif width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            elif not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            elif height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            elif not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            elif not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        if seed is not None:
            self.request_json["seed"] = seed

        self.request_id: str | None = None
        self.result: dict | None = None
        self._image_bytes: bytes | None = None
        self._url: str | None = None
        self.api_key = os.environ.get("BFL_API_KEY") if api_key is None else api_key

        if launch:
            self.request()

    @staticmethod
    def _parse_json(response: requests.Response) -> dict:
        """Decode ``response`` as a JSON object, raising ApiException if it is not one."""
        try:
            payload = response.json()
        except ValueError:
            # Non-JSON bodies (e.g. an HTML error page) would otherwise surface as
            # a bare decoding error instead of an ApiException.
            raise ApiException(
                status_code=response.status_code, detail=response.text or None
            ) from None
        if not isinstance(payload, dict):
            raise ApiException(
                status_code=response.status_code,
                detail=f"Unexpected response body: {payload!r}",
            )
        return payload

    def request(self):
        """
        Submit the generation request and store the returned id in ``request_id``.

        Does nothing if the request has already been submitted.

        Raises:
            ApiException: On a non-200 response or an undecodable response body.
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
            timeout=self._timeout,
        )
        result = self._parse_json(response)
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
        self.request_id = result["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and return its result payload.

        Submits the request first if needed, then polls ``/v1/get_result`` every
        0.5 s while the reported status is "Pending".

        Returns:
            The ``"result"`` object of the response once the status is "Ready".

        Raises:
            ApiException: If the response lacks a status, reports a status other
                than "Ready"/"Pending", or has an undecodable body.
        """
        if self.request_id is None:
            self.request()
        # NOTE(review): polling has no overall deadline; it ends only on "Ready"
        # or an error status.
        while self.result is None:
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
                timeout=self._timeout,
            )
            result = self._parse_json(response)
            if "status" not in result:
                raise ApiException(status_code=response.status_code, detail=result.get("detail"))
            elif result["status"] == "Ready":
                self.result = result["result"]
            elif result["status"] == "Pending":
                time.sleep(0.5)
            else:
                # Report the real HTTP status rather than assuming 200.
                raise ApiException(
                    status_code=response.status_code,
                    detail=f"API returned status '{result['status']}'",
                )
        return self.result

    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes, downloaded once from ``url`` and then cached.

        Raises:
            ApiException: If the download responds with a non-200 status.
        """
        if self._image_bytes is None:
            response = requests.get(self.url, timeout=self._timeout)
            if response.status_code != 200:
                raise ApiException(status_code=response.status_code)
            self._image_bytes = response.content
        return self._image_bytes

    @property
    def url(self) -> str:
        """
        Public URL to retrieve the image from (the result's ``"sample"`` field).

        Waits for the generation to finish on first access; cached afterwards.
        """
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image.
        """
        return Image.open(io.BytesIO(self.bytes))

    def save(self, path: str | os.PathLike):
        """
        Save the generated image to a local path.

        The extension of the image URL is appended to ``path`` unless ``path``
        already ends with it; missing parent directories are created.

        Args:
            path: Destination file path.
        """
        path = os.fspath(path)
        # Use only the URL's path component: a query string (e.g. on a signed URL)
        # would otherwise leak into the file extension.
        suffix = Path(urlparse(self.url).path).suffix
        if not path.endswith(suffix):
            path = path + suffix
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            file.write(self.bytes)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: python-fire turns ImageRequest's constructor
    # arguments into CLI flags.
    import fire

    fire.Fire(ImageRequest)
|
|