import math
import os

import torch
import transformers
from tqdm import tqdm

import diffusion


@torch.no_grad()
def compute_ppl(pretrained_model, val_ds):
  """Computes perplexity of `pretrained_model` on the validation set `val_ds`."""
  ppl_metrics = diffusion.Perplexity().to('cuda')
  pbar = tqdm(val_ds, desc='PPL')
  for batch in pbar:
    input_ids = batch['input_ids'].to('cuda')
    if 'attention_mask' in batch:
      attention_mask = batch['attention_mask'].to('cuda')
    else:
      attention_mask = None
    losses = pretrained_model._loss(input_ids, attention_mask)
    ppl_metrics.update(losses.nlls, losses.token_mask)
    pbar.set_postfix({'ppl': ppl_metrics.compute().item()})
  return ppl_metrics.compute().item()


@torch.no_grad()
def compute_generative_ppl(
    sentences,
    eval_model_name_or_path,
    gen_ppl_eval_batch_size=8,
    max_length=128):
  """Computes generative perplexity of `sentences` under an eval causal LM."""
  gen_ppl_metric = diffusion.Perplexity().to('cuda')
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
  eval_model_tokenizer = transformers.AutoTokenizer.from_pretrained(
    eval_model_name_or_path)
  if eval_model_tokenizer.pad_token is None:
    # Causal LMs such as GPT-2 ship without a pad token; reuse EOS for padding.
    eval_model_tokenizer.pad_token = eval_model_tokenizer.eos_token
    eval_model_tokenizer.pad_token_id = eval_model_tokenizer.eos_token_id
  eval_model = transformers.AutoModelForCausalLM.from_pretrained(
    eval_model_name_or_path).eval()
  eval_model = eval_model.to('cuda')
  # Longest sequence fed to the eval model in a single forward pass.
  eval_context_size = 1024
  if max_length is None:
    # Fall back to the eval context size when no truncation length is given.
    max_length = eval_context_size
  tokenizer_kwargs = {
    'return_tensors': 'pt',
    'return_token_type_ids': False,
    'return_attention_mask': True,
    'truncation': True,
    'padding': True,
    'max_length': max_length,
  }
  samples = eval_model_tokenizer(sentences, **tokenizer_kwargs)
  attn_mask = samples['attention_mask'].to('cuda')
  samples = samples['input_ids'].to('cuda')
  # Round up so the final partial batch of sentences is not dropped.
  num_batches = math.ceil(samples.shape[0] / gen_ppl_eval_batch_size)
  for i in tqdm(range(num_batches), desc='Gen. PPL', leave=False):
    batch_slice = slice(i * gen_ppl_eval_batch_size,
                        (i + 1) * gen_ppl_eval_batch_size)
    # Chunk each batch so no forward pass exceeds the eval context size.
    _samples = torch.split(
      samples[batch_slice], eval_context_size, dim=-1)
    _attn_mask = torch.split(
      attn_mask[batch_slice], eval_context_size, dim=-1)
    for sample_chunk, attn_mask_chunk in zip(_samples, _attn_mask):
      logits = eval_model(
        sample_chunk, attention_mask=attn_mask_chunk).logits
      # cross_entropy expects the class (vocab) dimension second.
      logits = logits.transpose(-1, -2)
      # Next-token NLLs: position t predicts token t + 1.
      nlls = torch.nn.functional.cross_entropy(
        logits[..., :-1],
        sample_chunk[..., 1:],
        reduction='none')
      # Exclude padding positions from the running perplexity.
      gen_ppl_metric.update(nlls, attn_mask_chunk[..., 1:])
  return gen_ppl_metric.compute().item()
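# A minimal usage sketch (assumptions: a CUDA device is available, the
# placeholder sentences below stand in for text decoded from the diffusion
# model, and 'gpt2' is just one possible choice of evaluator checkpoint).
if __name__ == '__main__':
  generated_samples = [
    'The quick brown fox jumps over the lazy dog.',
    'Diffusion language models generate text by iterative denoising.',
  ]
  gen_ppl = compute_generative_ppl(
    generated_samples,
    eval_model_name_or_path='gpt2',
    gen_ppl_eval_batch_size=2,
    max_length=128)
  print(f'Generative PPL under gpt2: {gen_ppl:.2f}')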