import logging
import pickle

import numpy as np
import torch

from funcineforge.utils.hinter import hint_once
from funcineforge.datasets import FunCineForgeDS
from funcineforge.models import FunCineForgeSpecAug


class FunCineForgeDataset(torch.utils.data.Dataset):
    """Dataset for the mixed LM of FunCineForge.

    Each example interleaves text tokens, time/speaker ids and (SpecAug-masked)
    speech-codec tokens into one autoregressive sequence, and attaches
    per-frame face embeddings loaded from a pickle file aligned to the codec.
    """

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        face_encoder=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.index_ds = FunCineForgeDS(path, **kwargs)
        self.tokenizer = tokenizer
        self.face_encoder = face_encoder
        # Padding values used by collator(): ints vs floats are padded differently.
        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value
        self.batch_size = kwargs.get("batch_size")
        self.batch_type = kwargs.get("batch_type")
        # Number of attempts in __getitem__ / collator before giving up.
        self.retry = kwargs.get("retry", 100)
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5)
        self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500)
        self.multiturn_num_max = kwargs.get("multiturn_num_max", 1)
        self.face_size = kwargs.get("face_size", 512)
        # Special token ids live directly above the codec vocabulary.
        self.codebook_size = kwargs.get("codebook_size", 6561)
        self.sos = kwargs.get("sos", self.codebook_size)
        self.eos = kwargs.get("eos", self.codebook_size + 1)
        self.turn_of_speech = kwargs.get("turn_of_speech", self.codebook_size + 2)
        self.ignore_id = kwargs.get("ignore_id", -100)
        specaug = kwargs.get("specaug", None)
        specaug_conf = kwargs.get("specaug_conf", {})
        if specaug is not None:
            specaug = FunCineForgeSpecAug(**specaug_conf)
        self.specaug = specaug
        self.set_invalid_xvec_zeros = kwargs.get("set_invalid_xvec_zeros", False)
        self.use_emotion_clue = kwargs.get("use_emotion_clue", False)
        logging.info(f"use_emotion_clue: {self.use_emotion_clue}")

    def get_source_len(self, index):
        """Source length of the item at *index* (used for length-based batching)."""
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        """Target length of the item at *index* (used for length-based batching)."""
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def mixup_text_codec(
        self,
        text: torch.Tensor,
        aug_codec: torch.Tensor,
        timespk_ids: torch.Tensor,
        type_id: int,
    ):
        """Interleave text, time/speaker and codec tokens into one sequence.

        Layout:
            [sos, text..., type_id, timespk_ids..., turn_of_speech, codec..., eos]

        Returns:
            (input_ids, labels, text_flag, codec_flag, timespk_flag), where the
            flags are float masks over the sequence and ``labels`` has every
            position before the first codec token set to ``ignore_id`` so the
            loss is only computed on the speech tokens (and eos).
        """
        text_len = text.shape[0]
        timespk_len = timespk_ids.shape[0]
        sequence = [
            self.sos,
            *text.tolist(),
            type_id,
            *timespk_ids.tolist(),
            self.turn_of_speech,
            *aug_codec.tolist(),
            self.eos,
        ]
        input_ids = torch.tensor(sequence, dtype=torch.int64)
        text_flag = torch.zeros(len(sequence), dtype=torch.float32)
        # Text tokens only: positions right after sos.
        text_flag[1:text_len + 1] = 1
        timespk_flag = torch.zeros(len(sequence), dtype=torch.float32)
        # Covers type_id plus all time/speaker ids.
        timespk_flag[text_len + 1:text_len + 2 + timespk_len] = 1
        # Everything not text/timespk (sos, turn_of_speech, codec, eos).
        codec_flag = 1 - (text_flag + timespk_flag)
        labels = torch.tensor(sequence, dtype=torch.int64)
        # Mask sos + text + type_id + timespk + turn_of_speech (= text_len+timespk_len+3 tokens).
        labels[:text_len + timespk_len + 3] = self.ignore_id
        return input_ids, labels, text_flag, codec_flag, timespk_flag

    def __getitem__(self, index):
        """Build one training example.

        On any per-item failure (bad file, missing key, ...) a random index is
        retried up to ``self.retry`` times; returns None if all attempts fail.
        """
        output = None
        for idx in range(self.retry):
            index_cur = index if idx == 0 else torch.randint(0, len(self.index_ds), ()).item()
            try:
                item = self.index_ds[index_cur]

                # Text, optionally prefixed with an emotion clue.
                text = item["text"]
                if self.use_emotion_clue:
                    clue = "<|startofclue|>" + item["clue"] + "<|endofclue|>"
                    text = clue + text
                text_ids = torch.tensor(self.tokenizer.encode(text), dtype=torch.int32)
                hint_once(f"raw text: {text}", "log_text")

                # Speech codec tokens (length could also come from the
                # dataset's speech_length field).
                target_out = item["token"]
                codec = torch.from_numpy(np.load(target_out))
                codec_len = codec.shape[0]

                # aug_codec is a randomly masked copy of codec (robustness augmentation).
                aug_codec = codec.clone()
                if self.specaug is not None:
                    aug_codec, _ = self.specaug(aug_codec.float().unsqueeze(0).unsqueeze(-1))
                    aug_codec = aug_codec.squeeze(0).squeeze(-1).long()

                # Dialogue time/speaker ids.
                timespk_ids = torch.from_numpy(item["timespk_ids"])

                type_id = item["type_id"]
                input_ids, labels, text_flag, codec_flag, timespk_flag = self.mixup_text_codec(
                    text_ids, aug_codec, timespk_ids, type_id
                )

                # Face embeddings, frame-aligned with the codec sequence.
                # NOTE(security): pickle.load is unsafe on untrusted input —
                # these feature files must come from a trusted pipeline.
                face_emb = torch.zeros((codec_len, self.face_size), dtype=torch.float32)
                with open(item["face"], 'rb') as f:
                    stat_obj = pickle.load(f)
                for emb, frameI in zip(stat_obj['embeddings'], stat_obj['faceI']):
                    fi = int(frameI)
                    if 0 <= fi < codec_len:
                        # Each detected face embedding covers up to 5 frames.
                        end = min(fi + 5, codec_len)
                        face_emb[fi:end] = torch.from_numpy(emb).expand(end - fi, -1)

                # attention_mask spans the whole sequence:
                # (sos, [clue,] text, type_id, timespk_ids, turn_of_speech, speech, eos).
                attention_mask = torch.ones(len(input_ids), dtype=torch.int32)
                output = {
                    "input_ids": input_ids,
                    "face_emb": face_emb,
                    "attention_mask": attention_mask,
                    "labels_ids": labels,
                    "text_flag": text_flag,
                    "codec_flag": codec_flag,
                    "timespk_flag": timespk_flag,
                    "codec_len": torch.tensor([codec_len], dtype=torch.int32),
                }
                break
            except Exception as exc:
                # Without this handler the retry loop could never actually
                # retry: any bad item would crash the first iteration.
                logging.warning(
                    "getitem failed at index %s (attempt %d): %s", index_cur, idx, exc
                )
                output = None
        return output

    def collator(self, samples: list = None):
        """Pad per-sample tensors into a batch.

        None samples are skipped. If the padded batch exceeds
        ``batch_size_token_max`` (token-based batching), the last sample is
        dropped and the batch is rebuilt, up to ``self.retry`` times.
        """
        outputs = {}
        for idx in range(self.retry):
            outputs = {}
            for sample in samples:
                if sample is None:
                    continue
                for key, value in sample.items():
                    bucket = outputs.setdefault(key, [])
                    if isinstance(value, (list, tuple)):
                        bucket.extend(value)
                    else:
                        bucket.append(value)

            for key, data_list in outputs.items():
                if isinstance(data_list[0], torch.Tensor):
                    if data_list[0].dtype in (torch.int64, torch.int32):
                        pad_value = self.int_pad_value
                    else:
                        pad_value = self.float_pad_value
                    outputs[key] = torch.nn.utils.rnn.pad_sequence(
                        data_list, batch_first=True, padding_value=pad_value
                    )

            # Guard: all samples may have been None, leaving outputs empty.
            if self.batch_type != "example" and "input_ids" in outputs:
                b, t = outputs["input_ids"].shape
                if b > 1 and b * t > self.batch_size_token_max:
                    logging.info(
                        f"Warning, {idx}th, b*t: {b}*{t}={b * t} > batch_size_token_max: {self.batch_size_token_max}, drop last data"
                    )
                    samples = samples[:-1]
                    continue
            break
        return outputs