AlienChen commited on
Commit
ff6f52f
·
verified ·
1 Parent(s): f52b5d3

Update models/peptide_classifiers.py

Browse files
Files changed (1) hide show
  1. models/peptide_classifiers.py +4 -19
models/peptide_classifiers.py CHANGED
@@ -509,7 +509,7 @@ class AffinityModel(nn.Module):
509
 
510
  class HemolysisModel:
511
  def __init__(self, device):
512
- self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/collection/classifiers/ckpt/wt_hemolysis.json')
513
 
514
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
515
  self.model.eval()
@@ -569,7 +569,7 @@ class MLPClassifier(nn.Module):
569
 
570
  class NonfoulingModel:
571
  def __init__(self, device):
572
- ckpt = torch.load('/scratch/pranamlab/tong/collection/classifiers/ckpt/wt_nonfouling.pt', weights_only=False, map_location=device)
573
  best_params = ckpt["best_params"]
574
  self.predictor = MLPClassifier(in_dim=1280, hidden=int(best_params["hidden"]), dropout=float(best_params.get("dropout", 0.1)))
575
  self.predictor.load_state_dict(ckpt["state_dict"])
@@ -595,7 +595,7 @@ class NonfoulingModel:
595
  class SolubilityModel:
596
  def __init__(self, device):
597
  # change model path
598
- self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json')
599
 
600
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
601
  self.model.eval()
@@ -675,21 +675,6 @@ class PeptideCNN(nn.Module):
675
  return features
676
  return self.predictor(features) # Output shape: (B, 1)
677
 
678
- # class HalfLifeModel:
679
- # def __init__(self, device):
680
- # input_dim = 1280
681
- # hidden_dims = [input_dim // 2, input_dim // 4]
682
- # output_dim = input_dim // 8
683
- # dropout_rate = 0.3
684
- # self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
685
- # self.model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
686
- # self.model.eval()
687
-
688
- # def __call__(self, x):
689
- # prediction = self.model(x, return_features=False)
690
- # halflife = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)
691
- # return halflife / 2
692
-
693
 
694
  # -----------------------------
695
  # Model definition (must match training)
@@ -758,7 +743,7 @@ class HalfLifeModel:
758
  def __init__(
759
  self,
760
  device,
761
- ckpt_path = "/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_transformer_log/final_model.pt",
762
  ):
763
  self.device = device
764
 
 
509
 
510
  class HemolysisModel:
511
  def __init__(self, device):
512
+ self.predictor = xgb.Booster(model_file='../classifier_ckpt/wt_hemolysis.json')
513
 
514
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
515
  self.model.eval()
 
569
 
570
  class NonfoulingModel:
571
  def __init__(self, device):
572
+ ckpt = torch.load('../classifier_ckpt/wt_nonfouling.pt', weights_only=False, map_location=device)
573
  best_params = ckpt["best_params"]
574
  self.predictor = MLPClassifier(in_dim=1280, hidden=int(best_params["hidden"]), dropout=float(best_params.get("dropout", 0.1)))
575
  self.predictor.load_state_dict(ckpt["state_dict"])
 
595
  class SolubilityModel:
596
  def __init__(self, device):
597
  # change model path
598
+ self.predictor = xgb.Booster(model_file='../classifier_ckpt/best_model_solubility.json')
599
 
600
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
601
  self.model.eval()
 
675
  return features
676
  return self.predictor(features) # Output shape: (B, 1)
677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
678
 
679
  # -----------------------------
680
  # Model definition (must match training)
 
743
  def __init__(
744
  self,
745
  device,
746
+ ckpt_path = "../classifier_ckpt/wt_halflife.pt",
747
  ):
748
  self.device = device
749