from transformers.models.bert import modeling_bert
from open_clip import CustomTextCLIP
from open_clip.hf_model import HFTextEncoder
import torch.nn.functional as F
from torch import TensorType


def patch_encode_text():
    """Monkey-patch open_clip so text encoding can also return attentions.

    Replaces ``CustomTextCLIP.encode_text`` and ``HFTextEncoder.forward`` with
    versions that accept ``output_attentions`` and ``output_tokens`` keyword
    arguments. The replacement is made on the classes, so it affects every
    instance of both classes.
    """

    def encode_text_patched(self, text, normalize: bool = False,
                            output_attentions=False, output_tokens=False):
        """Encode ``text`` with the text tower ``self.text``.

        Args:
            text: Input passed straight through to ``self.text``.
            normalize: If true, L2-normalise the features along the last dim.
            output_attentions: If true, also return the attentions produced
                by the text tower.
            output_tokens: If true, request per-token projected features from
                the text tower instead of the pooled projection.

        Returns:
            ``features`` or, when ``output_attentions`` is true,
            ``(features, attentions)``.
        """
        out = self.text(text, output_attentions=output_attentions,
                        output_tokens=output_tokens)
        if output_attentions:
            features, attn_scores = out
            features = F.normalize(features, dim=-1) if normalize else features
            return features, attn_scores
        return F.normalize(out, dim=-1) if normalize else out

    def HFText_encoder_patched(self, x: TensorType, output_attentions=False,
                               output_tokens=False):
        """Run the HF transformer on token ids ``x`` and project the result.

        Args:
            x: Token ids; positions equal to ``self.config.pad_token_id`` are
                masked out.
            output_attentions: If true, also return the transformer's
                ``attentions`` output (in both the pooled and token branches).
            output_tokens: If true, project every position's last hidden
                state; otherwise pool with ``self.pooler`` and project that.

        Returns:
            ``features`` or, when ``output_attentions`` is true,
            ``(features, attentions)``.
        """
        # NOTE(review): persists the per-call flag on the module, overwriting
        # the value configured at construction; kept for backward
        # compatibility with code that may read ``self.output_tokens``.
        self.output_tokens = output_tokens
        # 1 for real tokens, 0 for padding, so the model/pooler ignore padding.
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask,
                               output_attentions=output_attentions)
        if output_tokens:
            features = self.proj(out.last_hidden_state)
        else:
            pooled_out = self.pooler(out, attn_mask)
            features = self.proj(pooled_out)
        if output_attentions:
            # Read attentions by name: in the tuple view of an HF ModelOutput,
            # index 1 is ``pooler_output`` when the model has a pooling layer.
            # Returning them here for the pooled branch too keeps the return
            # shape consistent with what encode_text_patched unpacks.
            return features, out.attentions
        return features

    CustomTextCLIP.encode_text = encode_text_patched
    HFTextEncoder.forward = HFText_encoder_patched