from .embedders import (
    CaduceusEmbedder,
    DNABertEmbedder,
    NucleotideTransformerEmbedder,
    GPNEmbedder,
    SegmentNTEmbedder,
    ESMEmbedder,
    ESMDBPEmbedder,
    ProGenEmbedder,
)


def get_embedder(name, device, for_dna=True):
    """Construct the embedder registered under ``name``.

    Matching is case-insensitive: ``name`` is lower-cased before lookup.
    DNA and protein models are kept in separate registries, and ``for_dna``
    (tested for truthiness) decides which registry is searched.

    DNA keys:     ``caduceus``, ``dnabert``, ``nucleotide``, ``gpn``,
                  ``segmentnt``.
    Protein keys: ``esm``, ``esm-dbp`` / ``esm_dbp``, ``progen``.

    Args:
        name: Model identifier; must support ``.lower()``.
        device: Passed unchanged as the sole positional argument to the
            selected embedder class's constructor.
        for_dna: When truthy, search the DNA registry; otherwise search the
            protein registry.

    Returns:
        A freshly constructed embedder instance.

    Raises:
        ValueError: If the lower-cased ``name`` is not a key of the selected
            registry (e.g. a protein model name requested with
            ``for_dna=True``).
    """
    name = name.lower()

    # The registries are built on each call so the class names are looked up
    # at call time rather than frozen at import time.
    if for_dna:
        registry = {
            "caduceus": CaduceusEmbedder,
            "dnabert": DNABertEmbedder,
            "nucleotide": NucleotideTransformerEmbedder,
            "gpn": GPNEmbedder,
            "segmentnt": SegmentNTEmbedder,
        }
    else:
        registry = {
            "esm": ESMEmbedder,
            # Both spellings of the DNA-binding-protein ESM variant are accepted.
            "esm-dbp": ESMDBPEmbedder,
            "esm_dbp": ESMDBPEmbedder,
            "progen": ProGenEmbedder,
        }

    embedder_cls = registry.get(name)
    if embedder_cls is None:
        raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
    return embedder_cls(device)