Sophia Tang commited on
Commit
c511e34
·
1 Parent(s): 4e90f71
Files changed (1) hide show
  1. tr2d2-pep/peptide_mcts.py +3 -3
tr2d2-pep/peptide_mcts.py CHANGED
@@ -345,7 +345,7 @@ class MCTS:
345
  print(f"[BUFFER] reason={reason} sv={np.round(sv,4)} "
346
  f"buf_len={len(self.buffer)} extra={extra}")
347
 
348
- def updateBuffer(self, x_final, log_rnd, score_vectors, childSequences):
349
  B = x_final.shape[0]
350
  traj_log_rnds, scalar_rewards = [], []
351
 
@@ -367,7 +367,7 @@ class MCTS:
367
  "log_rnd": traj_log_rnd.clone(),
368
  "final_reward": scalar_reward,
369
  "score_vector": sv.copy(),
370
- "seq": childSequences[i],
371
  }
372
 
373
  # Drop if dominated by any existing
@@ -601,7 +601,7 @@ class MCTS:
601
  valid_traj_log_rnd = torch.stack(valid_traj_log_rnd, dim=0)
602
  # update buffer and get rewards
603
  with self.timer.section("expand.update_buffer"):
604
- traj_log_rnds, scalar_rewards = self.updateBuffer(valid_x_final, valid_traj_log_rnd, score_vectors, childSequences)
605
 
606
  allChildReward = np.zeros_like(score_vectors[0])
607
 
 
345
  print(f"[BUFFER] reason={reason} sv={np.round(sv,4)} "
346
  f"buf_len={len(self.buffer)} extra={extra}")
347
 
348
+ def updateBuffer(self, x_final, log_rnd, score_vectors, validSequences):
349
  B = x_final.shape[0]
350
  traj_log_rnds, scalar_rewards = [], []
351
 
 
367
  "log_rnd": traj_log_rnd.clone(),
368
  "final_reward": scalar_reward,
369
  "score_vector": sv.copy(),
370
+ "seq": validSequences[i],
371
  }
372
 
373
  # Drop if dominated by any existing
 
601
  valid_traj_log_rnd = torch.stack(valid_traj_log_rnd, dim=0)
602
  # update buffer and get rewards
603
  with self.timer.section("expand.update_buffer"):
604
+ traj_log_rnds, scalar_rewards = self.updateBuffer(valid_x_final, valid_traj_log_rnd, score_vectors, validSequences)
605
 
606
  allChildReward = np.zeros_like(score_vectors[0])
607