| import math |
| import argparse |
| import torch |
| import random |
|
|
| from eval_utils import get_test_dataset |
| from .modeling_bitnet import BitnetForCausalLM |
| from .tokenization_bitnet import BitnetTokenizer |
|
|
| from tqdm import tqdm |
# Inference-only script: disable autograd globally to avoid building graphs.
torch.set_grad_enabled(False)
|
|
# CLI arguments for the perplexity evaluation.
parser = argparse.ArgumentParser()
# RNG seed for python/torch (reproducibility).
parser.add_argument('--seed', default=0, type=int)
# HuggingFace model id or local path of the BitNet checkpoint to evaluate.
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
# Number of tokens per evaluation segment.
parser.add_argument('--seqlen', default=2048, type=int)
|
|
|
|
def calulate_loss(model, input, loss_fct):
    """Compute the summed next-token cross-entropy loss for one batch.

    Args:
        model: causal LM; called as ``model(input, use_cache=False, ...)``
            and its first output is the logits tensor of shape
            (batch, seq, vocab).
        input: LongTensor of token ids, shape (batch, seq).
            NOTE: the parameter name shadows the ``input`` builtin; it is
            kept unchanged for interface stability with existing callers.
        loss_fct: a ``CrossEntropyLoss(reduction="sum")``-style criterion.

    Returns:
        Scalar tensor: total loss over the batch * (seq - 1) predicted
        positions (the first token of each row has no label).
    """
    output = model(input,
                   use_cache=False,
                   output_hidden_states=False,
                   output_attentions=False)[0]
    # Align logits at position t with the label at position t+1.
    shift_logits = output[:, :-1, :].contiguous()
    shift_labels = input[:, 1:]
    # Use reshape (not view) for the labels: ``input[:, 1:]`` is a
    # non-contiguous slice, so ``.view(-1)`` raises for batch sizes > 1.
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.reshape(-1))
    return loss
|
|
|
|
def main(args):
    """Evaluate the BitNet model's perplexity on the c4 and wikitext2 test sets.

    Args:
        args: parsed CLI namespace with ``hf_path`` (HF model id or local
            path) and ``seqlen`` (tokens per evaluation segment).

    Side effects:
        Loads the model onto available devices, runs on CUDA, and prints
        per-dataset and average perplexity to stdout.
    """
    datasets = ['c4', 'wikitext2']
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map='auto',
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    ).half()
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
    # Sum (not mean) reduction so we can normalise by the exact label count.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda()

    ppl = []
    for dataset in datasets:
        testdata = get_test_dataset(dataset, tokenizer, seqlen=args.seqlen)
        acc_loss, count = 0.0, 0
        progress = tqdm(range(len(testdata)))
        for ii in progress:
            # torch.tensor keeps integer token ids exact; torch.Tensor(...)
            # would round-trip them through float32. Also avoids shadowing
            # the ``input`` builtin.
            input_ids = torch.tensor(testdata[ii], dtype=torch.long,
                                     device='cuda').view(1, -1)
            loss = calulate_loss(model, input_ids, loss_fct)
            # One label per token except the first in the segment.
            count += (input_ids.size(-1) - 1)
            acc_loss += loss.item()
            progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}")

        # Average loss in bits per token -> perplexity with base 2.
        avg_loss = acc_loss / count / math.log(2)
        ppl.append(2 ** avg_loss)
        print("{} PPL: {}".format(dataset, ppl[-1]))

    print(ppl)
    print("Avg PPL:", sum(ppl) / len(ppl))
|
|
|
|
if __name__ == '__main__':
    # Belt-and-braces: autograd is already disabled at import time, but keep
    # the explicit call so the entry point is self-contained.
    torch.set_grad_enabled(False)
    cli_args = parser.parse_args()
    # Seed both python and torch RNGs before any model/data work.
    random.seed(cli_args.seed)
    torch.random.manual_seed(cli_args.seed)
    main(cli_args)