import sys,os
# Directory containing this file; used to make the project root importable below.
cur_dir = os.path.dirname(os.path.abspath(__file__))
if __name__=='__main__':
    # When executed as a script, add the parent directory to sys.path so
    # project-local packages (confs, MoE, ldm, ...) resolve.
    sys.path.append(os.path.abspath(os.path.join(cur_dir, '..')))
from confs import *
import json
import argparse, os, sys, glob  # NOTE(review): os/sys re-imported (harmless duplicate of line 1)
import cv2
import torch
import numpy as np
from MoE import *
from multiTask_model import *
from lora_layers import *
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A
import time
import copy
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import torchvision
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.bank import Bank
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from transformers import AutoFeatureExtractor
# import clip
from torchvision.transforms import Resize
from fnmatch import fnmatch
from PIL import Image  # NOTE(review): duplicate import of PIL.Image (harmless)
from torchvision.transforms import PILToTensor


#----------------------------------------------------------------------------
def get_moe():
    """Build the multi-task MoE LatentDiffusion model.

    Instantiates the base LatentDiffusion model from `LatentDiffusion.yaml`,
    then (unless the file-level flag FOR_upcycle_ckpt_GEN_or_USE is set)
    "upcycles" it into a mixture-of-experts model: selected sub-modules are
    replaced by TaskSpecific_MoE wrappers built from deep copies of the
    original weights, so that checkpoint keys match the MoE layout.

    Returns:
        The upcycled LatentDiffusion model (`modelMOE`).
        NOTE(review): as reconstructed here, when FOR_upcycle_ckpt_GEN_or_USE
        is truthy the function falls through and implicitly returns None —
        confirm against the original indentation.
    """
    if 1:
        # Fixed seed so that upcycled weights are reproducible across runs/ranks.
        seed_everything(42)
        # torch.cuda.set_device(opt.device_ID)
        model :LatentDiffusion = instantiate_from_config(OmegaConf.load(f"LatentDiffusion.yaml").model,)
        if REFNET.ENABLE:
            # Sanity check: the configured refNet branch really is a refNet.
            assert model.model.diffusion_model_refNet.is_refNet
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # NOTE(review): the line below unconditionally overrides the CUDA/CPU
        # selection above and forces CPU. Looks like debug leftover or a
        # deliberate CPU-only upcycling path — confirm before removing.
        device = torch.device("cpu")
        model = model.to(device)
        if FOR_upcycle_ckpt_GEN_or_USE:
            # Not needed when generating/using an upcycle checkpoint; drop to save memory.
            del model.ptsM_Generator

    def average_module_weight(
        src_modules: list,
    ):
        """Average the weights of multiple modules.

        Args:
            src_modules: modules with identical state_dict keys/shapes.
        Returns:
            A state dict whose every tensor is the element-wise mean over
            `src_modules`, or None if `src_modules` is empty.
        """
        if not src_modules:
            return None
        # Get the state dict of the first module as template
        avg_state_dict = {}
        first_state_dict = src_modules[0].state_dict()
        # Initialize with zeros
        for key in first_state_dict:
            avg_state_dict[key] = torch.zeros_like(first_state_dict[key])
        # Sum
        for module in src_modules:
            module_state_dict = module.state_dict()
            for key in avg_state_dict:
                avg_state_dict[key] += module_state_dict[key]
        # Average
        for key in avg_state_dict:
            avg_state_dict[key] /= len(src_modules)
        return avg_state_dict

    def recursive_average_module_weight(
        tgt_module: nn.Module,
        src_modules: list,
        cb,
    ):
        """
        Recursively find modules and replace with averaged weights based on callback.

        Walks `tgt_module`'s children in lock-step with the same-named children
        of each module in `src_modules`. Where `cb(child, name, tgt_module)` is
        truthy, loads the averaged source weights into the child; otherwise
        recurses. Returns `tgt_module` (modified in place).
        """
        for name, child in tgt_module.named_children():
            if 1:
                # Get corresponding modules from source models
                src_child_modules = []
                for src_module in src_modules:
                    src_child = getattr(src_module, name)
                    assert src_child is not None,name
                    src_child_modules.append(src_child)
                # assert not isinstance(child, TaskSpecific_MoE)
                if cb(child, name, tgt_module):
                    print(f"[recursive_average_module_weight] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
                    # Average & load
                    avg_weights = average_module_weight(src_child_modules)
                    child.load_state_dict(avg_weights)
                else:
                    recursive_average_module_weight(child, src_child_modules, cb)
        return tgt_module

    def replace_module_with_TaskSpecific(
        tgt_module: nn.Module,# tgt module
        src_modules: list,
        cb,
        parent_name: str = "",
        depth :int = 0,
    ):
        """Recursively replace selected children of `tgt_module` with TaskSpecific_MoE.

        Args:
            tgt_module: module modified in place.
            src_modules: one source module per expert; their same-named children
                become the experts of each TaskSpecific_MoE.
            cb: predicate ``cb(child, name, full_name, tgt_module)`` deciding
                whether `child` is wrapped.
            parent_name: dotted-path prefix used to build `full_name` for `cb`.
            depth: recursion is only continued while depth <= 0, i.e. at most
                two levels of children are visited.
        Returns:
            `tgt_module` (modified in place).
        """
        for name, child in tgt_module.named_children():
            if 1:
                # Get corresponding modules from source models
                src_child_modules = []
                for src_module in src_modules:
                    src_child = getattr(src_module, name)
                    assert src_child is not None,name
                    src_child_modules.append(src_child)
                # Must not wrap an already-wrapped module twice.
                assert not isinstance(child, TaskSpecific_MoE)
                full_name = f"{parent_name}.{name}"
                if cb(child, name, full_name, tgt_module):
                    print(f"[replace_module_with_TaskSpecific] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
                    setattr(tgt_module, name, TaskSpecific_MoE(src_child_modules,TASKS))
                else:
                    if depth<=0:
                        replace_module_with_TaskSpecific(child, src_child_modules,cb,parent_name=full_name,depth=depth+1)
        return tgt_module

    if not FOR_upcycle_ckpt_GEN_or_USE:
        modelMOE :LatentDiffusion = model
        # Only `modelMOE` is mutated below; drop the old alias to avoid accidental use.
        del model
        if 1:
            # ensure distinct module instances per task (avoid shared identities)
            with open(PRETRAIN_JSON_PATH, 'r') as f:
                global_.moduleName_2_adaRank = json.load(f)
                print(f"loaded from {PRETRAIN_JSON_PATH=}")
            # One deep copy per expert/task so the experts do not share parameters.
            _src0 = copy.deepcopy(modelMOE.model.diffusion_model)
            _src1 = copy.deepcopy(modelMOE.model.diffusion_model)
            _src2 = copy.deepcopy(modelMOE.model.diffusion_model)
            _src3 = copy.deepcopy(modelMOE.model.diffusion_model)
            replace_modules_lossless(
                modelMOE.model.diffusion_model,
                [ _src0, _src1, _src2, _src3 ],
                [0,1,2,3],
                parent_name=".model.diffusion_model",
            )
            # Build-time dummy wrapping for task-specific heads so that ckpt keys match
            # NOTE(review): task-id lists ([0,2,3] / [2,3]) select which tasks own
            # each head — grounded only in these literals; semantics live in TaskSpecific_MoE.
            modelMOE.ID_proj_out = TaskSpecific_MoE([
                copy.deepcopy(modelMOE.ID_proj_out),
                copy.deepcopy(modelMOE.ID_proj_out),
                copy.deepcopy(modelMOE.ID_proj_out),
            ], [0,2,3])
            modelMOE.landmark_proj_out = TaskSpecific_MoE([
                copy.deepcopy(modelMOE.landmark_proj_out),
                copy.deepcopy(modelMOE.landmark_proj_out),
                copy.deepcopy(modelMOE.landmark_proj_out),
            ], [0,2,3])
            modelMOE.proj_out_source__head = TaskSpecific_MoE([
                copy.deepcopy(modelMOE.proj_out_source__head),
                copy.deepcopy(modelMOE.proj_out_source__head),
            ], [2,3])
            # Upcycle single refNet using three source refNets, and keep only one
            if REFNET.ENABLE:
                shared_ref = modelMOE.model.diffusion_model_refNet
                # src0 intentionally aliases shared_ref; the other three are copies.
                src0 = shared_ref
                src1 = copy.deepcopy(shared_ref)
                src2 = copy.deepcopy(shared_ref)
                src3 = copy.deepcopy(shared_ref)
                replace_modules_lossless(shared_ref, [src0, src1, src2, src3],[0,1,2,3], parent_name=".model.diffusion_model_refNet", for_refnet=True)
        # load from ./modelMOE.ckpt
        # Stagger ranks by 20s each — presumably to avoid concurrent checkpoint
        # reads in a multi-process launch; TODO confirm.
        time.sleep(20*rank_)
        print(f"ckpt load over. m,u:")
        # Initialize bank here (after model structure is finalized)
        if REFNET.ENABLE :
            modelMOE.model.bank = Bank(reader=modelMOE.model.diffusion_model,writer=modelMOE.model.diffusion_model_refNet)
        if __name__=='__main__':
            # Debug aid when run directly: list representative checkpoint keys.
            for key in sorted( get_representative_moduleNames(modelMOE.state_dict().keys()) ):
                print(f" - {key}")
        return modelMOE