File size: 970 Bytes
bc0d37c 29899b4 bc0d37c 29899b4 bc0d37c 29899b4 bc0d37c 29899b4 bc0d37c 29899b4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 | from .embedders import (
CaduceusEmbedder,
DNABertEmbedder,
NucleotideTransformerEmbedder,
GPNEmbedder,
SegmentNTEmbedder,
ESMEmbedder,
ESMDBPEmbedder,
ProGenEmbedder,
)
def get_embedder(name, device, for_dna=True):
    """Construct the embedder registered under ``name`` for ``device``.

    Matching is case-insensitive because ``name`` is lower-cased first.
    DNA and protein models live in two separate registries, chosen by the
    truthiness of ``for_dna``:

    * DNA (``for_dna`` truthy): ``caduceus``, ``dnabert``, ``nucleotide``,
      ``gpn``, ``segmentnt``.
    * Protein (``for_dna`` falsy): ``esm``, ``esm-dbp`` / ``esm_dbp``,
      ``progen``.

    Args:
        name: Model identifier, in any letter case.
        device: Forwarded unchanged as the sole positional argument of the
            selected embedder's constructor.
        for_dna: Look the name up in the DNA registry when truthy, otherwise
            in the protein registry.

    Returns:
        A newly constructed embedder instance.

    Raises:
        ValueError: If the lower-cased ``name`` is not registered in the
            chosen registry.
    """
    key = name.lower()

    # Each entry is a zero-argument factory; only the selected one is ever
    # called, so no other embedder is constructed.
    if for_dna:
        factories = {
            "caduceus": lambda: CaduceusEmbedder(device),
            "dnabert": lambda: DNABertEmbedder(device),
            "nucleotide": lambda: NucleotideTransformerEmbedder(device),
            "gpn": lambda: GPNEmbedder(device),
            "segmentnt": lambda: SegmentNTEmbedder(device),
        }
    else:
        factories = {
            "esm": lambda: ESMEmbedder(device),
            # Both spellings are accepted for the DNA-binding ESM variant.
            "esm-dbp": lambda: ESMDBPEmbedder(device),
            "esm_dbp": lambda: ESMDBPEmbedder(device),
            "progen": lambda: ProGenEmbedder(device),
        }

    factory = factories.get(key)
    if factory is None:
        raise ValueError(f"Unknown model {key} (for_dna={for_dna})")
    return factory()
|