| import gradio as gr |
| from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler |
| import torch |
| from PIL import Image |
| import random |
| from peft import PeftModel, LoraConfig |
|
|
# Base Stable Diffusion checkpoint loaded by load_model().
model_id: str = "CompVis/stable-diffusion-v1-4"
# Adapter that load_model() wraps around the UNet via PEFT.
# NOTE(review): the name suggests a FLUX LoRA, which would not match the
# SD v1.4 UNet architecture — confirm this adapter targets this base model.
lora_model_id: str = "codermert/mert_flux"
|
|
def load_model(base_model_id=None, lora_id=None, device="cpu"):
    """Build the Stable Diffusion pipeline with the LoRA adapter applied to its UNet.

    Args:
        base_model_id: Hub id or local path of the base checkpoint. Defaults to
            the module-level ``model_id``.
        lora_id: Hub id or local path of the PEFT adapter. Defaults to the
            module-level ``lora_model_id``.
        device: Device string the pipeline is moved to (default ``"cpu"``).

    Returns:
        The configured ``StableDiffusionPipeline``. It uses a DPM-Solver
        multistep scheduler, has the safety checker disabled, and has its UNet
        wrapped in a ``PeftModel``.
    """
    # Resolve defaults lazily so the module constants are read at call time.
    if base_model_id is None:
        base_model_id = model_id
    if lora_id is None:
        lora_id = lora_model_id

    # float32 weights: half precision is generally not usable for CPU inference.
    # NOTE(review): revisit the dtype if `device` is a GPU.
    pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float32)
    # Reuse the original scheduler config but switch to DPM-Solver multistep sampling.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    pipe.safety_checker = None

    # PeftModel.from_pretrained reads the adapter config itself, so no separate
    # LoraConfig fetch is needed.
    pipe.unet = PeftModel.from_pretrained(pipe.unet, lora_id)
    return pipe
|
|
# Build the pipeline once at import time; generate_image() reads this global.
pipe = load_model()
|
|
def generate_image(prompt, negative_prompt, steps, cfg_scale, seed, strength=None):
    """Run the module-level pipeline once.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what to steer away from.
        steps: Number of denoising steps. Coerced to ``int`` because UI
            sliders may deliver floats.
        cfg_scale: Classifier-free guidance scale.
        seed: RNG seed. ``-1`` means pick a random seed in ``[1, 1e9]``.
            Coerced to ``int`` because ``Generator.manual_seed`` rejects floats.
        strength: Accepted so callers passing six arguments keep working.
            This text-to-image call does not use it.

    Returns:
        Tuple ``(image, seed)``: the first generated image and the integer seed
        actually used, so a result can be reproduced.
    """
    seed = int(seed)
    if seed == -1:
        seed = random.randint(1, 1000000000)

    # CPU generator, matching the pipeline's default device in load_model().
    generator = torch.Generator().manual_seed(seed)

    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg_scale),
            generator=generator,
        )

    return result.images[0], seed
|
|
# Stylesheet for the Blocks app: centers the #app-container column and caps its width.
css = (
    "\n"
    "#app-container {\n"
    "    max-width: 800px;\n"
    "    margin-left: auto;\n"
    "    margin-right: auto;\n"
    "}\n"
)
|
|
# One-click examples for the UI: each row is [prompt, negative prompt].
examples = [
    list(pair)
    for pair in (
        ("A beautiful landscape with mountains and a lake", "ugly, deformed"),
        ("A futuristic cityscape at night", "daytime, rural"),
        ("A portrait of a smiling person in a colorful outfit", "monochrome, frowning"),
    )
]
|
|
# Gradio UI: prompt inputs, sampling controls, output image, and the reported seed.
with gr.Blocks(theme='default', css=css) as app:
    gr.HTML("<center><h1>Mert Flux LoRA Explorer (CPU Version)</h1></center>")
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            text_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here", lines=2)
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="What to avoid in the image", lines=2)
        with gr.Row():
            with gr.Column():
                steps = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=50, step=1)
                cfg_scale = gr.Slider(label="CFG Scale", value=7.5, minimum=1, maximum=15, step=0.5)
            with gr.Column():
                # -1 asks generate_image to pick a random seed.
                seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1)

        # generate_image's sixth parameter, `strength`, has no control in this
        # UI. A State holding None fills that slot so the click inputs line up
        # with the function's arguments.
        strength = gr.State(value=None)

        with gr.Row():
            generate_button = gr.Button("Generate", variant='primary')
        with gr.Row():
            image_output = gr.Image(type="pil", label="Generated Image", show_download_button=True)
        with gr.Row():
            seed_output = gr.Number(label="Seed Used")

        gr.Examples(examples=examples, inputs=[text_prompt, negative_prompt])

        generate_button.click(
            generate_image,
            inputs=[text_prompt, negative_prompt, steps, cfg_scale, seed, strength],
            outputs=[image_output, seed_output],
        )
|
|
# Start the Gradio server for the Blocks app built above.
app.launch()