import json
import os
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from safetensors.torch import safe_open
from transformers import AutoModel
from ..vision_encoder.pe import CLIP as PeEncoder
from .aligner import AlignModalities
from .audio_codec import DacEncoderVAE
from .config import (
AudioEncoderConfig,
PEAudioFrameConfig,
PEAudioVisualConfig,
PEAudioVisualEncoderConfig,
VisualEncoderConfig,
)
from .transformer import BaseModelOutputWithPooling, Transformer
@dataclass
class AudioOutput(BaseModelOutputWithPooling):
audio_feature_padding_mask: Optional[torch.Tensor] = None
dac_vae_features: Optional[torch.Tensor] = None
@dataclass
class VisualOutput(BaseModelOutputWithPooling):
pe_output: Optional[torch.Tensor] = None
@dataclass
class AudioVisualOutput(BaseModelOutputWithPooling):
audio_output: Optional[AudioOutput] = None
visual_output: Optional[VisualOutput] = None
@dataclass
class PEAudioFrameOutput:
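    """
    Output embeddings and frame-level grounding results from the PEAudioFrame model.

    `audio_embeds` holds frame-level audio embeddings, `text_embeds` the pooled text
    embeddings, and `spans` (when requested) a per-example list of [start, end] times
    in seconds where the audio matches the text query.
    """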
audio_embeds: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
spans: Optional[list[list[list[float]]]] = None
audio_output: Optional[AudioOutput] = None
text_output: Optional[BaseModelOutputWithPooling] = None
@dataclass
class PEAudioVisualOutput:
"""
Output embeddings and intermediate results from the PEAudioVisual model.
Attributes:
audio_embeds (Optional[torch.FloatTensor]): Embeddings for the audio modality.
audio_visual_embeds (Optional[torch.FloatTensor]): Embeddings for the combined audio-visual modality.
visual_embeds (Optional[torch.FloatTensor]): Embeddings for the visual modality.
audio_text_embeds (Optional[torch.FloatTensor]): Embeddings for the audio-text modality. This should be used for Audio <-> Text retrieval.
audio_visual_text_embeds (Optional[torch.FloatTensor]): Embeddings for the audio-visual-text modality. This should be used for Audio/Video <-> Text retrieval.
visual_text_embeds (Optional[torch.FloatTensor]): Embeddings for the visual-text modality. This should be used for Video <-> Text retrieval.
audio_plus_text_embeds (Optional[torch.FloatTensor]): Embeddings for combined audio and text features.
visual_plus_text_embeds (Optional[torch.FloatTensor]): Embeddings for combined visual and text features.
audio_visual_output (Optional[AudioVisualOutput]): Intermediate outputs from the audio-visual encoder.
text_output (Optional[BaseModelOutputWithPooling]): Intermediate outputs from the text encoder.
"""
audio_embeds: Optional[torch.FloatTensor] = None
audio_visual_embeds: Optional[torch.FloatTensor] = None
visual_embeds: Optional[torch.FloatTensor] = None
audio_text_embeds: Optional[torch.FloatTensor] = None
audio_visual_text_embeds: Optional[torch.FloatTensor] = None
visual_text_embeds: Optional[torch.FloatTensor] = None
audio_plus_text_embeds: Optional[torch.FloatTensor] = None
visual_plus_text_embeds: Optional[torch.FloatTensor] = None
audio_visual_output: Optional[AudioVisualOutput] = None
text_output: Optional[BaseModelOutputWithPooling] = None
class ContrastiveHead(torch.nn.Module):
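    """LayerNorm followed by a bias-free linear projection into the shared contrastive embedding space."""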
def __init__(
self,
in_dim: int,
out_dim: int,
) -> None:
super().__init__()
self.layer_norm = torch.nn.LayerNorm(normalized_shape=in_dim, eps=1e-6)
self.proj = torch.nn.Linear(in_dim, out_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.proj(self.layer_norm(x))
class AVTransformer(Transformer):
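    """
    Fuses the audio and video token streams: video features are aligned to the audio
    sequence, the two streams are concatenated along the feature dimension, projected
    back to the hidden size, and contextualized by the base Transformer.
    """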
def __init__(self, config):
super().__init__(config)
self.modality_aligner = AlignModalities(
self.config.hidden_size, self.config.hidden_size, normalize=True, btc=True
)
self.concat_modality_proj = torch.nn.Linear(
self.config.hidden_size * 2, self.config.hidden_size
)
self.data_proj = torch.nn.Linear(
self.config.hidden_size, self.config.hidden_size
)
def forward(
self,
audio: torch.Tensor,
video: torch.Tensor,
audio_padding_mask: Optional[torch.Tensor] = None,
video_padding_mask: Optional[torch.Tensor] = None,
):
video, video_padding_mask = self.modality_aligner(
audio, audio_padding_mask, video, video_padding_mask
)
x = torch.cat([audio, video], dim=-1)
x = self.concat_modality_proj(x)
return super().forward(self.data_proj(x), attention_mask=video_padding_mask)
class AudioEncoder(torch.nn.Module):
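    """
    Audio encoder: a DAC VAE codec encoder produces frame-level codec features, which
    are projected to the transformer width and contextualized by a Transformer.
    """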
def __init__(self, config: AudioEncoderConfig):
super().__init__()
self.data_proj = torch.nn.Linear(
config.dac_vae_encoder.codebook_dim, config.audio_transformer.hidden_size
)
self.dac_vae_encoder = DacEncoderVAE(config.dac_vae_encoder)
self.audio_transformer = Transformer(config.audio_transformer)
def forward(
self,
input_values: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
input_features: Optional[torch.Tensor] = None, # codec_features
) -> AudioOutput:
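        """
        Encode raw waveforms (`input_values`), or precomputed codec features
        (`input_features`) that bypass the DAC VAE encoder. A sample-level
        `padding_mask` is downsampled by the codec hop length to the frame rate.
        """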
if input_features is None:
codec_features = self.dac_vae_encoder(input_values).transpose(1, 2)
feature_padding_mask = None
if padding_mask is not None:
feature_padding_mask = padding_mask[
:, :: self.dac_vae_encoder.config.hop_length
]
else:
codec_features = input_features
feature_padding_mask = padding_mask
outputs = self.audio_transformer(
self.data_proj(codec_features), attention_mask=feature_padding_mask
)
return AudioOutput(
last_hidden_state=outputs.last_hidden_state,
pooler_output=outputs.pooler_output,
audio_feature_padding_mask=feature_padding_mask,
dac_vae_features=codec_features,
)
class VisualEncoder(torch.nn.Module):
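    """
    Visual encoder: the visual branch of a PE CLIP model embeds each video frame, and
    a Transformer contextualizes the normalized per-frame features.
    """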
def __init__(self, config: VisualEncoderConfig):
super().__init__()
# Note we only use the visual branch of the model. Throw the rest away to save space
self.pe_encoder = PeEncoder.from_config(
config.pe_encoder, pretrained=False
).visual
self.proj = torch.nn.Linear(
self.pe_encoder.output_dim,
config.visual_transformer.hidden_size,
bias=False,
)
self.data_proj = torch.nn.Linear(
config.visual_transformer.hidden_size, config.visual_transformer.hidden_size
)
self.visual_transformer = Transformer(config.visual_transformer)
def forward(
self,
pixel_values_videos: torch.Tensor,
padding_mask_videos: Optional[torch.Tensor] = None,
pe_features: Optional[torch.Tensor] = None,
) -> BaseModelOutputWithPooling:
B, N, C, H, W = pixel_values_videos.shape
if pe_features is None:
backbone_output = self.pe_encoder(
pixel_values_videos.view(B * N, C, H, W)
).view(B, N, -1)
pe_features = F.normalize(backbone_output, dim=-1)
projected = self.proj(pe_features)
output = self.visual_transformer(
self.data_proj(projected), attention_mask=padding_mask_videos
)
return VisualOutput(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
pe_output=pe_features,
)
class AudioVisualEncoder(torch.nn.Module):
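    """
    Joint audio-visual encoder: runs the audio and visual encoders independently, then
    fuses their hidden states with an AVTransformer.
    """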
def __init__(self, config: PEAudioVisualEncoderConfig):
super().__init__()
self.audio_model = AudioEncoder(config.audio_model)
self.visual_model = VisualEncoder(config.visual_model)
self.audio_visual_transformer = AVTransformer(config.audio_visual_transformer)
def forward(
self,
input_values: torch.Tensor,
pixel_values_videos: torch.Tensor,
pe_features: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
padding_mask_videos: Optional[torch.Tensor] = None,
input_features: Optional[torch.Tensor] = None, # codec_features
) -> AudioVisualOutput:
audio_output = self.audio_model(input_values, padding_mask=padding_mask, input_features=input_features)
video_output = self.visual_model(
pixel_values_videos, padding_mask_videos=padding_mask_videos, pe_features=pe_features
)
av_output = self.audio_visual_transformer(
audio_output.last_hidden_state,
video_output.last_hidden_state,
audio_padding_mask=audio_output.audio_feature_padding_mask,
video_padding_mask=padding_mask_videos,
)
return AudioVisualOutput(
last_hidden_state=av_output.last_hidden_state,
pooler_output=av_output.pooler_output,
audio_output=audio_output,
visual_output=video_output,
)
class BasePEAudio(torch.nn.Module):
@classmethod
def from_config(cls, name_or_checkpoint: str, pretrained: bool = False):
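        """
        Build a model from a local checkpoint directory or a Hugging Face Hub name
        (resolved as `facebook/<name>` on the `perception_models` revision). If
        `pretrained` is True, weights are loaded from `model.safetensors`.
        """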
if os.path.isdir(name_or_checkpoint):
checkpoint_dir = name_or_checkpoint
else:
checkpoint_dir = snapshot_download(
repo_id=f"facebook/{name_or_checkpoint}", revision="perception_models"
)
config_path = os.path.join(checkpoint_dir, "config.json")
with open(config_path) as fin:
config_dict = json.load(fin)
config = cls.config_cls(**config_dict)
model = cls(config)
if pretrained:
checkpoint_path = os.path.join(checkpoint_dir, "model.safetensors")
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
model.load_state_dict({k: f.get_tensor(k) for k in f.keys()})
return model
class PEAudioVisual(BasePEAudio):
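    """
    Contrastive audio-visual-text model. A joint audio-visual encoder and a text
    encoder feed separate contrastive heads, one per modality pairing (audio<->text,
    video<->text, audio+video<->text) plus fusion heads for media+text embeddings.
    """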
config_cls = PEAudioVisualConfig
def __init__(self, config: PEAudioVisualConfig):
super().__init__()
self.config = config
self.audio_visual_model = AudioVisualEncoder(config.audio_visual_model)
self.text_model = AutoModel.from_config(config.text_model)
self.audio_visual_text_head = ContrastiveHead(
config.text_model.hidden_size, config.output_dim
)
self.audio_text_head = ContrastiveHead(
config.text_model.hidden_size, config.output_dim
)
self.visual_text_head = ContrastiveHead(
config.text_model.hidden_size, config.output_dim
)
self.audio_visual_head = ContrastiveHead(
config.audio_visual_model.audio_visual_transformer.hidden_size,
config.output_dim,
)
self.audio_head = ContrastiveHead(
config.audio_visual_model.audio_model.audio_transformer.hidden_size,
config.output_dim,
)
self.visual_head = ContrastiveHead(
config.audio_visual_model.visual_model.visual_transformer.hidden_size,
config.output_dim,
)
self.visual_plus_text_head = ContrastiveHead(
config.audio_visual_model.visual_model.visual_transformer.hidden_size
+ config.text_model.hidden_size,
config.output_dim,
)
self.audio_plus_text_head = ContrastiveHead(
config.audio_visual_model.audio_model.audio_transformer.hidden_size
+ config.text_model.hidden_size,
config.output_dim,
)
def _get_text_output(self, input_ids, attention_mask):
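        """
        Run the text encoder and pool by taking the first token's hidden state. If
        `config.nth_text_layer` is set, that hidden layer is used instead of the last
        hidden state.
        """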
nth_layer = self.config.nth_text_layer
output = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=nth_layer is not None,
)
if nth_layer is None:
text_model_output = output.last_hidden_state
else:
text_model_output = output.hidden_states[nth_layer]
return BaseModelOutputWithPooling(
last_hidden_state=text_model_output, pooler_output=text_model_output[:, 0]
)
def encode_video_text(self, input_ids, attention_mask=None):
text_outputs = self._get_text_output(input_ids, attention_mask)
return self.visual_text_head(text_outputs.pooler_output)
def encode_audio_text(self, input_ids, attention_mask=None):
text_outputs = self._get_text_output(input_ids, attention_mask)
return self.audio_text_head(text_outputs.pooler_output)
def encode_audio_video_text(self, input_ids, attention_mask=None):
text_outputs = self._get_text_output(input_ids, attention_mask)
return self.audio_visual_text_head(text_outputs.pooler_output)
def encode_audio(self, input_values, padding_mask=None, input_features=None):
audio_outputs = self.audio_visual_model.audio_model(
input_values, padding_mask=padding_mask, input_features=input_features
)
return self.audio_head(audio_outputs.pooler_output)
def encode_video(self, pixel_values_videos, padding_mask_videos=None, pe_features=None):
video_outputs = self.audio_visual_model.visual_model(
pixel_values_videos, padding_mask_videos=padding_mask_videos, pe_features=pe_features
)
return self.visual_head(video_outputs.pooler_output)
def encode_audio_video(
self,
input_values,
pixel_values_videos,
padding_mask=None,
padding_mask_videos=None,
pe_features=None,
input_features=None,
):
audio_video_outputs = self.audio_visual_model(
input_values,
pixel_values_videos,
padding_mask=padding_mask,
padding_mask_videos=padding_mask_videos,
pe_features=pe_features,
input_features=input_features,
)
return self.audio_visual_head(audio_video_outputs.pooler_output)
def encode_audio_plus_text(
self, input_ids, input_values, attention_mask=None, padding_mask=None, input_features=None
):
text_outputs = self._get_text_output(input_ids, attention_mask)
audio_outputs = self.audio_visual_model.audio_model(
input_values, padding_mask=padding_mask, input_features=input_features
)
return self.audio_plus_text_head(
torch.cat(
[audio_outputs.pooler_output, text_outputs.pooler_output],
dim=-1,
)
)
def encode_video_plus_text(
self,
input_ids,
pixel_values_videos,
attention_mask=None,
padding_mask_videos=None,
pe_features=None,
):
text_outputs = self._get_text_output(input_ids, attention_mask)
video_outputs = self.audio_visual_model.visual_model(
pixel_values_videos, padding_mask_videos=padding_mask_videos, pe_features=pe_features
)
return self.visual_plus_text_head(
torch.cat(
[video_outputs.pooler_output, text_outputs.pooler_output],
dim=-1,
)
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.Tensor] = None,
input_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
padding_mask_videos: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
pe_features: Optional[torch.Tensor] = None,
input_features: Optional[torch.Tensor] = None,
return_loss=False,
) -> PEAudioVisualOutput:
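        """
        Compute embeddings for whichever modalities are provided: text only, audio only,
        video only, audio+video, or any of these paired with text. Embeddings for
        missing modalities are returned as None.
        """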
# text embeddings
audio_text_embeds = visual_text_embeds = audio_visual_text_embeds = None
# media embeddings (audio, video, audio_video)
audio_embeds = visual_embeds = audio_visual_embeds = None
# media + text embeddings
audio_plus_text_embeds = visual_plus_text_embeds = None
audio_visual_outputs = None
# Compute model outputs and embeddings for each modality
text_outputs = None
if input_ids is not None:
text_outputs = self._get_text_output(input_ids, attention_mask)
if input_values is not None and pixel_values_videos is not None:
            # When both audio and video are provided, compute the joint output and reuse its intermediate per-modality outputs
audio_visual_outputs = self.audio_visual_model(
input_values,
pixel_values_videos,
padding_mask=padding_mask,
padding_mask_videos=padding_mask_videos,
pe_features=pe_features,
input_features=input_features,
)
audio_outputs = audio_visual_outputs.audio_output
video_outputs = audio_visual_outputs.visual_output
audio_embeds = self.audio_head(audio_outputs.pooler_output)
visual_embeds = self.visual_head(video_outputs.pooler_output)
audio_visual_embeds = self.audio_visual_head(
audio_visual_outputs.pooler_output
)
if text_outputs is not None:
# Compute the corresponding text embeddings
audio_text_embeds = self.audio_text_head(text_outputs.pooler_output)
visual_text_embeds = self.visual_text_head(text_outputs.pooler_output)
audio_visual_text_embeds = self.audio_visual_text_head(
text_outputs.pooler_output
)
audio_plus_text_embeds = self.audio_plus_text_head(
torch.cat(
[audio_outputs.pooler_output, text_outputs.pooler_output],
dim=-1,
)
)
visual_plus_text_embeds = self.visual_plus_text_head(
torch.cat(
[video_outputs.pooler_output, text_outputs.pooler_output],
dim=-1,
)
)
else:
if pixel_values_videos is not None:
video_outputs = self.audio_visual_model.visual_model(
pixel_values_videos, padding_mask_videos=padding_mask_videos, pe_features=pe_features
)
audio_visual_outputs = AudioVisualOutput(visual_output=video_outputs)
visual_embeds = self.visual_head(video_outputs.pooler_output)
if text_outputs is not None:
visual_text_embeds = self.visual_text_head(
text_outputs.pooler_output
)
visual_plus_text_embeds = self.visual_plus_text_head(
torch.cat(
[video_outputs.pooler_output, text_outputs.pooler_output],
dim=-1,
)
)
elif input_values is not None:
audio_outputs = self.audio_visual_model.audio_model(
input_values, padding_mask=padding_mask, input_features=input_features
)
audio_visual_outputs = AudioVisualOutput(audio_output=audio_outputs)
audio_embeds = self.audio_head(audio_outputs.pooler_output)
if text_outputs is not None:
audio_text_embeds = self.audio_text_head(text_outputs.pooler_output)
audio_plus_text_embeds = self.audio_plus_text_head(
torch.cat(
[audio_outputs.pooler_output, text_outputs.pooler_output],
dim=-1,
)
)
elif text_outputs is not None:
            # If text is supplied but no audio or video, use the audio-visual-text head as the default text embedding
audio_visual_text_embeds = self.audio_visual_text_head(
text_outputs.pooler_output
)
return PEAudioVisualOutput(
audio_embeds=audio_embeds,
audio_visual_embeds=audio_visual_embeds,
visual_embeds=visual_embeds,
audio_text_embeds=audio_text_embeds,
audio_visual_text_embeds=audio_visual_text_embeds,
visual_text_embeds=visual_text_embeds,
audio_plus_text_embeds=audio_plus_text_embeds,
visual_plus_text_embeds=visual_plus_text_embeds,
audio_visual_output=audio_visual_outputs,
text_output=text_outputs,
)
class PEAudioFrame(BasePEAudio):
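    """
    Frame-level audio-text grounding model: scores each audio frame against a text
    query and extracts the time spans in which the query occurs.
    """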
config_cls = PEAudioFrameConfig
def __init__(self, config: PEAudioFrameConfig):
super().__init__()
self.config = config
self.text_model = AutoModel.from_config(config.text_model)
self.audio_model = AudioEncoder(config.audio_model)
self.text_head = ContrastiveHead(
config.text_model.hidden_size, config.output_dim
)
self.audio_head = ContrastiveHead(
config.audio_model.audio_transformer.hidden_size, config.output_dim
)
self.logit_scale = torch.nn.Parameter(torch.tensor([0.0]))
self.logit_bias = torch.nn.Parameter(torch.tensor([0.0]))
def _get_text_output(self, input_ids, attention_mask):
nth_layer = self.config.nth_text_layer
output = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=nth_layer is not None,
)
if nth_layer is None:
text_model_output = output.last_hidden_state
else:
text_model_output = output.hidden_states[nth_layer]
return BaseModelOutputWithPooling(
last_hidden_state=text_model_output, pooler_output=text_model_output[:, 0]
)
def forward(
self,
input_ids: torch.Tensor, # tokenized text
input_values: Optional[torch.Tensor] = None, # audio waveform (may be None if input_features is provided)
input_features: Optional[torch.Tensor] = None, # codec_features (if already computed)
attention_mask: Optional[torch.Tensor] = None, # text attention mask
padding_mask: Optional[torch.Tensor] = None, # audio padding mask
threshold: float = 0.3,
return_spans: bool = True,
) -> PEAudioFrameOutput:
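        """
        Embed the audio at frame level and the text at sequence level. When
        `return_spans` is True, frames whose sigmoid-scaled similarity to the text
        exceeds `threshold` are grouped into contiguous [start, end) spans and
        converted from codec frames to seconds.
        """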
audio_output = self.audio_model(input_values, padding_mask, input_features=input_features)
text_model_output = self._get_text_output(input_ids, attention_mask)
text_embeds = self.text_head(text_model_output.pooler_output)
audio_embeds = self.audio_head(audio_output.last_hidden_state)
spans = None
if return_spans:
bsz = input_ids.size(0)
unscaled_logits = audio_embeds @ text_embeds.unsqueeze(1).transpose(-1, -2)
logits = unscaled_logits.squeeze(-1) * self.logit_scale + self.logit_bias
probs = logits.sigmoid()
preds = probs > threshold
# Find where predictions changed from False->True and True->False
changes = torch.diff(F.pad(preds, (1, 1), value=False), dim=1).nonzero()
span_tensor = torch.cat([changes[::2], changes[1::2, [1]]], dim=1)
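            # Each row of span_tensor is (batch_index, start_frame, end_frame), with end_frame exclusive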
# Convert audio frame index to time
dac_config = self.config.audio_model.dac_vae_encoder
spans = [
(
span_tensor[span_tensor[:, 0] == i, 1:]
* dac_config.hop_length
/ dac_config.sampling_rate
).tolist()
for i in range(bsz)
]
return PEAudioFrameOutput(
text_embeds=text_embeds,
audio_embeds=audio_embeds,
spans=spans,
text_output=text_model_output,
audio_output=audio_output,
)
__all__ = ["PEAudioVisual", "PEAudioFrame"]
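# Example usage (illustrative sketch; the checkpoint name below is a placeholder, not a
# verified repo id -- a local checkpoint directory also works):
#
#   model = PEAudioVisual.from_config("<pe-audio-visual-checkpoint>", pretrained=True).eval()
#   with torch.no_grad():
#       out = model(
#           input_ids=input_ids,                      # tokenized captions
#           attention_mask=attention_mask,            # text attention mask
#           input_values=input_values,                # raw audio waveforms
#           pixel_values_videos=pixel_values_videos,  # video frames, (B, N, C, H, W)
#       )
#   # Similarity matrix for Audio/Video <-> Text retrieval (normalize the embeddings
#   # first if cosine similarity is desired).
#   scores = out.audio_visual_embeds @ out.audio_visual_text_embeds.T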