SkinTokens / src /model /parse_encoder.py
pookiefoof's picture
Public release: SkinTokens · TokenRig demo
9d7cf7f
from copy import deepcopy
from dataclasses import dataclass
from .michelangelo.get_model import get_encoder as get_encoder_michelangelo
from .michelangelo.get_model import AlignedShapeLatentPerceiver
from .michelangelo.get_model import get_encoder_simplified as get_encoder_michelangelo_encoder
from .michelangelo.get_model import ShapeAsLatentPerceiverEncoder
from .skin_vae.autoencoders.autoencoder_kl_tripo2 import Tripo2Encoder
@dataclass(frozen=True)
class _MAP_MESH_ENCODER:
    """Namespace that maps mesh-encoder names to their encoder classes.

    None of the attributes below are annotated, so ``dataclass`` does not
    turn them into fields; they remain ordinary class attributes. Because the
    dataclass is frozen, assigning or deleting attributes on an *instance*
    raises ``dataclasses.FrozenInstanceError``.

    The attribute names match the ``'michelangelo'``,
    ``'michelangelo_encoder'`` and ``'tripo'`` keys used by
    ``get_mesh_encoder`` in this module. Note that the two michelangelo
    entries here are classes, while that function dispatches to factory
    functions for them.
    """

    michelangelo = AlignedShapeLatentPerceiver
    michelangelo_encoder = ShapeAsLatentPerceiverEncoder
    tripo = Tripo2Encoder


# Single shared, read-only instance; look encoders up as attributes,
# e.g. ``MAP_MESH_ENCODER.tripo``.
MAP_MESH_ENCODER = _MAP_MESH_ENCODER()
def get_mesh_encoder(**kwargs):
    """Build a mesh encoder selected by the ``__target__`` keyword argument.

    ``__target__`` picks one of the constructors registered below. Every
    other keyword argument is deep-copied and forwarded to that constructor.

    Args:
        **kwargs: Must contain ``__target__``, one of ``'michelangelo'``,
            ``'michelangelo_encoder'`` or ``'tripo'``. All remaining items
            are passed as keyword arguments to the selected constructor.

    Returns:
        Whatever the selected constructor returns.

    Raises:
        KeyError: If ``__target__`` is not supplied.
        ValueError: If ``__target__`` is not one of the registered names.
    """
    registry = {
        'michelangelo': get_encoder_michelangelo,
        'michelangelo_encoder': get_encoder_michelangelo_encoder,
        'tripo': Tripo2Encoder,
    }
    # ``**kwargs`` always builds a fresh dict for this call, so popping the
    # selector out of it never touches the caller's mapping.
    try:
        target = kwargs.pop('__target__')
    except KeyError:
        raise KeyError(
            f"get_mesh_encoder() requires a '__target__' keyword; "
            f"expect one of [{','.join(registry)}]"
        ) from None
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown name reach the dict lookup.
    if target not in registry:
        raise ValueError(
            f"expect: [{','.join(registry.keys())}], found: {target}"
        )
    # Deep-copy so the constructor receives its own copies of any nested
    # config values rather than references to the caller's objects.
    return registry[target](**deepcopy(kwargs))