Buckets:

rtrm's picture
download
raw
40.5 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;文生图&quot;,&quot;local&quot;:&quot;文生图&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;脚本参数&quot;,&quot;local&quot;:&quot;脚本参数&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Min-SNR加权策略&quot;,&quot;local&quot;:&quot;min-snr加权策略&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;训练脚本解析&quot;,&quot;local&quot;:&quot;训练脚本解析&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;启动脚本&quot;,&quot;local&quot;:&quot;启动脚本&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;后续步骤&quot;,&quot;local&quot;:&quot;后续步骤&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_12652/zh/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/entry/start.ca7a833f.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/scheduler.e4ff9b64.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/singletons.71526a34.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/index.f9be34a7.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/paths.0df57e7f.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/entry/app.746b83f3.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/preload-helper.bb94e341.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/index.09f1bca0.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/nodes/0.8237e20e.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/nodes/49.26a8f157.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.3bffcf96.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/CodeBlock.3dd9a65d.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/HfOption.44827c7f.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;文生图&quot;,&quot;local&quot;:&quot;文生图&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;脚本参数&quot;,&quot;local&quot;:&quot;脚本参数&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Min-SNR加权策略&quot;,&quot;local&quot;:&quot;min-snr加权策略&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;训练脚本解析&quot;,&quot;local&quot;:&quot;训练脚本解析&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;启动脚本&quot;,&quot;local&quot;:&quot;启动脚本&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;后续步骤&quot;,&quot;local&quot;:&quot;后续步骤&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="文生图" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#文生图"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>文生图</span></h1> <blockquote class="warning" data-svelte-h="svelte-14t9fto"><p>文生图训练脚本目前处于实验阶段,容易出现过拟合和灾难性遗忘等问题。建议尝试不同超参数以获得最佳数据集适配效果。</p></blockquote> <p data-svelte-h="svelte-1c3ggng">Stable Diffusion 等文生图模型能够根据文本提示生成对应图像。</p> <p data-svelte-h="svelte-tf4mti">模型训练对硬件要求较高,但启用 <code>gradient_checkpointing</code><code>mixed_precision</code> 后,可在单块24GB显存GPU上完成训练。如需更大批次或更快训练速度,建议使用30GB以上显存的GPU设备。通过启用 <a href="../optimization/xformers">xFormers</a> 内存高效注意力机制可降低显存占用。JAX/Flax 训练方案也支持TPU/GPU高效训练,但不支持梯度检查点、梯度累积和xFormers。使用Flax训练时建议配备30GB以上显存GPU或TPU v3。</p> <p data-svelte-h="svelte-mmr0em">本指南将详解 <a href="https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py" rel="nofollow">train_text_to_image.py</a> 训练脚本,助您掌握其原理并适配自定义需求。</p> <p data-svelte-h="svelte-mwjd95">运行脚本前请确保已从源码安装库:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->git <span class="hljs-built_in">clone</span> https://github.com/huggingface/diffusers
<span class="hljs-built_in">cd</span> diffusers
pip install .<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ie0m52">然后进入包含训练脚本的示例目录,安装对应依赖:</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">PyTorch </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">Flax </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">cd</span> examples/text_to_image
pip install -r requirements.txt<!-- HTML_TAG_END --></pre></div> </div> <blockquote class="tip" data-svelte-h="svelte-tt4yuu"><p>🤗 Accelerate 是支持多GPU/TPU训练和混合精度的工具库,能根据硬件环境自动配置训练参数。参阅 🤗 Accelerate <a href="https://huggingface.co/docs/accelerate/quicktour" rel="nofollow">快速入门</a> 了解更多。</p></blockquote> <p data-svelte-h="svelte-2mxnyi">初始化 🤗 Accelerate 环境:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate config<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8qo0b1">要创建默认配置环境(不进行交互式选择):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate config default<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hg4sov">若环境不支持交互式shell(如notebook),可使用:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate.utils <span class="hljs-keyword">import</span> write_basic_config
write_basic_config()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1g3ykjk">最后,如需在自定义数据集上训练,请参阅 <a href="create_dataset">创建训练数据集</a> 指南了解如何准备适配脚本的数据集。</p> <h2 class="relative group"><a id="脚本参数" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#脚本参数"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>脚本参数</span></h2> <blockquote class="tip" data-svelte-h="svelte-14psxqh"><p>以下重点介绍脚本中影响训练效果的关键参数,如需完整参数说明可查阅 <a href="https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py" rel="nofollow">脚本源码</a>。如有疑问欢迎反馈。</p></blockquote> <p data-svelte-h="svelte-1fyn80r">训练脚本提供丰富参数供自定义训练流程,所有参数及说明详见 <a href="https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L193" rel="nofollow"><code>parse_args()</code></a> 函数。该函数为每个参数提供默认值(如批次大小、学习率等),也可通过命令行参数覆盖。</p> <p data-svelte-h="svelte-je3eqx">例如使用fp16混合精度加速训练:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch train_text_to_image.py \
--mixed_precision=<span class="hljs-string">&quot;fp16&quot;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-urfoa3">基础重要参数包括:</p> <ul data-svelte-h="svelte-vs9gl3"><li><code>--pretrained_model_name_or_path</code>: Hub模型名称或本地预训练模型路径</li> <li><code>--dataset_name</code>: Hub数据集名称或本地训练数据集路径</li> <li><code>--image_column</code>: 数据集中图像列名</li> <li><code>--caption_column</code>: 数据集中文本列名</li> <li><code>--output_dir</code>: 模型保存路径</li> <li><code>--push_to_hub</code>: 是否将训练模型推送至Hub</li> <li><code>--checkpointing_steps</code>: 模型检查点保存步数;训练中断时可添加 <code>--resume_from_checkpoint</code> 从该检查点恢复训练</li></ul> <h3 class="relative group"><a id="min-snr加权策略" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#min-snr加权策略"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Min-SNR加权策略</span></h3> <p data-svelte-h="svelte-1kealoa"><a href="https://huggingface.co/papers/2303.09556" rel="nofollow">Min-SNR</a> 加权策略通过重新平衡损失函数加速模型收敛。训练脚本支持预测 <code>epsilon</code>(噪声)或 <code>v_prediction</code>,而Min-SNR兼容两种预测类型。该策略仅限PyTorch版本,Flax训练脚本不支持。</p> <p data-svelte-h="svelte-13rrc02">添加 <code>--snr_gamma</code> 参数并设为推荐值5.0:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch train_text_to_image.py \
--snr_gamma=5.0<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-p650kw">可通过此 <a href="https://wandb.ai/sayakpaul/text2image-finetune-minsnr" rel="nofollow">Weights and Biases</a> 报告比较不同 <code>snr_gamma</code> 值的损失曲面。小数据集上Min-SNR效果可能不如大数据集显著。</p> <h2 class="relative group"><a id="训练脚本解析" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#训练脚本解析"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>训练脚本解析</span></h2> <p data-svelte-h="svelte-wh0i5">数据集预处理代码和训练循环位于 <a href="https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L490" rel="nofollow"><code>main()</code></a> 函数,自定义修改需在此处进行。</p> <p data-svelte-h="svelte-1nlf1uu"><code>train_text_to_image</code> 脚本首先 <a href="https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L543" rel="nofollow">加载调度器</a> 和分词器,此处可替换其他调度器:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=<span class="hljs-string">&quot;scheduler&quot;</span>)
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder=<span class="hljs-string">&quot;tokenizer&quot;</span>, revision=args.revision
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18tiko3">接着 <a href="https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L619" rel="nofollow">加载UNet模型</a></p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=<span class="hljs-string">&quot;unet&quot;</span>)
model.register_to_config(**load_model.config)
model.load_state_dict(load_model.state_dict())<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-cpfd0u">随后对数据集的文本和图像列进行预处理。<a href="https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L724" rel="nofollow"><code>tokenize_captions</code></a> 函数处理文本分词,<a href="https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L742" rel="nofollow"><code>train_transforms</code></a> 定义图像增强策略,二者集成于 <code>preprocess_train</code></p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_train</span>(<span class="hljs-params">examples</span>):
images = [image.convert(<span class="hljs-string">&quot;RGB&quot;</span>) <span class="hljs-keyword">for</span> image <span class="hljs-keyword">in</span> examples[image_column]]
examples[<span class="hljs-string">&quot;pixel_values&quot;</span>] = [train_transforms(image) <span class="hljs-keyword">for</span> image <span class="hljs-keyword">in</span> images]
examples[<span class="hljs-string">&quot;input_ids&quot;</span>] = tokenize_captions(examples)
<span class="hljs-keyword">return</span> examples<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-2hmlty">最后,<a href="https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L878" rel="nofollow">训练循环</a> 处理剩余流程:图像编码为潜空间、添加噪声、计算文本嵌入条件、更新模型参数、保存并推送模型至Hub。想深入了解训练循环原理,可参阅 <a href="../using-diffusers/write_own_pipeline">理解管道、模型与调度器</a> 教程,该教程解析了去噪过程的核心逻辑。</p> <h2 class="relative group"><a id="启动脚本" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#启动脚本"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>启动脚本</span></h2> <p data-svelte-h="svelte-mxmxgt">完成所有配置后,即可启动训练脚本!🚀</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">PyTorch </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">Flax </div></div> <div class="language-select"><p data-svelte-h="svelte-tmdnu8"><a href="https://huggingface.co/datasets/lambdalabs/naruto-blip-captions" rel="nofollow">火影忍者BLIP标注数据集</a> 为例训练生成火影角色。设置环境变量 <code>MODEL_NAME</code><code>dataset_name</code> 指定模型和数据集(Hub或本地路径)。多GPU训练需在 <code>accelerate launch</code> 命令中添加 <code>--multi_gpu</code> 参数。</p> <blockquote class="tip" data-svelte-h="svelte-5nkv91"><p>使用本地数据集时,设置 <code>TRAIN_DIR</code><code>OUTPUT_DIR</code> 环境变量为数据集路径和模型保存路径。</p></blockquote> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">export</span> MODEL_NAME=<span class="hljs-string">&quot;stable-diffusion-v1-5/stable-diffusion-v1-5&quot;</span>
<span class="hljs-built_in">export</span> dataset_name=<span class="hljs-string">&quot;lambdalabs/naruto-blip-captions&quot;</span>
accelerate launch --mixed_precision=<span class="hljs-string">&quot;fp16&quot;</span> train_text_to_image.py \
--pretrained_model_name_or_path=<span class="hljs-variable">$MODEL_NAME</span> \
--dataset_name=<span class="hljs-variable">$dataset_name</span> \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--enable_xformers_memory_efficient_attention \
--lr_scheduler=<span class="hljs-string">&quot;constant&quot;</span> --lr_warmup_steps=0 \
--output_dir=<span class="hljs-string">&quot;sd-naruto-model&quot;</span> \
--push_to_hub<!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-hxct76">训练完成后,即可使用新模型进行推理:</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">PyTorch </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">Flax </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionPipeline
<span class="hljs-keyword">import</span> torch
pipeline = StableDiffusionPipeline.from_pretrained(<span class="hljs-string">&quot;path/to/saved_model&quot;</span>, torch_dtype=torch.float16, use_safetensors=<span class="hljs-literal">True</span>).to(<span class="hljs-string">&quot;cuda&quot;</span>)
image = pipeline(prompt=<span class="hljs-string">&quot;yoda&quot;</span>).images[<span class="hljs-number">0</span>]
image.save(<span class="hljs-string">&quot;yoda-naruto.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> </div> <h2 class="relative group"><a id="后续步骤" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#后续步骤"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>后续步骤</span></h2> <p data-svelte-h="svelte-19kknl">恭喜完成文生图模型训练!如需进一步使用模型,以下指南可能有所帮助:</p> <ul data-svelte-h="svelte-1tmdo3u"><li>了解如何加载 <a href="../using-diffusers/loading_adapters#LoRA">LoRA权重</a> 进行推理(如果训练时使用了LoRA)</li> <li><a href="../using-diffusers/conditional_image_generation">文生图</a> 任务指南中,了解引导尺度等参数或提示词加权等技术如何控制生成效果</li></ul> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/zh/training/text2image.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
__sveltekit_oaf0l2 = {
assets: "/docs/diffusers/pr_12652/zh",
base: "/docs/diffusers/pr_12652/zh",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/diffusers/pr_12652/zh/_app/immutable/entry/start.ca7a833f.js"),
import("/docs/diffusers/pr_12652/zh/_app/immutable/entry/app.746b83f3.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 49],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
40.5 kB
·
Xet hash:
cf321454091f66794804967ee2b7e56b1edf8007092f36e6e7d539901cc5ca29

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.