{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b3b23a40-8354-4287-bac2-32f9d084fff3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_validators.py:202: UserWarning: The `local_dir_use_symlinks` argument is deprecated and ignored in `hf_hub_download`. Downloading to a local directory does not use symlinks anymore.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sdxs_vae log-variance: 1.840\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vae8 log-variance: 1.840\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vae9 log-variance: 1.840\n",
      "Готово\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from PIL import Image\n",
    "from diffusers import AutoencoderKL, AsymmetricAutoencoderKL\n",
    "from torchvision.transforms import ToTensor, Normalize, CenterCrop\n",
    "from torchvision.transforms.functional import to_pil_image\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import probplot\n",
    "\n",
    "# Path to the test image and the output directory.\n",
    "IMG_PATH = \"1234567890.png\"\n",
    "OUT_DIR = \"vaetest\"\n",
    "device = \"cuda\"\n",
    "dtype = torch.float32  # single float32 everywhere\n",
    "os.makedirs(OUT_DIR, exist_ok=True)\n",
    "\n",
    "# VAE checkpoints to compare: HF repo id or local path.\n",
    "VAES = {\n",
    "    #\"sdxl\": \"madebyollin/sdxl-vae-fp16-fix\",\n",
    "    \"sdxs_vae\": \"AiArtLab/sdxs-1b\",\n",
    "    \"vae8\": \"/workspace/simplevae2x/vae8\",\n",
    "    \"vae9\": \"/workspace/simplevae2x/vae9\"\n",
    "}\n",
    "\n",
    "def load_image(path):\n",
    "    \"\"\"Load an RGB image, center-crop to a multiple of 8, return (PIL image, [-1, 1] tensor).\"\"\"\n",
    "    img = Image.open(path).convert('RGB')\n",
    "    # Crop so both sides are divisible by 8 (the VAE downsampling factor).\n",
    "    w, h = img.size\n",
    "    img = CenterCrop((h // 8 * 8, w // 8 * 8))(img)\n",
    "    tensor = ToTensor()(img).unsqueeze(0)                  # [0, 1]\n",
    "    tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(tensor)  # [-1, 1]\n",
    "    return img, tensor.to(device, dtype=dtype)\n",
    "\n",
    "def tensor_to_img(t):\n",
    "    \"\"\"Map a [-1, 1] batch tensor back to a PIL image (first batch element).\"\"\"\n",
    "    t = (t * 0.5 + 0.5).clamp(0, 1)\n",
    "    return to_pil_image(t[0])\n",
    "\n",
    "def logvariance(latents):\n",
    "    \"\"\"Return the log-variance over all latent elements (eps avoids log(0)).\"\"\"\n",
    "    return torch.log(latents.var() + 1e-8).item()\n",
    "\n",
    "def plot_latent_distribution(latents, title, save_path):\n",
    "    \"\"\"Save a histogram + normal QQ-plot of the flattened latents to save_path.\"\"\"\n",
    "    lat = latents.detach().cpu().numpy().flatten()\n",
    "    plt.figure(figsize=(10, 4))\n",
    "\n",
    "    # histogram\n",
    "    plt.subplot(1, 2, 1)\n",
    "    plt.hist(lat, bins=100, density=True, alpha=0.7, color='steelblue')\n",
    "    plt.title(f\"{title} histogram\")\n",
    "    plt.xlabel(\"latent value\")\n",
    "    plt.ylabel(\"density\")\n",
    "\n",
    "    # QQ-plot against a normal distribution\n",
    "    plt.subplot(1, 2, 2)\n",
    "    probplot(lat, dist=\"norm\", plot=plt)\n",
    "    plt.title(f\"{title} QQ-plot\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(save_path)\n",
    "    plt.close()\n",
    "\n",
    "# The input image is the same for every VAE — load and save it once.\n",
    "img, x = load_image(IMG_PATH)\n",
    "img.save(os.path.join(OUT_DIR, \"original.png\"))\n",
    "\n",
    "for name, repo in VAES.items():\n",
    "    if name == \"sdxs_vae\":\n",
    "        vae = AsymmetricAutoencoderKL.from_pretrained(repo, subfolder=\"vae\", torch_dtype=dtype).to(device)\n",
    "    else:\n",
    "        vae = AsymmetricAutoencoderKL.from_pretrained(repo, torch_dtype=dtype).to(device)  #, subfolder=\"vae\", variant=\"fp16\"\n",
    "\n",
    "    cfg = vae.config\n",
    "    scale = getattr(cfg, \"scaling_factor\", 1.)\n",
    "    shift = getattr(cfg, \"shift_factor\", 0.0)\n",
    "    mean = getattr(cfg, \"latents_mean\", None)\n",
    "    std = getattr(cfg, \"latents_std\", None)\n",
    "\n",
    "    # Latent channel count from the config (4 for SD/SDXL-style VAEs).\n",
    "    C = getattr(cfg, \"latent_channels\", 4)\n",
    "    if mean is not None:\n",
    "        mean = torch.tensor(mean, device=device, dtype=dtype).view(1, C, 1, 1)\n",
    "    if std is not None:\n",
    "        std = torch.tensor(std, device=device, dtype=dtype).view(1, C, 1, 1)\n",
    "    # shift_factor may be explicitly None in the config — fall back to 0.0.\n",
    "    if shift is not None:\n",
    "        shift = torch.tensor(shift, device=device, dtype=dtype)\n",
    "    else:\n",
    "        shift = 0.0\n",
    "\n",
    "    scale = torch.tensor(scale, device=device, dtype=dtype)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # encode: pixel space -> normalized latent space\n",
    "        latents = vae.encode(x).latent_dist.sample().to(dtype)\n",
    "        if mean is not None and std is not None:\n",
    "            latents = (latents - mean) / std\n",
    "        latents = latents * scale + shift\n",
    "\n",
    "        lv = logvariance(latents)\n",
    "        print(f\"{name} log-variance: {lv:.3f}\")\n",
    "\n",
    "        # distribution plot\n",
    "        plot_latent_distribution(latents, f\"{name}_latents\",\n",
    "                                 os.path.join(OUT_DIR, f\"dist_{name}.png\"))\n",
    "\n",
    "        # decode: invert the normalization, then reconstruct\n",
    "        latents = (latents - shift) / scale\n",
    "        if mean is not None and std is not None:\n",
    "            latents = latents * std + mean\n",
    "        rec = vae.decode(latents).sample\n",
    "\n",
    "        tensor_to_img(rec).save(os.path.join(OUT_DIR, f\"decoded_{name}.png\"))\n",
    "\n",
    "print(\"Готово\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "200b72ab-1978-4d71-9aba-b1ef97cf0b27",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}