# Fun-CineForge-Demo / funcineforge/models/language_model.py
# xuan3986's picture
# Upload 111 files
# 03022ee verified
import logging
import os
import torch
import torch.nn as nn
from funcineforge.models.utils.llm_decoding import LLMDecoder
from funcineforge.utils.device_funcs import to_device
import numpy as np
from funcineforge.models.utils import dtype_map
from funcineforge.models import FunCineForgeSpecAug
from transformers import AutoModelForCausalLM
import pickle
class FunCineForgeLM(nn.Module):
    """Causal-LM wrapper combining text, time/speaker, codec and face embeddings.

    The module loads a Hugging Face causal language model and adds its own
    embedding tables (``codec_embed``, ``timespk_embed``), a linear projection
    for face features (``face_linear``) and an output head over codec units
    (``codec_head``).
    """

    def __init__(
        self,
        llm: str = None,
        llm_conf: dict = None,
        input_size: int = 80,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        """Load the language model and create the extra embedding/projection layers.

        Args:
            llm: Not read by this constructor.
            llm_conf: Language-model options. Keys read here:
                ``init_param_path``, ``load_kwargs``, ``freeze`` (default True),
                ``use_lora``, ``use_qlora``, ``activation_checkpoint`` and
                ``llm_dtype`` (default ``"fp32"``). ``None`` is treated as an
                empty config. NOTE: when ``infer_lora_merged`` is set, the
                ``use_lora``/``use_qlora`` entries of this dict are overwritten
                in place.
            input_size: Not read by this constructor.
            length_normalized_loss: Stored on the instance as-is.
            **kwargs: Optional settings: ``sample_rate`` (24000),
                ``token_rate`` (25), ``infer_lora_merged``, ``infer_use_lora``,
                ``codec_unit`` (6761), ``timespk_unit`` (1550),
                ``face_size`` (512), ``ignore_id`` (-100), ``specaug`` and
                ``specaug_conf``.
        """
        super().__init__()
        # A missing config is treated as empty so the lookups below do not
        # fail with AttributeError on None.
        if llm_conf is None:
            llm_conf = {}

        # llm
        self.llm_conf = llm_conf
        self.llm = None
        init_param_path = llm_conf.get("init_param_path", "")
        llm_load_kwargs = llm_conf.get("load_kwargs") or {}
        self.sample_rate = kwargs.get("sample_rate", 24000)
        self.token_rate = kwargs.get("token_rate", 25)

        # With merged LoRA weights the model is handled like a plain model,
        # so every LoRA-related switch is turned off.
        if kwargs.get("infer_lora_merged", False):
            llm_conf["use_qlora"] = False
            llm_conf["use_lora"] = False
            kwargs["infer_use_lora"] = False

        # Merge the fixed defaults with the user-supplied load kwargs in one
        # dict; passing both as keywords would raise TypeError on a duplicate.
        load_kwargs = {
            "load_in_8bit": None,
            "device_map": None,
            "use_cache": None,
            **llm_load_kwargs,
        }
        model = AutoModelForCausalLM.from_pretrained(init_param_path, **load_kwargs)

        if llm_conf.get("freeze", True):
            for param in model.parameters():
                param.requires_grad = False
            model.eval()

        logging.info(f"use_lora: {llm_conf.get('use_lora', False)}, use_qlora: {llm_conf.get('use_qlora', False)}, infer_use_lora: {kwargs.get('infer_use_lora',False)}, infer_lora_merged: {kwargs.get('infer_lora_merged',False)}")

        if llm_conf.get("activation_checkpoint", False):
            model.gradient_checkpointing_enable()

        self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
        self.llm = model.to(dtype_map[self.llm_dtype])
        llm_dim = model.get_input_embeddings().weight.shape[-1]

        # Without LoRA the original LM head is dropped; `inference` installs
        # `codec_head` in its place.
        if (not llm_conf.get("use_lora", False)) and (not kwargs.get("infer_use_lora", False)):
            del self.llm.lm_head

        self.codec_unit = kwargs.get("codec_unit", 6761)
        self.timespk_unit = kwargs.get("timespk_unit", 1550)
        # Index 0 is the padding index of both embedding tables.
        self.codec_embed = nn.Embedding(self.codec_unit, llm_dim, 0)
        self.timespk_embed = nn.Embedding(self.timespk_unit, llm_dim, 0)
        self.codec_head = nn.Linear(llm_dim, self.codec_unit, bias=False)
        self.face_size = kwargs.get("face_size", 512)
        self.face_linear = nn.Linear(self.face_size, llm_dim)

        self.length_normalized_loss = length_normalized_loss
        self.ignore_id = kwargs.get("ignore_id", -100)

        # Any non-None `specaug` value enables the augmenter, which is built
        # from `specaug_conf`.
        specaug = kwargs.get("specaug", None)
        specaug_conf = kwargs.get("specaug_conf", {})
        if specaug is not None:
            specaug = FunCineForgeSpecAug(**specaug_conf)
        self.specaug = specaug

        rank = int(os.environ.get("RANK", 0))
        logging.info(f"rank: {rank}, model is builded.")

    @staticmethod
    def _uses_unmerged_lora(options: dict) -> bool:
        """Return True when options select LoRA/QLoRA with adapters not merged.

        In that case the methods below reach the wrapped model through
        ``self.llm.base_model.model`` instead of ``self.llm`` directly.
        """
        lora_enabled = options.get("use_qlora", False) or options.get("infer_use_lora", False)
        return bool(lora_enabled) and not options.get("infer_lora_merged", False)

    def insert_face_embeddings(
        self, inputs_embeds, face_emb, attention_mask, labels_ids,
        codec_len, insert_pos, device
    ):
        """Insert face embeddings into the input embeddings at a fixed position.

        For every sample ``i`` the first ``codec_len[i]`` rows of
        ``face_emb[i]`` are written at ``insert_pos`` and the original entries
        from ``insert_pos`` onward are shifted right. The attention mask and
        the labels are updated in step: inserted positions get mask ``1`` and
        label ``self.ignore_id``. Trailing unused positions stay zero
        (embeddings, mask) or ``self.ignore_id`` (labels).

        Args:
            inputs_embeds: (batch_size, token_num, dims) input embeddings.
            face_emb: (batch_size, max_face_len, dims) face embeddings.
            attention_mask: (batch_size, token_num) attention mask.
            labels_ids: (batch_size, token_num) label ids.
            codec_len: (batch_size,) actual face-embedding length per sample.
            insert_pos: int insertion position, right after the SOS token.
            device: Device on which the output tensors are allocated.

        Returns:
            Tuple ``(padded_inputs_embeds, padded_attention_mask,
            padded_labels)``, each with sequence length
            ``token_num + max_face_len``.
        """
        batch_size, token_num, dims = inputs_embeds.shape
        max_face_len = face_emb.size(1)
        # Longest possible sequence after insertion.
        new_max_length = token_num + max_face_len

        # Pre-allocate the outputs. The embeddings keep the input dtype (as the
        # mask and labels do) so half-precision inputs are not upcast.
        padded_inputs_embeds = torch.zeros(
            batch_size, new_max_length, dims,
            device=device, dtype=inputs_embeds.dtype,
        )
        padded_attention_mask = torch.zeros(
            batch_size, new_max_length,
            device=device, dtype=attention_mask.dtype,
        )
        padded_labels = torch.full(
            (batch_size, new_max_length), self.ignore_id,
            device=device, dtype=labels_ids.dtype,
        )

        for i in range(batch_size):
            # int() guarantees a value that is usable as a slice bound.
            face_len = int(codec_len[i])
            face_end = insert_pos + face_len
            tail_end = token_num + face_len

            # Fill the three segments directly instead of concatenating.
            padded_inputs_embeds[i, :insert_pos] = inputs_embeds[i, :insert_pos]
            padded_inputs_embeds[i, insert_pos:face_end] = face_emb[i, :face_len]
            padded_inputs_embeds[i, face_end:tail_end] = inputs_embeds[i, insert_pos:]

            # Same layout for the mask and the labels.
            padded_attention_mask[i, :insert_pos] = attention_mask[i, :insert_pos]
            padded_attention_mask[i, insert_pos:face_end] = 1
            padded_attention_mask[i, face_end:tail_end] = attention_mask[i, insert_pos:]

            padded_labels[i, :insert_pos] = labels_ids[i, :insert_pos]
            padded_labels[i, insert_pos:face_end] = self.ignore_id
            padded_labels[i, face_end:tail_end] = labels_ids[i, insert_pos:]

        return padded_inputs_embeds, padded_attention_mask, padded_labels

    def load_data(self, contents: dict, **kwargs):
        """Turn one sample description into batched (batch size 1) model inputs.

        The token sequence is
        ``[sos, *text_ids, type_id, *timespk_ids, turn_of_speech]``.

        Args:
            contents: Sample dict. Keys read: ``text``, ``clue`` (only when
                ``lm_use_prompt`` is true), ``timespk_ids`` (must offer
                ``.tolist()``), ``type_id``, ``speech_len`` and optionally
                ``face`` (path to a pickle file holding ``embeddings`` and
                ``faceI``).
            **kwargs: ``tokenizer`` (object with ``encode``), ``dataset_conf``
                (provides ``sos`` and ``turn_of_speech``) and
                ``lm_use_prompt`` (default True).

        Returns:
            Dict with ``input_ids``, ``face_embs``, the boolean masks
            ``text_flag`` / ``timespk_flag`` / ``codec_flag`` and
            ``prompt_codec`` (always None here).
        """
        lm_use_prompt = kwargs.get("lm_use_prompt", True)
        tokenizer = kwargs.get("tokenizer")

        # text + clue; the clue is only looked up when it is actually used.
        text = contents["text"]
        if lm_use_prompt:
            text = "<|startofclue|>" + contents["clue"] + "<|endofclue|>" + text
        text_ids = tokenizer.encode(text)
        text_len = len(text_ids)

        # timespk_ids
        timespk_ids = contents["timespk_ids"].tolist()
        type_id = contents["type_id"]

        # sequence
        dataset_conf = kwargs["dataset_conf"]
        sequence = [
            dataset_conf["sos"],
            *text_ids,
            type_id,
            *timespk_ids,
            dataset_conf["turn_of_speech"],
        ]
        input_ids = torch.tensor(sequence, dtype=torch.int64)

        # Flag tensors: text covers the tokenizer output; timespk covers
        # type_id plus timespk_ids; everything else (first and last position)
        # is flagged as codec.
        seq_len = len(sequence)
        text_flag = torch.zeros(seq_len, dtype=torch.float32)
        timespk_flag = torch.zeros(seq_len, dtype=torch.float32)
        text_flag[1: text_len + 1] = 1
        timespk_flag[text_len + 1: -1] = 1
        codec_flag = 1 - text_flag - timespk_flag

        # face embs: positions without a detected face stay zero.
        speech_len = contents["speech_len"]
        face_embs = torch.zeros((speech_len, self.face_size), dtype=torch.float32)
        face_path = contents.get("face")
        if face_path is None:
            logging.warning("no face file given, using all-zero face embeddings")
        else:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load face files that come from a trusted source.
            with open(face_path, 'rb') as f:
                stat_obj = pickle.load(f)
                embeddings = stat_obj['embeddings']
                faceI = stat_obj['faceI']
            for emb, frame_index in zip(embeddings, faceI):
                start = int(frame_index)
                if 0 <= start < speech_len:
                    # Each face embedding is repeated over up to 5 consecutive
                    # positions. NOTE(review): the meaning of 5 is not visible
                    # here (presumably a frame-to-token ratio) - confirm.
                    end = min(start + 5, speech_len)
                    face_embs[start:end] = torch.from_numpy(emb).expand(end - start, -1)

        # batch dimension
        output = {
            "input_ids": input_ids[None, :],
            "face_embs": face_embs[None, :, :],
            "text_flag": text_flag[None, :] > 0,
            "timespk_flag": timespk_flag[None, :] > 0,
            "codec_flag": codec_flag[None, :] > 0,
            "prompt_codec": None,  # you can add prompt codec here if needed
        }
        return output

    def inference_prepare(self, data_in, **kwargs):
        """Build the input embedding sequence for decoding a single sample.

        Args:
            data_in: Sequence of sample dicts; only ``data_in[0]`` is used.
            **kwargs: Must contain ``device`` and whatever ``load_data``
                needs. ``batch_size`` greater than 1 is rejected.
                ``use_qlora`` / ``infer_use_lora`` / ``infer_lora_merged``
                select where the text embedding layer is looked up.

        Returns:
            Embeddings laid out as ``[sos, face embeddings, rest of the
            sequence, optional prompt codec]`` along dimension 1.

        Raises:
            NotImplementedError: If ``batch_size`` is greater than 1.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        output = self.load_data(data_in[0], **kwargs)
        batch = to_device(output, kwargs["device"])

        input_ids = batch["input_ids"]
        # Negative ids are mapped to 0 so they are valid embedding indices.
        input_ids = input_ids * (input_ids > 0)
        text_flag = batch["text_flag"]
        timespk_flag = batch["timespk_flag"]
        codec_flag = batch["codec_flag"]
        face_embs = batch["face_embs"]

        if self._uses_unmerged_lora(kwargs):
            text_embedder = self.llm.base_model.model.model.get_input_embeddings()
        else:
            text_embedder = self.llm.model.get_input_embeddings()

        # Each table embeds only its own positions (other ids are zeroed
        # before lookup and the result is masked afterwards), so the three
        # partial results can simply be added.
        text_embeds = text_embedder(input_ids * text_flag) * text_flag.unsqueeze(-1)
        timespk_embeds = self.timespk_embed(input_ids * timespk_flag) * timespk_flag.unsqueeze(-1)
        codec_embs = self.codec_embed(input_ids * codec_flag) * codec_flag.unsqueeze(-1)
        face_embs = self.face_linear(face_embs)

        inputs_embeds = text_embeds + timespk_embeds + codec_embs
        inputs_embeds = torch.cat([
            inputs_embeds[:, 0:1, :],  # sos token
            face_embs,                 # face embeddings
            inputs_embeds[:, 1:, :],   # inputs_embeds after sos
        ], dim=1)

        prompt_codec = batch.get("prompt_codec", None)
        if prompt_codec is not None:
            codec_emb = self.codec_embed(prompt_codec)
            inputs_embeds = torch.cat((inputs_embeds, codec_emb), dim=1)
        return inputs_embeds

    @torch.no_grad()
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        **kwargs,
    ):
        """Generate codec tokens for one sample and optionally save them.

        Args:
            data_in: Sequence of sample dicts (see ``inference_prepare``).
            data_lengths: Not read by this method.
            key: List whose first element is the utterance id; required.
            **kwargs: Decoding options forwarded to ``inference_prepare`` and
                to the decoder. Keys read here: ``llm_dtype`` (default
                ``"fp32"``), ``dataset_conf`` (decoder construction options),
                ``states``, ``output_dir``, ``min_length`` / ``max_length``
                (logged) and the LoRA switches.

        Returns:
            Tuple ``(gen_codec, hit_eos, states)`` as returned by the decoder.
        """
        uttid = key[0]
        inputs_emb = self.inference_prepare(data_in, **kwargs)
        # .get keeps a missing length option from failing the log call.
        logging.info(f"{uttid}: min length: {kwargs.get('min_length')}, max length: {kwargs.get('max_length')}")
        dtype = dtype_map[kwargs.get("llm_dtype", "fp32")]

        # The decoder is created once and the codec head is installed as the
        # LM head at the same time.
        if not hasattr(self, "llm_generator"):
            llm_generator_conf = kwargs.get("dataset_conf", {})
            self.llm_generator = LLMDecoder(
                token_embeder=self.codec_embed,
                **llm_generator_conf
            ).to(dtype)
            if self._uses_unmerged_lora(kwargs):
                self.llm.base_model.model.lm_head = self.codec_head.to(dtype)
            else:
                self.llm.lm_head = self.codec_head.to(dtype)

        # `states` is passed explicitly, so it must not also travel inside the
        # forwarded kwargs (that would raise "multiple values for keyword").
        states = kwargs.get("states", {})
        decode_kwargs = {k: v for k, v in kwargs.items() if k != "states"}
        gen_codec, hit_eos, states = self.llm_generator(
            inputs_emb.to(dtype),
            self.llm,
            states=states,
            **decode_kwargs
        )

        output_dir = kwargs.get("output_dir", None)
        if output_dir is not None:
            output_dir = os.path.join(output_dir, "codec")
            os.makedirs(output_dir, exist_ok=True)
            np.save(
                os.path.join(output_dir, f"{uttid}.npy"),
                gen_codec[0].cpu().numpy()
            )
        return gen_codec, hit_eos, states