CloverLM / lm_eval /eval.py
BlackSamorez's picture
Upload folder using huggingface_hub
a29dc33 verified
import torch
import torch.nn.functional as F
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
@register_model("cloverlm")
class CloverLMHFLM(HFLM):
def __init__(self, pad_multiple=128, **kwargs):
super().__init__(**kwargs)
self.pad_multiple = pad_multiple
def _encode_pair(self, context, continuation):
context_enc, continuation_enc = super()._encode_pair(context, continuation)
if not continuation_enc and continuation:
whole_enc = self.tok_encode(context + continuation)
if len(whole_enc) > 1:
continuation_enc = whole_enc[-1:]
context_enc = whole_enc[:-1]
elif whole_enc:
continuation_enc = whole_enc
context_enc = [self.prefix_token_id]
else:
continuation_enc = [self.prefix_token_id]
context_enc = [self.prefix_token_id]
return context_enc, continuation_enc
def _model_call(self, inps: torch.Tensor, attn_mask: torch.Tensor = None, **kwargs):
orig_len = inps.shape[1]
remainder = orig_len % self.pad_multiple
if remainder != 0:
pad_len = self.pad_multiple - remainder
inps = F.pad(inps, (0, pad_len), value=self.tokenizer.pad_token_id)
if attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, pad_len), value=0)
logits = super()._model_call(inps, attn_mask=attn_mask, **kwargs)
if remainder != 0:
logits = logits[:, :orig_len, :]
return logits
def _model_generate(self, context, max_length, **kwargs):
orig_len = context.shape[1]
remainder = orig_len % self.pad_multiple
if remainder != 0:
pad_len = self.pad_multiple - remainder
context = F.pad(context, (pad_len, 0), value=self.tokenizer.pad_token_id)
if "attention_mask" in kwargs and kwargs["attention_mask"] is not None:
kwargs["attention_mask"] = F.pad(kwargs["attention_mask"], (pad_len, 0), value=0)
out = super()._model_generate(context, max_length, **kwargs)
if remainder != 0:
out = out[:, pad_len:]
return out
if __name__ == "__main__":
from lm_eval.__main__ import cli_evaluate
cli_evaluate()