import torch
import torch.nn as nn
from models.CLAP.open_clip import create_model
from transformers import RobertaTokenizer

# removed: import random, torchaudio, get_audio_features — only used by
# _get_audio_embed (audio modality), not needed for text inference


class CLAP_Encoder(nn.Module):
    """Frozen CLAP model used to turn text queries into embeddings.

    The wrapped model is built once in ``__init__`` via ``create_model``, all
    of its parameters have ``requires_grad`` disabled, and it is put in eval
    mode. Only the text modality is supported by ``get_query_embed``.
    """

    def __init__(
        self,
        pretrained_path='checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt',
        amodel="HTSAT-base",
    ):
        """Build the tokenizer and the frozen CLAP model.

        Args:
            pretrained_path: Checkpoint location forwarded to ``create_model``.
            amodel: Audio model name forwarded to ``create_model``.
        """
        super().__init__()
        self.device = "cpu"
        self.precision = "fp32"
        self.amodel = amodel  # or 'PANN-14'
        self.tmodel = "roberta"  # the best text encoder in our training
        self.enable_fusion = False  # False if you do not want to use the fusion model
        self.fusion_type = "aff_2d"
        self.pretrained = pretrained_path
        # removed: self.sampling_rate — only used by _get_audio_embed
        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")

        self.model, self.model_cfg = create_model(
            self.amodel,
            self.tmodel,
            self.pretrained,
            precision=self.precision,
            device=self.device,
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
        )

        # The encoder is used for inference only: freeze every weight.
        for p in self.model.parameters():
            p.requires_grad = False

        self.model.eval()
        self.encoder_type = 'CLAP'

    # removed: batch_to_list, _get_audio_embed — audio modality not used in inference

    def _get_text_embed(self, batch):
        """Embed a batch of text prompts with the CLAP text branch.

        Args:
            batch: A sequence of prompt strings. A bare ``str`` is also
                accepted and treated as a batch holding that one prompt.

        Returns:
            The detached result of ``self.model.get_text_embedding``; for a
            single prompt only the first row is kept, with a leading
            dimension of 1.
        """
        # A bare string has a len() and supports `* 2`, so without this guard
        # a one-character prompt would be silently turned into e.g. "aa" and
        # longer strings would be tokenized without a batch dimension.
        if isinstance(batch, str):
            batch = [batch]

        # `tokenizer` squeezes dim 0, which would drop the batch dimension of
        # a one-prompt batch; duplicating the prompt appears to be the
        # workaround, and the duplicate row is discarded below.
        double_batch = False
        if len(batch) == 1:
            batch = batch * 2
            double_batch = True

        with torch.no_grad():
            text_data = self.tokenizer(batch)
            embed = self.model.get_text_embedding(text_data)

        if double_batch:
            embed = embed[0].unsqueeze(0)

        return embed.detach()

    def get_query_embed(self, modality, audio=None, text=None, use_text_ratio=0.5, device=None):
        """Return the query embedding for the requested modality.

        Args:
            modality: Must be ``'text'``; any other value raises.
            audio: Unused; kept for interface compatibility.
            text: Prompt(s) forwarded to ``_get_text_embed``.
            use_text_ratio: Unused; kept for interface compatibility.
            device: Unused; kept for interface compatibility.

        Returns:
            The text embedding converted with ``.float()``.

        Raises:
            NotImplementedError: If ``modality`` is not ``'text'``.
        """
        # removed: audio and hybrid modality branches — only text modality used in inference
        if modality == 'text':
            embed = self._get_text_embed(text)
        else:
            raise NotImplementedError("Please check flag 'training_modality'.")

        return embed.float()

    def tokenizer(self, text):
        """Tokenize ``text`` with the RoBERTa tokenizer.

        Inputs are padded/truncated to 512 tokens and returned as PyTorch
        tensors. ``squeeze(0)`` is applied to every returned tensor, so a
        leading dimension of size 1 is removed.

        Returns:
            A dict mapping each tokenizer output key to its squeezed tensor.
        """
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        return {k: v.squeeze(0) for k, v in result.items()}