Spaces:
Running on Zero
Running on Zero
| from contextlib import nullcontext | |
| import torch | |
| import torch.nn as nn | |
| from typing import Union | |
| from funcineforge.utils.hinter import hint_once | |
| import numpy as np | |
# Maps config strings to torch dtypes; used to pick the autocast dtype in __call__.
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
class LLMDecoder(nn.Module):
    """Autoregressive token decoder that samples ids from an LLM's output logits.

    Expected constructor kwargs:
        eos: int or list[int] -- token id(s) that terminate generation (required).
        token_embeder: callable mapping a (1, 1) id tensor to input embeddings (required).
        ras_conf: optional dict overriding Repetition Aware Sampling parameters.
        token_offset: optional int added to sampled ids before re-embedding (default 0).
    """

    def __init__(self, **kwargs):
        # Python-3 zero-argument super(); equivalent to super(LLMDecoder, self).
        super().__init__()
        self.eos_token = kwargs["eos"]
        # Normalize a single eos id to a list so `in` membership tests work uniformly.
        if isinstance(self.eos_token, int):
            self.eos_token = [self.eos_token]
        self.token_embeder = kwargs["token_embeder"]
        self.ras_conf = kwargs.get("ras_conf", {})
        self.token_offset = kwargs.get("token_offset", 0)
| def nucleus_sampling(self, weighted_scores, top_p=0.8, top_k=25, beam_size=1): | |
| prob, indices = [], [] | |
| cum_prob = 0.0 | |
| sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) | |
| for i in range(len(sorted_idx)): | |
| # sampling both top-p and numbers. | |
| if cum_prob < top_p and len(prob) < top_k: | |
| cum_prob += sorted_value[i] | |
| prob.append(sorted_value[i]) | |
| indices.append(sorted_idx[i]) | |
| else: | |
| break | |
| prob = torch.tensor(prob).to(weighted_scores) | |
| indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) | |
| sampling_ids = prob.multinomial(beam_size, replacement=True) | |
| top_ids = indices[sampling_ids] | |
| return top_ids | |
| def random_sampling(self, weighted_scores, beam_size=1): | |
| top_ids = weighted_scores.softmax(dim=0).multinomial(beam_size, replacement=True) | |
| return top_ids | |
| # Repetition Aware Sampling in VALL-E 2 | |
| def ras_sampling( | |
| self, weighted_scores, decoded_tokens, *, | |
| top_p=0.8, top_k=25, win_size=10, tau_r=0.1 | |
| ): | |
| if self.ras_conf is not None: | |
| top_p = self.ras_conf.get("top_p", top_p) | |
| top_k = self.ras_conf.get("top_k", top_k) | |
| win_size = self.ras_conf.get("win_size", win_size) | |
| tau_r = self.ras_conf.get("tau_r", tau_r) | |
| hint_once(f"using Repetition Aware Sampling: top_p: {top_p}, top_k: {top_k},win_size: {win_size}, tau_r: {tau_r}", "ras_sampling") | |
| top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) | |
| rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item() | |
| if rep_num >= win_size * tau_r: | |
| top_ids = self.random_sampling(weighted_scores) | |
| return top_ids | |
| def sampling_ids( | |
| self, | |
| weighted_scores: torch.Tensor, | |
| sampling: Union[bool, int, float] = True, | |
| decoded_tokens: list = None, | |
| ): | |
| if isinstance(sampling, bool): | |
| if sampling: | |
| top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True) | |
| else: | |
| top_ids = weighted_scores.topk(1)[1] | |
| elif isinstance(sampling, int): | |
| prob, indices = weighted_scores.softmax(dim=0).topk(sampling) | |
| sampling_ids = prob.multinomial(1, replacement=True) | |
| top_ids = indices[sampling_ids] | |
| elif isinstance(sampling, float): | |
| prob, indices = [], [] | |
| cum_prob = 0.0 | |
| sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) | |
| for i in range(len(sorted_idx)): | |
| # sampling both top-p and numbers. | |
| if cum_prob < sampling and len(prob) < 25: | |
| cum_prob += sorted_value[i] | |
| prob.append(sorted_value[i]) | |
| indices.append(sorted_idx[i]) | |
| else: | |
| break | |
| prob = torch.tensor(prob).to(weighted_scores) | |
| indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) | |
| sampling_ids = prob.multinomial(1, replacement=True) | |
| top_ids = indices[sampling_ids] | |
| elif isinstance(sampling, str) and sampling.lower() == "ras": | |
| top_ids = self.ras_sampling(weighted_scores, decoded_tokens=decoded_tokens) | |
| else: | |
| raise NotImplementedError(f"Not implemented for {type(sampling)} sampling") | |
| return top_ids | |
    def __call__(self, input_embeddings, llm, states, quantize=False, **kwargs):
        """Autoregressively decode token ids from `llm` until EOS or max_length.

        Args:
            input_embeddings: prompt embeddings fed to the LLM; presumably
                shaped (1, seq, hidden) -- TODO confirm against callers.
            llm: HF-style causal LM exposing `lm_head` and, for unmerged
                (Q)LoRA, `base_model.model`.
            states: dict; "llm_cache" (past_key_values) is read here and the
                updated cache is returned in the output states.
            quantize: if True, skip autocast and cast inputs to bfloat16.
            **kwargs: max_length, min_length, sampling, device, llm_dtype,
                use_llm_cache, include_eos, custom_eos_token, avoid_token,
                use_qlora, infer_use_lora, infer_lora_merged.

        Returns:
            (out_tokens, hit_eos, states): generated ids as a (1, T) int64
            tensor, whether an EOS id was produced, and the updated states.
        """
        # Defaults look like a 25 token/sec frame rate (60 s max, 2 s min) -- TODO confirm.
        max_length = kwargs.get("max_length", 60 * 25)
        min_length = kwargs.get("min_length", 2 * 25)
        sampling = kwargs.get("sampling", True)
        device = kwargs.get("device", "cuda")
        llm_dtype = kwargs.get("llm_dtype", "fp32")
        use_llm_cache = kwargs.get("use_llm_cache", True)
        include_eos = kwargs.get("include_eos", False)
        custom_eos_token = kwargs.get("custom_eos_token", self.eos_token)
        avoid_token = kwargs.get("avoid_token", None)
        llm_cache = states.get("llm_cache", None)
        out_tokens, hit_eos = [], False
        for i in range(max_length):
            # Autocast only when not quantized; the quantized path casts its own inputs.
            with torch.cuda.amp.autocast(
                enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
            ) if quantize is False else nullcontext():
                # default attention_mask is causal, no longer need manually construct
                # input_masks = torch.ones((1, input_embeddings.shape[1]), device=input_embeddings.device).to(torch.bool)
                if (kwargs.get("use_qlora",False) or kwargs.get("infer_use_lora",False)) and (not kwargs.get("infer_lora_merged",False)):
                    # Unmerged (Q)LoRA: call through the PEFT wrapper's underlying model.
                    outputs = llm.base_model.model(
                        inputs_embeds=input_embeddings.to(torch.bfloat16) if quantize is True else input_embeddings,
                        # attention_mask=input_masks,
                        output_hidden_states=True,
                        return_dict=True,
                        use_cache=use_llm_cache,
                        past_key_values=llm_cache,
                    )
                else:
                    outputs = llm(
                        inputs_embeds=input_embeddings.to(torch.bfloat16) if quantize is True else input_embeddings,
                        # attention_mask=input_masks,
                        output_hidden_states=True,
                        return_dict=True,
                        use_cache=use_llm_cache,
                        past_key_values=llm_cache,
                    )
                lm_hidden_states = outputs.hidden_states[-1]
                # Project only the final position's hidden state to vocabulary logits.
                h = llm.lm_head(lm_hidden_states[:, -1])
                # logp = h.log_softmax(dim=-1).squeeze(0)
                logp = h.squeeze(0)
            if use_llm_cache:
                llm_cache = outputs.past_key_values
            pred = torch.log_softmax(logp, dim=-1)
            # Suppress EOS until min_length by forcing its log-prob to a dtype minimum.
            # NOTE(review): the bf16 branch uses float16's minimum, not bfloat16's -- confirm intended.
            if min_length is not None and i < min_length:
                for x in custom_eos_token:
                    if pred.dtype == torch.bfloat16:
                        pred[x] = float(np.finfo(np.float16).min)
                    else:
                        pred[x] = float(np.finfo(np.float32).min)
            # Likewise mask out any explicitly forbidden token ids.
            if avoid_token is not None and len(avoid_token) > 0:
                for x in avoid_token:
                    if pred.dtype == torch.bfloat16:
                        pred[x] = float(np.finfo(np.float16).min)
                    else:
                        pred[x] = float(np.finfo(np.float32).min)
            top_id = self.sampling_ids(pred, sampling, out_tokens)[0].item()
            if top_id in custom_eos_token:
                # EOS is appended to the output only when the caller asks for it.
                if include_eos:
                    out_tokens.append(top_id)
                hit_eos = True
                break
            out_tokens.append(top_id)
            if use_llm_cache:
                # With a KV cache, feed only the newly sampled token's embedding.
                input_embeddings = self.token_embeder(torch.tensor([[top_id]], dtype=torch.int64, device=device) + self.token_offset)
            else:
                # Without a cache, re-feed the full (growing) embedding sequence.
                input_embeddings = torch.cat([
                    input_embeddings,
                    self.token_embeder(torch.tensor([[top_id]], dtype=torch.int64, device=device) + self.token_offset)
                ], dim=1)
        out_tokens = torch.tensor([out_tokens], dtype=torch.int64, device=device)
        states = {"llm_cache": llm_cache}
        return out_tokens, hit_eos, states