| | import torch |
| | import torchvision |
| | import os |
| | import gc |
| | import tqdm |
| | import matplotlib.pyplot as plt |
| | import torchvision.transforms as transforms |
| | from transformers import CLIPTextModel |
| | from peft import PeftModel, LoraConfig |
| | from lora_w2w import LoRAw2w |
| | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler |
| | from peft.utils.save_and_load import load_peft_weights, set_peft_model_state_dict |
| | from transformers import AutoTokenizer, PretrainedConfig |
| | from PIL import Image |
| | import warnings |
| | warnings.filterwarnings("ignore") |
| | from diffusers import ( |
| | AutoencoderKL, |
| | DDPMScheduler, |
| | DiffusionPipeline, |
| | DPMSolverMultistepScheduler, |
| | UNet2DConditionModel, |
| | PNDMScheduler, |
| | StableDiffusionPipeline |
| | ) |
| |
|
| |
|
| |
|
| | |
| |
|
| |
|
def sample_weights(unet, proj, mean, std, v, device, factor=1.0, *, num_components=1000):
    """Draw one random coefficient vector from per-column Gaussians fitted to ``proj``.

    For each of the first ``num_components`` columns of ``proj`` the mean and
    standard deviation over dim 0 are computed, and one value is drawn from
    ``Normal(column_mean, factor * column_std)``.

    Args:
        unet, mean, std, v: Accepted for interface compatibility; they are not
            read anywhere in this function body.
        proj: Tensor whose dim 0 is reduced to obtain the per-column
            statistics (presumably rows are samples and columns are
            components -- confirm against the caller).
        device: Device on which the returned tensor lives.
        factor: Multiplier applied to the per-column standard deviation before
            sampling; ``0`` returns the column means exactly.
        num_components: How many leading columns of ``proj`` to sample.
            Defaults to 1000, the value that used to be hard-coded.

    Returns:
        A tensor of shape ``[1, num_components]`` in the default floating
        dtype, placed on ``device``.

    Raises:
        IndexError: If ``proj`` provides fewer than ``num_components`` columns.

    Note:
        All components are drawn in a single ``torch.normal`` call, so under a
        fixed seed the individual values may differ from drawing them one at a
        time, although the distribution is the same.
    """
    # Per-column statistics over dim 0 of ``proj``.
    col_mean = torch.mean(proj, 0)
    col_std = torch.std(proj, 0)

    # Drops only this function's reference to ``proj``; the cache flush is kept
    # so the function's side effects stay the same for existing callers.
    del proj
    torch.cuda.empty_cache()

    if col_mean.dim() == 0 or col_mean.shape[0] < num_components:
        raise IndexError(
            f"proj must provide at least {num_components} columns to sample from"
        )

    # Sample in the default dtype on the target device, matching the
    # ``torch.zeros([1, n]).to(device)`` buffer that was previously filled
    # element by element.
    out_dtype = torch.get_default_dtype()
    loc = col_mean[:num_components].to(device=device, dtype=out_dtype)
    scale = (factor * col_std[:num_components]).to(device=device, dtype=out_dtype)

    # One vectorized draw instead of ``num_components`` scalar draws.
    return torch.normal(loc, scale).unsqueeze(0)
| |
|
| |
|
| |
|
| |
|
| |
|