from transformers.models.bert import modeling_bert
from open_clip import CustomTextCLIP
from open_clip.hf_model import HFTextEncoder
import torch.nn.functional as F
from torch import TensorType

def patch_encode_text():
    """Monkey-patch open_clip so text encoding can also return attention scores.

    Replaces CustomTextCLIP.encode_text and HFTextEncoder.forward with versions
    that accept `output_attentions` / `output_tokens` flags and propagate the
    Hugging Face attention outputs back to the caller.
    """

    def encode_text_patched(self, text, normalize: bool = False, output_attentions=False, output_tokens=False):
        if output_attentions:
            # The patched HFTextEncoder.forward returns (features, attention scores).
            features, attn_scores = self.text(text, output_attentions=output_attentions, output_tokens=output_tokens)
            features = F.normalize(features, dim=-1) if normalize else features
            return features, attn_scores
        else:
            features = self.text(text, output_attentions=output_attentions, output_tokens=output_tokens)
            return F.normalize(features, dim=-1) if normalize else features

    def HFText_encoder_patched(self, x: TensorType, output_attentions=False, output_tokens=False):
        self.output_tokens = output_tokens
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask, output_attentions=output_attentions)
        if self.output_tokens:
            # Project every token's hidden state instead of only the pooled output.
            tokens = self.proj(out[0])
            if output_attentions:
                # out.attentions is the per-layer attention tuple from the HF model;
                # positional indexing (out[1]) is unreliable, e.g. for BERT it is the pooler output.
                return tokens, out.attentions
            return tokens
        else:
            pooled_out = self.pooler(out, attn_mask)
            projected = self.proj(pooled_out)
            if output_attentions:
                # Keep the pooled path consistent with encode_text_patched, which
                # always unpacks (features, attn_scores) when attentions are requested.
                return projected, out.attentions
            return projected

    CustomTextCLIP.encode_text = encode_text_patched
    HFTextEncoder.forward = HFText_encoder_patched
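

# Usage sketch (illustrative, not part of the patch): after calling
# patch_encode_text(), an open_clip model whose text tower is an HFTextEncoder
# can hand back the Hugging Face attention tuple alongside the text features.
# The checkpoint name below is an example assumption, not something this file
# depends on.
#
#   import torch
#   import open_clip
#
#   patch_encode_text()
#   model, _, preprocess = open_clip.create_model_and_transforms(
#       "roberta-ViT-B-32", pretrained="laion2b_s12b_b32k")
#   tokenizer = open_clip.get_tokenizer("roberta-ViT-B-32")
#   tokens = tokenizer(["a photo of a cat"])
#   with torch.no_grad():
#       features, attn_scores = model.encode_text(tokens, normalize=True, output_attentions=True)
#   # attn_scores is a tuple with one attention tensor per transformer layer.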