import torch
import torch.nn.functional as F
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
@register_model("cloverlm")
class CloverLMHFLM(HFLM):
    """HFLM wrapper that pads batch sequence lengths up to a fixed multiple.

    Padding every batch to the next multiple of ``pad_multiple`` limits the
    number of distinct input shapes the model sees, which helps backends that
    recompile or re-bucket per shape. Padding is stripped again before results
    are returned, so callers observe the same outputs as plain ``HFLM``.
    """

    def __init__(self, pad_multiple=128, **kwargs):
        """
        Args:
            pad_multiple: sequence lengths are padded up to the next multiple
                of this value before each forward / generate call. Must be >= 1.
            **kwargs: forwarded unchanged to ``HFLM``.

        Raises:
            ValueError: if ``pad_multiple`` is smaller than 1.
        """
        # Fail fast on a nonsensical bucket size before doing any model setup.
        if pad_multiple < 1:
            raise ValueError(f"pad_multiple must be >= 1, got {pad_multiple}")
        super().__init__(**kwargs)
        self.pad_multiple = pad_multiple

    def _pad_id(self):
        """Return a safe token id to pad with.

        Many causal-LM tokenizers define no pad token; fall back to EOS, then
        0. Padded positions are either masked out (generate) or causally
        ignored to the right (loglikelihood), so the exact id is not critical,
        but ``F.pad(value=None)`` must be avoided.
        """
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        return 0 if pad_id is None else pad_id

    def _encode_pair(self, context, continuation):
        """Tokenize ``(context, continuation)``, guaranteeing a non-empty continuation.

        Falls back to re-encoding the concatenation when the parent split
        yields an empty continuation (this can happen when the continuation's
        characters merge into the context's final token at the boundary).
        """
        context_enc, continuation_enc = super()._encode_pair(context, continuation)
        if not continuation_enc and continuation:
            whole_enc = self.tok_encode(context + continuation)
            if len(whole_enc) > 1:
                # Treat the final token of the merged encoding as the continuation.
                continuation_enc = whole_enc[-1:]
                context_enc = whole_enc[:-1]
            elif whole_enc:
                # Everything collapsed into one token: it must be the
                # continuation; use the prefix token as a stand-in context.
                continuation_enc = whole_enc
                context_enc = [self.prefix_token_id]
            else:
                # Degenerate input that tokenizes to nothing at all.
                continuation_enc = [self.prefix_token_id]
                context_enc = [self.prefix_token_id]
        return context_enc, continuation_enc

    def _model_call(self, inps: torch.Tensor, attn_mask: torch.Tensor = None, **kwargs):
        """Right-pad ``inps`` to a multiple of ``pad_multiple``, run the model,
        and slice the logits back to the original length.

        Right padding is safe for causal models: padded positions cannot
        influence the logits of the (earlier) positions we keep.
        """
        orig_len = inps.shape[1]
        remainder = orig_len % self.pad_multiple
        if remainder != 0:
            pad_len = self.pad_multiple - remainder
            inps = F.pad(inps, (0, pad_len), value=self._pad_id())
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, pad_len), value=0)
        logits = super()._model_call(inps, attn_mask=attn_mask, **kwargs)
        if remainder != 0:
            # Drop logits for the padded tail positions.
            logits = logits[:, :orig_len, :]
        return logits

    def _model_generate(self, context, max_length, **kwargs):
        """Left-pad ``context`` to a multiple of ``pad_multiple`` and generate.

        Left padding keeps the real prompt adjacent to the generated tokens
        (required for causal generation); the pad is stripped from the output.
        ``max_length`` is extended by the pad length because HF ``generate``
        counts the whole sequence — without the adjustment, padding would
        silently shrink the new-token budget by ``pad_len``.
        """
        orig_len = context.shape[1]
        remainder = orig_len % self.pad_multiple
        pad_len = 0
        if remainder != 0:
            pad_len = self.pad_multiple - remainder
            context = F.pad(context, (pad_len, 0), value=self._pad_id())
            attn = kwargs.get("attention_mask")
            if attn is not None:
                # Padded positions must be invisible to attention.
                kwargs["attention_mask"] = F.pad(attn, (pad_len, 0), value=0)
            # Bugfix: preserve the number of new tokens the model may emit.
            max_length = max_length + pad_len
        out = super()._model_generate(context, max_length, **kwargs)
        if pad_len:
            # Strip the left pad so callers see prompt + generation only.
            out = out[:, pad_len:]
        return out
if __name__ == "__main__":
    # Entry point: delegate to the standard lm-eval CLI. Importing this module
    # has already registered the "cloverlm" model via @register_model above.
    from lm_eval.__main__ import cli_evaluate
    cli_evaluate()