Buckets:

rtrm's picture
download
raw
37.1 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;torchao&quot;,&quot;local&quot;:&quot;torchao&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;torch.compile&quot;,&quot;local&quot;:&quot;torchcompile&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;autoquant&quot;,&quot;local&quot;:&quot;autoquant&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Supported quantization types&quot;,&quot;local&quot;:&quot;supported-quantization-types&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Serializing and Deserializing quantized models&quot;,&quot;local&quot;:&quot;serializing-and-deserializing-quantized-models&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Resources&quot;,&quot;local&quot;:&quot;resources&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_12807/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/entry/start.81ef957e.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/scheduler.53228c21.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/singletons.650a2250.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/index.e93d0901.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/paths.8020c977.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/entry/app.8ed8d352.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/preload-helper.3843ebca.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/index.100fac89.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/nodes/0.5f0c602d.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/nodes/297.3fedc5d6.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/CopyLLMTxtMenu.f558cba0.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/globals.7f7f1b26.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/IconCopy.38cf8f56.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.9bafb610.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12807/en/_app/immutable/chunks/CodeBlock.d30a6509.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;torchao&quot;,&quot;local&quot;:&quot;torchao&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;torch.compile&quot;,&quot;local&quot;:&quot;torchcompile&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;autoquant&quot;,&quot;local&quot;:&quot;autoquant&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Supported quantization types&quot;,&quot;local&quot;:&quot;supported-quantization-types&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Serializing and Deserializing quantized models&quot;,&quot;local&quot;:&quot;serializing-and-deserializing-quantized-models&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Resources&quot;,&quot;local&quot;:&quot;resources&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="torchao" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchao"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" 
fill="currentColor"></path></svg></span></a> <span>torchao</span></h1> <p data-svelte-h="svelte-1ucx4sz"><a href="https://github.com/pytorch/ao" rel="nofollow">torchao</a> provides high-performance dtypes and optimizations based on quantization and sparsity for inference and training PyTorch models. It is supported for any model in any modality, as long as it supports loading with <a href="https://hf.co/docs/accelerate/index" rel="nofollow">Accelerate</a> and contains <code>torch.nn.Linear</code> layers.</p> <p data-svelte-h="svelte-fmrzpk">Make sure PyTorch 2.5+ and torchao are installed with the command below.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->uv pip install -U torch torchao<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1nzolae">Each quantization dtype is available as a separate instance of an <a 
href="https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize" rel="nofollow">AOBaseConfig</a> class. This provides more flexible configuration options by exposing more available arguments.</p> <p data-svelte-h="svelte-ienguh">Pass the <code>AOBaseConfig</code> of a quantization dtype, like <a href="https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig" rel="nofollow">Int4WeightOnlyConfig</a> to <a href="/docs/diffusers/pr_12807/en/api/quantization#diffusers.TorchAoConfig">TorchAoConfig</a> in <a href="/docs/diffusers/pr_12807/en/api/models/overview#diffusers.ModelMixin.from_pretrained">from_pretrained()</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
<span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int8WeightOnlyConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={<span class="hljs-string">&quot;transformer&quot;</span>: TorchAoConfig(Int8WeightOnlyConfig(group_size=<span class="hljs-number">128</span>))}
)
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map=<span class="hljs-string">&quot;cuda&quot;</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1i1n0wg">For simple use cases, you can also pass a string identifier to <code>TorchAoConfig</code>, as shown below.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={<span class="hljs-string">&quot;transformer&quot;</span>: TorchAoConfig(<span class="hljs-string">&quot;int8wo&quot;</span>)}
)
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map=<span class="hljs-string">&quot;cuda&quot;</span>
)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="torchcompile" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchcompile"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torch.compile</span></h2> <p data-svelte-h="svelte-1re1mrg">torchao supports <a href="../optimization/fp16#torchcompile">torch.compile</a> which can speed up inference with one line of code.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div 
class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
<span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> Int4WeightOnlyConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={<span class="hljs-string">&quot;transformer&quot;</span>: TorchAoConfig(Int4WeightOnlyConfig(group_size=<span class="hljs-number">128</span>))}
)
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map=<span class="hljs-string">&quot;cuda&quot;</span>
)
pipeline.transformer.<span class="hljs-built_in">compile</span>(mode=<span class="hljs-string">&quot;max-autotune&quot;</span>, fullgraph=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-139k18a">Refer to this <a href="https://github.com/huggingface/diffusers/pull/10009#issue-2688781450" rel="nofollow">table</a> for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao <a href="https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks" rel="nofollow">repository</a>.</p> <blockquote class="tip" data-svelte-h="svelte-1xstrmi"><p>The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.</p></blockquote> <h2 class="relative group"><a id="autoquant" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#autoquant"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" 
fill="currentColor"></path></svg></span></a> <span>autoquant</span></h2> <p data-svelte-h="svelte-1sxhdyw">torchao provides <a href="https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant" rel="nofollow">autoquant</a>, an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
<span class="hljs-keyword">from</span> torchao.quantization <span class="hljs-keyword">import</span> autoquant
<span class="hljs-comment"># Load the pipeline</span>
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-schnell&quot;</span>,
torch_dtype=torch.bfloat16,
device_map=<span class="hljs-string">&quot;cuda&quot;</span>
)
transformer = autoquant(pipeline.transformer)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="supported-quantization-types" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#supported-quantization-types"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Supported quantization types</span></h2> <p data-svelte-h="svelte-1dy3rwb">torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.</p> <p data-svelte-h="svelte-17x1tdo">Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like <code>bfloat16</code>. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.</p> <p data-svelte-h="svelte-1nbmql6">Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. 
However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.</p> <p data-svelte-h="svelte-4xexxq">The quantization methods supported are as follows:</p> <table data-svelte-h="svelte-1y0n94l"><thead><tr><th><strong>Category</strong></th> <th><strong>Full Function Names</strong></th> <th><strong>Shorthands</strong></th></tr></thead> <tbody><tr><td><strong>Integer quantization</strong></td> <td><code>int4_weight_only</code>, <code>int8_dynamic_activation_int4_weight</code>, <code>int8_weight_only</code>, <code>int8_dynamic_activation_int8_weight</code></td> <td><code>int4wo</code>, <code>int4dq</code>, <code>int8wo</code>, <code>int8dq</code></td></tr> <tr><td><strong>Floating point 8-bit quantization</strong></td> <td><code>float8_weight_only</code>, <code>float8_dynamic_activation_float8_weight</code>, <code>float8_static_activation_float8_weight</code></td> <td><code>float8wo</code>, <code>float8wo_e5m2</code>, <code>float8wo_e4m3</code>, <code>float8dq</code>, <code>float8dq_e4m3</code>, <code>float8dq_e4m3_tensor</code>, <code>float8dq_e4m3_row</code></td></tr> <tr><td><strong>Floating point X-bit quantization</strong></td> <td><code>fpx_weight_only</code></td> <td><code>fpX_eAwB</code> where <code>X</code> is the number of bits (1-7), <code>A</code> is exponent bits, and <code>B</code> is mantissa bits. Constraint: <code>X == A + B + 1</code></td></tr> <tr><td><strong>Unsigned Integer quantization</strong></td> <td><code>uintx_weight_only</code></td> <td><code>uint1wo</code>, <code>uint2wo</code>, <code>uint3wo</code>, <code>uint4wo</code>, <code>uint5wo</code>, <code>uint6wo</code>, <code>uint7wo</code></td></tr></tbody></table> <p data-svelte-h="svelte-1r64eqr">Some quantization methods are aliases (for example, <code>int8wo</code> is the commonly used shorthand for <code>int8_weight_only</code>). 
This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.</p> <p data-svelte-h="svelte-6zfuzk">Refer to the <a href="https://docs.pytorch.org/ao/stable/index.html" rel="nofollow">official torchao documentation</a> for a better understanding of the available quantization methods and the exhaustive list of configuration options available.</p> <h2 class="relative group"><a id="serializing-and-deserializing-quantized-models" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#serializing-and-deserializing-quantized-models"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Serializing and Deserializing quantized models</span></h2> <p data-svelte-h="svelte-16wg2wb">To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the <a href="/docs/diffusers/pr_12807/en/api/models/overview#diffusers.ModelMixin.save_pretrained">save_pretrained()</a> method.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative 
text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel, TorchAoConfig
quantization_config = TorchAoConfig(<span class="hljs-string">&quot;int8wo&quot;</span>)
transformer = AutoModel.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained(<span class="hljs-string">&quot;/path/to/flux_int8wo&quot;</span>, safe_serialization=<span class="hljs-literal">False</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1loqepz">To load a serialized quantized model, use the <a href="/docs/diffusers/pr_12807/en/api/models/overview#diffusers.ModelMixin.from_pretrained">from_pretrained()</a> method.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> FluxPipeline, AutoModel
transformer = AutoModel.from_pretrained(<span class="hljs-string">&quot;/path/to/flux_int8wo&quot;</span>, torch_dtype=torch.bfloat16, use_safetensors=<span class="hljs-literal">False</span>)
pipe = FluxPipeline.from_pretrained(<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to(<span class="hljs-string">&quot;cuda&quot;</span>)
prompt = <span class="hljs-string">&quot;A cat holding a sign that says hello world&quot;</span>
image = pipe(prompt, num_inference_steps=<span class="hljs-number">30</span>, guidance_scale=<span class="hljs-number">7.0</span>).images[<span class="hljs-number">0</span>]
image.save(<span class="hljs-string">&quot;output.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1m61wn7">If you are using <code>torch&lt;=2.6.0</code>, some quantization methods, such as <code>uint4wo</code>, cannot be loaded directly and may result in an <code>UnpicklingError</code> when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using <code>weights_only=False</code> in <code>torch.load</code>, so it should be run only if the weights were obtained from a trustable source.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> init_empty_weights
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> FluxPipeline, AutoModel, TorchAoConfig
<span class="hljs-comment"># Serialize the model</span>
transformer = AutoModel.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/Flux.1-Dev&quot;</span>,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
quantization_config=TorchAoConfig(<span class="hljs-string">&quot;uint4wo&quot;</span>),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained(<span class="hljs-string">&quot;/path/to/flux_uint4wo&quot;</span>, safe_serialization=<span class="hljs-literal">False</span>, max_shard_size=<span class="hljs-string">&quot;50GB&quot;</span>)
<span class="hljs-comment"># ...</span>
<span class="hljs-comment"># Load the model</span>
state_dict = torch.load(<span class="hljs-string">&quot;/path/to/flux_uint4wo/diffusion_pytorch_model.bin&quot;</span>, weights_only=<span class="hljs-literal">False</span>, map_location=<span class="hljs-string">&quot;cpu&quot;</span>)
<span class="hljs-keyword">with</span> init_empty_weights():
transformer = AutoModel.from_config(<span class="hljs-string">&quot;/path/to/flux_uint4wo/config.json&quot;</span>)
transformer.load_state_dict(state_dict, strict=<span class="hljs-literal">True</span>, assign=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-t8mm75"><p>The <a href="/docs/diffusers/pr_12807/en/api/models/auto_model#diffusers.AutoModel">AutoModel</a> API is supported for PyTorch &gt;= 2.6, as shown in the examples above.</p></blockquote> <h2 class="relative group"><a id="resources" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#resources"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Resources</span></h2> <ul data-svelte-h="svelte-ok3vq4"><li><a href="https://docs.pytorch.org/ao/stable/index.html" rel="nofollow">TorchAO Quantization API</a></li> <li><a href="https://github.com/sayakpaul/diffusers-torchao" rel="nofollow">Diffusers-TorchAO examples</a></li></ul> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" 
focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
// Wrapping braces create a block scope, so the `const` bindings below
// (`element`, `data`) do not leak into the global scope.
// Assign (without a declaration, so in a classic script this becomes a global)
// a config object; judging by the name it is read by the SvelteKit runtime.
// NOTE(review): presumably `assets`/`base` are the URL prefixes of this docs
// build — confirm against the kit runtime.
__sveltekit_1x3mwc2 = {
assets: "/docs/diffusers/pr_12807/en",
base: "/docs/diffusers/pr_12807/en",
env: {}
};
// `document.currentScript` is only set while this script executes
// synchronously, so the mount target (this script's parent element) is
// captured up front, before the asynchronous imports below.
const element = document.currentScript.parentElement;
// Passed to kit.start as `data`; it has one entry per id in `node_ids` below.
// NOTE(review): presumably null means "no server-loaded data" — confirm.
const data = [null,null];
// Load the `start` and `app` entry modules in parallel, then call
// kit.start(app, element, options) once both have resolved.
Promise.all([
import("/docs/diffusers/pr_12807/en/_app/immutable/entry/start.81ef957e.js"),
import("/docs/diffusers/pr_12807/en/_app/immutable/entry/app.8ed8d352.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 297],
data,
form: null,
error: null
});
});
// NOTE(review): there is no .catch(); a failed module import surfaces only as
// an unhandled promise rejection, and kit.start is never called.
}
</script>

Xet Storage Details

Size:
37.1 kB
·
Xet hash:
8bd2adc7252b3d5e77ae6cfdf09572d35f11bd5c354fd53baf1dbe50f9425e91

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.