Update sample.py
Browse files
sample.py
CHANGED
|
@@ -91,6 +91,7 @@ def main(config: omegaconf.DictConfig) -> None:
|
|
| 91 |
tokenizer=tokenizer,
|
| 92 |
config=config, logger=False)
|
| 93 |
pretrained.eval()
|
|
|
|
| 94 |
|
| 95 |
bindevaluator = BindEvaluator.load_from_checkpoint(
|
| 96 |
config.guidance.classifier_checkpoint_path,
|
|
@@ -101,6 +102,7 @@ def main(config: omegaconf.DictConfig) -> None:
|
|
| 101 |
d_k=64,
|
| 102 |
d_v=128,
|
| 103 |
d_inner=64)
|
|
|
|
| 104 |
|
| 105 |
samples = []
|
| 106 |
for _ in tqdm(
|
|
|
|
| 91 |
tokenizer=tokenizer,
|
| 92 |
config=config, logger=False)
|
| 93 |
pretrained.eval()
|
| 94 |
+
pretrained = pretrained.to('cuda')
|
| 95 |
|
| 96 |
bindevaluator = BindEvaluator.load_from_checkpoint(
|
| 97 |
config.guidance.classifier_checkpoint_path,
|
|
|
|
| 102 |
d_k=64,
|
| 103 |
d_v=128,
|
| 104 |
d_inner=64)
|
| 105 |
+
bindevaluator = bindevaluator.to('cuda')
|
| 106 |
|
| 107 |
samples = []
|
| 108 |
for _ in tqdm(
|