# NOTE(review): stray page-scrape header ("Spaces: Running on Zero",
# repeated) commented out so the file can parse as Python.
# Standard library
import logging
import os
import pickle

# Third-party
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

# Project-local
from funcineforge.models import FunCineForgeSpecAug
from funcineforge.models.utils import dtype_map
from funcineforge.models.utils.llm_decoding import LLMDecoder
from funcineforge.utils.device_funcs import to_device
class FunCineForgeLM(nn.Module):
    """Causal-LM wrapper that generates codec tokens conditioned on text,
    time/speaker tokens and face embeddings.

    A pretrained LLM (loaded via ``AutoModelForCausalLM``) is combined with
    local embedding tables for codec and time/speaker tokens, a linear
    projection for face embeddings, and a codec output head that replaces
    the text ``lm_head`` at inference time.
    """

    def __init__(
        self,
        llm: str = None,
        llm_conf: dict = None,
        input_size: int = 80,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        """Build the model.

        Args:
            llm: unused in this file; kept for config/signature compatibility.
            llm_conf: backbone configuration (``init_param_path``,
                ``load_kwargs``, ``freeze``, ``use_lora``/``use_qlora``,
                ``llm_dtype``, ``activation_checkpoint``).
            input_size: unused in this file; kept for config/signature
                compatibility.
            length_normalized_loss: stored on the instance; not read by the
                methods visible here.
            **kwargs: misc options — ``sample_rate``, ``token_rate``,
                ``codec_unit``, ``timespk_unit``, ``face_size``,
                ``ignore_id``, ``specaug``/``specaug_conf``,
                ``infer_lora_merged``/``infer_use_lora``.
        """
        super().__init__()
        # Fix: the declared default is None, but the original dereferenced
        # llm_conf.get(...) unconditionally and crashed on default construction.
        if llm_conf is None:
            llm_conf = {}
        self.llm_conf = llm_conf
        self.llm = None
        init_param_path = llm_conf.get("init_param_path", "")
        llm_load_kwargs = llm_conf.get("load_kwargs", {})
        self.sample_rate = kwargs.get("sample_rate", 24000)
        self.token_rate = kwargs.get("token_rate", 25)
        if kwargs.get("infer_lora_merged", False):
            # LoRA weights already merged offline: disable every adapter path.
            # (Deliberately mutates the caller-provided dicts, as before.)
            llm_conf["use_qlora"] = False
            llm_conf["use_lora"] = False
            kwargs["infer_use_lora"] = False
        model = AutoModelForCausalLM.from_pretrained(
            init_param_path,
            load_in_8bit=None,
            device_map=None,
            use_cache=None,
            **llm_load_kwargs,
        )
        freeze = llm_conf.get("freeze", True)
        if freeze:
            for _name, param in model.named_parameters():
                param.requires_grad = False
            model.eval()
        logging.info(f"use_lora: {llm_conf.get('use_lora', False)}, use_qlora: {llm_conf.get('use_qlora', False)}, infer_use_lora: {kwargs.get('infer_use_lora',False)}, infer_lora_merged: {kwargs.get('infer_lora_merged',False)}")
        if llm_conf.get("activation_checkpoint", False):
            model.gradient_checkpointing_enable()
        self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
        self.llm = model.to(dtype_map[self.llm_dtype])
        llm_dim = model.get_input_embeddings().weight.shape[-1]
        if (not llm_conf.get("use_lora", False)) and (not kwargs.get("infer_use_lora", False)):
            # Without LoRA the original text head is never used; inference()
            # installs self.codec_head as lm_head instead.
            del self.llm.lm_head
        self.codec_unit = kwargs.get("codec_unit", 6761)
        self.timespk_unit = kwargs.get("timespk_unit", 1550)
        # padding_idx=0 for both local token-embedding tables.
        self.codec_embed = nn.Embedding(self.codec_unit, llm_dim, 0)
        self.timespk_embed = nn.Embedding(self.timespk_unit, llm_dim, 0)
        self.codec_head = nn.Linear(llm_dim, self.codec_unit, bias=False)
        self.face_size = kwargs.get("face_size", 512)
        self.face_linear = nn.Linear(self.face_size, llm_dim)
        self.length_normalized_loss = length_normalized_loss
        self.ignore_id = kwargs.get("ignore_id", -100)
        specaug = kwargs.get("specaug", None)
        specaug_conf = kwargs.get("specaug_conf", {})
        if specaug is not None:
            specaug = FunCineForgeSpecAug(**specaug_conf)
        self.specaug = specaug
        rank = int(os.environ.get("RANK", 0))
        logging.info(f"rank: {rank}, model is builded.")

    def insert_face_embeddings(
        self, inputs_embeds, face_emb, attention_mask, labels_ids,
        codec_len, insert_pos, device
    ):
        """Insert per-sample face embeddings into the sequence at
        ``insert_pos`` (right after the SOS token), updating the attention
        mask and labels in lockstep.

        Args:
            inputs_embeds: (batch, token_num, dims) input embeddings.
            face_emb: (batch, max_face_len, dims) face embeddings.
            attention_mask: (batch, token_num) attention mask.
            labels_ids: (batch, token_num) label ids.
            codec_len: (batch,) actual face-embedding length per sample.
            insert_pos: int insertion point.
            device: device for the output tensors.

        Returns:
            (padded_inputs_embeds, padded_attention_mask, padded_labels),
            each of length ``token_num + max_face_len`` along dim 1; any
            tail beyond ``token_num + codec_len[i]`` keeps its zero /
            ``ignore_id`` padding.
        """
        batch_size, token_num, dims = inputs_embeds.shape
        max_face_len = face_emb.size(1)
        # Worst-case output length; shorter samples keep trailing padding.
        new_max_length = token_num + max_face_len
        # Pre-allocate outputs once; per-sample slices are filled in place.
        # Fix: allocate with inputs_embeds.dtype (was implicit fp32, which
        # silently upcast fp16/bf16 embeddings).
        padded_inputs_embeds = torch.zeros(
            batch_size, new_max_length, dims,
            device=device, dtype=inputs_embeds.dtype,
        )
        padded_attention_mask = torch.zeros(
            batch_size, new_max_length,
            device=device, dtype=attention_mask.dtype,
        )
        padded_labels = torch.full(
            (batch_size, new_max_length), self.ignore_id,
            device=device, dtype=labels_ids.dtype,
        )
        for i in range(batch_size):
            current_face_len = codec_len[i].item()
            # Copy head, the face span, then the rest of the sequence.
            padded_inputs_embeds[i, :insert_pos] = inputs_embeds[i, :insert_pos]
            padded_inputs_embeds[i, insert_pos:insert_pos + current_face_len] = face_emb[i, :current_face_len]
            padded_inputs_embeds[i, insert_pos + current_face_len:token_num + current_face_len] = inputs_embeds[i, insert_pos:]
            # Attention mask: inserted face positions are attended to.
            padded_attention_mask[i, :insert_pos] = attention_mask[i, :insert_pos]
            padded_attention_mask[i, insert_pos:insert_pos + current_face_len] = 1
            padded_attention_mask[i, insert_pos + current_face_len:token_num + current_face_len] = attention_mask[i, insert_pos:]
            # Labels: face positions never contribute to the loss.
            padded_labels[i, :insert_pos] = labels_ids[i, :insert_pos]
            padded_labels[i, insert_pos:insert_pos + current_face_len] = self.ignore_id
            padded_labels[i, insert_pos + current_face_len:token_num + current_face_len] = labels_ids[i, insert_pos:]
        return padded_inputs_embeds, padded_attention_mask, padded_labels

    def load_data(self, contents: dict, **kwargs):
        """Turn one raw sample into batched model inputs.

        Builds the token sequence
        ``[sos, text_ids..., type_id, timespk_ids..., turn_of_speech]``,
        per-segment boolean flags, and a (speech_len, face_size) face
        embedding matrix loaded from the pickle at ``contents["face"]``.

        Args:
            contents: dict with keys "text", "clue", "timespk_ids",
                "type_id", "speech_len", "face".
            **kwargs: must carry "tokenizer" and "dataset_conf" (with "sos"
                and "turn_of_speech" ids); "lm_use_prompt" controls whether
                the clue is prepended to the text.

        Returns:
            dict of batched tensors: input_ids, face_embs, boolean
            text/timespk/codec flags, and a ``prompt_codec`` placeholder.
        """
        lm_use_prompt = kwargs.get("lm_use_prompt", True)
        tokenizer = kwargs.get("tokenizer")
        # text + clue
        text = contents["text"]
        clue = "<|startofclue|>" + contents["clue"] + "<|endofclue|>"
        if lm_use_prompt:
            text = clue + text
        text_ids = tokenizer.encode(text)
        text_len = len(text_ids)
        # timespk_ids
        timespk_ids = contents["timespk_ids"].tolist()
        type_id = contents["type_id"]
        # sequence
        sequence = [
            kwargs['dataset_conf']["sos"],
            *text_ids,
            type_id,
            *timespk_ids,
            kwargs['dataset_conf']["turn_of_speech"]
        ]
        input_ids = torch.tensor(sequence, dtype=torch.int64)
        # Segment flags: text covers positions [1, text_len]; timespk covers
        # [text_len+1, -1) — i.e. type_id plus the timespk ids; codec is the
        # complement (here only the sos and final turn_of_speech positions).
        # Fix: removed a dead `codec_flag = torch.zeros(...)` allocation that
        # was immediately overwritten below.
        text_flag = torch.zeros(len(sequence), dtype=torch.float32)
        timespk_flag = torch.zeros(len(sequence), dtype=torch.float32)
        text_flag[1: text_len + 1] = 1
        timespk_flag[text_len + 1: -1] = 1
        codec_flag = 1 - text_flag - timespk_flag
        # Face embeddings: each detected face frame is propagated over up to
        # 5 consecutive frames (frames with no face stay zero).
        speech_len = contents["speech_len"]
        face_embs = torch.zeros((speech_len, self.face_size), dtype=torch.float32)
        face_path = contents.get("face")
        # NOTE(review): pickle.load on an external file is unsafe on
        # untrusted data; paths are presumed to come from a trusted manifest.
        with open(face_path, 'rb') as f:
            stat_obj = pickle.load(f)
        embeddings = stat_obj['embeddings']
        faceI = stat_obj['faceI']
        for emb, frameI in zip(embeddings, faceI):
            fi = int(frameI)
            if 0 <= fi < speech_len:
                end = min(fi + 5, speech_len)
                face_embs[fi:end] = torch.from_numpy(emb).expand(end - fi, -1)
        # Add the batch dimension.
        input_ids = input_ids[None, :]
        text_flag = text_flag[None, :]
        timespk_flag = timespk_flag[None, :]
        codec_flag = codec_flag[None, :]
        face_embs = face_embs[None, :, :]
        output = {
            "input_ids": input_ids,
            "face_embs": face_embs,
            "text_flag": text_flag > 0,
            "timespk_flag": timespk_flag > 0,
            "codec_flag": codec_flag > 0,
            "prompt_codec": None,  # you can add prompt codec here if needed
        }
        return output

    def inference_prepare(self, data_in, **kwargs):
        """Embed one sample into the LLM input space.

        Loads the sample via :meth:`load_data`, embeds each segment with its
        own table (text via the LLM embeddings, timespk/codec via local
        tables), projects and inserts the face embeddings right after the
        SOS position, and optionally appends prompt codec embeddings.

        Args:
            data_in: list holding a single raw sample dict (batch decoding
                is not supported).
            **kwargs: must include "device"; LoRA flags select how the LLM
                embedding table is reached.

        Returns:
            inputs_embeds: (1, seq_len, llm_dim) embeddings for generation.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        output = self.load_data(data_in[0], **kwargs)
        batch = to_device(output, kwargs["device"])
        input_ids = batch["input_ids"]
        # Clamp negative ids (e.g. ignore/pad markers) to 0 before embedding.
        input_ids = input_ids * (input_ids > 0)
        text_flag = batch["text_flag"]
        timespk_flag = batch["timespk_flag"]
        codec_flag = batch["codec_flag"]
        face_embs = batch["face_embs"]
        # With an unmerged (Q)LoRA wrapper the embedding table sits one level
        # deeper in the module tree.
        if (kwargs.get("use_qlora", False) or kwargs.get("infer_use_lora", False)) and (not kwargs.get("infer_lora_merged", False)):
            text_embeds = self.llm.base_model.model.model.get_input_embeddings()(input_ids * text_flag) * text_flag.unsqueeze(-1)
        else:
            text_embeds = self.llm.model.get_input_embeddings()(input_ids * text_flag) * text_flag.unsqueeze(-1)
        timespk_embeds = self.timespk_embed(input_ids * timespk_flag) * timespk_flag.unsqueeze(-1)
        codec_embs = self.codec_embed(input_ids * codec_flag) * codec_flag.unsqueeze(-1)
        face_embs = self.face_linear(face_embs)
        # The three segment flags are disjoint, so summation merges them.
        inputs_embeds = text_embeds + timespk_embeds + codec_embs
        inputs_embeds = torch.cat([
            inputs_embeds[:, 0:1, :],   # sos token
            face_embs,                  # face embeddings
            inputs_embeds[:, 1:, :]     # inputs_embeds after sos
        ], dim=1)
        prompt_codec = batch.get("prompt_codec", None)
        if prompt_codec is not None:
            codec_emb = self.codec_embed(prompt_codec)
            inputs_embeds = torch.cat((inputs_embeds, codec_emb), dim=1)
        return inputs_embeds

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        **kwargs,
    ):
        """Autoregressively generate codec tokens for one utterance.

        Args:
            data_in: list holding one raw sample dict.
            data_lengths: unused; kept for interface compatibility.
            key: list whose first element is the utterance id.
            **kwargs: decoding options — ``min_length``/``max_length``,
                ``llm_dtype``, ``dataset_conf`` (LLMDecoder config), LoRA
                flags, ``states``, ``output_dir``, ``device``.

        Returns:
            (gen_codec, hit_eos, states) from the LLM decoder. If
            ``output_dir`` is given, the generated codec is also saved to
            ``<output_dir>/codec/<uttid>.npy``.
        """
        uttid = key[0]
        inputs_emb = self.inference_prepare(data_in, **kwargs)
        logging.info(f"{uttid}: min length: {kwargs['min_length']}, max length: {kwargs['max_length']}")
        dtype = dtype_map[kwargs.get("llm_dtype", "fp32")]
        # Build the decoder lazily on first call and cache it on the module.
        if not hasattr(self, "llm_generator"):
            llm_generator_conf = kwargs.get("dataset_conf", {})
            self.llm_generator = LLMDecoder(
                token_embeder=self.codec_embed,
                **llm_generator_conf
            ).to(dtype)
        # Route the LM head to codec logits (idempotent; LoRA wrappers nest
        # the head one level deeper).
        if (kwargs.get("use_qlora", False) or kwargs.get("infer_use_lora", False)) and (not kwargs.get("infer_lora_merged", False)):
            self.llm.base_model.model.lm_head = self.codec_head.to(dtype)
        else:
            self.llm.lm_head = self.codec_head.to(dtype)
        gen_codec, hit_eos, states = self.llm_generator(
            inputs_emb.to(dtype),
            self.llm,
            states=kwargs.get("states", {}),
            **kwargs
        )
        output_dir = kwargs.get("output_dir", None)
        if output_dir is not None:
            output_dir = os.path.join(output_dir, "codec")
            os.makedirs(output_dir, exist_ok=True)
            np.save(
                os.path.join(output_dir, f"{key[0]}.npy"),
                gen_codec[0].cpu().numpy()
            )
        return gen_codec, hit_eos, states