| from .embedders import ( |
| CaduceusEmbedder, |
| DNABertEmbedder, |
| NucleotideTransformerEmbedder, |
| GPNEmbedder, |
| SegmentNTEmbedder, |
| ESMEmbedder, |
| ESMDBPEmbedder, |
| ProGenEmbedder, |
| ) |
|
|
|
|
def get_embedder(name, device, for_dna=True):
    """Construct the embedder registered under ``name`` for ``device``.

    Matching is case-insensitive (``name`` is passed through ``str.lower``).
    DNA and protein models live in separate namespaces, and ``for_dna``
    picks which one is searched:

    * truthy ``for_dna``: "caduceus", "dnabert", "nucleotide", "gpn",
      "segmentnt"
    * falsy ``for_dna``: "esm", "esm-dbp" (also spelled "esm_dbp"),
      "progen"

    Args:
        name: Model identifier.
        device: Forwarded unchanged as the only constructor argument.
        for_dna: Selects the DNA (truthy) or protein (falsy) namespace.

    Returns:
        A newly constructed embedder instance.

    Raises:
        ValueError: If ``name`` is not registered in the selected namespace.
    """
    key = name.lower()

    # Factories are lambdas so an embedder class is only looked up and
    # instantiated once its name has actually been selected.
    if for_dna:
        factories = {
            "caduceus": lambda: CaduceusEmbedder(device),
            "dnabert": lambda: DNABertEmbedder(device),
            "nucleotide": lambda: NucleotideTransformerEmbedder(device),
            "gpn": lambda: GPNEmbedder(device),
            "segmentnt": lambda: SegmentNTEmbedder(device),
        }
    else:
        factories = {
            "esm": lambda: ESMEmbedder(device),
            "esm-dbp": lambda: ESMDBPEmbedder(device),
            "esm_dbp": lambda: ESMDBPEmbedder(device),
            "progen": lambda: ProGenEmbedder(device),
        }

    factory = factories.get(key)
    if factory is None:
        raise ValueError(f"Unknown model {key} (for_dna={for_dna})")
    return factory()
|
|