deepfake-server / src /models /backbones.py
DevQueen's picture
Sync from GitHub via hub-sync
1dc2504 verified
Raw
History Blame Contribute Delete
455 Bytes
from __future__ import annotations
import timm
import torch
import torch.nn as nn
def build_backbone(name: str, pretrained: bool = True) -> nn.Module:
    """Create a timm model to be used as a feature-extracting backbone.

    Args:
        name: Model identifier forwarded to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.

    Returns:
        The module built by ``timm.create_model``.
    """
    # num_classes=0 together with global_pool="avg" asks timm for a model
    # without a classification head that emits average-pooled features.
    return timm.create_model(
        name,
        pretrained=pretrained,
        num_classes=0,
        global_pool="avg",
    )
@torch.no_grad()
def infer_feature_dim(backbone: nn.Module, image_size: int = 224) -> int:
    """Return the size of the last output dimension of *backbone*.

    A single random 3-channel square image is pushed through the backbone
    with gradients disabled, and the size of the output's last axis is
    reported.

    The probe runs in eval mode so that it neither updates normalisation
    running statistics with random data nor fails in layers that reject a
    batch of one while training. The original ``training`` flag of every
    submodule is restored afterwards, even if the forward pass raises.

    Args:
        backbone: Module to probe; called with a ``(1, 3, H, W)`` tensor.
        image_size: Height and width of the probe image.

    Returns:
        ``shape[-1]`` of the backbone's output, as an ``int``.
    """
    # Build the probe on the backbone's own device/dtype so a model that was
    # moved (e.g. ``.cuda()``, ``.double()``) does not hit a mismatch error.
    reference = next(
        (p for p in backbone.parameters() if p.is_floating_point()),
        None,
    )
    device = reference.device if reference is not None else None
    dtype = reference.dtype if reference is not None else None
    probe = torch.randn(1, 3, image_size, image_size, device=device, dtype=dtype)

    # Remember per-module modes: a blanket ``train(True)`` afterwards would
    # wrongly re-enable training on submodules the caller deliberately froze.
    saved_modes = [(module, module.training) for module in backbone.modules()]
    backbone.eval()
    try:
        features = backbone(probe)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return int(features.shape[-1])