Buckets:

rtrm's picture
download
raw
81.1 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Diffusion模型评估指南&quot;,&quot;local&quot;:&quot;diffusion模型评估指南&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;评估场景&quot;,&quot;local&quot;:&quot;评估场景&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;定性评估&quot;,&quot;local&quot;:&quot;定性评估&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;定量评估&quot;,&quot;local&quot;:&quot;定量评估&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;文本引导图像生成&quot;,&quot;local&quot;:&quot;文本引导图像生成&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;图像条件式文本生成图像&quot;,&quot;local&quot;:&quot;图像条件式文本生成图像&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;基于类别的图像生成&quot;,&quot;local&quot;:&quot;基于类别的图像生成&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_12652/zh/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/entry/start.ca7a833f.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/scheduler.e4ff9b64.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/singletons.71526a34.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/index.f9be34a7.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/paths.0df57e7f.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/entry/app.746b83f3.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/preload-helper.bb94e341.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/index.09f1bca0.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/nodes/0.8237e20e.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/nodes/5.72407acc.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.3bffcf96.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/CodeBlock.3dd9a65d.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Diffusion模型评估指南&quot;,&quot;local&quot;:&quot;diffusion模型评估指南&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;评估场景&quot;,&quot;local&quot;:&quot;评估场景&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;定性评估&quot;,&quot;local&quot;:&quot;定性评估&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;定量评估&quot;,&quot;local&quot;:&quot;定量评估&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;文本引导图像生成&quot;,&quot;local&quot;:&quot;文本引导图像生成&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;图像条件式文本生成图像&quot;,&quot;local&quot;:&quot;图像条件式文本生成图像&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;基于类别的图像生成&quot;,&quot;local&quot;:&quot;基于类别的图像生成&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="diffusion模型评估指南" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#diffusion模型评估指南"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" 
fill="currentColor"></path></svg></span></a> <span>Diffusion模型评估指南</span></h1> <a target="_blank" href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/evaluation.ipynb" data-svelte-h="svelte-h14285"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="在 Colab 中打开"></a> <blockquote class="tip" data-svelte-h="svelte-1jugxt2"><p>鉴于当前已出现针对图像生成Diffusion模型的成熟评估框架(如<a href="https://crfm.stanford.edu/helm/heim/latest/" rel="nofollow">HEIM</a>、<a href="https://huggingface.co/papers/2307.06350" rel="nofollow">T2I-Compbench</a>、<a href="https://huggingface.co/papers/2310.11513" rel="nofollow">GenEval</a>),本文档部分内容已过时。</p></blockquote> <p data-svelte-h="svelte-1pes0mw"><a href="https://huggingface.co/docs/diffusers/stable_diffusion" rel="nofollow">Stable Diffusion</a> 这类生成模型的评估本质上是主观的。但作为开发者和研究者,我们经常需要在众多可能性中做出审慎选择。那么当面对不同生成模型(如 GANs、Diffusion 等)时,该如何决策?</p> <p data-svelte-h="svelte-k5obfv">定性评估容易产生偏差,可能导致错误结论;而定量指标又未必能准确反映图像质量。因此,通常需要结合定性与定量评估来获得更可靠的模型选择依据。</p> <p data-svelte-h="svelte-1ajaxvu">本文档将系统介绍扩散模型的定性与定量评估方法(非穷尽列举)。对于定量方法,我们将重点演示如何结合 <code>diffusers</code> 库实现这些评估。</p> <p data-svelte-h="svelte-e6klhp">文档所示方法同样适用于评估不同<a href="https://huggingface.co/docs/diffusers/main/en/api/schedulers/overview" rel="nofollow">噪声调度器</a>在固定生成模型下的表现差异。</p> <h2 class="relative group"><a id="评估场景" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#评估场景"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 
79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>评估场景</span></h2> <p data-svelte-h="svelte-1iu3uxb">我们涵盖以下Diffusion模型管线的评估:</p> <ul data-svelte-h="svelte-1nhmbj0"><li>文本引导图像生成(如 <a href="https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img" rel="nofollow"><code>StableDiffusionPipeline</code></a>)</li> <li>基于文本和输入图像的引导生成(如 <a href="https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/img2img" rel="nofollow"><code>StableDiffusionImg2ImgPipeline</code></a>、<a href="https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix" rel="nofollow"><code>StableDiffusionInstructPix2PixPipeline</code></a>)</li> <li>类别条件图像生成模型(如 <a href="https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit" rel="nofollow"><code>DiTPipeline</code></a>)</li></ul> <h2 class="relative group"><a id="定性评估" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#定性评估"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 
0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>定性评估</span></h2> <p data-svelte-h="svelte-1txffyk">定性评估通常涉及对生成图像的人工评判。评估维度包括构图质量、图文对齐度和空间关系等方面。标准化的提示词能为这些主观指标提供统一基准。DrawBench和PartiPrompts是常用的定性评估提示词数据集,分别由<a href="https://imagen.research.google/" rel="nofollow">Imagen</a>和<a href="https://parti.research.google/" rel="nofollow">Parti</a>团队提出。</p> <p data-svelte-h="svelte-1vuvlvf">根据<a href="https://parti.research.google/" rel="nofollow">Parti官方网站</a>说明:</p> <blockquote data-svelte-h="svelte-m0q7q0"><p>PartiPrompts (P2)是我们发布的包含1600多个英文提示词的丰富集合,可用于测量模型在不同类别和挑战维度上的能力。</p></blockquote> <p data-svelte-h="svelte-19xz367"><img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts.png" alt="parti-prompts"></p> <p data-svelte-h="svelte-f20ejf">PartiPrompts包含以下字段:</p> <ul data-svelte-h="svelte-11gxful"><li>Prompt(提示词)</li> <li>Category(类别,如“抽象”、“世界知识”等)</li> <li>Challenge(挑战维度,如“基础”、“复杂”、“文字与符号”等)</li></ul> <p data-svelte-h="svelte-1op8x1l">这些基准测试支持对不同图像生成模型进行并排人工对比评估。为此,🧨 Diffusers团队构建了<strong>Open Parti Prompts</strong>——一个基于Parti Prompts的社区驱动型定性评估基准,用于比较顶尖开源diffusion模型:</p> <ul data-svelte-h="svelte-jo9asf"><li><a href="https://huggingface.co/spaces/OpenGenAI/open-parti-prompts" rel="nofollow">Open Parti Prompts游戏</a>:展示10个parti提示词对应的4张生成图像,用户选择最符合提示的图片</li> <li><a href="https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard" rel="nofollow">Open Parti Prompts排行榜</a>:对比当前最优开源diffusion模型的性能榜单</li></ul> <p data-svelte-h="svelte-1hb73k9">为进行手动图像对比,我们演示如何使用<code>diffusers</code>处理部分PartiPrompts提示词。</p> <p data-svelte-h="svelte-16alyrm">以下是从不同挑战维度(基础、复杂、语言结构、想象力、文字与符号)采样的提示词示例(使用<a href="https://huggingface.co/datasets/nateraw/parti-prompts" rel="nofollow">PartiPrompts作为数据集</a>):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 
ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-comment"># prompts = load_dataset(&quot;nateraw/parti-prompts&quot;, split=&quot;train&quot;)</span>
<span class="hljs-comment"># prompts = prompts.shuffle()</span>
<span class="hljs-comment"># sample_prompts = [prompts[i][&quot;Prompt&quot;] for i in range(5)]</span>
<span class="hljs-comment"># Fixing these sample prompts in the interest of reproducibility.</span>
sample_prompts = [
<span class="hljs-string">&quot;a corgi&quot;</span>,
<span class="hljs-string">&quot;a hot air balloon with a yin-yang symbol, with the moon visible in the daytime sky&quot;</span>,
<span class="hljs-string">&quot;a car with no windows&quot;</span>,
<span class="hljs-string">&quot;a cube made of porcupine&quot;</span>,
<span class="hljs-string">&#x27;The saying &quot;BE EXCELLENT TO EACH OTHER&quot; written on a red brick wall with a graffiti image of a green alien wearing a tuxedo. A yellow fire hydrant is on a sidewalk in the foreground.&#x27;</span>,
]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-rdopsg">现在我们可以使用Stable Diffusion(<a href="https://huggingface.co/CompVis/stable-diffusion-v1-4" rel="nofollow">v1-4检查点</a>)生成这些提示词对应的图像(此处的<code>sd_pipeline</code>即下文“文本引导图像生成”一节中加载的 v1-4 管道,运行本段代码前需先完成该加载步骤):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
seed = <span class="hljs-number">0</span>
generator = torch.manual_seed(seed)
images = sd_pipeline(sample_prompts, num_images_per_prompt=<span class="hljs-number">1</span>, generator=generator).images<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-4i7yd5"><img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-14.png" alt="parti-prompts-14"></p> <p data-svelte-h="svelte-8snsne">我们也可以通过设置<code>num_images_per_prompt</code>参数来比较同一提示词生成的不同图像。使用不同检查点(<a href="https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5" rel="nofollow">v1-5</a>)运行相同流程后,结果如下:</p> <p data-svelte-h="svelte-gipltn"><img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-15.png" alt="parti-prompts-15"></p> <p data-svelte-h="svelte-ueq19r">当使用多个待评估模型为所有提示词生成若干图像后,这些结果将提交给人类评估员进行打分。有关DrawBench和PartiPrompts基准测试的更多细节,请参阅各自的论文。</p> <blockquote class="tip" data-svelte-h="svelte-yewjf8"><p>在模型训练过程中查看推理样本有助于评估训练进度。我们的<a href="https://github.com/huggingface/diffusers/tree/main/examples/" rel="nofollow">训练脚本</a>支持此功能,并额外提供TensorBoard和Weights &amp; Biases日志记录功能。</p></blockquote> <h2 class="relative group"><a id="定量评估" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#定量评估"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 
28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>定量评估</span></h2> <p data-svelte-h="svelte-r7hhw0">本节将指导您如何评估三种不同的扩散流程,使用以下指标:</p> <ul data-svelte-h="svelte-b93mpw"><li>CLIP分数</li> <li>CLIP方向相似度</li> <li>FID(弗雷歇Inception距离)</li></ul> <h3 class="relative group"><a id="文本引导图像生成" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#文本引导图像生成"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>文本引导图像生成</span></h3> <p data-svelte-h="svelte-n2bg45"><a href="https://huggingface.co/papers/2104.08718" rel="nofollow">CLIP分数</a>用于衡量图像-标题对的匹配程度。CLIP分数越高表明匹配度越高🔼。该分数是对“匹配度”这一定性概念的量化测量,也可以理解为图像与标题之间的语义相似度。研究发现CLIP分数与人类判断具有高度相关性。</p> <p data-svelte-h="svelte-1hzc3dk">首先加载<code>StableDiffusionPipeline</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" 
focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionPipeline
<span class="hljs-keyword">import</span> torch
model_ckpt = <span class="hljs-string">&quot;CompVis/stable-diffusion-v1-4&quot;</span>
sd_pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1k1g1o6">使用多个提示词生成图像:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->prompts = [
<span class="hljs-string">&quot;a photo of an astronaut riding a horse on mars&quot;</span>,
<span class="hljs-string">&quot;A high tech solarpunk utopia in the Amazon rainforest&quot;</span>,
<span class="hljs-string">&quot;A pikachu fine dining with a view to the Eiffel Tower&quot;</span>,
<span class="hljs-string">&quot;A mecha robot in a favela in expressionist style&quot;</span>,
<span class="hljs-string">&quot;an insect robot preparing a delicious meal&quot;</span>,
<span class="hljs-string">&quot;A small cabin on top of a snowy mountain in the style of Disney, artstation&quot;</span>,
]
images = sd_pipeline(prompts, num_images_per_prompt=<span class="hljs-number">1</span>, output_type=<span class="hljs-string">&quot;np&quot;</span>).images
<span class="hljs-built_in">print</span>(images.shape)
<span class="hljs-comment"># (6, 512, 512, 3)</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1k28vwi">然后计算CLIP分数:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torchmetrics.functional.multimodal <span class="hljs-keyword">import</span> clip_score
<span class="hljs-keyword">from</span> functools <span class="hljs-keyword">import</span> partial
clip_score_fn = partial(clip_score, model_name_or_path=<span class="hljs-string">&quot;openai/clip-vit-base-patch16&quot;</span>)
<span class="hljs-keyword">def</span> <span class="hljs-title function_">calculate_clip_score</span>(<span class="hljs-params">images, prompts</span>):
images_int = (images * <span class="hljs-number">255</span>).astype(<span class="hljs-string">&quot;uint8&quot;</span>)
clip_score = clip_score_fn(torch.from_numpy(images_int).permute(<span class="hljs-number">0</span>, <span class="hljs-number">3</span>, <span class="hljs-number">1</span>, <span class="hljs-number">2</span>), prompts).detach()
<span class="hljs-keyword">return</span> <span class="hljs-built_in">round</span>(<span class="hljs-built_in">float</span>(clip_score), <span class="hljs-number">4</span>)
sd_clip_score = calculate_clip_score(images, prompts)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;CLIP分数: <span class="hljs-subst">{sd_clip_score}</span>&quot;</span>)
<span class="hljs-comment"># CLIP分数: 35.7038</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1iltbps">上述示例中,我们为每个提示生成一张图像。如果为每个提示生成多张图像,则需要计算每个提示生成图像的平均分数。</p> <p data-svelte-h="svelte-mydfaa">当需要比较两个兼容<code>StableDiffusionPipeline</code>的检查点时,应在调用管道时传入生成器。首先使用<a href="https://huggingface.co/CompVis/stable-diffusion-v1-4" rel="nofollow">v1-4 Stable Diffusion检查点</a>以固定种子生成图像:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->seed = <span class="hljs-number">0</span>
generator = torch.manual_seed(seed)
images = sd_pipeline(prompts, num_images_per_prompt=<span class="hljs-number">1</span>, generator=generator, output_type=<span class="hljs-string">&quot;np&quot;</span>).images<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1o4hn2o">然后加载<a href="https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5" rel="nofollow">v1-5检查点</a>生成图像。注意,上一步调用已推进了<code>generator</code>的随机状态;若要让两个检查点使用相同的初始噪声,应在调用前重新执行<code>generator = torch.manual_seed(seed)</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model_ckpt_1_5 = <span class="hljs-string">&quot;stable-diffusion-v1-5/stable-diffusion-v1-5&quot;</span>
sd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda&quot;</span>)
images_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=<span class="hljs-number">1</span>, generator=generator, output_type=<span class="hljs-string">&quot;np&quot;</span>).images<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-2io45m">最后比较两者的CLIP分数:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->sd_clip_score_1_4 = calculate_clip_score(images, prompts)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;v-1-4版本的CLIP分数: <span class="hljs-subst">{sd_clip_score_1_4}</span>&quot;</span>)
<span class="hljs-comment"># v-1-4版本的CLIP分数: 34.9102</span>
sd_clip_score_1_5 = calculate_clip_score(images_1_5, prompts)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;v-1-5版本的CLIP分数: <span class="hljs-subst">{sd_clip_score_1_5}</span>&quot;</span>)
<span class="hljs-comment"># v-1-5版本的CLIP分数: 36.2137</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ujkfdk">结果表明<a href="https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5" rel="nofollow">v1-5</a>检查点性能优于前代。但需注意,我们用于计算CLIP分数的提示词数量较少。实际评估时应使用更多样化且数量更大的提示词集。</p> <blockquote class="warning" data-svelte-h="svelte-iu3cp7"><p>该分数存在固有局限性:训练数据中的标题是从网络爬取,并提取自图片关联的<code>alt</code>等标签。这些描述未必符合人类描述图像的方式,因此我们需要人工“设计”部分提示词。</p></blockquote> <h3 class="relative group"><a id="图像条件式文本生成图像" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#图像条件式文本生成图像"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>图像条件式文本生成图像</span></h3> <p data-svelte-h="svelte-1yzloih">这种情况下,生成管道同时接受输入图像和文本提示作为条件。以<code>StableDiffusionInstructPix2PixPipeline</code>为例,该管道接收编辑指令作为输入提示,并接受待编辑的输入图像。</p> <p data-svelte-h="svelte-7c2hev">示例图示:</p> <p data-svelte-h="svelte-1u1nt56"><img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png" alt="编辑指令"></p> <p data-svelte-h="svelte-1nycz9x">评估此类模型的策略之一是测量两幅图像间变化的连贯性(通过<a href="https://huggingface.co/docs/transformers/model_doc/clip" 
rel="nofollow">CLIP</a>定义)中两个图像之间的变化与两个图像描述之间的变化的一致性(如论文<a href="https://huggingface.co/papers/2108.00946" rel="nofollow">《CLIP-Guided Domain Adaptation of Image Generators》</a>所示)。这被称为“<strong>CLIP方向相似度</strong>”。</p> <ul data-svelte-h="svelte-169ilj8"><li><strong>描述1</strong>对应输入图像(图像1),即待编辑的图像。</li> <li><strong>描述2</strong>对应编辑后的图像(图像2),应反映编辑指令。</li></ul> <p data-svelte-h="svelte-f9n3pa">以下是示意图:</p> <p data-svelte-h="svelte-8e89hb"><img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-consistency.png" alt="edit-consistency"></p> <p data-svelte-h="svelte-r23mbf">我们准备了一个小型数据集来实现该指标。首先加载数据集:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
dataset = load_dataset(<span class="hljs-string">&quot;sayakpaul/instructpix2pix-demo&quot;</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
dataset.features<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;input&#x27;</span>: Value(dtype=<span class="hljs-string">&#x27;string&#x27;</span>, <span class="hljs-built_in">id</span>=None),
<span class="hljs-string">&#x27;edit&#x27;</span>: Value(dtype=<span class="hljs-string">&#x27;string&#x27;</span>, <span class="hljs-built_in">id</span>=None),
<span class="hljs-string">&#x27;output&#x27;</span>: Value(dtype=<span class="hljs-string">&#x27;string&#x27;</span>, <span class="hljs-built_in">id</span>=None),
<span class="hljs-string">&#x27;image&#x27;</span>: Image(decode=True, <span class="hljs-built_in">id</span>=None)}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-4az32q">数据字段说明:</p> <ul data-svelte-h="svelte-1rw2hq2"><li><code>input</code>:与<code>image</code>对应的原始描述。</li> <li><code>edit</code>:编辑指令。</li> <li><code>output</code>:反映<code>edit</code>指令的修改后描述。</li></ul> <p data-svelte-h="svelte-11cxr2x">查看一个样本:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->idx = <span class="hljs-number">0</span>
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Original caption: <span class="hljs-subst">{dataset[idx][<span class="hljs-string">&#x27;input&#x27;</span>]}</span>&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Edit instruction: <span class="hljs-subst">{dataset[idx][<span class="hljs-string">&#x27;edit&#x27;</span>]}</span>&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Modified caption: <span class="hljs-subst">{dataset[idx][<span class="hljs-string">&#x27;output&#x27;</span>]}</span>&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->Original caption: 2. FAROE ISLANDS: An archipelago of 18 mountainous isles <span class="hljs-keyword">in</span> the North Atlantic Ocean between Norway and Iceland, the Faroe Islands has <span class="hljs-string">&#x27;everything you could hope for&#x27;</span>, according to Big 7 Travel. It boasts <span class="hljs-string">&#x27;crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills&#x27;</span>
Edit instruction: make the isles all white marble
Modified caption: 2. WHITE MARBLE ISLANDS: An archipelago of 18 mountainous white marble isles <span class="hljs-keyword">in</span> the North Atlantic Ocean between Norway and Iceland, the White Marble Islands has <span class="hljs-string">&#x27;everything you could hope for&#x27;</span>, according to Big 7 Travel. It boasts <span class="hljs-string">&#x27;crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xp6zg6">对应的图像:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->dataset[idx][<span class="hljs-string">&quot;image&quot;</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ogvsh3"><img 
src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-dataset.png" alt="edit-dataset"></p> <p data-svelte-h="svelte-16d8j0z">我们将根据编辑指令修改数据集中的图像,并计算方向相似度。</p> <p data-svelte-h="svelte-1yqwgt0">首先加载<code>StableDiffusionInstructPix2PixPipeline</code></p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionInstructPix2PixPipeline
instruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
<span class="hljs-string">&quot;timbrooks/instruct-pix2pix&quot;</span>, torch_dtype=torch.float16
).to(<span class="hljs-string">&quot;cuda&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1llczdr">执行编辑操作:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">def</span> <span class="hljs-title function_">edit_image</span>(<span class="hljs-params">input_image, instruction</span>):
image = instruct_pix2pix_pipeline(
instruction,
image=input_image,
output_type=<span class="hljs-string">&quot;np&quot;</span>,
generator=generator,
).images[<span class="hljs-number">0</span>]
<span class="hljs-keyword">return</span> image
input_images = []
original_captions = []
modified_captions = []
edited_images = []
<span class="hljs-keyword">for</span> idx <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-built_in">len</span>(dataset)):
input_image = dataset[idx][<span class="hljs-string">&quot;image&quot;</span>]
edit_instruction = dataset[idx][<span class="hljs-string">&quot;edit&quot;</span>]
edited_image = edit_image(input_image, edit_instruction)
input_images.append(np.array(input_image))
original_captions.append(dataset[idx][<span class="hljs-string">&quot;input&quot;</span>])
modified_captions.append(dataset[idx][<span class="hljs-string">&quot;output&quot;</span>])
edited_images.append(edited_image)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-y69d0n">为测量方向相似度,我们首先加载CLIP的图像和文本编码器:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> (
CLIPTokenizer,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
CLIPImageProcessor,
)
clip_id = <span class="hljs-string">&quot;openai/clip-vit-large-patch14&quot;</span>
tokenizer = CLIPTokenizer.from_pretrained(clip_id)
text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(<span class="hljs-string">&quot;cuda&quot;</span>)
image_processor = CLIPImageProcessor.from_pretrained(clip_id)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(<span class="hljs-string">&quot;cuda&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1olfepe">注意我们使用的是特定CLIP检查点——<code>openai/clip-vit-large-patch14</code>,因为Stable Diffusion预训练正是基于此CLIP变体。详见<a href="https://huggingface.co/docs/transformers/model_doc/clip" rel="nofollow">文档</a></p> <p data-svelte-h="svelte-1cfzew0">接着准备计算方向相似度的PyTorch <code>nn.Module</code></p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch.nn <span class="hljs-keyword">as</span> nn
<span class="hljs-keyword">import</span> torch.nn.functional <span class="hljs-keyword">as</span> F
<span class="hljs-keyword">class</span> <span class="hljs-title class_">DirectionalSimilarity</span>(nn.Module):
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, tokenizer, text_encoder, image_processor, image_encoder</span>):
<span class="hljs-built_in">super</span>().__init__()
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.image_processor = image_processor
self.image_encoder = image_encoder
<span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_image</span>(<span class="hljs-params">self, image</span>):
image = self.image_processor(image, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>)[<span class="hljs-string">&quot;pixel_values&quot;</span>]
<span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;pixel_values&quot;</span>: image.to(<span class="hljs-string">&quot;cuda&quot;</span>)}
<span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize_text</span>(<span class="hljs-params">self, text</span>):
inputs = self.tokenizer(
text,
max_length=self.tokenizer.model_max_length,
padding=<span class="hljs-string">&quot;max_length&quot;</span>,
truncation=<span class="hljs-literal">True</span>,
return_tensors=<span class="hljs-string">&quot;pt&quot;</span>,
)
<span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;input_ids&quot;</span>: inputs.input_ids.to(<span class="hljs-string">&quot;cuda&quot;</span>)}
<span class="hljs-keyword">def</span> <span class="hljs-title function_">encode_image</span>(<span class="hljs-params">self, image</span>):
preprocessed_image = self.preprocess_image(image)
image_features = self.image_encoder(**preprocessed_image).image_embeds
image_features = image_features / image_features.norm(dim=<span class="hljs-number">1</span>, keepdim=<span class="hljs-literal">True</span>)
<span class="hljs-keyword">return</span> image_features
<span class="hljs-keyword">def</span> <span class="hljs-title function_">encode_text</span>(<span class="hljs-params">self, text</span>):
tokenized_text = self.tokenize_text(text)
text_features = self.text_encoder(**tokenized_text).text_embeds
text_features = text_features / text_features.norm(dim=<span class="hljs-number">1</span>, keepdim=<span class="hljs-literal">True</span>)
<span class="hljs-keyword">return</span> text_features
<span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_directional_similarity</span>(<span class="hljs-params">self, img_feat_one, img_feat_two, text_feat_one, text_feat_two</span>):
sim_direction = F.cosine_similarity(img_feat_two - img_feat_one, text_feat_two - text_feat_one)
<span class="hljs-keyword">return</span> sim_direction
<span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, image_one, image_two, caption_one, caption_two</span>):
img_feat_one = self.encode_image(image_one)
img_feat_two = self.encode_image(image_two)
text_feat_one = self.encode_text(caption_one)
text_feat_two = self.encode_text(caption_two)
directional_similarity = self.compute_directional_similarity(
img_feat_one, img_feat_two, text_feat_one, text_feat_two
)
<span class="hljs-keyword">return</span> directional_similarity<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ah96o9">现在让我们使用<code>DirectionalSimilarity</code>模块:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->dir_similarity = DirectionalSimilarity(tokenizer, text_encoder, image_processor, image_encoder)
scores = []
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-built_in">len</span>(input_images)):
original_image = input_images[i]
original_caption = original_captions[i]
edited_image = edited_images[i]
modified_caption = modified_captions[i]
similarity_score = dir_similarity(original_image, edited_image, original_caption, modified_caption)
scores.append(<span class="hljs-built_in">float</span>(similarity_score.detach().cpu()))
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;CLIP方向相似度: <span class="hljs-subst">{np.mean(scores)}</span>&quot;</span>)
<span class="hljs-comment"># CLIP方向相似度: 0.0797976553440094</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1evlka">与CLIP分数类似,CLIP方向相似度数值越高越好。</p> <p data-svelte-h="svelte-du0g2y">需要注意的是,<code>StableDiffusionInstructPix2PixPipeline</code>提供了两个控制参数<code>image_guidance_scale</code>和<code>guidance_scale</code>来调节最终编辑图像的质量。建议您尝试调整这两个参数,观察它们对方向相似度的影响。</p> <p data-svelte-h="svelte-1ck0ggc">我们可以扩展这个度量标准来评估原始图像与编辑版本的相似度,只需计算<code>F.cosine_similarity(img_feat_two, img_feat_one)</code>。对于这类编辑任务,我们仍希望尽可能保留图像的主要语义特征(即保持较高的相似度分数)。</p> <p data-svelte-h="svelte-15kssed">该度量方法同样适用于类似流程,例如<a href="https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline" rel="nofollow"><code>StableDiffusionPix2PixZeroPipeline</code></a>。</p> <blockquote class="tip" data-svelte-h="svelte-yfuov9"><p>CLIP分数和CLIP方向相似度都依赖CLIP模型,可能导致评估结果存在偏差。</p></blockquote> <p data-svelte-h="svelte-10m0vep">当被评估模型是在大型图文数据集(如<a href="https://laion.ai/blog/laion-5b/" rel="nofollow">LAION-5B数据集</a>)上预训练时,<strong><em>很难将IS、FID(后文讨论)或KID等指标扩展到这类模型上</em></strong>。这是因为这些指标的底层都使用了在ImageNet-1k数据集上预训练的InceptionNet来提取图像特征。Stable Diffusion的预训练数据集与InceptionNet的预训练数据集可能重叠有限,因此InceptionNet并不适合在此作为特征提取器。</p> <p data-svelte-h="svelte-1sls2vo"><strong><em>上述指标更适合评估类别条件模型</em></strong>,例如<a href="https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit" rel="nofollow">DiT</a>。该模型是在ImageNet-1k类别条件下预训练的。
</p> <h3 class="relative group"><a id="基于类别的图像生成" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#基于类别的图像生成"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>基于类别的图像生成</span></h3> <p data-svelte-h="svelte-1be9oeu">基于类别的生成模型通常是在带有类别标签的数据集(如<a href="https://huggingface.co/datasets/imagenet-1k" rel="nofollow">ImageNet-1k</a>)上进行预训练的。评估这些模型的常用指标包括Fréchet Inception Distance(FID)、Kernel Inception Distance(KID)和Inception Score(IS)。本文档重点介绍FID(<a href="https://huggingface.co/papers/1706.08500" rel="nofollow">Heusel等人</a>),并展示如何使用<a href="https://huggingface.co/docs/diffusers/api/pipelines/dit" rel="nofollow"><code>DiTPipeline</code></a>计算该指标,该管道底层使用了<a href="https://huggingface.co/papers/2212.09748" rel="nofollow">DiT模型</a>。</p> <p data-svelte-h="svelte-18nm4k9">FID旨在衡量两组图像数据集的相似程度。根据<a href="https://mmgeneration.readthedocs.io/en/latest/quick_run.html#fid" rel="nofollow">此资源</a>:</p> <blockquote data-svelte-h="svelte-15x4a9v"><p>Fréchet Inception Distance是衡量两组图像数据集相似度的指标。研究表明其与人类对视觉质量的主观判断高度相关,因此最常用于评估生成对抗网络(GAN)生成样本的质量。FID通过计算Inception网络特征表示所拟合的两个高斯分布之间的Fréchet距离来实现。</p></blockquote> <p 
data-svelte-h="svelte-fg97ol">这两个数据集本质上是真实图像数据集和生成图像数据集(本例中即由模型生成的图像)。FID通常基于两个大型数据集计算,但本文档将使用两个小型数据集进行演示。</p> <p data-svelte-h="svelte-pwwdv9">首先下载ImageNet-1k训练集中的部分图像:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> zipfile <span class="hljs-keyword">import</span> ZipFile
<span class="hljs-keyword">import</span> requests
<span class="hljs-keyword">def</span> <span class="hljs-title function_">download</span>(<span class="hljs-params">url, local_filepath</span>):
r = requests.get(url)
<span class="hljs-keyword">with</span> <span class="hljs-built_in">open</span>(local_filepath, <span class="hljs-string">&quot;wb&quot;</span>) <span class="hljs-keyword">as</span> f:
f.write(r.content)
<span class="hljs-keyword">return</span> local_filepath
dummy_dataset_url = <span class="hljs-string">&quot;https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/sample-imagenet-images.zip&quot;</span>
local_filepath = download(dummy_dataset_url, dummy_dataset_url.split(<span class="hljs-string">&quot;/&quot;</span>)[-<span class="hljs-number">1</span>])
<span class="hljs-keyword">with</span> ZipFile(local_filepath, <span class="hljs-string">&quot;r&quot;</span>) <span class="hljs-keyword">as</span> zipper:
zipper.extractall(<span class="hljs-string">&quot;.&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image
<span class="hljs-keyword">import</span> os
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
dataset_path = <span class="hljs-string">&quot;sample-imagenet-images&quot;</span>
image_paths = <span class="hljs-built_in">sorted</span>([os.path.join(dataset_path, x) <span class="hljs-keyword">for</span> x <span class="hljs-keyword">in</span> os.listdir(dataset_path)])
real_images = [np.array(Image.<span class="hljs-built_in">open</span>(path).convert(<span class="hljs-string">&quot;RGB&quot;</span>)) <span class="hljs-keyword">for</span> path <span class="hljs-keyword">in</span> image_paths]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6bz4xb">这些是来自以下ImageNet-1k类别的10张图像:“cassette_player”、“chain_saw”(2张)、“church”、“gas_pump”(3张)、“parachute”(2张)和“tench”。</p> <p align="center" data-svelte-h="svelte-ke5ahc"><img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/real-images.png" alt="真实图像"><br> <em>真实图像</em></p> <p data-svelte-h="svelte-ilkg7f">加载图像后,我们对其进行轻量级预处理以便用于FID计算:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torchvision.transforms <span class="hljs-keyword">import</span> functional <span class="hljs-keyword">as</span> F
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_image</span>(<span class="hljs-params">image</span>):
image = torch.tensor(image).unsqueeze(<span class="hljs-number">0</span>)
image = image.permute(<span class="hljs-number">0</span>, <span class="hljs-number">3</span>, <span class="hljs-number">1</span>, <span class="hljs-number">2</span>) / <span class="hljs-number">255.0</span>
<span class="hljs-keyword">return</span> F.center_crop(image, (<span class="hljs-number">256</span>, <span class="hljs-number">256</span>))
real_images = torch.stack([dit_pipeline.preprocess_image(image) <span class="hljs-keyword">for</span> image <span class="hljs-keyword">in</span> real_images])
<span class="hljs-built_in">print</span>(real_images.shape)
<span class="hljs-comment"># torch.Size([10, 3, 256, 256])</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1g1dyy3">我们现在加载<a href="https://huggingface.co/docs/diffusers/api/pipelines/dit" rel="nofollow"><code>DiTPipeline</code></a>来生成基于上述类别的条件图像。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiTPipeline, DPMSolverMultistepScheduler
<span class="hljs-comment"># Load the DiT checkpoint in float16, replace its scheduler with a DPMSolverMultistepScheduler</span>
<span class="hljs-comment"># built from the same scheduler config, and move the pipeline to the CUDA device.</span>
dit_pipeline = DiTPipeline.from_pretrained(<span class="hljs-string">&quot;facebook/DiT-XL-2-256&quot;</span>, torch_dtype=torch.float16)
dit_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config)
dit_pipeline = dit_pipeline.to(<span class="hljs-string">&quot;cuda&quot;</span>)
<span class="hljs-comment"># Seed the default generator so the sampled noise is repeatable across runs.</span>
seed = <span class="hljs-number">0</span>
generator = torch.manual_seed(seed)
<span class="hljs-comment"># One class name per real image, with the same per-class counts as the real images listed above.</span>
words = [
<span class="hljs-string">&quot;cassette player&quot;</span>,
<span class="hljs-string">&quot;chainsaw&quot;</span>,
<span class="hljs-string">&quot;chainsaw&quot;</span>,
<span class="hljs-string">&quot;church&quot;</span>,
<span class="hljs-string">&quot;gas pump&quot;</span>,
<span class="hljs-string">&quot;gas pump&quot;</span>,
<span class="hljs-string">&quot;gas pump&quot;</span>,
<span class="hljs-string">&quot;parachute&quot;</span>,
<span class="hljs-string">&quot;parachute&quot;</span>,
<span class="hljs-string">&quot;tench&quot;</span>,
]
<span class="hljs-comment"># Map the names to the pipeline&#x27;s class ids, then generate one image per label as NumPy arrays.</span>
class_ids = dit_pipeline.get_label_ids(words)
output = dit_pipeline(class_labels=class_ids, generator=generator, output_type=<span class="hljs-string">&quot;np&quot;</span>)
fake_images = output.images
<span class="hljs-comment"># Convert the channels-last NumPy batch to a channels-first tensor so it matches real_images.</span>
fake_images = torch.tensor(fake_images)
fake_images = fake_images.permute(<span class="hljs-number">0</span>, <span class="hljs-number">3</span>, <span class="hljs-number">1</span>, <span class="hljs-number">2</span>)
<span class="hljs-built_in">print</span>(fake_images.shape)
<span class="hljs-comment"># torch.Size([10, 3, 256, 256])</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1wphc0p">现在,我们可以使用<a href="https://torchmetrics.readthedocs.io/" rel="nofollow"><code>torchmetrics</code></a>计算FID分数。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torchmetrics.image.fid <span class="hljs-keyword">import</span> FrechetInceptionDistance
<span class="hljs-comment"># normalize=True makes the metric expect float images scaled to [0, 1] (real_images were divided by 255)</span>
<span class="hljs-comment"># rather than uint8 images in [0, 255].</span>
fid = FrechetInceptionDistance(normalize=<span class="hljs-literal">True</span>)
<span class="hljs-comment"># Accumulate statistics for the real batch and the generated batch separately.</span>
fid.update(real_images, real=<span class="hljs-literal">True</span>)
fid.update(fake_images, real=<span class="hljs-literal">False</span>)
<span class="hljs-comment"># compute() returns a tensor; float() turns it into a plain Python number for printing.</span>
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;FID分数: <span class="hljs-subst">{<span class="hljs-built_in">float</span>(fid.compute())}</span>&quot;</span>)
<span class="hljs-comment"># FID分数: 177.7147216796875</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1v2r3ps">FID分数越低越好。以下因素会影响FID结果:</p> <ul data-svelte-h="svelte-1iufyfv"><li>图像数量(包括真实图像和生成图像)</li> <li>扩散过程中引入的随机性</li> <li>扩散过程的推理步数</li> <li>扩散过程中使用的调度器</li></ul> <p data-svelte-h="svelte-1skjrzw">对于最后两点,最佳实践是使用不同的随机种子和推理步数进行多次评估,然后报告平均结果。</p> <blockquote class="warning" data-svelte-h="svelte-11xfkr6"><p>FID结果往往比较脆弱,因为它依赖于许多因素:</p> <ul><li>计算过程中使用的特定Inception模型</li> <li>计算实现的正确性</li> <li>图像格式(从PNG图像还是JPG图像开始计算,结果会有所不同)</li></ul> <p>需要注意的是,FID通常在比较相似实验时最有用,但除非作者仔细公开FID测量代码,否则很难复现论文结果。</p> <p>这些注意事项同样适用于其他相关指标,如KID和IS。</p></blockquote> <p data-svelte-h="svelte-147tlvf">最后,让我们直观地查看这些<code>fake_images</code>。</p> <p align="center" data-svelte-h="svelte-1if1i1w"><img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/fake-images.png" alt="生成图像"><br> <em>生成图像示例</em></p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/zh/conceptual/evaluation.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
// Client-side hydration bootstrap; the hashed entry paths suggest it is emitted by the SvelteKit build,
// so regenerate it from the build rather than editing it by hand.
{
// Runtime config global read by the SvelteKit client: base URL path and asset path for this docs build.
// env is presumably the set of public env vars exposed to the client (none here).
__sveltekit_oaf0l2 = {
assets: "/docs/diffusers/pr_12652/zh",
base: "/docs/diffusers/pr_12652/zh",
env: {}
};
// document.currentScript is only set while this script executes synchronously, so the mount
// element (this script's parent) is captured now, before the async imports below resolve.
const element = document.currentScript.parentElement;
// One entry per id in node_ids below; null presumably means no server-loaded data for that node.
const data = [null,null];
// Load the client runtime entry and the app entry in parallel, then hydrate into `element`.
Promise.all([
import("/docs/diffusers/pr_12652/zh/_app/immutable/entry/start.ca7a833f.js"),
import("/docs/diffusers/pr_12652/zh/_app/immutable/entry/app.746b83f3.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 5],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
81.1 kB
·
Xet hash:
b85e33c8ccc9b68beb9bd72b1ae4592c99339b22b7ee9edfd995760e84c7bdee

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.