Sbhat2026 commited on
Commit
fce0ea0
·
1 Parent(s): 05e2a7c

fix mlb load order again

Browse files
Files changed (1) hide show
  1. server.py +39 -37
server.py CHANGED
@@ -24,21 +24,40 @@ HF_REPO = "Sbhat2026/protfunc-models" # exact case matters
24
  HF_FILES = ["baseline_res.pth", "mlb_public_v1.pkl", "go_annotations_fixed.csv", "go_names.json"]
25
 
26
  def ensure_model_files():
27
- # go_names.json is optional — skip if not yet on HF
 
 
 
28
  optional = {"go_names.json"}
29
  missing = [f for f in HF_FILES if not os.path.exists(os.path.join(BASE_DIR, f))]
30
  if not missing:
31
  return
 
32
  print(f"Downloading {len(missing)} file(s) from HuggingFace...")
33
- from huggingface_hub import hf_hub_download
34
  for fname in missing:
35
- print(f" {fname}...")
36
- path = hf_hub_download(
37
- repo_id=HF_REPO, filename=fname,
38
- local_dir=BASE_DIR, repo_type="model",
39
- token=os.environ.get("HF_TOKEN"),
40
- )
41
- print(f" saved to {path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  ensure_model_files()
44
 
@@ -67,6 +86,9 @@ def load_go_map():
67
 
68
  go_map = load_go_map()
69
 
 
 
 
70
  # Build MF-only whitelist from go_names.json aspect data (populated after fetch)
71
  # Falls back to allowing all labels if not available
72
  mf_indices = None # set below after go_names loaded
@@ -77,8 +99,9 @@ if os.path.exists(go_names_path):
77
  go_map.update(json.load(open(go_names_path)))
78
  print(f"Canonical GO names loaded: {len(go_map)} entries")
79
 
80
- mlb = joblib.load(os.path.join(BASE_DIR, "mlb_public_v1.pkl"))
81
- NUM_LABELS = len(mlb.classes_)
 
82
  mf_go_ids = {go_id for go_id, name in go_map.items() if name != go_id and not name.startswith("GO:")}
83
  if mf_go_ids:
84
  mf_indices = {i for i, go_id in enumerate(mlb.classes_) if go_id in mf_go_ids}
@@ -136,24 +159,6 @@ print("ESM-2 loaded OK")
136
  class ProteinRequest(BaseModel):
137
  sequence: str
138
 
139
- VALID_AA = set('ACDEFGHIKLMNPQRSTVWY')
140
- INVALID_AA = set('BJOUXY Z') # ambiguous or non-standard single-letter codes
141
-
142
- def clean_sequence(raw):
143
- """Uppercase, strip all non-alpha characters, return cleaned string."""
144
- return re.sub(r'[^A-Za-z]', '', raw).upper()
145
-
146
- def validate_sequence(seq, name):
147
- """Return error string if invalid, else None."""
148
- if not seq:
149
- return "Empty sequence"
150
- if len(seq) > 2500:
151
- return f"Sequence too long ({len(seq):,} aa, max 2500)"
152
- invalid = sorted(set(seq) - VALID_AA)
153
- if invalid:
154
- return f"Invalid amino acid characters: {', '.join(invalid)} — only standard 20 AA accepted"
155
- return None
156
-
157
  def parse_sequences(text):
158
  text = text.strip()
159
  if text.startswith(">"):
@@ -162,15 +167,13 @@ def parse_sequences(text):
162
  i = 1
163
  while i < len(blocks):
164
  name = blocks[i][1:].strip()
165
- raw = blocks[i+1] if i+1 < len(blocks) else ""
166
- seq = clean_sequence(raw)
167
  if seq:
168
- names.append(name or f"Sequence {len(names)+1}")
169
  seqs.append(seq)
170
  i += 2
171
  return list(zip(names, seqs))
172
- seqs = [clean_sequence(l) for l in text.splitlines() if l.strip()]
173
- seqs = [s for s in seqs if s]
174
  return [(f"Sequence {i+1}", s) for i, s in enumerate(seqs)]
175
 
176
  @app.post("/predict")
@@ -178,9 +181,8 @@ async def predict(request: ProteinRequest):
178
  entries = parse_sequences(request.sequence)
179
  results = []
180
  for name, sequence in entries:
181
- err = validate_sequence(sequence, name)
182
- if err:
183
- results.append({"name": name, "error": err})
184
  continue
185
  try:
186
  _, _, tokens = batch_converter([("p", sequence)])
 
24
  HF_FILES = ["baseline_res.pth", "mlb_public_v1.pkl", "go_annotations_fixed.csv", "go_names.json"]
25
 
26
  def ensure_model_files():
27
+ import time
28
+ from huggingface_hub import hf_hub_download
29
+
30
+ # go_names.json is optional — do not fail if absent from repo
31
  optional = {"go_names.json"}
32
  missing = [f for f in HF_FILES if not os.path.exists(os.path.join(BASE_DIR, f))]
33
  if not missing:
34
  return
35
+
36
  print(f"Downloading {len(missing)} file(s) from HuggingFace...")
 
37
  for fname in missing:
38
+ # Retry with exponential back-off so DNS resolves after cold-start network delay
39
+ max_attempts = 5
40
+ for attempt in range(1, max_attempts + 1):
41
+ try:
42
+ print(f" {fname} (attempt {attempt}/{max_attempts})...")
43
+ path = hf_hub_download(
44
+ repo_id=HF_REPO, filename=fname,
45
+ local_dir=BASE_DIR, repo_type="model",
46
+ token=os.environ.get("HF_TOKEN"),
47
+ )
48
+ print(f" saved to {path}")
49
+ break
50
+ except Exception as e:
51
+ if fname in optional:
52
+ print(f" {fname} optional — skipping ({e})")
53
+ break
54
+ if attempt == max_attempts:
55
+ raise RuntimeError(
56
+ f"Failed to download {fname} after {max_attempts} attempts: {e}"
57
+ )
58
+ wait = 2 ** attempt # 2s, 4s, 8s, 16s
59
+ print(f" DNS/network error, retrying in {wait}s... ({e})")
60
+ time.sleep(wait)
61
 
62
  ensure_model_files()
63
 
 
86
 
87
  go_map = load_go_map()
88
 
89
+ mlb = joblib.load(os.path.join(BASE_DIR, "mlb_public_v1.pkl"))
90
+ NUM_LABELS = len(mlb.classes_)
91
+
92
  # Build MF-only whitelist from go_names.json aspect data (populated after fetch)
93
  # Falls back to allowing all labels if not available
94
  mf_indices = None # set below after go_names loaded
 
99
  go_map.update(json.load(open(go_names_path)))
100
  print(f"Canonical GO names loaded: {len(go_map)} entries")
101
 
102
+ # Build index whitelist: only predict labels that are MF terms
103
+ # go_names.json maps GO ID -> name; non-MF terms were stored as raw GO ID (e.g. "GO:0005886")
104
+ # We identify MF terms as those whose name != their GO ID (i.e. successfully resolved)
105
  mf_go_ids = {go_id for go_id, name in go_map.items() if name != go_id and not name.startswith("GO:")}
106
  if mf_go_ids:
107
  mf_indices = {i for i, go_id in enumerate(mlb.classes_) if go_id in mf_go_ids}
 
159
  class ProteinRequest(BaseModel):
160
  sequence: str
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  def parse_sequences(text):
163
  text = text.strip()
164
  if text.startswith(">"):
 
167
  i = 1
168
  while i < len(blocks):
169
  name = blocks[i][1:].strip()
170
+ seq = re.sub(r"\s+", "", blocks[i+1]) if i+1 < len(blocks) else ""
 
171
  if seq:
172
+ names.append(name)
173
  seqs.append(seq)
174
  i += 2
175
  return list(zip(names, seqs))
176
+ seqs = [l.strip() for l in text.splitlines() if l.strip()]
 
177
  return [(f"Sequence {i+1}", s) for i, s in enumerate(seqs)]
178
 
179
  @app.post("/predict")
 
181
  entries = parse_sequences(request.sequence)
182
  results = []
183
  for name, sequence in entries:
184
+ if len(sequence) > 2500:
185
+ results.append({"name": name, "error": "Sequence too long (max 2500 aa)"})
 
186
  continue
187
  try:
188
  _, _, tokens = batch_converter([("p", sequence)])