"""lm-evaluation-harness adapter that pads sequence lengths to a fixed multiple.

Registers a custom model type, "cloverlm", that behaves like the standard
HuggingFace backend but rounds every input length up to a multiple of
`pad_multiple` tokens (useful for backends or compiled kernels that require
fixed-size blocks).
"""

import torch
import torch.nn.functional as F

from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


@register_model("cloverlm")
class CloverLMHFLM(HFLM):
    def __init__(self, pad_multiple: int = 128, **kwargs):
        super().__init__(**kwargs)
        self.pad_multiple = pad_multiple

    @property
    def _pad_id(self) -> int:
        # Some tokenizers (e.g. GPT-2 style) define no pad token; fall back to
        # EOS so F.pad always receives a concrete token id instead of None.
        if self.tokenizer.pad_token_id is not None:
            return self.tokenizer.pad_token_id
        return self.tokenizer.eos_token_id

    def _encode_pair(self, context, continuation):
        """Guarantee a non-empty continuation encoding.

        Some tokenizers merge the continuation into the final context token,
        leaving `continuation_enc` empty; re-split the joint encoding so that
        at least one token is attributed to the continuation.
        """
        context_enc, continuation_enc = super()._encode_pair(context, continuation)
        if not continuation_enc and continuation:
            whole_enc = self.tok_encode(context + continuation)
            if len(whole_enc) > 1:
                # Attribute the last token of the joint encoding to the continuation.
                context_enc = whole_enc[:-1]
                continuation_enc = whole_enc[-1:]
            elif whole_enc:
                # Single-token encoding: treat it all as continuation.
                context_enc = [self.prefix_token_id]
                continuation_enc = whole_enc
            else:
                # Degenerate case (empty encoding): fall back to the prefix token.
                context_enc = [self.prefix_token_id]
                continuation_enc = [self.prefix_token_id]
        return context_enc, continuation_enc

    def _model_call(self, inps: torch.Tensor, attn_mask: torch.Tensor = None, **kwargs):
        """Right-pad scoring inputs to a multiple of `pad_multiple`.

        Right padding is safe for causal scoring: positions before `orig_len`
        cannot attend to the padding, and the extra logits are sliced off.
        """
        orig_len = inps.shape[1]
        remainder = orig_len % self.pad_multiple
        if remainder != 0:
            pad_len = self.pad_multiple - remainder
            inps = F.pad(inps, (0, pad_len), value=self._pad_id)
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, pad_len), value=0)
        logits = super()._model_call(inps, attn_mask=attn_mask, **kwargs)
        if remainder != 0:
            # Drop logits for the padding positions.
            logits = logits[:, :orig_len, :]
        return logits

    def _model_generate(self, context, max_length, **kwargs):
        """Left-pad generation prompts to a multiple of `pad_multiple`.

        Left padding keeps the prompt flush against the generated tokens; the
        attention mask is extended with zeros so the model ignores the padding,
        and the pad columns are stripped from the output before returning.
        """
        orig_len = context.shape[1]
        remainder = orig_len % self.pad_multiple
        if remainder != 0:
            pad_len = self.pad_multiple - remainder
            context = F.pad(context, (pad_len, 0), value=self._pad_id)
            if kwargs.get("attention_mask") is not None:
                kwargs["attention_mask"] = F.pad(
                    kwargs["attention_mask"], (pad_len, 0), value=0
                )
            # Grow the total-length cap by the pad amount so the padding does
            # not eat into the number of new tokens the caller asked for.
            max_length = max_length + pad_len
        out = super()._model_generate(context, max_length, **kwargs)
        if remainder != 0:
            out = out[:, pad_len:]
        return out


if __name__ == "__main__":
    from lm_eval.__main__ import cli_evaluate

    cli_evaluate()
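# Example invocation (illustrative: the filename, checkpoint, and task below
# are assumptions, not taken from the source). `pad_multiple` reaches the
# constructor through --model_args, which lm-eval parses into kwargs:
#
#   python clover_lm.py --model cloverlm \
#       --model_args pretrained=EleutherAI/pythia-70m,pad_multiple=128 \
#       --tasks lambada_openai --batch_size 8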