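"""Custom lm-evaluation-harness model wrapper that pads inputs to a fixed multiple.

Registers a "cloverlm" model type on top of HFLM. Each forward and generation
call is padded so the sequence length is a multiple of ``pad_multiple``
(presumably for a backend or kernel that expects block-aligned lengths), and
the padding is trimmed back off the outputs before they are returned.
"""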
import torch
import torch.nn.functional as F

from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


@register_model("cloverlm")
class CloverLMHFLM(HFLM):
    def __init__(self, pad_multiple=128, **kwargs):
        super().__init__(**kwargs)
        # Every batch's sequence length is padded up to the nearest multiple of this value.
        self.pad_multiple = pad_multiple

    def _encode_pair(self, context, continuation):
        context_enc, continuation_enc = super()._encode_pair(context, continuation)

        # The parent split can produce an empty continuation when the tokenizer
        # merges the continuation into the last context token(s); re-encode the
        # whole string and peel off the final token so there is always at least
        # one token to score.
        if not continuation_enc and continuation:
            whole_enc = self.tok_encode(context + continuation)
            if len(whole_enc) > 1:
                continuation_enc = whole_enc[-1:]
                context_enc = whole_enc[:-1]
            elif whole_enc:
                # Single-token encoding: score it all as continuation, with the
                # prefix token standing in for the context.
                continuation_enc = whole_enc
                context_enc = [self.prefix_token_id]
            else:
                # Degenerate case: nothing tokenized at all.
                continuation_enc = [self.prefix_token_id]
                context_enc = [self.prefix_token_id]

        return context_enc, continuation_enc

    def _model_call(self, inps: torch.Tensor, attn_mask: torch.Tensor = None, **kwargs):
        # Right-pad the batch so its sequence length is a multiple of
        # pad_multiple, run the underlying model, then trim the logits back to
        # the original length so scoring never sees the padded positions.
        orig_len = inps.shape[1]
        remainder = orig_len % self.pad_multiple

        if remainder != 0:
            pad_len = self.pad_multiple - remainder
            inps = F.pad(inps, (0, pad_len), value=self.tokenizer.pad_token_id)
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, pad_len), value=0)

        logits = super()._model_call(inps, attn_mask=attn_mask, **kwargs)
        if remainder != 0:
            logits = logits[:, :orig_len, :]
        return logits

    def _model_generate(self, context, max_length, **kwargs):
        # Left-pad the prompt to a multiple of pad_multiple (left, so the
        # generated tokens still directly follow the real context), generate,
        # then strip the padding columns from the returned sequences.
        orig_len = context.shape[1]
        remainder = orig_len % self.pad_multiple

        if remainder != 0:
            pad_len = self.pad_multiple - remainder
            context = F.pad(context, (pad_len, 0), value=self.tokenizer.pad_token_id)
            if "attention_mask" in kwargs and kwargs["attention_mask"] is not None:
                kwargs["attention_mask"] = F.pad(kwargs["attention_mask"], (pad_len, 0), value=0)

        out = super()._model_generate(context, max_length, **kwargs)
        if remainder != 0:
            out = out[:, pad_len:]

        return out


if __name__ == "__main__":
    from lm_eval.__main__ import cli_evaluate
    cli_evaluate()
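
# Example invocation (hypothetical file name, model path, and task choice),
# assuming lm-eval is installed and this file is saved as cloverlm.py:
#
#   python cloverlm.py \
#       --model cloverlm \
#       --model_args pretrained=/path/to/model,pad_multiple=128 \
#       --tasks hellaswag \
#       --batch_size 8
#
# Extra --model_args keys (such as pad_multiple) are forwarded to __init__.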