ImageGen / core /pipelines /pipeline_input_processor.py
RioShiina's picture
Upload folder using huggingface_hub
fbc9e7b verified
Raw
History Blame Contribute Delete
17.7 kB
import os
import random
import uuid
from typing import Dict, Any, List

import gradio as gr
import numpy as np
from PIL import Image, ImageChops

from core.settings import INPUT_DIR
from utils.app_utils import (
    sanitize_filename,
    get_lora_path,
    get_embedding_path,
    ensure_controlnet_model_downloaded,
    ensure_ipadapter_models_downloaded,
    _ensure_model_downloaded,
    ensure_sd3_ipadapter_models_downloaded,
    get_vae_path,
)
# Tasks that take a single uploaded image: task type -> (ui_inputs key, error
# message raised when the image is missing).
_UPLOAD_TASKS = {
    'img2img': ('img2img_image', "Please upload an image for Image-to-Image."),
    'outpaint': ('outpaint_image', "Please upload an image for Outpainting."),
    'hires_fix': ('hires_image', "Please upload an image for Hires Fix."),
}


def _save_temp_image(image: Any, prefix: str, temp_files: List[str]) -> str:
    """Save *image* as a uniquely named PNG in INPUT_DIR and return its basename.

    The path is recorded in *temp_files* before the save so that a partially
    written file is still removed if saving fails. A UUID suffix is used so
    that files from the same call or from concurrent calls cannot overwrite
    each other.
    """
    os.makedirs(INPUT_DIR, exist_ok=True)
    path = os.path.join(INPUT_DIR, f"{prefix}_{uuid.uuid4().hex}.png")
    temp_files.append(path)
    image.save(path, "PNG")
    return os.path.basename(path)


def _remove_temp_files(paths: List[str]) -> None:
    """Best-effort removal of the temp files created by a failed call."""
    for path in paths:
        try:
            os.remove(path)
        except OSError:
            # Cleanup must never mask the original error; a file that is
            # missing or locked is simply left behind.
            pass


def _interleaved_columns(data: Any, stride: int) -> List[Any]:
    """Split a flat ``[a0, b0, c0, a1, b1, c1, ...]`` sequence into columns."""
    return [data[offset::stride] for offset in range(stride)]


def _chunked_columns(data: Any, count: int) -> List[Any]:
    """Split a flat ``[a0, a1, ..., b0, b1, ...]`` sequence into *count* columns."""
    size = len(data) // count
    return [data[k * size:(k + 1) * size] for k in range(count)]


def _resolve_model_filename(label: str, fetch_path: Any, source: str, model_id: str, progress: Any) -> Any:
    """Return the local filename for a "File" or "Civitai" model, else None.

    Raises:
        gr.Error: if *fetch_path* does not return a local path for a Civitai id.
    """
    if source == "File":
        return sanitize_filename(model_id)
    if source == "Civitai":
        local_path, status = fetch_path(source, model_id, os.environ.get("CIVITAI_API_KEY", ""), progress)
        if not local_path:
            raise gr.Error(f"Failed to prepare {label} {model_id}: {status}")
        return os.path.basename(local_path)
    return None


def _collect_loras(lora_data: Any, progress: Any):
    """Build the (gpu, metadata) LoRA lists from interleaved 4-field units."""
    for_gpu, for_meta = [], []
    if not lora_data:
        return for_gpu, for_meta
    for source, lora_id, scale, _ in zip(*_interleaved_columns(lora_data, 4)):
        if not (scale > 0 and lora_id and lora_id.strip()):
            continue
        filename = _resolve_model_filename("LoRA", get_lora_path, source, lora_id, progress)
        if filename:
            for_gpu.append({"lora_name": filename, "strength_model": scale, "strength_clip": scale})
            for_meta.append(f"{source} {lora_id}:{scale}")
    return for_gpu, for_meta


def _apply_denoise(ui_inputs: Dict[str, Any], task_type: str) -> None:
    """Set ``ui_inputs['denoise']`` from the task-specific slider (1.0 otherwise)."""
    per_task = {
        'img2img': ('img2img_denoise', 0.7),
        'hires_fix': ('hires_denoise', 0.55),
        'inpaint': ('inpaint_denoise', 1.0),
    }
    ui_inputs['denoise'] = 1.0
    if task_type in per_task:
        key, fallback = per_task[task_type]
        ui_inputs['denoise'] = ui_inputs.get(key, fallback)


def _build_inpaint_composite(inpaint_dict: Dict[str, Any]) -> Any:
    """Return the background as RGBA whose alpha is the inverted drawn mask."""
    background = inpaint_dict['background'].convert("RGBA")
    mask = Image.new('L', background.size, 0)
    for layer in inpaint_dict['layers']:
        if layer:
            # The last band of each layer is used as that layer's mask.
            mask = ImageChops.lighter(mask, layer.split()[-1])
    red, green, blue, _ = background.split()
    # Painted areas end up with alpha 0 (255 - mask value).
    return Image.merge('RGBA', [red, green, blue, ImageChops.invert(mask)])


def _prepare_task_image(ui_inputs: Dict[str, Any], task_type: str, temp_files: List[str]) -> None:
    """Persist the task's main input image and record it in *ui_inputs*."""
    if task_type == 'inpaint':
        inpaint_dict = ui_inputs.get('inpaint_image_dict')
        if not inpaint_dict or not inpaint_dict.get('background') or not inpaint_dict.get('layers'):
            raise gr.Error("Inpainting requires an input image and a drawn mask.")
        composite = _build_inpaint_composite(inpaint_dict)
        ui_inputs['input_image'] = _save_temp_image(composite, "temp_inpaint_composite", temp_files)
        ui_inputs.pop('inpaint_mask', None)
        return

    if task_type not in _UPLOAD_TASKS:
        return
    image_key, missing_message = _UPLOAD_TASKS[task_type]
    image = ui_inputs.get(image_key)
    if not image:
        raise gr.Error(missing_message)
    ui_inputs['input_image'] = _save_temp_image(image, "temp_input", temp_files)

    if task_type == 'img2img':
        ui_inputs['width'] = image.width
        ui_inputs['height'] = image.height
    elif task_type == 'outpaint':
        ui_inputs['megapixels'] = 0.25
        ui_inputs['grow_mask_by'] = ui_inputs.get('feathering', 10)


def _append_embeddings(ui_inputs: Dict[str, Any], progress: Any) -> None:
    """Resolve embeddings and append ``embedding:<file>`` tags to the positive prompt."""
    embedding_data = ui_inputs.get('embedding_data', [])
    if not embedding_data:
        return
    filenames = []
    for source, emb_id, _ in zip(*_interleaved_columns(embedding_data, 3)):
        if not (emb_id and emb_id.strip()):
            continue
        filename = _resolve_model_filename("Embedding", get_embedding_path, source, emb_id, progress)
        if filename:
            filenames.append(filename)
    if not filenames:
        return
    embedding_prompt_text = " ".join(f"embedding:{name}" for name in filenames)
    existing_prompt = ui_inputs.get('positive_prompt')
    if existing_prompt:
        ui_inputs['positive_prompt'] = f"{existing_prompt}, {embedding_prompt_text}"
    else:
        ui_inputs['positive_prompt'] = embedding_prompt_text


def _prepare_controlnet_units(data: Any, stride: int, prefix: str, ensure_model: Any,
                              temp_files: List[str], progress: Any) -> List[Any]:
    """Download models and save images for the enabled units of *data*.

    *data* is interleaved with *stride* fields per unit; field 0 is the image,
    field 3 the strength and field 4 the model file. Returns a list of
    ``(saved_image_basename, unit_fields)`` pairs for enabled units only.
    """
    prepared = []
    if not data:
        return prepared
    for index, unit in enumerate(zip(*_interleaved_columns(data, stride))):
        image, strength, model_name = unit[0], unit[3], unit[4]
        if not (image and strength > 0 and model_name and model_name != "None"):
            continue
        ensure_model(model_name, progress)
        prepared.append((_save_temp_image(image, f"{prefix}_{index}", temp_files), unit))
    return prepared


def _collect_ipadapters(ipadapter_data: Any, workflow_model_type: str,
                        temp_files: List[str], progress: Any) -> List[Dict[str, Any]]:
    """Build the IP-Adapter unit list followed by one trailing final-settings entry.

    The last five items of *ipadapter_data* are the shared final settings; the
    rest is three equal chunks: images, weights and LoRA strengths.
    """
    active = []
    if not ipadapter_data:
        return active
    (final_preset, final_weight, final_lora_strength,
     final_embeds_scaling, final_combine_method) = ipadapter_data[-5:]
    images, weights, lora_strengths = _chunked_columns(ipadapter_data[:-5], 3)
    for index, (image, weight, lora_strength) in enumerate(zip(images, weights, lora_strengths)):
        if image and weight > 0 and final_preset:
            active.append({
                "image": _save_temp_image(image, f"temp_ipa_{index}", temp_files),
                "preset": final_preset,
                "weight": weight, "lora_strength": lora_strength,
            })
    if not active:
        return active
    # Every active unit shares final_preset, so a single download suffices.
    ensure_ipadapter_models_downloaded(final_preset, progress)
    active.append({
        'is_final_settings': True,
        'model_type': 'sd15' if workflow_model_type == 'sd15' else 'sdxl',
        'final_preset': final_preset,
        'final_weight': final_weight, 'final_lora_strength': final_lora_strength,
        'final_embeds_scaling': final_embeds_scaling, 'final_combine_method': final_combine_method,
    })
    return active


def _ensure_flux1_ipadapter_models(progress: Any) -> None:
    """Fetch the Flux.1 IP-Adapter weights and cache the SigLIP model from the Hub."""
    for filename in ["ip-adapter.bin"]:
        _ensure_model_downloaded(filename, progress)
    # Imported lazily so the dependency is only loaded when this feature is used.
    from huggingface_hub import snapshot_download
    progress(0.5, desc="Caching HF SigLIP model...")
    snapshot_download(
        repo_id="google/siglip-so400m-patch14-384",
        allow_patterns=["*.json", "*.safetensors", "*.txt"],
        ignore_patterns=["*.msgpack", "*.h5", "*.bin"]
    )


def _collect_ranged_ipadapters(data: Any, prefix: str, ensure_models: Any,
                               temp_files: List[str], progress: Any) -> List[Dict[str, Any]]:
    """Build units from four equal chunks: images, weights, starts and ends.

    ``ensure_models(progress)`` is called once, just before the first enabled
    unit is saved, and not at all when no unit is enabled.
    """
    active = []
    if not data:
        return active
    models_ready = False
    for index, (image, weight, start, end) in enumerate(zip(*_chunked_columns(data, 4))):
        if not (image and weight > 0):
            continue
        if not models_ready:
            ensure_models(progress)
            models_ready = True
        active.append({
            "image": _save_temp_image(image, f"{prefix}_{index}", temp_files),
            "weight": weight, "start_percent": start, "end_percent": end,
        })
    return active


def _collect_styles(style_data: Any, temp_files: List[str], progress: Any) -> List[Dict[str, Any]]:
    """Build style units from two equal chunks: images and strengths."""
    active = []
    if not style_data:
        return active
    model_ready = False
    for index, (image, strength) in enumerate(zip(*_chunked_columns(style_data, 2))):
        if not (image and strength > 0):
            continue
        if not model_ready:
            _ensure_model_downloaded("sigclip_vision_patch14_384.safetensors", progress)
            model_ready = True
        active.append({
            "image": _save_temp_image(image, f"temp_style_{index}", temp_files),
            "strength": strength,
        })
    return active


def _save_reference_images(images: Any, prefix: str, temp_files: List[str]) -> List[str]:
    """Save every non-empty image in *images* and return the saved basenames."""
    return [_save_temp_image(image, prefix, temp_files) for image in (images or []) if image]


def _apply_vae_override(ui_inputs: Dict[str, Any], progress: Any) -> None:
    """Set ``ui_inputs['vae_name']`` when a File or Civitai VAE is selected."""
    vae_source = ui_inputs.get('vae_source')
    vae_id = ui_inputs.get('vae_id')
    vae_name_override = None
    if vae_source and vae_source != "None":
        if vae_source == "File":
            vae_name_override = sanitize_filename(vae_id)
        elif vae_source == "Civitai" and vae_id and vae_id.strip():
            vae_name_override = _resolve_model_filename("VAE", get_vae_path, vae_source, vae_id, progress)
    if vae_name_override:
        ui_inputs['vae_name'] = vae_name_override


def _collect_conditioning(conditioning_data: Any) -> List[Dict[str, Any]]:
    """Build regional conditioning entries from six equal chunks of fields."""
    if not conditioning_data:
        return []
    return [
        {
            "prompt": prompt, "width": int(width), "height": int(height),
            "x": int(x), "y": int(y), "strength": float(strength),
        }
        for prompt, width, height, x, y, strength in zip(*_chunked_columns(conditioning_data, 6))
        if prompt and prompt.strip()
    ]


def _build_pipeline_payload(ui_inputs: Dict[str, Any], progress: Any, workflow_model_type: str,
                            temp_files: List[str]) -> Dict[str, Any]:
    """Run every preparation step in order and assemble the result dict."""
    task_type = ui_inputs['task_type']
    active_loras_for_gpu, active_loras_for_meta = _collect_loras(ui_inputs.get('lora_data', []), progress)

    _apply_denoise(ui_inputs, task_type)
    # The input directory is created even when no image ends up being written.
    os.makedirs(INPUT_DIR, exist_ok=True)
    _prepare_task_image(ui_inputs, task_type, temp_files)
    _append_embeddings(ui_inputs, progress)

    active_controlnets = [
        {
            "image": image_name, "strength": unit[3],
            "start_percent": 0.0, "end_percent": 1.0, "control_net_name": unit[4],
        }
        for image_name, unit in _prepare_controlnet_units(
            ui_inputs.get('controlnet_data', []), 5, "temp_cn",
            ensure_controlnet_model_downloaded, temp_files, progress)
    ]
    active_anima_controlnets = [
        {
            "image": image_name, "strength": unit[3],
            "start_percent": unit[5], "end_percent": unit[6], "control_net_name": unit[4],
        }
        for image_name, unit in _prepare_controlnet_units(
            ui_inputs.get('anima_controlnet_lllite_data', []), 7, "temp_anima_cn",
            _ensure_model_downloaded, temp_files, progress)
    ]
    active_diffsynth_controlnets = [
        {"image": image_name, "strength": unit[3], "control_net_name": unit[4]}
        for image_name, unit in _prepare_controlnet_units(
            ui_inputs.get('diffsynth_controlnet_data', []), 5, "temp_diffsynth_cn",
            ensure_controlnet_model_downloaded, temp_files, progress)
    ]

    active_ipadapters = _collect_ipadapters(
        ui_inputs.get('ipadapter_data', []), workflow_model_type, temp_files, progress)
    active_flux1_ipadapters = _collect_ranged_ipadapters(
        ui_inputs.get('flux1_ipadapter_data', []), "temp_fipa",
        _ensure_flux1_ipadapter_models, temp_files, progress)
    active_sd3_ipadapters = _collect_ranged_ipadapters(
        ui_inputs.get('sd3_ipadapter_chain', []), "temp_s3ipa",
        ensure_sd3_ipadapter_models_downloaded, temp_files, progress)
    active_styles = _collect_styles(ui_inputs.get('style_data', []), temp_files, progress)

    active_reference_latents = _save_reference_images(
        ui_inputs.get('reference_latent_data', []), "temp_ref", temp_files)
    active_hidream_o1_reference = _save_reference_images(
        ui_inputs.get('hidream_o1_reference_data', []), "temp_ho1_ref", temp_files)

    _apply_vae_override(ui_inputs, progress)
    active_conditioning = _collect_conditioning(ui_inputs.get('conditioning_data', []))

    return {
        "active_loras_for_gpu": active_loras_for_gpu,
        "active_loras_for_meta": active_loras_for_meta,
        "active_controlnets": active_controlnets,
        "active_anima_controlnets": active_anima_controlnets,
        "active_diffsynth_controlnets": active_diffsynth_controlnets,
        "active_ipadapters": active_ipadapters,
        "active_flux1_ipadapters": active_flux1_ipadapters,
        "active_sd3_ipadapters": active_sd3_ipadapters,
        "active_styles": active_styles,
        "active_reference_latents": active_reference_latents,
        "active_hidream_o1_reference": active_hidream_o1_reference,
        "active_conditioning": active_conditioning,
        "temp_files_to_clean": temp_files
    }


def process_pipeline_inputs(ui_inputs: Dict[str, Any], progress: gr.Progress, workflow_model_type: str) -> Dict[str, Any]:
    """Turn raw UI inputs into the structures the generation pipeline consumes.

    Resolves LoRA, embedding and VAE selections to local filenames, triggers
    the model downloads needed by enabled ControlNet / IP-Adapter / style
    units, and writes every input image to ``INPUT_DIR`` as a temporary PNG.

    ``ui_inputs`` is modified in place. Depending on the task and selections
    the keys ``denoise``, ``input_image``, ``width``, ``height``,
    ``megapixels``, ``grow_mask_by``, ``positive_prompt`` and ``vae_name`` may
    be set, and ``inpaint_mask`` is removed for inpainting.

    Args:
        ui_inputs: Flat dict of UI values; must contain ``task_type``. Image
            values are saved with ``.save(path, "PNG")`` (presumably PIL
            images, confirm against the UI layer).
        progress: Progress callback forwarded to the download helpers.
        workflow_model_type: ``'sd15'`` selects the SD 1.5 IP-Adapter model
            type; any other value is treated as ``'sdxl'``.

    Returns:
        A dict with the ``active_*`` lists for each feature plus
        ``temp_files_to_clean``, the full paths of every temp file written.

    Raises:
        gr.Error: if a required image is missing or a Civitai download fails.
            Temp files already written by this call are removed before the
            error propagates, because the caller never receives the list.
    """
    temp_files_to_clean: List[str] = []
    try:
        return _build_pipeline_payload(ui_inputs, progress, workflow_model_type, temp_files_to_clean)
    except Exception:
        _remove_temp_files(temp_files_to_clean)
        raise