Spaces:
Paused
Paused
| import argparse | |
| import json | |
| import itertools | |
| import math | |
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
| # Add the parent of this file's directory to sys.path | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import numpy as np | |
| import torch | |
| from einops import rearrange, repeat | |
| from PIL import Image | |
| from safetensors.torch import load_file | |
| from torchvision.transforms import functional as F | |
| from tqdm import tqdm | |
| import torch.nn.functional as Func | |
| import infer.sampling as sampling | |
| from modules.autoencoder import AutoEncoder | |
| from modules.model_edit import Step1XParams, Step1XEdit | |
# Directory one level above the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Default location of the Qwen model directory (used as the CLI default for --qwen2vl_model_path).
DEFAULT_QWEN_DIR = REPO_ROOT / "Qwen"
# Default .npz cache read by ImageGenerator.prepare when empty_llm is set; it is indexed with
# the keys 'embeds' and 'masks'.
EMPTY_PROMPT_LATENT_PATH = REPO_ROOT / "latent" / "no_info.npz"
def cudagc():
    """Release cached CUDA memory; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    # Hand cached allocator blocks back to the driver, then collect CUDA IPC memory.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
| def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True): | |
| if Path(ckpt_path).suffix == ".safetensors": | |
| state_dict = load_file(ckpt_path, device) | |
| else: | |
| state_dict = torch.load(ckpt_path, map_location="cpu") | |
| missing, unexpected = model.load_state_dict(state_dict, strict=strict, assign=assign) | |
| if len(missing) > 0 and len(unexpected) > 0: | |
| print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) | |
| print("\n" + "-" * 79 + "\n") | |
| print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) | |
| elif len(missing) > 0: | |
| print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) | |
| elif len(unexpected) > 0: | |
| print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) | |
| return model | |
def load_models(dit_path=None, ae_path=None, qwen2vl_model_path=None, device="cuda", max_length=256, dtype=torch.bfloat16, args=None):
    """Build the autoencoder, the DiT and (optionally) the Qwen prompt encoder.

    The prompt encoder is skipped, and ``None`` returned in its place, when
    ``args.prompt_type == 'empty'``. The autoencoder and DiT are constructed
    on the ``meta`` device and then filled from their checkpoints, which are
    loaded with device ``'cpu'``.

    Returns:
        ``(ae, dit, qwen2vl_encoder)``; the autoencoder is cast to float32.
    """
    qwen2vl_encoder = None
    if getattr(args, "prompt_type", None) == "empty":
        print("[INFO] prompt_type=empty, 跳过Qwen模型加载")
    else:
        # Lazy import to avoid pulling transformers/vision stack during evaluation with prompt_type=empty.
        from modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder

        qwen2vl_encoder = Qwen2VLEmbedder(
            qwen2vl_model_path,
            device=device,
            max_length=max_length,
            dtype=dtype,
            args=args,
        )

    ae_config = dict(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )
    dit_config = dict(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
    )

    # Construct both networks on the meta device; real weights come from the
    # checkpoints below.
    # NOTE(review): the extent of this `with` block was reconstructed from a
    # source whose indentation was lost; confirm it covers construction only.
    with torch.device("meta"):
        ae = AutoEncoder(**ae_config)
        step1x_params = Step1XParams(**dit_config)
        dit = Step1XEdit(step1x_params)

    ae = load_state_dict(ae, ae_path, 'cpu')
    dit = load_state_dict(dit, dit_path, 'cpu')
    ae = ae.to(dtype=torch.float32)
    return ae, dit, qwen2vl_encoder
def equip_dit_with_lora_sd_scripts(ae, text_encoders, dit, lora, device='cuda'):
    """Load LoRA weights from ``lora`` and merge them into the given models.

    Args:
        ae: Autoencoder, forwarded to ``create_network_from_weights``.
        text_encoders: Text encoders, forwarded to the network factory and to
            ``merge_to``.
        dit: DiT model the LoRA weights are merged into.
        lora: Path to a safetensors file with the LoRA weights.
        device: Unused; kept so existing callers keep working.

    Returns:
        The LoRA network object, with its multiplier set to 1.0.
    """
    weights_sd = load_file(lora)

    # Imported lazily so the module is only required when a LoRA is used.
    from library import lora_module

    lora_model, _ = lora_module.create_network_from_weights(1.0, None, ae, text_encoders, dit, weights_sd, True)
    lora_model.merge_to(text_encoders, dit, weights_sd)
    lora_model.set_multiplier(1.0)
    return lora_model
class ImageGenerator:
    """Inference wrapper around the autoencoder, the DiT and the prompt encoder.

    The instance loads the three models once (see ``load_models``) and exposes
    ``generate_image`` as the main entry point. When ``offload`` is true the
    models are moved to ``self.device`` only while they are in use and are
    returned to the CPU afterwards.
    """

    def __init__(
        self,
        dit_path=None,
        ae_path=None,
        qwen2vl_model_path=None,
        device="cuda",
        max_length=640,
        dtype=torch.bfloat16,
        quantized=False,
        offload=False,
        lora=None,
        args=None,
    ) -> None:
        """Load all models and optionally merge a LoRA into the DiT.

        Args:
            dit_path: DiT checkpoint path.
            ae_path: Autoencoder checkpoint path.
            qwen2vl_model_path: Path handed to the Qwen prompt encoder.
            device: Device used for inference.
            max_length: Forwarded to the prompt encoder.
            dtype: Forwarded to ``load_models`` (the DiT itself is cast to
                bfloat16, or to float8_e4m3fn when ``quantized``).
            quantized: Cast the DiT to ``torch.float8_e4m3fn``.
            offload: Keep models on the CPU between uses.
            lora: Optional path to LoRA weights to merge.
            args: Optional namespace; ``prompt_type`` and
                ``empty_prompt_cache`` are read from it when present.
        """
        self.device = torch.device(device)
        self.args = args
        self.ae, self.dit, self.llm_encoder = load_models(
            dit_path=dit_path,
            ae_path=ae_path,
            qwen2vl_model_path=qwen2vl_model_path,
            max_length=max_length,
            dtype=dtype,
            device=device,
            args=args,
        )
        if not quantized:
            self.dit = self.dit.to(dtype=torch.bfloat16)
        else:
            self.dit = self.dit.to(dtype=torch.float8_e4m3fn)
        if not offload:
            self.dit = self.dit.to(device=self.device)
            self.ae = self.ae.to(device=self.device)
        self.quantized = quantized
        self.offload = offload
        if lora is not None:
            self.lora_module = equip_dit_with_lora_sd_scripts(
                self.ae,
                [self.llm_encoder],
                self.dit,
                lora,
                device=self.dit.device,
            )
        else:
            self.lora_module = None

    def _ae_to_device(self):
        """Move the autoencoder to the inference device when offloading."""
        if self.offload:
            self.ae = self.ae.to(self.device)

    def _ae_to_cpu(self):
        """Return the autoencoder to the CPU and free CUDA cache when offloading."""
        if self.offload:
            self.ae = self.ae.cpu()
            cudagc()

    @staticmethod
    def _patch_position_ids(height, width, batch_size):
        """Build ``(batch, (height//2)*(width//2), 3)`` position ids.

        Component 0 stays zero, component 1 holds the row index and
        component 2 the column index of each 2x2 patch.
        """
        ids = torch.zeros(height // 2, width // 2, 3)
        # Broadcasting assigns the row index to every cell of a row and the
        # column index to every cell of a column.
        ids[..., 1] = ids[..., 1] + torch.arange(height // 2)[:, None]
        ids[..., 2] = ids[..., 2] + torch.arange(width // 2)[None, :]
        # Flatten the 2-D grid into a sequence and replicate it per batch item.
        return repeat(ids, "h w c -> b (h w) c", b=batch_size)

    def prepare(self, prompt, img, ref_image, ref_image_raw, empty_llm=False):
        """Pack latents into token sequences and build the DiT inputs.

        Args:
            prompt: A string or a list of strings.
            img: Latent tensor unpacked as ``(b, c, h, w)``.
            ref_image: Reference latent with the same ``h`` and ``w``.
            ref_image_raw: Reference image passed to the prompt encoder.
            empty_llm: Read cached embeddings from an ``.npz`` file instead
                of running the prompt encoder.

        Returns:
            A dict with keys ``img`` (target and reference tokens
            concatenated along the sequence axis), ``mask``, ``img_ids``,
            ``llm_embedding`` and ``txt_ids``.
        """
        _, _, h, w = img.shape
        bs, _, ref_h, ref_w = ref_image.shape
        assert h == ref_h and w == ref_w
        if bs == 1 and not isinstance(prompt, str):
            bs = len(prompt)
        elif bs >= 1 and isinstance(prompt, str):
            prompt = [prompt] * bs

        # Turn each 2x2 latent patch into one token, e.g. (2,16,82,110) -> (2,2255,64).
        img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        # Same flattening for the reference latent, so both are sequences.
        ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2)
        if img.shape[0] == 1 and bs > 1:
            img = repeat(img, "1 ... -> bs ...", bs=bs)
            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)

        # The token sequences carry no layout information, so explicit
        # (reserved, row, column) ids are built for every patch.
        img_ids = self._patch_position_ids(h, w, bs)
        ref_img_ids = self._patch_position_ids(ref_h, ref_w, bs)

        if isinstance(prompt, str):
            prompt = [prompt]

        # The encoder is None when it was skipped at load time
        # (prompt_type=empty), so only move it when it exists.
        has_encoder = self.llm_encoder is not None
        if self.offload and has_encoder:
            self.llm_encoder = self.llm_encoder.to(self.device)
        if empty_llm:
            empty_prompt_cache = getattr(self.args, "empty_prompt_cache", None)
            cache_path = Path(empty_prompt_cache) if empty_prompt_cache else EMPTY_PROMPT_LATENT_PATH
            # Close the archive once both arrays are materialised.
            with np.load(cache_path) as data:
                txt = torch.from_numpy(data['embeds']).to(img.device).unsqueeze(0)
                mask = torch.from_numpy(data['masks']).to(img.device).unsqueeze(0)
            # Duplicated to a batch of two (one entry per prompt slot).
            txt = torch.cat([txt, txt], dim=0)
            mask = torch.cat([mask, mask], dim=0)
        else:
            # One embedding per prompt (positive and negative).
            txt, mask = self.llm_encoder(prompt, ref_image_raw)
        if self.offload:
            if has_encoder:
                self.llm_encoder = self.llm_encoder.cpu()
            cudagc()

        txt_ids = torch.zeros(bs, txt.shape[1], 3)
        # Concatenate target and reference along the token axis, e.g. -> (2,4550,64).
        img = torch.cat([img, ref_img.to(device=img.device, dtype=img.dtype)], dim=-2)
        img_ids = torch.cat([img_ids, ref_img_ids], dim=-2)
        return {
            "img": img,
            "mask": mask,
            "img_ids": img_ids.to(img.device),  # patch coordinates
            "llm_embedding": txt.to(img.device),  # prompt embedding
            "txt_ids": txt_ids.to(img.device),  # prompt coordinates (all zero)
        }

    @staticmethod
    def process_diff_norm(diff_norm, k):
        """Return ``diff_norm ** k`` where ``diff_norm > 1`` and 1 where it is below 1.

        Elements equal to 1 (and NaNs) are passed through unchanged.
        """
        pow_result = torch.pow(diff_norm, k)
        result = torch.where(
            diff_norm > 1.0,
            pow_result,
            torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm),
        )
        return result

    def denoise(
        self,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        llm_embedding: torch.Tensor,
        txt_ids: torch.Tensor,
        timesteps: list[float],
        cfg_guidance: float = 6.0,
        mask=None,
        show_progress=False,
        timesteps_truncate=1.0,
    ):
        """Integrate the DiT prediction over consecutive timestep pairs.

        Each step runs the DiT on a batch whose first half is treated as the
        conditional branch and whose second half as the unconditional branch,
        combines them with classifier-free guidance, updates the first half
        of the token sequence and keeps the second half (reference tokens)
        unchanged.

        Args:
            img: Packed tokens ``(batch, sequence, channels)``.
            img_ids, llm_embedding, txt_ids, mask: Forwarded to the DiT.
            timesteps: Schedule; consecutive pairs define the steps.
            cfg_guidance: Guidance scale; ``-1`` is rejected.
            show_progress: Wrap the step iterator in ``tqdm``.
            timesteps_truncate: Unused; kept for interface compatibility.

        Returns:
            ``(tokens, cond_pred)``: the first half of the updated token
            sequence, and the second half of the conditional prediction
            from the last step.

        Raises:
            AssertionError: If ``cfg_guidance`` is ``-1``.
            ValueError: If ``timesteps`` yields no step pair.
        """
        # Checked up front so an unsupported setting fails before any forward pass.
        assert cfg_guidance != -1, " cfg_guidance must not be -1 NOW!!!!"

        if self.offload:
            self.dit = self.dit.to(self.device)

        steps = itertools.pairwise(timesteps)
        if show_progress:
            steps = tqdm(steps, desc='denoising...')

        pred1 = None
        for t_curr, t_prev in steps:
            # After the first step only one batch row is left; duplicate it
            # so the DiT again sees a conditional and an unconditional row.
            if img.shape[0] == 1:
                img = torch.cat([img, img], dim=0)
            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
            pred, _ = self.dit(
                img=img,
                img_ids=img_ids,
                txt_ids=txt_ids,
                timesteps=t_vec,
                llm_embedding=llm_embedding,
                t_vec=t_vec,
                mask=mask,
            )
            half_batch = pred.shape[0] // 2
            cond, uncond = pred[:half_batch], pred[half_batch:]
            pred = uncond + cfg_guidance * (cond - uncond)  # e.g. (1,4608,64)
            pred1 = cond  # TODO: only the single-step case is supported for this output

            tem_img = img[:img.shape[0] // 2] + (t_prev - t_curr) * pred
            img_input_length = img.shape[1] // 2
            # Keep the updated target tokens and the untouched reference tokens.
            img = torch.cat(
                [
                    tem_img[:, :img_input_length],
                    img[:img.shape[0] // 2, img_input_length:],
                ],
                dim=1,
            )

        if self.offload:
            self.dit = self.dit.cpu()
            cudagc()

        if pred1 is None:
            raise ValueError("timesteps must contain at least two values")
        return img[:, :img.shape[1] // 2], pred1[:, img.shape[1] // 2:]

    def double_denoise(self, img, img_ids, llm_embedding, txt_ids, timesteps, cfg_guidance=6.0, mask=None, height=None, width=None):
        """Run one DiT pass at t=1.0 and split the conditional prediction in two.

        ``timesteps`` is unused. The first half of the batch of the DiT
        output is unpacked with ``unpack_latents`` into two latent images,
        both returned as float32.

        Raises:
            AssertionError: If ``cfg_guidance`` is ``-1``.
        """
        assert cfg_guidance != -1, " cfg_guidance must not be -1 NOW!!!!"
        if img.shape[0] == 1:
            img = torch.cat([img, img], dim=0)
        t_vec = torch.full((img.shape[0],), 1.0, dtype=img.dtype, device=img.device)
        pred, _ = self.dit(
            img=img,
            img_ids=img_ids,
            txt_ids=txt_ids,
            timesteps=t_vec,
            llm_embedding=llm_embedding,
            t_vec=t_vec,
            mask=mask,
        )
        cond = pred[:pred.shape[0] // 2]
        Lpred, Rpred = self.unpack_latents(cond, height // 16, width // 16)
        return Lpred.to(torch.float32), Rpred.to(torch.float32)

    @staticmethod
    def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Turn ``b (h w) (c ph pw)`` tokens back into ``b c (h ph) (w pw)``.

        ``h`` and ``w`` are ``ceil(height / 16)`` and ``ceil(width / 16)``.
        """
        return rearrange(
            x,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=math.ceil(height / 16),
            w=math.ceil(width / 16),
            ph=2,
            pw=2,
        )

    @staticmethod
    def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int):
        """Split a token sequence holding two packed images and unpack both.

        x: [b (p h w) (c ph pw)] -> [b p c (h ph) (w pw)], ph=2, pw=2, p=2

        Returns:
            The two images ``x[:, 0]`` and ``x[:, 1]``.
        """
        x = rearrange(
            x,
            "b (p h w) (c ph pw) -> b p c (h ph) (w pw)",
            h=packed_latent_height,
            w=packed_latent_width,
            ph=2,
            pw=2,
            p=2,
        )
        return x[:, 0], x[:, 1]

    @staticmethod
    def load_image(image):
        """Convert an image-like input to a tensor.

        * ``np.ndarray`` (H, W, C): scaled by 1/255 and returned as
          ``(1, C, H, W)`` float.
        * ``torch.Tensor``: returned unchanged.
        * ``str``: opened as a file path, converted to RGB.
        * ``PIL.Image.Image``: converted to RGB.

        Raises:
            ValueError: For any other input type.
        """
        if isinstance(image, np.ndarray):
            return (torch.from_numpy(image).permute(2, 0, 1).float() / 255.0).unsqueeze(0)
        if isinstance(image, torch.Tensor):
            return image
        if isinstance(image, str):
            return F.to_tensor(Image.open(image).convert("RGB")).unsqueeze(0)
        if isinstance(image, Image.Image):
            return F.to_tensor(image.convert("RGB")).unsqueeze(0)
        raise ValueError(f"Unsupported image type: {type(image)}")

    def output_process_image(self, resize_img, image_size):
        """Resize a PIL image to ``image_size`` (width, height)."""
        return resize_img.resize(image_size)

    def input_process_image(self, img):
        """Resize an input image to one of the fixed working resolutions.

        The target size is chosen from the input size:
        up to 1024x768 -> 1024x768; up to 1280x960 -> 1216x352;
        up to 6048x4032 -> 864x576; anything larger is left unchanged.

        Args:
            img: A ``torch.Tensor`` (resized bilinearly and clamped to
                [0, 1]) or a ``PIL.Image.Image``.

        Returns:
            ``(resized_image, (new_width, new_height))``.

        Raises:
            TypeError: If ``img`` is neither a tensor nor a PIL image.
        """
        is_tensor = isinstance(img, torch.Tensor)
        if is_tensor:
            w, h = img.shape[-1], img.shape[-2]
        elif isinstance(img, Image.Image):
            w, h = img.size
        else:
            raise TypeError(f"Unsupported image type: {type(img)}")

        if w <= 1024 and h <= 768:
            w_new, h_new = 1024, 768
        elif w <= 1280 and h <= 960:
            w_new, h_new = 1216, 352
        elif w <= 6048 and h <= 4032:
            w_new, h_new = 864, 576
        else:
            w_new, h_new = w, h

        if is_tensor:
            img_resized = Func.interpolate(img, (h_new, w_new), mode='bilinear', align_corners=False)
            img_resized = img_resized.clamp(0, 1)
        else:
            img_resized = img.resize((w_new, h_new))
        return img_resized, (w_new, h_new)

    def generate_image(
        self, prompt, negative_prompt, ref_images, num_steps, cfg_guidance, seed, num_samples=1, init_image=None, image2image_strength=0.0, show_progress=False, size_level=512, args=None, judge=None, name=None
    ):
        """Generate predictions for one reference image.

        Args:
            prompt, negative_prompt: Prompts passed to ``prepare``.
            ref_images: Reference image (tensor or PIL image).
            num_steps: Number of schedule steps.
            cfg_guidance: Guidance scale passed to ``denoise``.
            seed: Seed; a negative value draws a random one.
            num_samples: Must be 1.
            init_image: Optional image blended into the starting latent.
            image2image_strength: Controls where the schedule starts when
                ``init_image`` is given.
            show_progress: Show a progress bar while denoising.
            size_level, name: Unused; kept for interface compatibility.
            args: Optional namespace; ``single_denoise`` and ``prompt_type``
                are read from it when present.
            judge: Optional tensor; when given, two MSE values between the
                second prediction latent and the encoded ``judge`` are printed.

        Returns:
            ``(images_list, Lpred, Rpred)``: PIL images built from ``Rpred``
            and resized to the working resolution, the first decoded
            prediction mapped to [0, 1], and the second decoded prediction
            clamped to [-1, 1].
        """
        assert num_samples == 1, "num_samples > 1 is not supported yet."

        ref_images_raw, img_info = self.input_process_image(ref_images)
        if isinstance(ref_images, Image.Image):
            ref_images_raw = self.load_image(ref_images_raw)
        height, width = ref_images_raw.shape[-2], ref_images_raw.shape[-1]
        ref_images_raw = ref_images_raw.to(self.device)

        # Encode the reference image, e.g. (bs,3,656,880) -> (1,16,82,110).
        self._ae_to_device()
        ref_images = self.ae.encode(ref_images_raw * 2 - 1)
        self._ae_to_cpu()

        seed = int(seed)
        if seed < 0:
            seed = torch.Generator(device="cpu").seed()

        if init_image is not None:
            init_image = self.load_image(init_image).to(self.device)
            init_image = Func.interpolate(init_image, (height, width))
            self._ae_to_device()
            init_image = self.ae.encode(init_image * 2 - 1)
            self._ae_to_cpu()

        _dtype = torch.float32 if self.device.type == "cpu" else torch.bfloat16
        latent_shape = (num_samples, 16, height // 8, width // 8)
        # Random noise is used only when args.single_denoise exists and is
        # false; otherwise the starting latent is all zeros.
        if not getattr(args, "single_denoise", True):
            x = torch.randn(
                *latent_shape,
                device=self.device,
                dtype=_dtype,
                generator=torch.Generator(device=self.device).manual_seed(seed),
            )
        else:
            x = torch.zeros(*latent_shape, device=self.device, dtype=_dtype)

        timesteps = sampling.get_schedule(num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True)
        if init_image is not None:
            t_idx = int((1 - image2image_strength) * num_steps)
            t = timesteps[t_idx]
            timesteps = timesteps[t_idx:]
            x = t * x + (1.0 - t) * init_image.to(x.dtype)

        # Duplicate everything: one row for the prompt, one for the negative prompt.
        x = torch.cat([x, x], dim=0)
        ref_images = torch.cat([ref_images, ref_images], dim=0)
        ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0)

        empty_llm = getattr(args, "prompt_type", None) == "empty"
        inputs = self.prepare(
            [prompt, negative_prompt],
            x,
            ref_image=ref_images,
            ref_image_raw=ref_images_raw,
            empty_llm=empty_llm,
        )

        # _dtype is float32 on CPU and bfloat16 otherwise, matching the
        # autocast dtype used per device type.
        # NOTE(review): the extent of this autocast block was reconstructed
        # from a source whose indentation was lost — confirm.
        with torch.autocast(device_type=self.device.type, dtype=_dtype):
            # The token sequence includes the reference image tokens.
            Lpred, Rpred = self.denoise(
                **inputs,
                cfg_guidance=cfg_guidance,
                timesteps=timesteps,
                show_progress=show_progress,
                timesteps_truncate=1.0,
            )
            Lpred = self.unpack(Lpred.float(), height, width)
            Rpred = self.unpack(Rpred.float(), height, width)

            # The autoencoder was returned to the CPU after encoding; bring
            # it back before the judge encode and the final decode.
            self._ae_to_device()
            if judge is not None:
                judge = Func.interpolate(judge, (height, width), mode='bilinear', align_corners=False)
                training_gt = self.ae.encode(judge)
                raw_loss = torch.nn.functional.mse_loss(Rpred, training_gt)
                print(f"training_loss with rgb2: {raw_loss}")
                # Normalise along dim 1, guarding against division by zero.
                norm = torch.linalg.norm(judge, dim=1, keepdim=True)
                norm[norm < 1e-9] = 1e-9
                judge = judge / norm
                training_gt = self.ae.encode(judge)
                training_loss = torch.nn.functional.mse_loss(Rpred, training_gt)
                print(f"training_loss with normed_rgb: {training_loss}")
            Lpred = self.ae.decode(Lpred)
            Rpred = self.ae.decode(Rpred)
            self._ae_to_cpu()

            Lpred = Lpred.clamp(-1, 1).mul(0.5).add(0.5)
            # Rpred is deliberately left in [-1, 1] (no rescale to [0, 1]).
            Rpred = Rpred.clamp(-1, 1)

        # NOTE(review): Rpred is in [-1, 1] when converted to PIL here;
        # confirm that negative values are expected by to_pil_image.
        images_list = [
            self.output_process_image(F.to_pil_image(img), img_info)
            for img in Rpred.float()
        ]
        return images_list, Lpred.float(), Rpred.float()
def main():
    """Command-line entry point: edit every image listed in a JSON file.

    The JSON file maps image names (relative to ``--input_dir``) to prompts.
    Each result is written under an output directory whose name is
    ``--output_dir`` plus suffixes for the offload/quantized flags and the
    size level. The average time per image, excluding the first image, is
    printed at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True, help='Path to the model checkpoint')
    parser.add_argument('--input_dir', type=str, required=True, help='Path to the input image directory')
    parser.add_argument('--output_dir', type=str, required=True, help='Path to the output image directory')
    parser.add_argument('--json_path', type=str, required=True, help='Path to the JSON file containing image names and prompts')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for generation')
    parser.add_argument('--num_steps', type=int, default=28, help='Number of diffusion steps')
    parser.add_argument('--cfg_guidance', type=float, default=6.0, help='CFG guidance strength')
    parser.add_argument('--size_level', default=512, type=int)
    parser.add_argument('--offload', action='store_true', help='Use offload for large models')
    parser.add_argument('--quantized', action='store_true', help='Use fp8 model weights')
    parser.add_argument('--lora', type=str, default=None)
    parser.add_argument('--qwen2vl_model_path', type=str, default=str(DEFAULT_QWEN_DIR), help='Path to the local Qwen2.5-VL model directory')
    # NOTE(review): this option is parsed, but `args` is not passed to
    # ImageGenerator below, so it currently has no effect — confirm intent.
    parser.add_argument('--empty_prompt_cache', type=str, default=str(EMPTY_PROMPT_LATENT_PATH), help='Path to the empty-prompt latent cache')
    args = parser.parse_args()

    # Explicit checks instead of assert, which is stripped under `python -O`.
    if not os.path.exists(args.input_dir):
        parser.error(f"Input directory {args.input_dir} does not exist.")
    if not os.path.exists(args.json_path):
        parser.error(f"JSON file {args.json_path} does not exist.")

    args.output_dir = args.output_dir.rstrip('/') + ('-offload' if args.offload else "") + ('-quantized' if args.quantized else "") + f"-{args.size_level}"
    os.makedirs(args.output_dir, exist_ok=True)

    # Close the file deterministically instead of leaking the handle.
    with open(args.json_path, 'r', encoding='utf-8') as f:
        image_and_prompts = json.load(f)

    image_edit = ImageGenerator(
        ae_path=os.path.join(args.model_path, 'vae.safetensors'),
        dit_path=os.path.join(args.model_path, "step1x-edit-i1258-FP8.safetensors" if args.quantized else "step1x-edit-i1258.safetensors"),
        qwen2vl_model_path=args.qwen2vl_model_path,
        max_length=640,
        quantized=args.quantized,
        offload=args.offload,
        lora=args.lora,
    )

    time_list = []
    for image_name, prompt in image_and_prompts.items():
        image_path = os.path.join(args.input_dir, image_name)
        output_path = os.path.join(args.output_dir, image_name)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        start_time = time.time()
        images, _, _ = image_edit.generate_image(
            prompt,
            negative_prompt="",
            ref_images=Image.open(image_path).convert("RGB"),
            num_samples=1,
            num_steps=args.num_steps,
            cfg_guidance=args.cfg_guidance,
            seed=args.seed,
            show_progress=True,
            size_level=args.size_level,
        )
        # Measure once so the printed and the recorded value agree.
        elapsed = time.time() - start_time
        print(f"Time taken: {elapsed:.2f} seconds")
        time_list.append(elapsed)
        images[0].save(output_path, lossless=True)

    # The first image is excluded from the average.
    if len(time_list) > 1:
        print(f'average time for {args.output_dir}: ', sum(time_list[1:]) / len(time_list[1:]))


if __name__ == "__main__":
    main()