Buckets:
| import{s as ra,n as Qa,o as ga}from"../chunks/scheduler.23542ac5.js";import{S as ua,i as Va,e as n,s as M,c as J,h as ba,a as U,d as a,b as t,f as Ia,g as j,j as p,k as _s,l as da,m as e,n as y,t as T,o as c,p as i}from"../chunks/index.9b1f405b.js";import{C as Aa,H as Vl,E as fa}from"../chunks/MermaidChart.svelte_svelte_type_style_lang.ad6b418d.js";import{C as m}from"../chunks/CodeBlock.259f3544.js";import{D as Ra}from"../chunks/DocNotebookDropdown.68a629d2.js";function Ba(Os){let C,Al,bl,fl,o,Rl,h,Bl,I,Zl,r,Ws='Unconditional 이미지 생성은 학습에 사용된 데이터셋과 유사한 이미지를 생성하는 diffusion 모델에서 인기 있는 어플리케이션입니다. 일반적으로, 가장 좋은 결과는 특정 데이터셋에 사전 훈련된 모델을 파인튜닝하는 것으로 얻을 수 있습니다. 이 <a href="https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model" rel="nofollow">허브</a>에서 이러한 많은 체크포인트를 찾을 수 있지만, 만약 마음에 드는 체크포인트를 찾지 못했다면, 언제든지 스스로 학습할 수 있습니다!',kl,Q,Ss='이 튜토리얼은 나만의 🦋 나비 🦋를 생성하기 위해 <a href="https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset" rel="nofollow">Smithsonian Butterflies</a> 데이터셋의 하위 집합에서 <code>UNet2DModel</code> 모델을 학습하는 방법을 가르쳐줄 것입니다.',Gl,w,Ds='<p>💡 이 학습 튜토리얼은 <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb" rel="nofollow">Training with 🧨 Diffusers</a> 노트북 기반으로 합니다. Diffusion 모델의 작동 방식 및 자세한 내용은 노트북을 확인하세요!</p>',El,g,Ys='시작 전에, 🤗 Datasets을 불러오고 전처리하기 위해 데이터셋이 설치되어 있는지 다수 GPU에서 학습을 간소화하기 위해 🤗 Accelerate 가 설치되어 있는지 확인하세요. 그 후 학습 메트릭을 시각화하기 위해 <a href="https://www.tensorflow.org/tensorboard" rel="nofollow">TensorBoard</a>를 또한 설치하세요. (또한 학습 추적을 위해 <a href="https://docs.wandb.ai/" rel="nofollow">Weights & Biases</a>를 사용할 수 있습니다.)',Fl,u,Xl,V,zs='커뮤니티에 모델을 공유할 것을 권장하며, 이를 위해서 Hugging Face 계정에 로그인을 해야 합니다. (계정이 없다면 <a href="https://hf.co/join" rel="nofollow">여기</a>에서 만들 수 있습니다.) 
노트북에서 로그인할 수 있으며 메시지가 표시되면 토큰을 입력할 수 있습니다.',Nl,b,_l,d,vs="또는 터미널로 로그인할 수 있습니다:",Ol,A,Wl,f,$s='모델 체크포인트가 상당히 크기 때문에 <a href="https://git-lfs.com/" rel="nofollow">Git-LFS</a>에서 대용량 파일의 버전 관리를 할 수 있습니다.',Sl,R,Dl,B,Yl,Z,xs="편의를 위해 학습 파라미터들을 포함한 <code>TrainingConfig</code> 클래스를 생성합니다 (자유롭게 조정 가능):",zl,k,vl,G,$l,E,Hs='🤗 Datasets 라이브러리와 <a href="https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset" rel="nofollow">Smithsonian Butterflies</a> 데이터셋을 쉽게 불러올 수 있습니다.',xl,F,Hl,X,Ls='💡<a href="https://huggingface.co/huggan" rel="nofollow">HugGan Community Event</a> 에서 추가의 데이터셋을 찾거나 로컬의 <a href="https://huggingface.co/docs/datasets/image_dataset#imagefolder" rel="nofollow"><code>ImageFolder</code></a>를 만듦으로써 나만의 데이터셋을 사용할 수 있습니다. HugGan Community Event 에 가져온 데이터셋의 경우 리포지토리의 id로 <code>config.dataset_name</code> 을 설정하고, 나만의 이미지를 사용하는 경우 <code>imagefolder</code> 를 설정합니다.',Ll,N,qs='🤗 Datasets은 <code>Image</code> 기능을 사용해 자동으로 이미지 데이터를 디코딩하고 <a href="https://pillow.readthedocs.io/en/stable/reference/Image.html" rel="nofollow"><code>PIL.Image</code></a>로 불러옵니다. 이를 시각화 해보면:',ql,_,Kl,O,Ks='<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_ds.png"/>',Pl,W,Ps="이미지는 모두 다른 사이즈이기 때문에, 우선 전처리가 필요합니다:",ls,S,la="<li><code>Resize</code> 는 <code>config.image_size</code> 에 정의된 이미지 사이즈로 변경합니다.</li> <li><code>RandomHorizontalFlip</code> 은 랜덤적으로 이미지를 미러링하여 데이터셋을 보강합니다.</li> <li><code>Normalize</code> 는 모델이 예상하는 [-1, 1] 범위로 픽셀 값을 재조정 하는데 중요합니다.</li>",ss,D,as,Y,sa="학습 도중에 <code>preprocess</code> 함수를 적용하려면 🤗 Datasets의 <code>set_transform</code> 방법이 사용됩니다.",es,z,Ms,v,aa='이미지의 크기가 조정되었는지 확인하기 위해 이미지를 다시 시각화해보세요. 이제 <a href="https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader" rel="nofollow">DataLoader</a>에 데이터셋을 포함해 학습할 준비가 되었습니다!',ts,$,ns,x,Us,H,ea="🧨 Diffusers에 사전학습된 모델들은 모델 클래스에서 원하는 파라미터로 쉽게 생성할 수 있습니다. 
예를 들어, <code>UNet2DModel</code>를 생성하려면:",ps,L,Js,q,Ma="샘플의 이미지 크기와 모델 출력 크기가 맞는지 빠르게 확인하기 위한 좋은 아이디어가 있습니다:",js,K,ys,P,ta="훌륭해요! 다음, 이미지에 약간의 노이즈를 더하기 위해 스케줄러가 필요합니다.",Ts,ll,cs,sl,na="스케줄러는 모델을 학습 또는 추론에 사용하는지에 따라 다르게 작동합니다. 추론시에, 스케줄러는 노이즈로부터 이미지를 생성합니다. 학습시 스케줄러는 diffusion 과정에서의 특정 포인트로부터 모델의 출력 또는 샘플을 가져와 <em>노이즈 스케줄</em> 과 <em>업데이트 규칙</em>에 따라 이미지에 노이즈를 적용합니다.",is,al,Ua="<code>DDPMScheduler</code>를 보면 이전으로부터 <code>sample_image</code>에 랜덤한 노이즈를 더하는 <code>add_noise</code> 메서드를 사용합니다:",ms,el,Cs,Ml,pa='<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/noisy_butterfly.png"/>',ws,tl,Ja="모델의 학습 목적은 이미지에 더해진 노이즈를 예측하는 것입니다. 이 단계에서 손실은 다음과 같이 계산될 수 있습니다:",os,nl,hs,Ul,Is,pl,ja="지금까지, 모델 학습을 시작하기 위해 많은 부분을 갖추었으며 이제 남은 것은 모든 것을 조합하는 것입니다.",rs,Jl,ya="우선 옵티마이저(optimizer)와 학습률 스케줄러(learning rate scheduler)가 필요할 것입니다:",Qs,jl,gs,yl,Ta="그 후, 모델을 평가하는 방법이 필요합니다. 평가를 위해, <code>DDPMPipeline</code>을 사용해 배치의 이미지 샘플들을 생성하고 그리드 형태로 저장할 수 있습니다:",us,Tl,Vs,cl,ca="TensorBoard에 로깅, 그래디언트 누적 및 혼합 정밀도 학습을 쉽게 수행하기 위해 🤗 Accelerate를 학습 루프에 함께 앞서 말한 모든 구성 정보들을 묶어 진행할 수 있습니다. 허브에 모델을 업로드 하기 위해 리포지토리 이름 및 정보를 가져오기 위한 함수를 작성하고 허브에 업로드할 수 있습니다.",bs,il,ia="💡아래의 학습 루프는 어렵고 길어 보일 수 있지만, 나중에 한 줄의 코드로 학습을 한다면 그만한 가치가 있을 것입니다! 만약 기다리지 못하고 이미지를 생성하고 싶다면, 아래 코드를 자유롭게 붙여넣고 작동시키면 됩니다. 🤗",ds,ml,As,Cl,ma="휴, 코드가 꽤 많았네요! 하지만 🤗 Accelerate의 <code>notebook_launcher</code> 함수와 학습을 시작할 준비가 되었습니다. 함수에 학습 루프, 모든 학습 인수, 학습에 사용할 프로세스 수(사용 가능한 GPU의 수를 변경할 수 있음)를 전달합니다:",fs,wl,Rs,ol,Ca="한번 학습이 완료되면, diffusion 모델로 생성된 최종 🦋이미지🦋를 확인해보길 바랍니다!",Bs,hl,Zs,Il,wa='<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_final.png"/>',ks,rl,Gs,Ql,oa='Unconditional 이미지 생성은 학습될 수 있는 작업 중 하나의 예시입니다. 다른 작업과 학습 방법은 <a href="../training/overview">🧨 Diffusers 학습 예시</a> 페이지에서 확인할 수 있습니다. 
다음은 학습할 수 있는 몇 가지 예시입니다:',Es,gl,ha='<li><a href="../training/text_inversion">Textual Inversion</a>, 특정 시각적 개념을 학습시켜 생성된 이미지에 통합시키는 알고리즘입니다.</li> <li><a href="../training/dreambooth">DreamBooth</a>, 주제에 대한 몇 가지 입력 이미지들이 주어지면 주제에 대한 개인화된 이미지를 생성하기 위한 기술입니다.</li> <li><a href="../training/text2image">Guide</a> 데이터셋에 Stable Diffusion 모델을 파인튜닝하는 방법입니다.</li> <li><a href="../training/lora">Guide</a> LoRA를 사용해 매우 큰 모델을 빠르게 파인튜닝하기 위한 메모리 효율적인 기술입니다.</li>',Fs,ul,Xs,dl,Ns;return o=new Aa({props:{containerStyle:"float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"}}),h=new Ra({props:{containerStyle:"float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;",options:[{label:"Mixed",value:"https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers_doc/ko/basic_training.ipynb"},{label:"PyTorch",value:"https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers_doc/ko/pytorch/basic_training.ipynb"},{label:"TensorFlow",value:"https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers_doc/ko/tensorflow/basic_training.ipynb"},{label:"Mixed",value:"https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/diffusers_doc/ko/basic_training.ipynb"},{label:"PyTorch",value:"https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/diffusers_doc/ko/pytorch/basic_training.ipynb"},{label:"TensorFlow",value:"https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/diffusers_doc/ko/tensorflow/basic_training.ipynb"}]}}),I=new Vl({props:{title:"Diffusion 모델을 학습하기",local:"diffusion-모델을-학습하기",headingTag:"h1"}}),u=new m({props:{code:"IXBpcCUyMGluc3RhbGwlMjBkaWZmdXNlcnMlNUJ0cmFpbmluZyU1RA==",highlighted:"!pip install diffusers[training]",wrap:!1}}),b=new m({props:{code:"ZnJvbSUyMGh1Z2dpbmdmYWNlX2h1YiUyMGltcG9ydCUyMG5vdGVib29rX2xvZ2luJTBBJTBBbm90ZWJvb2tfbG9naW4oKQ==",highlighted:`<span class="hljs-meta">>>> 
</span><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login | |
| <span class="hljs-meta">>>> </span>notebook_login()`,wrap:!1}}),A=new m({props:{code:"aGYlMjBhdXRoJTIwbG9naW4=",highlighted:"hf auth login",wrap:!1}}),R=new m({props:{code:"IXN1ZG8lMjBhcHQlMjAtcXElMjBpbnN0YWxsJTIwZ2l0LWxmcyUwQSFnaXQlMjBjb25maWclMjAtLWdsb2JhbCUyMGNyZWRlbnRpYWwuaGVscGVyJTIwc3RvcmU=",highlighted:`!sudo apt -qq install git-lfs | |
| !git config --global credential.helper store`,wrap:!1}}),B=new Vl({props:{title:"학습 구성",local:"학습-구성",headingTag:"h2"}}),k=new m({props:{code:"ZnJvbSUyMGRhdGFjbGFzc2VzJTIwaW1wb3J0JTIwZGF0YWNsYXNzJTBBJTBBJTBBJTQwZGF0YWNsYXNzJTBBY2xhc3MlMjBUcmFpbmluZ0NvbmZpZyUzQSUwQSUyMCUyMCUyMCUyMGltYWdlX3NpemUlMjAlM0QlMjAxMjglMjAlMjAlMjMlMjAlRUMlODMlOUQlRUMlODQlQjElRUIlOTAlOTglRUIlOEElOTQlMjAlRUMlOUQlQjQlRUIlQUYlQjglRUMlQTclODAlMjAlRUQlOTUlQjQlRUMlODMlODElRUIlOEYlODQlMEElMjAlMjAlMjAlMjB0cmFpbl9iYXRjaF9zaXplJTIwJTNEJTIwMTYlMEElMjAlMjAlMjAlMjBldmFsX2JhdGNoX3NpemUlMjAlM0QlMjAxNiUyMCUyMCUyMyUyMCVFRCU4RiU4OSVFQSVCMCU4MCUyMCVFQiU4RiU5OSVFQyU5NSU4OCVFQyU5NyU5MCUyMCVFQyU4MyU5OCVFRCU5NCU4QyVFQiVBNyU4MSVFRCU5NSVBMCUyMCVFQyU5RCVCNCVFQiVBRiVCOCVFQyVBNyU4MCUyMCVFQyU4OCU5OCUwQSUyMCUyMCUyMCUyMG51bV9lcG9jaHMlMjAlM0QlMjA1MCUwQSUyMCUyMCUyMCUyMGdyYWRpZW50X2FjY3VtdWxhdGlvbl9zdGVwcyUyMCUzRCUyMDElMEElMjAlMjAlMjAlMjBsZWFybmluZ19yYXRlJTIwJTNEJTIwMWUtNCUwQSUyMCUyMCUyMCUyMGxyX3dhcm11cF9zdGVwcyUyMCUzRCUyMDUwMCUwQSUyMCUyMCUyMCUyMHNhdmVfaW1hZ2VfZXBvY2hzJTIwJTNEJTIwMTAlMEElMjAlMjAlMjAlMjBzYXZlX21vZGVsX2Vwb2NocyUyMCUzRCUyMDMwJTBBJTIwJTIwJTIwJTIwbWl4ZWRfcHJlY2lzaW9uJTIwJTNEJTIwJTIyZnAxNiUyMiUyMCUyMCUyMyUyMCU2MG5vJTYwJUVCJThBJTk0JTIwZmxvYXQzMiUyQyUyMCVFQyU5RSU5MCVFQiU4RiU5OSUyMCVFRCU5OCVCQyVFRCU5NSVBOSUyMCVFQyVBMCU5NSVFQiVCMCU4MCVFQiU4RiU4NCVFQiVBNSVCQyUyMCVFQyU5QyU4NCVFRCU5NSU5QyUyMCU2MGZwMTYlNjAlMEElMjAlMjAlMjAlMjBvdXRwdXRfZGlyJTIwJTNEJTIwJTIyZGRwbS1idXR0ZXJmbGllcy0xMjglMjIlMjAlMjAlMjMlMjAlRUIlQTElOUMlRUMlQkIlQUMlMjAlRUIlQjAlOEYlMjBIRiUyMEh1YiVFQyU5NyU5MCUyMCVFQyVBMCU4MCVFQyU5RSVBNSVFQiU5MCU5OCVFQiU4QSU5NCUyMCVFQiVBQSVBOCVFQiU4RCVCOCVFQiVBQSU4NSUwQSUwQSUyMCUyMCUyMCUyMHB1c2hfdG9faHViJTIwJTNEJTIwVHJ1ZSUyMCUyMCUyMyUyMCVFQyVBMCU4MCVFQyU5RSVBNSVFQiU5MCU5QyUyMCVFQiVBQSVBOCVFQiU4RCVCOCVFQyU5RCU4NCUyMEhGJTIwSHViJUVDJTk3JTkwJTIwJUVDJTk3JTg1JUVCJUExJTlDJUVCJTkzJTlDJUVEJTk1JUEwJUVDJUE3JTgwJTIwJUVDJTk3JUFDJUVCJUI2JTgwJTBBJTIwJTIwJTIwJTIwaHViX3ByaXZhdGVfcmVwbyUyMCUzRCUyME5vbmUlMEElMjAlMjAlMjAlMjBvdmVyd3JpdGVfb3V0cHV0X2RpciU
yMCUzRCUyMFRydWUlMjAlMjAlMjMlMjAlRUIlODUlQjglRUQlOEElQjglRUIlQjYlODElRUMlOUQlODQlMjAlRUIlOEIlQTQlRUMlOEIlOUMlMjAlRUMlOEIlQTQlRUQlOTYlODklRUQlOTUlQTAlMjAlRUIlOTUlOEMlMjAlRUMlOUQlQjQlRUMlQTAlODQlMjAlRUIlQUElQTglRUIlOEQlQjglRUMlOTclOTAlMjAlRUIlOEQlQUUlRUMlOTYlQjQlRUMlOTQlOEMlRUMlOUElQjglRUMlQTclODAlMEElMjAlMjAlMjAlMjBzZWVkJTIwJTNEJTIwMCUwQSUwQSUwQWNvbmZpZyUyMCUzRCUyMFRyYWluaW5nQ29uZmlnKCk=",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> dataclasses <span class="hljs-keyword">import</span> dataclass | |
| <span class="hljs-meta">>>> </span>@dataclass | |
| <span class="hljs-meta">... </span><span class="hljs-keyword">class</span> <span class="hljs-title class_">TrainingConfig</span>: | |
| <span class="hljs-meta">... </span> image_size = <span class="hljs-number">128</span> <span class="hljs-comment"># 생성되는 이미지 해상도</span> | |
| <span class="hljs-meta">... </span> train_batch_size = <span class="hljs-number">16</span> | |
| <span class="hljs-meta">... </span> eval_batch_size = <span class="hljs-number">16</span> <span class="hljs-comment"># 평가 동안에 샘플링할 이미지 수</span> | |
| <span class="hljs-meta">... </span> num_epochs = <span class="hljs-number">50</span> | |
| <span class="hljs-meta">... </span> gradient_accumulation_steps = <span class="hljs-number">1</span> | |
| <span class="hljs-meta">... </span> learning_rate = <span class="hljs-number">1e-4</span> | |
| <span class="hljs-meta">... </span> lr_warmup_steps = <span class="hljs-number">500</span> | |
| <span class="hljs-meta">... </span> save_image_epochs = <span class="hljs-number">10</span> | |
| <span class="hljs-meta">... </span> save_model_epochs = <span class="hljs-number">30</span> | |
| <span class="hljs-meta">... </span> mixed_precision = <span class="hljs-string">"fp16"</span> <span class="hljs-comment"># \`no\`는 float32, 자동 혼합 정밀도를 위한 \`fp16\`</span> | |
| <span class="hljs-meta">... </span> output_dir = <span class="hljs-string">"ddpm-butterflies-128"</span> <span class="hljs-comment"># 로컬 및 HF Hub에 저장되는 모델명</span> | |
| <span class="hljs-meta">... </span> push_to_hub = <span class="hljs-literal">True</span> <span class="hljs-comment"># 저장된 모델을 HF Hub에 업로드할지 여부</span> | |
| <span class="hljs-meta">... </span> hub_private_repo = <span class="hljs-literal">None</span> | |
| <span class="hljs-meta">... </span> overwrite_output_dir = <span class="hljs-literal">True</span> <span class="hljs-comment"># 노트북을 다시 실행할 때 이전 모델에 덮어씌울지</span> | |
| <span class="hljs-meta">... </span> seed = <span class="hljs-number">0</span> | |
| <span class="hljs-meta">>>> </span>config = TrainingConfig()`,wrap:!1}}),G=new Vl({props:{title:"데이터셋 불러오기",local:"데이터셋-불러오기",headingTag:"h2"}}),F=new m({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwbG9hZF9kYXRhc2V0JTBBJTBBY29uZmlnLmRhdGFzZXRfbmFtZSUyMCUzRCUyMCUyMmh1Z2dhbiUyRnNtaXRoc29uaWFuX2J1dHRlcmZsaWVzX3N1YnNldCUyMiUwQWRhdGFzZXQlMjAlM0QlMjBsb2FkX2RhdGFzZXQoY29uZmlnLmRhdGFzZXRfbmFtZSUyQyUyMHNwbGl0JTNEJTIydHJhaW4lMjIp",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-meta">>>> </span>config.dataset_name = <span class="hljs-string">"huggan/smithsonian_butterflies_subset"</span> | |
| <span class="hljs-meta">>>> </span>dataset = load_dataset(config.dataset_name, split=<span class="hljs-string">"train"</span>)`,wrap:!1}}),_=new m({props:{code:"aW1wb3J0JTIwbWF0cGxvdGxpYi5weXBsb3QlMjBhcyUyMHBsdCUwQSUwQWZpZyUyQyUyMGF4cyUyMCUzRCUyMHBsdC5zdWJwbG90cygxJTJDJTIwNCUyQyUyMGZpZ3NpemUlM0QoMTYlMkMlMjA0KSklMEFmb3IlMjBpJTJDJTIwaW1hZ2UlMjBpbiUyMGVudW1lcmF0ZShkYXRhc2V0JTVCJTNBNCU1RCU1QiUyMmltYWdlJTIyJTVEKSUzQSUwQSUyMCUyMCUyMCUyMGF4cyU1QmklNUQuaW1zaG93KGltYWdlKSUwQSUyMCUyMCUyMCUyMGF4cyU1QmklNUQuc2V0X2F4aXNfb2ZmKCklMEFmaWcuc2hvdygp",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> matplotlib.pyplot <span class="hljs-keyword">as</span> plt | |
| <span class="hljs-meta">>>> </span>fig, axs = plt.subplots(<span class="hljs-number">1</span>, <span class="hljs-number">4</span>, figsize=(<span class="hljs-number">16</span>, <span class="hljs-number">4</span>)) | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">for</span> i, image <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(dataset[:<span class="hljs-number">4</span>][<span class="hljs-string">"image"</span>]): | |
| <span class="hljs-meta">... </span> axs[i].imshow(image) | |
| <span class="hljs-meta">... </span> axs[i].set_axis_off() | |
| <span class="hljs-meta">>>> </span>fig.show()`,wrap:!1}}),D=new m({props:{code:"ZnJvbSUyMHRvcmNodmlzaW9uJTIwaW1wb3J0JTIwdHJhbnNmb3JtcyUwQSUwQXByZXByb2Nlc3MlMjAlM0QlMjB0cmFuc2Zvcm1zLkNvbXBvc2UoJTBBJTIwJTIwJTIwJTIwJTVCJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwdHJhbnNmb3Jtcy5SZXNpemUoKGNvbmZpZy5pbWFnZV9zaXplJTJDJTIwY29uZmlnLmltYWdlX3NpemUpKSUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMHRyYW5zZm9ybXMuUmFuZG9tSG9yaXpvbnRhbEZsaXAoKSUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMHRyYW5zZm9ybXMuVG9UZW5zb3IoKSUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMHRyYW5zZm9ybXMuTm9ybWFsaXplKCU1QjAuNSU1RCUyQyUyMCU1QjAuNSU1RCklMkMlMEElMjAlMjAlMjAlMjAlNUQlMEEp",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> torchvision <span class="hljs-keyword">import</span> transforms | |
| <span class="hljs-meta">>>> </span>preprocess = transforms.Compose( | |
| <span class="hljs-meta">... </span> [ | |
| <span class="hljs-meta">... </span> transforms.Resize((config.image_size, config.image_size)), | |
| <span class="hljs-meta">... </span> transforms.RandomHorizontalFlip(), | |
| <span class="hljs-meta">... </span> transforms.ToTensor(), | |
| <span class="hljs-meta">... </span> transforms.Normalize([<span class="hljs-number">0.5</span>], [<span class="hljs-number">0.5</span>]), | |
| <span class="hljs-meta">... </span> ] | |
| <span class="hljs-meta">... </span>)`,wrap:!1}}),z=new m({props:{code:"ZGVmJTIwdHJhbnNmb3JtKGV4YW1wbGVzKSUzQSUwQSUyMCUyMCUyMCUyMGltYWdlcyUyMCUzRCUyMCU1QnByZXByb2Nlc3MoaW1hZ2UuY29udmVydCglMjJSR0IlMjIpKSUyMGZvciUyMGltYWdlJTIwaW4lMjBleGFtcGxlcyU1QiUyMmltYWdlJTIyJTVEJTVEJTBBJTIwJTIwJTIwJTIwcmV0dXJuJTIwJTdCJTIyaW1hZ2VzJTIyJTNBJTIwaW1hZ2VzJTdEJTBBJTBBJTBBZGF0YXNldC5zZXRfdHJhbnNmb3JtKHRyYW5zZm9ybSk=",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">transform</span>(<span class="hljs-params">examples</span>): | |
| <span class="hljs-meta">... </span> images = [preprocess(image.convert(<span class="hljs-string">"RGB"</span>)) <span class="hljs-keyword">for</span> image <span class="hljs-keyword">in</span> examples[<span class="hljs-string">"image"</span>]] | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> {<span class="hljs-string">"images"</span>: images} | |
| <span class="hljs-meta">>>> </span>dataset.set_transform(transform)`,wrap:!1}}),$=new m({props:{code:"aW1wb3J0JTIwdG9yY2glMEElMEF0cmFpbl9kYXRhbG9hZGVyJTIwJTNEJTIwdG9yY2gudXRpbHMuZGF0YS5EYXRhTG9hZGVyKGRhdGFzZXQlMkMlMjBiYXRjaF9zaXplJTNEY29uZmlnLnRyYWluX2JhdGNoX3NpemUlMkMlMjBzaHVmZmxlJTNEVHJ1ZSk=",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-meta">>>> </span>train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=<span class="hljs-literal">True</span>)`,wrap:!1}}),x=new Vl({props:{title:"UNet2DModel 생성하기",local:"unet2dmodel-생성하기",headingTag:"h2"}}),L=new m({props:{code:"ZnJvbSUyMGRpZmZ1c2VycyUyMGltcG9ydCUyMFVOZXQyRE1vZGVsJTBBJTBBbW9kZWwlMjAlM0QlMjBVTmV0MkRNb2RlbCglMEElMjAlMjAlMjAlMjBzYW1wbGVfc2l6ZSUzRGNvbmZpZy5pbWFnZV9zaXplJTJDJTIwJTIwJTIzJTIwJUVEJTgzJTgwJUVBJUIyJTlGJTIwJUVDJTlEJUI0JUVCJUFGJUI4JUVDJUE3JTgwJTIwJUVEJTk1JUI0JUVDJTgzJTgxJUVCJThGJTg0JTBBJTIwJTIwJTIwJTIwaW5fY2hhbm5lbHMlM0QzJTJDJTIwJTIwJTIzJTIwJUVDJTlFJTg1JUVCJUEwJUE1JTIwJUVDJUIxJTg0JUVCJTg0JTkwJTIwJUVDJTg4JTk4JTJDJTIwUkdCJTIwJUVDJTlEJUI0JUVCJUFGJUI4JUVDJUE3JTgwJUVDJTk3JTkwJUVDJTg0JTlDJTIwMyUwQSUyMCUyMCUyMCUyMG91dF9jaGFubmVscyUzRDMlMkMlMjAlMjAlMjMlMjAlRUMlQjYlOUMlRUIlQTAlQTUlMjAlRUMlQjElODQlRUIlODQlOTAlMjAlRUMlODglOTglMEElMjAlMjAlMjAlMjBsYXllcnNfcGVyX2Jsb2NrJTNEMiUyQyUyMCUyMCUyMyUyMFVOZXQlMjAlRUIlQjglOTQlRUIlOUYlQUQlRUIlOEIlQjklMjAlRUIlQUElODclMjAlRUElQjAlOUMlRUMlOUQlOTglMjBSZXNOZXQlMjAlRUIlQTAlODglRUMlOUQlQjQlRUMlOTYlQjQlRUElQjAlODAlMjAlRUMlODIlQUMlRUMlOUElQTklRUIlOTAlOTglRUIlOEElOTQlRUMlQTclODAlMEElMjAlMjAlMjAlMjBibG9ja19vdXRfY2hhbm5lbHMlM0QoMTI4JTJDJTIwMTI4JTJDJTIwMjU2JTJDJTIwMjU2JTJDJTIwNTEyJTJDJTIwNTEyKSUyQyUyMCUyMCUyMyUyMCVFQSVCMCU4MSUyMFVOZXQlMjAlRUIlQjglOTQlRUIlOUYlQUQlRUMlOUQlODQlMjAlRUMlOUMlODQlRUQlOTUlOUMlMjAlRUMlQjYlOUMlRUIlQTAlQTUlMjAlRUMlQjElODQlRUIlODQlOTAlMjAlRUMlODglOTglMEElMjAlMjAlMjAlMjBkb3duX2Jsb2NrX3R5cGVzJTNEKCUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMkRvd25CbG9jazJEJTIyJTJDJTIwJTIwJTIzJTIwJUVDJTlEJUJDJUVCJUIwJTk4JUVDJUEwJTgxJUVDJTlEJUI4JTIwUmVzTmV0JTIwJUVCJThCJUE0JUVDJTlBJUI0JUVDJTgzJTk4JUVEJTk0JThDJUVCJUE3JTgxJTIwJUVCJUI4JTk0JUVCJTlGJUFEJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIyRG93bkJsb2NrMkQlMjIlMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJEb3duQmxvY2syRCUyMiUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMkRvd25CbG9jazJEJTIyJTJDJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTI
yQXR0bkRvd25CbG9jazJEJTIyJTJDJTIwJTIwJTIzJTIwc3BhdGlhbCUyMHNlbGYtYXR0ZW50aW9uJUVDJTlEJUI0JTIwJUVEJThGJUFDJUVEJTk1JUE4JUVCJTkwJTlDJTIwJUVDJTlEJUJDJUVCJUIwJTk4JUVDJUEwJTgxJUVDJTlEJUI4JTIwUmVzTmV0JTIwJUVCJThCJUE0JUVDJTlBJUI0JUVDJTgzJTk4JUVEJTk0JThDJUVCJUE3JTgxJTIwJUVCJUI4JTk0JUVCJTlGJUFEJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIyRG93bkJsb2NrMkQlMjIlMkMlMEElMjAlMjAlMjAlMjApJTJDJTBBJTIwJTIwJTIwJTIwdXBfYmxvY2tfdHlwZXMlM0QoJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIyVXBCbG9jazJEJTIyJTJDJTIwJTIwJTIzJTIwJUVDJTlEJUJDJUVCJUIwJTk4JUVDJUEwJTgxJUVDJTlEJUI4JTIwUmVzTmV0JTIwJUVDJTk3JTg1JUVDJTgzJTk4JUVEJTk0JThDJUVCJUE3JTgxJTIwJUVCJUI4JTk0JUVCJTlGJUFEJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIyQXR0blVwQmxvY2syRCUyMiUyQyUyMCUyMCUyMyUyMHNwYXRpYWwlMjBzZWxmLWF0dGVudGlvbiVFQyU5RCVCNCUyMCVFRCU4RiVBQyVFRCU5NSVBOCVFQiU5MCU5QyUyMCVFQyU5RCVCQyVFQiVCMCU5OCVFQyVBMCU4MSVFQyU5RCVCOCUyMFJlc05ldCUyMCVFQyU5NyU4NSVFQyU4MyU5OCVFRCU5NCU4QyVFQiVBNyU4MSUyMCVFQiVCOCU5NCVFQiU5RiVBRCUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMlVwQmxvY2syRCUyMiUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMlVwQmxvY2syRCUyMiUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMlVwQmxvY2syRCUyMiUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMlVwQmxvY2syRCUyMiUyQyUwQSUyMCUyMCUyMCUyMCklMkMlMEEp",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> UNet2DModel | |
| <span class="hljs-meta">>>> </span>model = UNet2DModel( | |
| <span class="hljs-meta">... </span> sample_size=config.image_size, <span class="hljs-comment"># 타겟 이미지 해상도</span> | |
| <span class="hljs-meta">... </span> in_channels=<span class="hljs-number">3</span>, <span class="hljs-comment"># 입력 채널 수, RGB 이미지에서 3</span> | |
| <span class="hljs-meta">... </span> out_channels=<span class="hljs-number">3</span>, <span class="hljs-comment"># 출력 채널 수</span> | |
| <span class="hljs-meta">... </span> layers_per_block=<span class="hljs-number">2</span>, <span class="hljs-comment"># UNet 블럭당 몇 개의 ResNet 레이어가 사용되는지</span> | |
| <span class="hljs-meta">... </span> block_out_channels=(<span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">256</span>, <span class="hljs-number">256</span>, <span class="hljs-number">512</span>, <span class="hljs-number">512</span>), <span class="hljs-comment"># 각 UNet 블럭을 위한 출력 채널 수</span> | |
| <span class="hljs-meta">... </span> down_block_types=( | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"DownBlock2D"</span>, <span class="hljs-comment"># 일반적인 ResNet 다운샘플링 블럭</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"DownBlock2D"</span>, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"DownBlock2D"</span>, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"DownBlock2D"</span>, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"AttnDownBlock2D"</span>, <span class="hljs-comment"># spatial self-attention이 포함된 일반적인 ResNet 다운샘플링 블럭</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"DownBlock2D"</span>, | |
| <span class="hljs-meta">... </span> ), | |
| <span class="hljs-meta">... </span> up_block_types=( | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"UpBlock2D"</span>, <span class="hljs-comment"># 일반적인 ResNet 업샘플링 블럭</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"AttnUpBlock2D"</span>, <span class="hljs-comment"># spatial self-attention이 포함된 일반적인 ResNet 업샘플링 블럭</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"UpBlock2D"</span>, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"UpBlock2D"</span>, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"UpBlock2D"</span>, | |
| <span class="hljs-meta">... </span> <span class="hljs-string">"UpBlock2D"</span>, | |
| <span class="hljs-meta">... </span> ), | |
| <span class="hljs-meta">... </span>)`,wrap:!1}}),K=new m({props:{code:"c2FtcGxlX2ltYWdlJTIwJTNEJTIwZGF0YXNldCU1QjAlNUQlNUIlMjJpbWFnZXMlMjIlNUQudW5zcXVlZXplKDApJTBBcHJpbnQoJTIySW5wdXQlMjBzaGFwZSUzQSUyMiUyQyUyMHNhbXBsZV9pbWFnZS5zaGFwZSklMEElMEFwcmludCglMjJPdXRwdXQlMjBzaGFwZSUzQSUyMiUyQyUyMG1vZGVsKHNhbXBsZV9pbWFnZSUyQyUyMHRpbWVzdGVwJTNEMCkuc2FtcGxlLnNoYXBlKQ==",highlighted:`<span class="hljs-meta">>>> </span>sample_image = dataset[<span class="hljs-number">0</span>][<span class="hljs-string">"images"</span>].unsqueeze(<span class="hljs-number">0</span>) | |
| <span class="hljs-meta">>>> </span><span class="hljs-built_in">print</span>(<span class="hljs-string">"Input shape:"</span>, sample_image.shape) | |
| Input shape: torch.Size([<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>]) | |
| <span class="hljs-meta">>>> </span><span class="hljs-built_in">print</span>(<span class="hljs-string">"Output shape:"</span>, model(sample_image, timestep=<span class="hljs-number">0</span>).sample.shape) | |
| Output shape: torch.Size([<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>])`,wrap:!1}}),ll=new Vl({props:{title:"스케줄러 생성하기",local:"스케줄러-생성하기",headingTag:"h2"}}),el=new m({props:{code:"aW1wb3J0JTIwdG9yY2glMEFmcm9tJTIwUElMJTIwaW1wb3J0JTIwSW1hZ2UlMEFmcm9tJTIwZGlmZnVzZXJzJTIwaW1wb3J0JTIwRERQTVNjaGVkdWxlciUwQSUwQW5vaXNlX3NjaGVkdWxlciUyMCUzRCUyMEREUE1TY2hlZHVsZXIobnVtX3RyYWluX3RpbWVzdGVwcyUzRDEwMDApJTBBbm9pc2UlMjAlM0QlMjB0b3JjaC5yYW5kbihzYW1wbGVfaW1hZ2Uuc2hhcGUpJTBBdGltZXN0ZXBzJTIwJTNEJTIwdG9yY2guTG9uZ1RlbnNvciglNUI1MCU1RCklMEFub2lzeV9pbWFnZSUyMCUzRCUyMG5vaXNlX3NjaGVkdWxlci5hZGRfbm9pc2Uoc2FtcGxlX2ltYWdlJTJDJTIwbm9pc2UlMkMlMjB0aW1lc3RlcHMpJTBBJTBBSW1hZ2UuZnJvbWFycmF5KCgobm9pc3lfaW1hZ2UucGVybXV0ZSgwJTJDJTIwMiUyQyUyMDMlMkMlMjAxKSUyMCUyQiUyMDEuMCklMjAqJTIwMTI3LjUpLnR5cGUodG9yY2gudWludDgpLm51bXB5KCklNUIwJTVEKQ==",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DDPMScheduler | |
| <span class="hljs-meta">>>> </span>noise_scheduler = DDPMScheduler(num_train_timesteps=<span class="hljs-number">1000</span>) | |
| <span class="hljs-meta">>>> </span>noise = torch.randn(sample_image.shape) | |
| <span class="hljs-meta">>>> </span>timesteps = torch.LongTensor([<span class="hljs-number">50</span>]) | |
| <span class="hljs-meta">>>> </span>noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps) | |
| <span class="hljs-meta">>>> </span>Image.fromarray(((noisy_image.permute(<span class="hljs-number">0</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>, <span class="hljs-number">1</span>) + <span class="hljs-number">1.0</span>) * <span class="hljs-number">127.5</span>).<span class="hljs-built_in">type</span>(torch.uint8).numpy()[<span class="hljs-number">0</span>])`,wrap:!1}}),nl=new m({props:{code:"aW1wb3J0JTIwdG9yY2gubm4uZnVuY3Rpb25hbCUyMGFzJTIwRiUwQSUwQW5vaXNlX3ByZWQlMjAlM0QlMjBtb2RlbChub2lzeV9pbWFnZSUyQyUyMHRpbWVzdGVwcykuc2FtcGxlJTBBbG9zcyUyMCUzRCUyMEYubXNlX2xvc3Mobm9pc2VfcHJlZCUyQyUyMG5vaXNlKQ==",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> torch.nn.functional <span class="hljs-keyword">as</span> F | |
| <span class="hljs-meta">>>> </span>noise_pred = model(noisy_image, timesteps).sample | |
| <span class="hljs-meta">>>> </span>loss = F.mse_loss(noise_pred, noise)`,wrap:!1}}),Ul=new Vl({props:{title:"모델 학습하기",local:"모델-학습하기",headingTag:"h2"}}),jl=new m({props:{code:"ZnJvbSUyMGRpZmZ1c2Vycy5vcHRpbWl6YXRpb24lMjBpbXBvcnQlMjBnZXRfY29zaW5lX3NjaGVkdWxlX3dpdGhfd2FybXVwJTBBJTBBb3B0aW1pemVyJTIwJTNEJTIwdG9yY2gub3B0aW0uQWRhbVcobW9kZWwucGFyYW1ldGVycygpJTJDJTIwbHIlM0Rjb25maWcubGVhcm5pbmdfcmF0ZSklMEFscl9zY2hlZHVsZXIlMjAlM0QlMjBnZXRfY29zaW5lX3NjaGVkdWxlX3dpdGhfd2FybXVwKCUwQSUyMCUyMCUyMCUyMG9wdGltaXplciUzRG9wdGltaXplciUyQyUwQSUyMCUyMCUyMCUyMG51bV93YXJtdXBfc3RlcHMlM0Rjb25maWcubHJfd2FybXVwX3N0ZXBzJTJDJTBBJTIwJTIwJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTNEKGxlbih0cmFpbl9kYXRhbG9hZGVyKSUyMColMjBjb25maWcubnVtX2Vwb2NocyklMkMlMEEp",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> diffusers.optimization <span class="hljs-keyword">import</span> get_cosine_schedule_with_warmup | |
| <span class="hljs-meta">>>> </span>optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) | |
| <span class="hljs-meta">>>> </span>lr_scheduler = get_cosine_schedule_with_warmup( | |
| <span class="hljs-meta">... </span> optimizer=optimizer, | |
| <span class="hljs-meta">... </span> num_warmup_steps=config.lr_warmup_steps, | |
| <span class="hljs-meta">... </span> num_training_steps=(<span class="hljs-built_in">len</span>(train_dataloader) * config.num_epochs), | |
| <span class="hljs-meta">... </span>)`,wrap:!1}}),Tl=new m({props:{code:"ZnJvbSUyMGRpZmZ1c2VycyUyMGltcG9ydCUyMEREUE1QaXBlbGluZSUwQWltcG9ydCUyMG1hdGglMEFpbXBvcnQlMjBvcyUwQSUwQSUwQWRlZiUyMG1ha2VfZ3JpZChpbWFnZXMlMkMlMjByb3dzJTJDJTIwY29scyklM0ElMEElMjAlMjAlMjAlMjB3JTJDJTIwaCUyMCUzRCUyMGltYWdlcyU1QjAlNUQuc2l6ZSUwQSUyMCUyMCUyMCUyMGdyaWQlMjAlM0QlMjBJbWFnZS5uZXcoJTIyUkdCJTIyJTJDJTIwc2l6ZSUzRChjb2xzJTIwKiUyMHclMkMlMjByb3dzJTIwKiUyMGgpKSUwQSUyMCUyMCUyMCUyMGZvciUyMGklMkMlMjBpbWFnZSUyMGluJTIwZW51bWVyYXRlKGltYWdlcyklM0ElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBncmlkLnBhc3RlKGltYWdlJTJDJTIwYm94JTNEKGklMjAlMjUlMjBjb2xzJTIwKiUyMHclMkMlMjBpJTIwJTJGJTJGJTIwY29scyUyMColMjBoKSklMEElMjAlMjAlMjAlMjByZXR1cm4lMjBncmlkJTBBJTBBJTBBZGVmJTIwZXZhbHVhdGUoY29uZmlnJTJDJTIwZXBvY2glMkMlMjBwaXBlbGluZSklM0ElMEElMjAlMjAlMjAlMjAlMjMlMjAlRUIlOUUlOUMlRUIlOEQlQTQlRUQlOTUlOUMlMjAlRUIlODUlQjglRUMlOUQlQjQlRUMlQTYlODglRUIlQTElOUMlMjAlRUIlQjYlODAlRUQlODQlQjAlMjAlRUMlOUQlQjQlRUIlQUYlQjglRUMlQTclODAlRUIlQTUlQkMlMjAlRUMlQjYlOTQlRUMlQjYlOUMlRUQlOTUlQTklRUIlOEIlODglRUIlOEIlQTQuKCVFQyU5RCVCNCVFQiU4QSU5NCUyMCVFQyU5NyVBRCVFQyVBMCU4NCVFRCU4QyU4QyUyMGRpZmZ1c2lvbiUyMCVFQSVCMyVCQyVFQyVBMCU5NSVFQyU5RSU4NSVFQiU4QiU4OCVFQiU4QiVBNC4pJTBBJTIwJTIwJTIwJTIwJTIzJTIwJUVBJUI4JUIwJUVCJUIzJUI4JTIwJUVEJThDJThDJUVDJTlEJUI0JUVEJTk0JTg0JUVCJTlEJUJDJUVDJTlEJUI4JTIwJUVDJUI2JTlDJUVCJUEwJUE1JTIwJUVEJTk4JTk1JUVEJTgzJTlDJUVCJThBJTk0JTIwJTYwTGlzdCU1QlBJTC5JbWFnZSU1RCU2MCUyMCVFQyU5RSU4NSVFQiU4QiU4OCVFQiU4QiVBNC4lMEElMjAlMjAlMjAlMjBpbWFnZXMlMjAlM0QlMjBwaXBlbGluZSglMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBiYXRjaF9zaXplJTNEY29uZmlnLmV2YWxfYmF0Y2hfc2l6ZSUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGdlbmVyYXRvciUzRHRvcmNoLm1hbnVhbF9zZWVkKGNvbmZpZy5zZWVkKSUyQyUwQSUyMCUyMCUyMCUyMCkuaW1hZ2VzJTBBJTBBJTIwJTIwJTIwJTIwJTIzJTIwJUVDJTlEJUI0JUVCJUFGJUI4JUVDJUE3JTgwJUVCJTkzJUE0JUVDJTlEJTg0JTIwJUVBJUI3JUI4JUVCJUE2JUFDJUVCJTkzJTlDJUVCJUExJTlDJTIwJUVCJUE3JThDJUVCJTkzJUE0JUVDJTk2JUI0JUVDJUE0JThEJUVCJThCJTg4JUVCJThCJUE0LiUwQSUyMCUyMCUyMCUyMGltYWdlX2dyaWQlMjAlM0QlMjBtYWtlX2dyaWQoaW1hZ2
VzJTJDJTIwcm93cyUzRDQlMkMlMjBjb2xzJTNENCklMEElMEElMjAlMjAlMjAlMjAlMjMlMjAlRUMlOUQlQjQlRUIlQUYlQjglRUMlQTclODAlRUIlOTMlQTQlRUMlOUQlODQlMjAlRUMlQTAlODAlRUMlOUUlQTUlRUQlOTUlQTklRUIlOEIlODglRUIlOEIlQTQuJTBBJTIwJTIwJTIwJTIwdGVzdF9kaXIlMjAlM0QlMjBvcy5wYXRoLmpvaW4oY29uZmlnLm91dHB1dF9kaXIlMkMlMjAlMjJzYW1wbGVzJTIyKSUwQSUyMCUyMCUyMCUyMG9zLm1ha2VkaXJzKHRlc3RfZGlyJTJDJTIwZXhpc3Rfb2slM0RUcnVlKSUwQSUyMCUyMCUyMCUyMGltYWdlX2dyaWQuc2F2ZShmJTIyJTdCdGVzdF9kaXIlN0QlMkYlN0JlcG9jaCUzQTA0ZCU3RC5wbmclMjIp",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DDPMPipeline | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> math | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> os | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">make_grid</span>(<span class="hljs-params">images, rows, cols</span>): | |
| <span class="hljs-meta">... </span> w, h = images[<span class="hljs-number">0</span>].size | |
| <span class="hljs-meta">... </span> grid = Image.new(<span class="hljs-string">"RGB"</span>, size=(cols * w, rows * h)) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> i, image <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(images): | |
| <span class="hljs-meta">... </span> grid.paste(image, box=(i % cols * w, i // cols * h)) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> grid | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">evaluate</span>(<span class="hljs-params">config, epoch, pipeline</span>): | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 랜덤한 노이즈로부터 이미지를 추출합니다. (이는 역방향 diffusion 과정입니다.)</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 기본 파이프라인 출력 형태는 \`List[PIL.Image]\` 입니다.</span> | |
| <span class="hljs-meta">... </span> images = pipeline( | |
| <span class="hljs-meta">... </span> batch_size=config.eval_batch_size, | |
| <span class="hljs-meta">... </span> generator=torch.manual_seed(config.seed), | |
| <span class="hljs-meta">... </span> ).images | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 이미지들을 그리드로 만들어줍니다.</span> | |
| <span class="hljs-meta">... </span> image_grid = make_grid(images, rows=<span class="hljs-number">4</span>, cols=<span class="hljs-number">4</span>) | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 이미지들을 저장합니다.</span> | |
| <span class="hljs-meta">... </span> test_dir = os.path.join(config.output_dir, <span class="hljs-string">"samples"</span>) | |
| <span class="hljs-meta">... </span> os.makedirs(test_dir, exist_ok=<span class="hljs-literal">True</span>) | |
| <span class="hljs-meta">... </span> image_grid.save(<span class="hljs-string">f"<span class="hljs-subst">{test_dir}</span>/<span class="hljs-subst">{epoch:04d}</span>.png"</span>)`,wrap:!1}}),ml=new m({props:{code:"from%20accelerate%20import%20Accelerator%0Afrom%20huggingface_hub%20import%20create_repo%2C%20upload_folder%0Afrom%20tqdm.auto%20import%20tqdm%0Afrom%20pathlib%20import%20Path%0Aimport%20os%0A%0A%0Adef%20train_loop(config%2C%20model%2C%20noise_scheduler%2C%20optimizer%2C%20train_dataloader%2C%20lr_scheduler)%3A%0A%20%20%20%20%23%20Initialize%20accelerator%20and%20tensorboard%20logging%0A%20%20%20%20accelerator%20%3D%20Accelerator(%0A%20%20%20%20%20%20%20%20mixed_precision%3Dconfig.mixed_precision%2C%0A%20%20%20%20%20%20%20%20gradient_accumulation_steps%3Dconfig.gradient_accumulation_steps%2C%0A%20%20%20%20%20%20%20%20log_with%3D%22tensorboard%22%2C%0A%20%20%20%20%20%20%20%20project_dir%3Dos.path.join(config.output_dir%2C%20%22logs%22)%2C%0A%20%20%20%20)%0A%20%20%20%20if%20accelerator.is_main_process%3A%0A%20%20%20%20%20%20%20%20if%20config.output_dir%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20os.makedirs(config.output_dir%2C%20exist_ok%3DTrue)%0A%20%20%20%20%20%20%20%20if%20config.push_to_hub%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20repo_id%20%3D%20create_repo(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20repo_id%3Dconfig.hub_model_id%20or%20Path(config.output_dir).name%2C%20exist_ok%3DTrue%0A%20%20%20%20%20%20%20%20%20%20%20%20).repo_id%0A%20%20%20%20%20%20%20%20accelerator.init_trackers(%22train_example%22)%0A%0A%20%20%20%20%23%20%EB%AA%A8%EB%93%A0%20%EA%B2%83%EC%9D%B4%20%EC%A4%80%EB%B9%84%EB%90%98%EC%97%88%EC%8A%B5%EB%8B%88%EB%8B%A4.%0A%20%20%20%20%23%20%EA%B8%B0%EC%96%B5%ED%95%B4%EC%95%BC%20%ED%95%A0%20%ED%8A%B9%EC%A0%95%ED%95%9C%20%EC%88%9C%EC%84%9C%EB%8A%94%20%EC%97%86%EC%9C%BC%EB%A9%B0%20%EC%A4%80%EB%B9%84%ED%95%9C%20%EB%B0%A9%EB%B2%95%EC%97%90%20%EC%A0%9C%EA%B3%B5%ED%95%9C%20%EA%B2%83%EA%B3%BC%20%EB%8F%99%EC%9D%BC%ED%95%9C%2
0%EC%88%9C%EC%84%9C%EB%A1%9C%20%EA%B0%9D%EC%B2%B4%EC%9D%98%20%EC%95%95%EC%B6%95%EC%9D%84%20%ED%92%80%EB%A9%B4%20%EB%90%A9%EB%8B%88%EB%8B%A4.%0A%20%20%20%20model%2C%20optimizer%2C%20train_dataloader%2C%20lr_scheduler%20%3D%20accelerator.prepare(%0A%20%20%20%20%20%20%20%20model%2C%20optimizer%2C%20train_dataloader%2C%20lr_scheduler%0A%20%20%20%20)%0A%0A%20%20%20%20global_step%20%3D%200%0A%0A%20%20%20%20%23%20%EC%9D%B4%EC%A0%9C%20%EB%AA%A8%EB%8D%B8%EC%9D%84%20%ED%95%99%EC%8A%B5%ED%95%A9%EB%8B%88%EB%8B%A4.%0A%20%20%20%20for%20epoch%20in%20range(config.num_epochs)%3A%0A%20%20%20%20%20%20%20%20progress_bar%20%3D%20tqdm(total%3Dlen(train_dataloader)%2C%20disable%3Dnot%20accelerator.is_local_main_process)%0A%20%20%20%20%20%20%20%20progress_bar.set_description(f%22Epoch%20%7Bepoch%7D%22)%0A%0A%20%20%20%20%20%20%20%20for%20step%2C%20batch%20in%20enumerate(train_dataloader)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20clean_images%20%3D%20batch%5B%22images%22%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20%EC%9D%B4%EB%AF%B8%EC%A7%80%EC%97%90%20%EB%8D%94%ED%95%A0%20%EB%85%B8%EC%9D%B4%EC%A6%88%EB%A5%BC%20%EC%83%98%ED%94%8C%EB%A7%81%ED%95%A9%EB%8B%88%EB%8B%A4.%0A%20%20%20%20%20%20%20%20%20%20%20%20noise%20%3D%20torch.randn(clean_images.shape%2C%20device%3Dclean_images.device)%0A%20%20%20%20%20%20%20%20%20%20%20%20bs%20%3D%20clean_images.shape%5B0%5D%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20%EA%B0%81%20%EC%9D%B4%EB%AF%B8%EC%A7%80%EB%A5%BC%20%EC%9C%84%ED%95%9C%20%EB%9E%9C%EB%8D%A4%ED%95%9C%20%ED%83%80%EC%9E%84%EC%8A%A4%ED%85%9D(timestep)%EC%9D%84%20%EC%83%98%ED%94%8C%EB%A7%81%ED%95%A9%EB%8B%88%EB%8B%A4.%0A%20%20%20%20%20%20%20%20%20%20%20%20timesteps%20%3D%20torch.randint(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%200%2C%20noise_scheduler.config.num_train_timesteps%2C%20(bs%2C)%2C%20device%3Dclean_images.device%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20dtype%3Dtorch.int64%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20%EA%B
0%81%20%ED%83%80%EC%9E%84%EC%8A%A4%ED%85%9D%EC%9D%98%20%EB%85%B8%EC%9D%B4%EC%A6%88%20%ED%81%AC%EA%B8%B0%EC%97%90%20%EB%94%B0%EB%9D%BC%20%EA%B9%A8%EB%81%97%ED%95%9C%20%EC%9D%B4%EB%AF%B8%EC%A7%80%EC%97%90%20%EB%85%B8%EC%9D%B4%EC%A6%88%EB%A5%BC%20%EC%B6%94%EA%B0%80%ED%95%A9%EB%8B%88%EB%8B%A4.%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20(%EC%9D%B4%EB%8A%94%20foward%20diffusion%20%EA%B3%BC%EC%A0%95%EC%9E%85%EB%8B%88%EB%8B%A4.)%0A%20%20%20%20%20%20%20%20%20%20%20%20noisy_images%20%3D%20noise_scheduler.add_noise(clean_images%2C%20noise%2C%20timesteps)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20with%20accelerator.accumulate(model)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20%EB%85%B8%EC%9D%B4%EC%A6%88%EB%A5%BC%20%EB%B0%98%EB%B3%B5%EC%A0%81%EC%9C%BC%EB%A1%9C%20%EC%98%88%EC%B8%A1%ED%95%A9%EB%8B%88%EB%8B%A4.%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20noise_pred%20%3D%20model(noisy_images%2C%20timesteps%2C%20return_dict%3DFalse)%5B0%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20F.mse_loss(noise_pred%2C%20noise)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20accelerator.backward(loss)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20accelerator.clip_grad_norm_(model.parameters()%2C%201.0)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20optimizer.step()%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20lr_scheduler.step()%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20optimizer.zero_grad()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20progress_bar.update(1)%0A%20%20%20%20%20%20%20%20%20%20%20%20logs%20%3D%20%7B%22loss%22%3A%20loss.detach().item()%2C%20%22lr%22%3A%20lr_scheduler.get_last_lr()%5B0%5D%2C%20%22step%22%3A%20global_step%7D%0A%20%20%20%20%20%20%20%20%20%20%20%20progress_bar.set_postfix(**logs)%0A%20%20%20%20%20%20%20%20%20%20%20%20accelerator.log(logs%2C%20step%3Dglobal_step)%0A%20%20%20%20%20%20%20%20%20%20%20%20global_step%20%2B%3D%201%0A%0A%20%20%20%20%20%20%20%20%23%20%EA%B0%81%20%EC%97%90%ED%8F%AC%ED%81%AC%EA%B0%80%20%
EB%81%9D%EB%82%9C%20%ED%9B%84%20evaluate()%EC%99%80%20%EB%AA%87%20%EA%B0%80%EC%A7%80%20%EB%8D%B0%EB%AA%A8%20%EC%9D%B4%EB%AF%B8%EC%A7%80%EB%A5%BC%20%EC%84%A0%ED%83%9D%EC%A0%81%EC%9C%BC%EB%A1%9C%20%EC%83%98%ED%94%8C%EB%A7%81%ED%95%98%EA%B3%A0%20%EB%AA%A8%EB%8D%B8%EC%9D%84%20%EC%A0%80%EC%9E%A5%ED%95%A9%EB%8B%88%EB%8B%A4.%0A%20%20%20%20%20%20%20%20if%20accelerator.is_main_process%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20pipeline%20%3D%20DDPMPipeline(unet%3Daccelerator.unwrap_model(model)%2C%20scheduler%3Dnoise_scheduler)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20(epoch%20%2B%201)%20%25%20config.save_image_epochs%20%3D%3D%200%20or%20epoch%20%3D%3D%20config.num_epochs%20-%201%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20evaluate(config%2C%20epoch%2C%20pipeline)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20(epoch%20%2B%201)%20%25%20config.save_model_epochs%20%3D%3D%200%20or%20epoch%20%3D%3D%20config.num_epochs%20-%201%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20config.push_to_hub%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20upload_folder(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20repo_id%3Drepo_id%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20folder_path%3Dconfig.output_dir%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20commit_message%3Df%22Epoch%20%7Bepoch%7D%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ignore_patterns%3D%5B%22step_*%22%2C%20%22epoch_*%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20pipeline.save_pretrained(config.output_dir)",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> create_repo, upload_folder | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> pathlib <span class="hljs-keyword">import</span> Path | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> os | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">train_loop</span>(<span class="hljs-params">config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler</span>): | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># Initialize accelerator and tensorboard logging</span> | |
| <span class="hljs-meta">... </span> accelerator = Accelerator( | |
| <span class="hljs-meta">... </span> mixed_precision=config.mixed_precision, | |
| <span class="hljs-meta">... </span> gradient_accumulation_steps=config.gradient_accumulation_steps, | |
| <span class="hljs-meta">... </span> log_with=<span class="hljs-string">"tensorboard"</span>, | |
| <span class="hljs-meta">... </span> project_dir=os.path.join(config.output_dir, <span class="hljs-string">"logs"</span>), | |
| <span class="hljs-meta">... </span> ) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> accelerator.is_main_process: | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> config.output_dir <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>: | |
| <span class="hljs-meta">... </span> os.makedirs(config.output_dir, exist_ok=<span class="hljs-literal">True</span>) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> config.push_to_hub: | |
| <span class="hljs-meta">... </span> repo_id = create_repo( | |
| <span class="hljs-meta">... </span> repo_id=config.hub_model_id <span class="hljs-keyword">or</span> Path(config.output_dir).name, exist_ok=<span class="hljs-literal">True</span> | |
| <span class="hljs-meta">... </span> ).repo_id | |
| <span class="hljs-meta">... </span> accelerator.init_trackers(<span class="hljs-string">"train_example"</span>) | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 모든 것을 준비합니다.</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 기억해야 할 특정한 순서는 없으며, prepare 메서드에 전달한 것과 동일한 순서로 객체를 언패킹하면 됩니다.</span> | |
| <span class="hljs-meta">... </span> model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( | |
| <span class="hljs-meta">... </span> model, optimizer, train_dataloader, lr_scheduler | |
| <span class="hljs-meta">... </span> ) | |
| <span class="hljs-meta">... </span> global_step = <span class="hljs-number">0</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 이제 모델을 학습합니다.</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(config.num_epochs): | |
| <span class="hljs-meta">... </span> progress_bar = tqdm(total=<span class="hljs-built_in">len</span>(train_dataloader), disable=<span class="hljs-keyword">not</span> accelerator.is_local_main_process) | |
| <span class="hljs-meta">... </span> progress_bar.set_description(<span class="hljs-string">f"Epoch <span class="hljs-subst">{epoch}</span>"</span>) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(train_dataloader): | |
| <span class="hljs-meta">... </span> clean_images = batch[<span class="hljs-string">"images"</span>] | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 이미지에 더할 노이즈를 샘플링합니다.</span> | |
| <span class="hljs-meta">... </span> noise = torch.randn(clean_images.shape, device=clean_images.device) | |
| <span class="hljs-meta">... </span> bs = clean_images.shape[<span class="hljs-number">0</span>] | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 각 이미지를 위한 랜덤한 타임스텝(timestep)을 샘플링합니다.</span> | |
| <span class="hljs-meta">... </span> timesteps = torch.randint( | |
| <span class="hljs-meta">... </span> <span class="hljs-number">0</span>, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device, | |
| <span class="hljs-meta">... </span> dtype=torch.int64 | |
| <span class="hljs-meta">... </span> ) | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 각 타임스텝의 노이즈 크기에 따라 깨끗한 이미지에 노이즈를 추가합니다.</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># (이는 forward diffusion 과정입니다.)</span> | |
| <span class="hljs-meta">... </span> noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">with</span> accelerator.accumulate(model): | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 노이즈 잔차를 예측합니다.</span> | |
| <span class="hljs-meta">... </span> noise_pred = model(noisy_images, timesteps, return_dict=<span class="hljs-literal">False</span>)[<span class="hljs-number">0</span>] | |
| <span class="hljs-meta">... </span> loss = F.mse_loss(noise_pred, noise) | |
| <span class="hljs-meta">... </span> accelerator.backward(loss) | |
| <span class="hljs-meta">... </span> accelerator.clip_grad_norm_(model.parameters(), <span class="hljs-number">1.0</span>) | |
| <span class="hljs-meta">... </span> optimizer.step() | |
| <span class="hljs-meta">... </span> lr_scheduler.step() | |
| <span class="hljs-meta">... </span> optimizer.zero_grad() | |
| <span class="hljs-meta">... </span> progress_bar.update(<span class="hljs-number">1</span>) | |
| <span class="hljs-meta">... </span> logs = {<span class="hljs-string">"loss"</span>: loss.detach().item(), <span class="hljs-string">"lr"</span>: lr_scheduler.get_last_lr()[<span class="hljs-number">0</span>], <span class="hljs-string">"step"</span>: global_step} | |
| <span class="hljs-meta">... </span> progress_bar.set_postfix(**logs) | |
| <span class="hljs-meta">... </span> accelerator.log(logs, step=global_step) | |
| <span class="hljs-meta">... </span> global_step += <span class="hljs-number">1</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-comment"># 각 에포크가 끝난 후 선택적으로 evaluate()로 몇 가지 데모 이미지를 샘플링하고 모델을 저장합니다.</span> | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> accelerator.is_main_process: | |
| <span class="hljs-meta">... </span> pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> (epoch + <span class="hljs-number">1</span>) % config.save_image_epochs == <span class="hljs-number">0</span> <span class="hljs-keyword">or</span> epoch == config.num_epochs - <span class="hljs-number">1</span>: | |
| <span class="hljs-meta">... </span> evaluate(config, epoch, pipeline) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> (epoch + <span class="hljs-number">1</span>) % config.save_model_epochs == <span class="hljs-number">0</span> <span class="hljs-keyword">or</span> epoch == config.num_epochs - <span class="hljs-number">1</span>: | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> config.push_to_hub: | |
| <span class="hljs-meta">... </span> upload_folder( | |
| <span class="hljs-meta">... </span> repo_id=repo_id, | |
| <span class="hljs-meta">... </span> folder_path=config.output_dir, | |
| <span class="hljs-meta">... </span> commit_message=<span class="hljs-string">f"Epoch <span class="hljs-subst">{epoch}</span>"</span>, | |
| <span class="hljs-meta">... </span> ignore_patterns=[<span class="hljs-string">"step_*"</span>, <span class="hljs-string">"epoch_*"</span>], | |
| <span class="hljs-meta">... </span> ) | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">else</span>: | |
| <span class="hljs-meta">... </span> pipeline.save_pretrained(config.output_dir)`,wrap:!1}}),wl=new m({props:{code:"ZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBub3RlYm9va19sYXVuY2hlciUwQSUwQWFyZ3MlMjAlM0QlMjAoY29uZmlnJTJDJTIwbW9kZWwlMkMlMjBub2lzZV9zY2hlZHVsZXIlMkMlMjBvcHRpbWl6ZXIlMkMlMjB0cmFpbl9kYXRhbG9hZGVyJTJDJTIwbHJfc2NoZWR1bGVyKSUwQSUwQW5vdGVib29rX2xhdW5jaGVyKHRyYWluX2xvb3AlMkMlMjBhcmdzJTJDJTIwbnVtX3Byb2Nlc3NlcyUzRDEp",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> notebook_launcher | |
| <span class="hljs-meta">>>> </span>args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler) | |
| <span class="hljs-meta">>>> </span>notebook_launcher(train_loop, args, num_processes=<span class="hljs-number">1</span>)`,wrap:!1}}),hl=new m({props:{code:"aW1wb3J0JTIwZ2xvYiUwQSUwQXNhbXBsZV9pbWFnZXMlMjAlM0QlMjBzb3J0ZWQoZ2xvYi5nbG9iKGYlMjIlN0Jjb25maWcub3V0cHV0X2RpciU3RCUyRnNhbXBsZXMlMkYqLnBuZyUyMikpJTBBSW1hZ2Uub3BlbihzYW1wbGVfaW1hZ2VzJTVCLTElNUQp",highlighted:`<span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> glob | |
| <span class="hljs-meta">>>> </span>sample_images = <span class="hljs-built_in">sorted</span>(glob.glob(<span class="hljs-string">f"<span class="hljs-subst">{config.output_dir}</span>/samples/*.png"</span>)) | |
| <span class="hljs-meta">>>> </span>Image.<span class="hljs-built_in">open</span>(sample_images[-<span class="hljs-number">1</span>])`,wrap:!1}}),rl=new Vl({props:{title:"다음 단계",local:"다음-단계",headingTag:"h2"}}),ul=new fa({props:{source:"https://github.com/huggingface/diffusers/blob/main/docs/source/ko/tutorials/basic_training.md"}}),{c(){C=n("meta"),Al=M(),bl=n("p"),fl=M(),J(o.$$.fragment),Rl=M(),J(h.$$.fragment),Bl=M(),J(I.$$.fragment),Zl=M(),r=n("p"),r.innerHTML=Ws,kl=M(),Q=n("p"),Q.innerHTML=Ss,Gl=M(),w=n("blockquote"),w.innerHTML=Ds,El=M(),g=n("p"),g.innerHTML=Ys,Fl=M(),J(u.$$.fragment),Xl=M(),V=n("p"),V.innerHTML=zs,Nl=M(),J(b.$$.fragment),_l=M(),d=n("p"),d.textContent=vs,Ol=M(),J(A.$$.fragment),Wl=M(),f=n("p"),f.innerHTML=$s,Sl=M(),J(R.$$.fragment),Dl=M(),J(B.$$.fragment),Yl=M(),Z=n("p"),Z.innerHTML=xs,zl=M(),J(k.$$.fragment),vl=M(),J(G.$$.fragment),$l=M(),E=n("p"),E.innerHTML=Hs,xl=M(),J(F.$$.fragment),Hl=M(),X=n("p"),X.innerHTML=Ls,Ll=M(),N=n("p"),N.innerHTML=qs,ql=M(),J(_.$$.fragment),Kl=M(),O=n("p"),O.innerHTML=Ks,Pl=M(),W=n("p"),W.textContent=Ps,ls=M(),S=n("ul"),S.innerHTML=la,ss=M(),J(D.$$.fragment),as=M(),Y=n("p"),Y.innerHTML=sa,es=M(),J(z.$$.fragment),Ms=M(),v=n("p"),v.innerHTML=aa,ts=M(),J($.$$.fragment),ns=M(),J(x.$$.fragment),Us=M(),H=n("p"),H.innerHTML=ea,ps=M(),J(L.$$.fragment),Js=M(),q=n("p"),q.textContent=Ma,js=M(),J(K.$$.fragment),ys=M(),P=n("p"),P.textContent=ta,Ts=M(),J(ll.$$.fragment),cs=M(),sl=n("p"),sl.innerHTML=na,is=M(),al=n("p"),al.innerHTML=Ua,ms=M(),J(el.$$.fragment),Cs=M(),Ml=n("p"),Ml.innerHTML=pa,ws=M(),tl=n("p"),tl.textContent=Ja,os=M(),J(nl.$$.fragment),hs=M(),J(Ul.$$.fragment),Is=M(),pl=n("p"),pl.textContent=ja,rs=M(),Jl=n("p"),Jl.textContent=ya,Qs=M(),J(jl.$$.fragment),gs=M(),yl=n("p"),yl.innerHTML=Ta,us=M(),J(Tl.$$.fragment),Vs=M(),cl=n("p"),cl.textContent=ca,bs=M(),il=n("p"),il.textContent=ia,ds=M(),J(ml.$$.fragment),As=M(),Cl=n("p"),Cl.innerHTML=ma,fs=M(),J(wl.$$.fragment),Rs=M(),ol=n("p"),ol.textContent=Ca,Bs=M(),J(hl.$$.f
ragment),Zs=M(),Il=n("p"),Il.innerHTML=wa,ks=M(),J(rl.$$.fragment),Gs=M(),Ql=n("p"),Ql.innerHTML=oa,Es=M(),gl=n("ul"),gl.innerHTML=ha,Fs=M(),J(ul.$$.fragment),Xs=M(),dl=n("p"),this.h()},l(l){const s=ba("svelte-u9bgzb",document.head);C=U(s,"META",{name:!0,content:!0}),s.forEach(a),Al=t(l),bl=U(l,"P",{}),Ia(bl).forEach(a),fl=t(l),j(o.$$.fragment,l),Rl=t(l),j(h.$$.fragment,l),Bl=t(l),j(I.$$.fragment,l),Zl=t(l),r=U(l,"P",{"data-svelte-h":!0}),p(r)!=="svelte-1aqgth7"&&(r.innerHTML=Ws),kl=t(l),Q=U(l,"P",{"data-svelte-h":!0}),p(Q)!=="svelte-32ojwt"&&(Q.innerHTML=Ss),Gl=t(l),w=U(l,"BLOCKQUOTE",{class:!0,"data-svelte-h":!0}),p(w)!=="svelte-19biaqu"&&(w.innerHTML=Ds),El=t(l),g=U(l,"P",{"data-svelte-h":!0}),p(g)!=="svelte-1o0c2y8"&&(g.innerHTML=Ys),Fl=t(l),j(u.$$.fragment,l),Xl=t(l),V=U(l,"P",{"data-svelte-h":!0}),p(V)!=="svelte-1owzlqj"&&(V.innerHTML=zs),Nl=t(l),j(b.$$.fragment,l),_l=t(l),d=U(l,"P",{"data-svelte-h":!0}),p(d)!=="svelte-10ayust"&&(d.textContent=vs),Ol=t(l),j(A.$$.fragment,l),Wl=t(l),f=U(l,"P",{"data-svelte-h":!0}),p(f)!=="svelte-1w5k095"&&(f.innerHTML=$s),Sl=t(l),j(R.$$.fragment,l),Dl=t(l),j(B.$$.fragment,l),Yl=t(l),Z=U(l,"P",{"data-svelte-h":!0}),p(Z)!=="svelte-1rpvlkg"&&(Z.innerHTML=xs),zl=t(l),j(k.$$.fragment,l),vl=t(l),j(G.$$.fragment,l),$l=t(l),E=U(l,"P",{"data-svelte-h":!0}),p(E)!=="svelte-1our457"&&(E.innerHTML=Hs),xl=t(l),j(F.$$.fragment,l),Hl=t(l),X=U(l,"P",{"data-svelte-h":!0}),p(X)!=="svelte-1hi7huh"&&(X.innerHTML=Ls),Ll=t(l),N=U(l,"P",{"data-svelte-h":!0}),p(N)!=="svelte-g2btn3"&&(N.innerHTML=qs),ql=t(l),j(_.$$.fragment,l),Kl=t(l),O=U(l,"P",{"data-svelte-h":!0}),p(O)!=="svelte-12z3lda"&&(O.innerHTML=Ks),Pl=t(l),W=U(l,"P",{"data-svelte-h":!0}),p(W)!=="svelte-2vcep9"&&(W.textContent=Ps),ls=t(l),S=U(l,"UL",{"data-svelte-h":!0}),p(S)!=="svelte-lrd3tn"&&(S.innerHTML=la),ss=t(l),j(D.$$.fragment,l),as=t(l),Y=U(l,"P",{"data-svelte-h":!0}),p(Y)!=="svelte-mhh25q"&&(Y.innerHTML=sa),es=t(l),j(z.$$.fragment,l),Ms=t(l),v=U(l,"P",{"data-svelte-h":!0}),p(v)!=="svel
te-jbxdac"&&(v.innerHTML=aa),ts=t(l),j($.$$.fragment,l),ns=t(l),j(x.$$.fragment,l),Us=t(l),H=U(l,"P",{"data-svelte-h":!0}),p(H)!=="svelte-ywj4cf"&&(H.innerHTML=ea),ps=t(l),j(L.$$.fragment,l),Js=t(l),q=U(l,"P",{"data-svelte-h":!0}),p(q)!=="svelte-1x8klru"&&(q.textContent=Ma),js=t(l),j(K.$$.fragment,l),ys=t(l),P=U(l,"P",{"data-svelte-h":!0}),p(P)!=="svelte-hu0j5c"&&(P.textContent=ta),Ts=t(l),j(ll.$$.fragment,l),cs=t(l),sl=U(l,"P",{"data-svelte-h":!0}),p(sl)!=="svelte-1ickp2n"&&(sl.innerHTML=na),is=t(l),al=U(l,"P",{"data-svelte-h":!0}),p(al)!=="svelte-boidtv"&&(al.innerHTML=Ua),ms=t(l),j(el.$$.fragment,l),Cs=t(l),Ml=U(l,"P",{"data-svelte-h":!0}),p(Ml)!=="svelte-3yki19"&&(Ml.innerHTML=pa),ws=t(l),tl=U(l,"P",{"data-svelte-h":!0}),p(tl)!=="svelte-1oo0f0r"&&(tl.textContent=Ja),os=t(l),j(nl.$$.fragment,l),hs=t(l),j(Ul.$$.fragment,l),Is=t(l),pl=U(l,"P",{"data-svelte-h":!0}),p(pl)!=="svelte-1syjdvo"&&(pl.textContent=ja),rs=t(l),Jl=U(l,"P",{"data-svelte-h":!0}),p(Jl)!=="svelte-1x5az67"&&(Jl.textContent=ya),Qs=t(l),j(jl.$$.fragment,l),gs=t(l),yl=U(l,"P",{"data-svelte-h":!0}),p(yl)!=="svelte-baczkn"&&(yl.innerHTML=Ta),us=t(l),j(Tl.$$.fragment,l),Vs=t(l),cl=U(l,"P",{"data-svelte-h":!0}),p(cl)!=="svelte-14zfk37"&&(cl.textContent=ca),bs=t(l),il=U(l,"P",{"data-svelte-h":!0}),p(il)!=="svelte-u719rq"&&(il.textContent=ia),ds=t(l),j(ml.$$.fragment,l),As=t(l),Cl=U(l,"P",{"data-svelte-h":!0}),p(Cl)!=="svelte-zyu14c"&&(Cl.innerHTML=ma),fs=t(l),j(wl.$$.fragment,l),Rs=t(l),ol=U(l,"P",{"data-svelte-h":!0}),p(ol)!=="svelte-1dbylv7"&&(ol.textContent=Ca),Bs=t(l),j(hl.$$.fragment,l),Zs=t(l),Il=U(l,"P",{"data-svelte-h":!0}),p(Il)!=="svelte-1bzvmcv"&&(Il.innerHTML=wa),ks=t(l),j(rl.$$.fragment,l),Gs=t(l),Ql=U(l,"P",{"data-svelte-h":!0}),p(Ql)!=="svelte-1mf9wqw"&&(Ql.innerHTML=oa),Es=t(l),gl=U(l,"UL",{"data-svelte-h":!0}),p(gl)!=="svelte-y5d1yz"&&(gl.innerHTML=ha),Fs=t(l),j(ul.$$.fragment,l),Xs=t(l),dl=U(l,"P",{}),Ia(dl).forEach(a),this.h()},h(){_s(C,"name","hf:doc:metadata"),_s(C,"content",Za),_s(w,
"class","tip")},m(l,s){da(document.head,C),e(l,Al,s),e(l,bl,s),e(l,fl,s),y(o,l,s),e(l,Rl,s),y(h,l,s),e(l,Bl,s),y(I,l,s),e(l,Zl,s),e(l,r,s),e(l,kl,s),e(l,Q,s),e(l,Gl,s),e(l,w,s),e(l,El,s),e(l,g,s),e(l,Fl,s),y(u,l,s),e(l,Xl,s),e(l,V,s),e(l,Nl,s),y(b,l,s),e(l,_l,s),e(l,d,s),e(l,Ol,s),y(A,l,s),e(l,Wl,s),e(l,f,s),e(l,Sl,s),y(R,l,s),e(l,Dl,s),y(B,l,s),e(l,Yl,s),e(l,Z,s),e(l,zl,s),y(k,l,s),e(l,vl,s),y(G,l,s),e(l,$l,s),e(l,E,s),e(l,xl,s),y(F,l,s),e(l,Hl,s),e(l,X,s),e(l,Ll,s),e(l,N,s),e(l,ql,s),y(_,l,s),e(l,Kl,s),e(l,O,s),e(l,Pl,s),e(l,W,s),e(l,ls,s),e(l,S,s),e(l,ss,s),y(D,l,s),e(l,as,s),e(l,Y,s),e(l,es,s),y(z,l,s),e(l,Ms,s),e(l,v,s),e(l,ts,s),y($,l,s),e(l,ns,s),y(x,l,s),e(l,Us,s),e(l,H,s),e(l,ps,s),y(L,l,s),e(l,Js,s),e(l,q,s),e(l,js,s),y(K,l,s),e(l,ys,s),e(l,P,s),e(l,Ts,s),y(ll,l,s),e(l,cs,s),e(l,sl,s),e(l,is,s),e(l,al,s),e(l,ms,s),y(el,l,s),e(l,Cs,s),e(l,Ml,s),e(l,ws,s),e(l,tl,s),e(l,os,s),y(nl,l,s),e(l,hs,s),y(Ul,l,s),e(l,Is,s),e(l,pl,s),e(l,rs,s),e(l,Jl,s),e(l,Qs,s),y(jl,l,s),e(l,gs,s),e(l,yl,s),e(l,us,s),y(Tl,l,s),e(l,Vs,s),e(l,cl,s),e(l,bs,s),e(l,il,s),e(l,ds,s),y(ml,l,s),e(l,As,s),e(l,Cl,s),e(l,fs,s),y(wl,l,s),e(l,Rs,s),e(l,ol,s),e(l,Bs,s),y(hl,l,s),e(l,Zs,s),e(l,Il,s),e(l,ks,s),y(rl,l,s),e(l,Gs,s),e(l,Ql,s),e(l,Es,s),e(l,gl,s),e(l,Fs,s),y(ul,l,s),e(l,Xs,s),e(l,dl,s),Ns=!0},p:Qa,i(l){Ns||(T(o.$$.fragment,l),T(h.$$.fragment,l),T(I.$$.fragment,l),T(u.$$.fragment,l),T(b.$$.fragment,l),T(A.$$.fragment,l),T(R.$$.fragment,l),T(B.$$.fragment,l),T(k.$$.fragment,l),T(G.$$.fragment,l),T(F.$$.fragment,l),T(_.$$.fragment,l),T(D.$$.fragment,l),T(z.$$.fragment,l),T($.$$.fragment,l),T(x.$$.fragment,l),T(L.$$.fragment,l),T(K.$$.fragment,l),T(ll.$$.fragment,l),T(el.$$.fragment,l),T(nl.$$.fragment,l),T(Ul.$$.fragment,l),T(jl.$$.fragment,l),T(Tl.$$.fragment,l),T(ml.$$.fragment,l),T(wl.$$.fragment,l),T(hl.$$.fragment,l),T(rl.$$.fragment,l),T(ul.$$.fragment,l),Ns=!0)},o(l){c(o.$$.fragment,l),c(h.$$.fragment,l),c(I.$$.fragment,l),c(u.$$.fragment,l),c(b.$$.fragment,l),c(A.$$.fragment,l),c(R
.$$.fragment,l),c(B.$$.fragment,l),c(k.$$.fragment,l),c(G.$$.fragment,l),c(F.$$.fragment,l),c(_.$$.fragment,l),c(D.$$.fragment,l),c(z.$$.fragment,l),c($.$$.fragment,l),c(x.$$.fragment,l),c(L.$$.fragment,l),c(K.$$.fragment,l),c(ll.$$.fragment,l),c(el.$$.fragment,l),c(nl.$$.fragment,l),c(Ul.$$.fragment,l),c(jl.$$.fragment,l),c(Tl.$$.fragment,l),c(ml.$$.fragment,l),c(wl.$$.fragment,l),c(hl.$$.fragment,l),c(rl.$$.fragment,l),c(ul.$$.fragment,l),Ns=!1},d(l){l&&(a(Al),a(bl),a(fl),a(Rl),a(Bl),a(Zl),a(r),a(kl),a(Q),a(Gl),a(w),a(El),a(g),a(Fl),a(Xl),a(V),a(Nl),a(_l),a(d),a(Ol),a(Wl),a(f),a(Sl),a(Dl),a(Yl),a(Z),a(zl),a(vl),a($l),a(E),a(xl),a(Hl),a(X),a(Ll),a(N),a(ql),a(Kl),a(O),a(Pl),a(W),a(ls),a(S),a(ss),a(as),a(Y),a(es),a(Ms),a(v),a(ts),a(ns),a(Us),a(H),a(ps),a(Js),a(q),a(js),a(ys),a(P),a(Ts),a(cs),a(sl),a(is),a(al),a(ms),a(Cs),a(Ml),a(ws),a(tl),a(os),a(hs),a(Is),a(pl),a(rs),a(Jl),a(Qs),a(gs),a(yl),a(us),a(Vs),a(cl),a(bs),a(il),a(ds),a(As),a(Cl),a(fs),a(Rs),a(ol),a(Bs),a(Zs),a(Il),a(ks),a(Gs),a(Ql),a(Es),a(gl),a(Fs),a(Xs),a(dl)),a(C),i(o,l),i(h,l),i(I,l),i(u,l),i(b,l),i(A,l),i(R,l),i(B,l),i(k,l),i(G,l),i(F,l),i(_,l),i(D,l),i(z,l),i($,l),i(x,l),i(L,l),i(K,l),i(ll,l),i(el,l),i(nl,l),i(Ul,l),i(jl,l),i(Tl,l),i(ml,l),i(wl,l),i(hl,l),i(rl,l),i(ul,l)}}}const Za='{"title":"Diffusion 모델을 학습하기","local":"diffusion-모델을-학습하기","sections":[{"title":"학습 구성","local":"학습-구성","sections":[],"depth":2},{"title":"데이터셋 불러오기","local":"데이터셋-불러오기","sections":[],"depth":2},{"title":"UNet2DModel 생성하기","local":"unet2dmodel-생성하기","sections":[],"depth":2},{"title":"스케줄러 생성하기","local":"스케줄러-생성하기","sections":[],"depth":2},{"title":"모델 학습하기","local":"모델-학습하기","sections":[],"depth":2},{"title":"다음 단계","local":"다음-단계","sections":[],"depth":2}],"depth":1}';function ka(Os){return ga(()=>{new URLSearchParams(window.location.search).get("fw")}),[]}class _a extends ua{constructor(C){super(),Va(this,C,ka,Ba,ra,{})}}export{_a as component}; | |
Xet Storage Details
- Size:
- 68.7 kB
- Xet hash:
- ec89d2ea2a8cf4e65e240a8ca55cbbe2b1052563ad0edf077f635ba63857e57a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.