File size: 970 Bytes
bc0d37c
29899b4
 
 
 
 
bc0d37c
 
29899b4
bc0d37c
 
29899b4
bc0d37c
 
 
29899b4
 
 
 
 
 
 
 
 
 
bc0d37c
29899b4
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from .embedders import (
    CaduceusEmbedder,
    DNABertEmbedder,
    NucleotideTransformerEmbedder,
    GPNEmbedder,
    SegmentNTEmbedder,
    ESMEmbedder,
    ESMDBPEmbedder,
    ProGenEmbedder,
)


def get_embedder(name, device, for_dna=True):
    """Construct the embedder registered under *name* for one modality.

    Matching is case-insensitive. DNA models and protein models form two
    separate namespaces, chosen by *for_dna*; a name that belongs only to
    the other namespace is reported as unknown.

    Args:
        name: Model identifier, e.g. ``"caduceus"`` or ``"esm-dbp"``.
        device: Passed through unchanged as the sole constructor argument
            of the selected embedder class.
        for_dna: Truthy to search the DNA models, falsy to search the
            protein models.

    Returns:
        A newly constructed embedder instance.

    Raises:
        ValueError: If *name* is not registered for the selected modality.
    """
    key = name.lower()

    # Entries are zero-argument callables returning the class, so an
    # embedder name is only looked up when that entry is actually chosen.
    if for_dna:
        registry = {
            "caduceus": lambda: CaduceusEmbedder,
            "dnabert": lambda: DNABertEmbedder,
            "nucleotide": lambda: NucleotideTransformerEmbedder,
            "gpn": lambda: GPNEmbedder,
            "segmentnt": lambda: SegmentNTEmbedder,
        }
    else:
        registry = {
            "esm": lambda: ESMEmbedder,
            # Both spellings are accepted for the DNA-binding ESM variant.
            "esm-dbp": lambda: ESMDBPEmbedder,
            "esm_dbp": lambda: ESMDBPEmbedder,
            "progen": lambda: ProGenEmbedder,
        }

    resolve_class = registry.get(key)
    if resolve_class is None:
        raise ValueError(f"Unknown model {key} (for_dna={for_dna})")
    return resolve_class()(device)