Buckets:

rtrm's picture
download
raw
41.2 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;文本反转(Textual Inversion)&quot;,&quot;local&quot;:&quot;文本反转textual-inversion&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;脚本参数&quot;,&quot;local&quot;:&quot;脚本参数&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;训练脚本&quot;,&quot;local&quot;:&quot;训练脚本&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;启动脚本&quot;,&quot;local&quot;:&quot;启动脚本&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;后续步骤&quot;,&quot;local&quot;:&quot;后续步骤&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_11739/zh/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/entry/start.95a8faef.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/chunks/scheduler.e4ff9b64.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/chunks/singletons.0a6f1d19.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/chunks/index.f9be34a7.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/chunks/paths.37f6e25a.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/entry/app.a988cdaf.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/chunks/preload-helper.3e2c3f46.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/chunks/index.09f1bca0.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/nodes/0.0ec3fec6.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/nodes/50.110be27b.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.f5199cd9.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/chunks/CodeBlock.1680a1fd.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/zh/_app/immutable/chunks/HfOption.44827c7f.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;文本反转(Textual Inversion)&quot;,&quot;local&quot;:&quot;文本反转textual-inversion&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;脚本参数&quot;,&quot;local&quot;:&quot;脚本参数&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;训练脚本&quot;,&quot;local&quot;:&quot;训练脚本&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;启动脚本&quot;,&quot;local&quot;:&quot;启动脚本&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;后续步骤&quot;,&quot;local&quot;:&quot;后续步骤&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" 
height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="文本反转textual-inversion" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#文本反转textual-inversion"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>文本反转(Textual Inversion)</span></h1> <p data-svelte-h="svelte-583d0v"><a href="https://hf.co/papers/2208.01618" 
rel="nofollow">文本反转</a>是一种训练技术,仅需少量示例图像即可个性化图像生成模型。该技术通过学习和更新文本嵌入(新嵌入会绑定到提示中必须使用的特殊词汇)来匹配您提供的示例图像。</p> <p data-svelte-h="svelte-11ie2pv">如果在显存有限的GPU上训练,建议在训练命令中启用<code>gradient_checkpointing</code>和<code>mixed_precision</code>参数。您还可以通过<a href="../optimization/xformers">xFormers</a>使用内存高效注意力机制来减少内存占用。JAX/Flax训练也支持在TPU和GPU上进行高效训练,但不支持梯度检查点或xFormers。在配置与PyTorch相同的情况下,Flax训练脚本的速度至少应快70%!</p> <p data-svelte-h="svelte-79domf">本指南将探索<a href="https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py" rel="nofollow">textual_inversion.py</a>脚本,帮助您更熟悉其工作原理,并了解如何根据自身需求进行调整。</p> <p data-svelte-h="svelte-1l0wcm7">运行脚本前,请确保从源码安装库:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->git <span class="hljs-built_in">clone</span> https://github.com/huggingface/diffusers
<span class="hljs-built_in">cd</span> diffusers
pip install .<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1por4m9">进入包含训练脚本的示例目录,并安装所需依赖:</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">PyTorch </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">Flax </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">cd</span> examples/textual_inversion
pip install -r requirements.txt<!-- HTML_TAG_END --></pre></div> </div> <blockquote class="tip" data-svelte-h="svelte-snmpgn"><p>🤗 Accelerate 是一个帮助您在多GPU/TPU或混合精度环境下训练的工具库。它会根据硬件和环境自动配置训练设置。查看🤗 Accelerate <a href="https://huggingface.co/docs/accelerate/quicktour" rel="nofollow">快速入门</a>了解更多。</p></blockquote> <p data-svelte-h="svelte-cfitki">初始化🤗 Accelerate环境:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate config<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-b5atr0">要设置默认的🤗 Accelerate环境(不选择任何配置):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" 
xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate config default<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-h13mpn">如果您的环境不支持交互式shell(如notebook),可以使用:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div 
class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate.utils <span class="hljs-keyword">import</span> write_basic_config
write_basic_config()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-13r470a">最后,如果想在自定义数据集上训练模型,请参阅<a href="create_dataset">创建训练数据集</a>指南,了解如何创建适用于训练脚本的数据集。</p> <blockquote class="tip" data-svelte-h="svelte-aoxkq0"><p>以下部分重点介绍训练脚本中需要理解的关键修改点,但未涵盖脚本所有细节。如需深入了解,可随时查阅<a href="https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py" rel="nofollow">脚本源码</a>,如有疑问欢迎反馈。</p></blockquote> <h2 class="relative group"><a id="脚本参数" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#脚本参数"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>脚本参数</span></h2> <p data-svelte-h="svelte-11ybobz">训练脚本包含众多参数,便于您定制训练过程。所有参数及其说明都列在<a href="https://github.com/huggingface/diffusers/blob/839c2a5ece0af4e75530cb520d77bc7ed8acf474/examples/textual_inversion/textual_inversion.py#L176" rel="nofollow"><code>parse_args()</code></a>函数中。Diffusers为每个参数提供了默认值(如训练批次大小和学习率),但您可以通过训练命令自由调整这些值。</p> <p data-svelte-h="svelte-yyke9b">例如,将梯度累积步数增加到默认值1以上:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer 
focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch textual_inversion.py \
--gradient_accumulation_steps=4<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1fmbksz">其他需要指定的基础重要参数包括:</p> <ul data-svelte-h="svelte-1v241ru"><li><code>--pretrained_model_name_or_path</code>:Hub上的模型名称或本地预训练模型路径</li> <li><code>--train_data_dir</code>:包含训练数据集(示例图像)的文件夹路径</li> <li><code>--output_dir</code>:训练模型保存位置</li> <li><code>--push_to_hub</code>:是否将训练好的模型推送至Hub</li> <li><code>--checkpointing_steps</code>:训练过程中保存检查点的频率;若训练意外中断,可通过在命令中添加<code>--resume_from_checkpoint</code>从该检查点恢复训练</li> <li><code>--num_vectors</code>:学习嵌入的向量数量;增加此参数可提升模型效果,但会提高训练成本</li> <li><code>--placeholder_token</code>:绑定学习嵌入的特殊词汇(推理时需在提示中使用该词)</li> <li><code>--initializer_token</code>:大致描述训练目标的单字词汇(如物体或风格)</li> <li><code>--learnable_property</code>:训练目标是学习新“风格”(如梵高画风)还是“物体”(如您的宠物狗)</li></ul> <h2 class="relative group"><a id="训练脚本" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#训练脚本"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>训练脚本</span></h2> <p data-svelte-h="svelte-188nunv">与其他训练脚本不同,textual_inversion.py包含自定义数据集类<a 
href="https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L487" rel="nofollow"><code>TextualInversionDataset</code></a>,用于创建数据集。您可以自定义图像尺寸、占位符词汇、插值方法、是否裁剪图像等。如需修改数据集创建方式,可调整<code>TextualInversionDataset</code>类。</p> <p data-svelte-h="svelte-1rv31ld">接下来,在<a href="https://github.com/huggingface/diffusers/blob/839c2a5ece0af4e75530cb520d77bc7ed8acf474/examples/textual_inversion/textual_inversion.py#L573" rel="nofollow"><code>main()</code></a>函数中可找到数据集预处理代码和训练循环。</p> <p data-svelte-h="svelte-t2slq7">脚本首先加载<a href="https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L616" rel="nofollow">tokenizer</a>、<a href="https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L622" rel="nofollow">scheduler和模型</a>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black 
border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># 加载tokenizer</span>
<span class="hljs-keyword">if</span> args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
<span class="hljs-keyword">elif</span> args.pretrained_model_name_or_path:
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=<span class="hljs-string">&quot;tokenizer&quot;</span>)
<span class="hljs-comment"># 加载scheduler和模型</span>
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=<span class="hljs-string">&quot;scheduler&quot;</span>)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder=<span class="hljs-string">&quot;text_encoder&quot;</span>, revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=<span class="hljs-string">&quot;vae&quot;</span>, revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder=<span class="hljs-string">&quot;unet&quot;</span>, revision=args.revision
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-162yif">随后将特殊<a href="https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L632" rel="nofollow">占位符词汇</a>加入tokenizer,并调整嵌入层以适配新词汇。</p> <p data-svelte-h="svelte-1kypmdn">接着,脚本通过<code>TextualInversionDataset</code><a href="https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L716" rel="nofollow">创建数据集</a>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->train_dataset = TextualInversionDataset(
data_root=args.train_data_dir,
tokenizer=tokenizer,
size=args.resolution,
placeholder_token=(<span class="hljs-string">&quot; &quot;</span>.join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),
repeats=args.repeats,
learnable_property=args.learnable_property,
center_crop=args.center_crop,
<span class="hljs-built_in">set</span>=<span class="hljs-string">&quot;train&quot;</span>,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=<span class="hljs-literal">True</span>, num_workers=args.dataloader_num_workers
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-16ltvar">最后,<a href="https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L784" rel="nofollow">训练循环</a>处理从预测噪声残差到更新特殊占位符词汇嵌入权重的所有流程。</p> <p data-svelte-h="svelte-1m9gut4">如需深入了解训练循环工作原理,请参阅<a href="../using-diffusers/write_own_pipeline">理解管道、模型与调度器</a>教程,该教程解析了去噪过程的基本模式。</p> <h2 class="relative group"><a id="启动脚本" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#启动脚本"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>启动脚本</span></h2> <p data-svelte-h="svelte-1fd8wg2">完成所有修改或确认默认配置后,即可启动训练脚本!🚀</p> <p data-svelte-h="svelte-he73wu">本指南将下载<a href="https://huggingface.co/datasets/diffusers/cat_toy_example" rel="nofollow">猫玩具</a>的示例图像并存储在目录中。当然,您也可以创建和使用自己的数据集(参见<a href="create_dataset">创建训练数据集</a>指南)。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg 
class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> snapshot_download
local_dir = <span class="hljs-string">&quot;./cat&quot;</span>
snapshot_download(
<span class="hljs-string">&quot;diffusers/cat_toy_example&quot;</span>, local_dir=local_dir, repo_type=<span class="hljs-string">&quot;dataset&quot;</span>, ignore_patterns=<span class="hljs-string">&quot;.gitattributes&quot;</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1b8hdsd">设置环境变量<code>MODEL_NAME</code>为Hub上的模型ID或本地模型路径,<code>DATA_DIR</code>为刚下载的猫图像路径。脚本会将以下文件保存至您的仓库:</p> <ul data-svelte-h="svelte-1ch14mg"><li><code>learned_embeds.bin</code>:与示例图像对应的学习嵌入向量</li> <li><code>token_identifier.txt</code>:特殊占位符词汇</li> <li><code>type_of_concept.txt</code>:训练概念类型(“object”或“style”)</li></ul> <blockquote class="warning" data-svelte-h="svelte-1shho7l"><p>在单块V100 GPU上完整训练约需1小时。</p></blockquote> <p data-svelte-h="svelte-xksh79">启动脚本前还有最后一步。如果想实时观察训练过程,可以定期保存生成图像。在训练命令中添加以下参数:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->--validation_prompt=<span class="hljs-string">&quot;A &lt;cat-toy&gt; train&quot;</span>
--num_validation_images=4
--validation_steps=100<!-- HTML_TAG_END --></pre></div> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">PyTorch </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">Flax </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">export</span> MODEL_NAME=<span class="hljs-string">&quot;stable-diffusion-v1-5/stable-diffusion-v1-5&quot;</span>
<span class="hljs-built_in">export</span> DATA_DIR=<span class="hljs-string">&quot;./cat&quot;</span>
accelerate launch textual_inversion.py \
--pretrained_model_name_or_path=<span class="hljs-variable">$MODEL_NAME</span> \
--train_data_dir=<span class="hljs-variable">$DATA_DIR</span> \
--learnable_property=<span class="hljs-string">&quot;object&quot;</span> \
--placeholder_token=<span class="hljs-string">&quot;&lt;cat-toy&gt;&quot;</span> \
--initializer_token=<span class="hljs-string">&quot;toy&quot;</span> \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=3000 \
--learning_rate=5.0e-04 \
--scale_lr \
--lr_scheduler=<span class="hljs-string">&quot;constant&quot;</span> \
--lr_warmup_steps=0 \
--output_dir=<span class="hljs-string">&quot;textual_inversion_cat&quot;</span> \
--push_to_hub<!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-eky57p">训练完成后,可以像这样使用新模型进行推理:</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">PyTorch </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">Flax </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionPipeline
<span class="hljs-keyword">import</span> torch
pipeline = StableDiffusionPipeline.from_pretrained(<span class="hljs-string">&quot;stable-diffusion-v1-5/stable-diffusion-v1-5&quot;</span>, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda&quot;</span>)
pipeline.load_textual_inversion(<span class="hljs-string">&quot;sd-concepts-library/cat-toy&quot;</span>)
image = pipeline(<span class="hljs-string">&quot;A &lt;cat-toy&gt; train&quot;</span>, num_inference_steps=<span class="hljs-number">50</span>).images[<span class="hljs-number">0</span>]
image.save(<span class="hljs-string">&quot;cat-train.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> </div> <h2 class="relative group"><a id="后续步骤" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#后续步骤"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>后续步骤</span></h2> <p data-svelte-h="svelte-1jignpy">恭喜您成功训练了自己的文本反转模型!🎉 如需了解更多使用技巧,以下指南可能会有所帮助:</p> <ul data-svelte-h="svelte-g6dwxa"><li>学习如何<a href="../using-diffusers/loading_adapters">加载文本反转嵌入</a>,并将其用作负面嵌入</li> <li>学习如何将<a href="textual_inversion_inference">文本反转</a>应用于Stable Diffusion 1/2和Stable Diffusion XL的推理</li></ul> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/zh/training/text_inversion.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path 
d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
	// Path configuration for this docs build. The name is assigned without a
	// declaration keyword, so it is not scoped to this block.
	// NOTE(review): presumably read by the imported client modules — confirm.
	__sveltekit_pa8cr6 = {
		assets: "/docs/diffusers/pr_11739/zh",
		base: "/docs/diffusers/pr_11739/zh",
		env: {}
	};

	// Must be read synchronously: document.currentScript is only set while
	// this script element is executing, not inside the promise callback below.
	const mountPoint = document.currentScript.parentElement;
	const payload = [null, null];

	// Options handed to kit.start once both modules have loaded.
	const startOptions = {
		node_ids: [0, 50],
		data: payload,
		form: null,
		error: null
	};

	// Load the client entry point and the app module in parallel.
	const pending = [
		import("/docs/diffusers/pr_11739/zh/_app/immutable/entry/start.95a8faef.js"),
		import("/docs/diffusers/pr_11739/zh/_app/immutable/entry/app.a988cdaf.js")
	];

	Promise.all(pending).then((modules) => {
		const [kit, app] = modules;
		kit.start(app, mountPoint, startOptions);
	});
}
</script>

Xet Storage Details

Size:
41.2 kB
·
Xet hash:
f827991f29e1ee8dd4f909f69aefe6dd99ce4fd804ef2b723ba0315b2742fd99

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.