#!/bin/bash
# Evaluate generative perplexity (gen PPL) of text8 samples for one model.
# Required env: MODEL (ar | mdlm | udlm). Optional: SAMPLING_STEPS, SEED.

# Usage note. The delimiter is quoted ('comment') so the shell treats the
# body as literal text: an unquoted delimiter would expand ${MODEL} and
# execute any $(...) inside this "comment".
: <<'comment'
# Usage:
cd scripts/
MODEL=<ar|mdlm|udlm>
sbatch \
  --export=ALL,MODEL=${MODEL} \
  --job-name=eval_text8_gen_ppl_${MODEL} \
  eval_text8_gen_ppl.sh
comment
# Move from scripts/ to the repo root and load the project environment.
# Diagnostics go to stderr so they survive stdout redirection in batch logs.
cd ../ || { echo "cannot cd to repo root" >&2; exit 1; }
# shellcheck disable=SC1091 -- setup_env.sh lives in the repo root, not here
source setup_env.sh || { echo "failed to source setup_env.sh" >&2; exit 1; }
# Make Hydra print full stack traces instead of its truncated summary.
export HYDRA_FULL_ERROR=1
|
| if [ -z "${MODEL}" ]; then |
| echo "MODEL is not set" |
| exit 1 |
| fi |
| if [ -z "${SAMPLING_STEPS}" ]; then |
| SAMPLING_STEPS=128 |
| fi |
| if [ -z "${SEED}" ]; then |
| SEED=1 |
| fi |
|
|
| if [ "${MODEL}" = "ar" ]; then |
| parameterization="ar" |
| diffusion="absorbing_state" |
| TRAIN_T=0 |
| time_conditioning=False |
| sampling_use_cache=False |
| CKPT="${PWD}/outputs/text8/ar" |
| elif [ "${MODEL}" = "mdlm" ]; then |
| parameterization="subs" |
| diffusion="absorbing_state" |
| TRAIN_T=0 |
| time_conditioning=False |
| sampling_use_cache=True |
| CKPT="${PWD}/outputs/text8/mdlm" |
| elif [ "${MODEL}" = "udlm" ]; then |
| parameterization="d3pm" |
| diffusion="uniform" |
| TRAIN_T=0 |
| time_conditioning=True |
| sampling_use_cache=False |
| CKPT="${PWD}/outputs/text8/udlm" |
| else |
| echo "Invalid MODEL: ${MODEL}" |
| exit 1 |
| fi |
| generated_seqs_path="${CKPT}/samples-text8-gen-ppl-eval-_T-${SAMPLING_STEPS}_seed-${SEED}.json" |
|
|
# Launch the gen-PPL evaluation. All variable expansions are quoted:
# generated_seqs_path in particular embeds $PWD and would word-split on
# paths containing spaces. Generated samples are scored with GPT-2 Large.
python -u -m main \
  hydra.output_subdir=null \
  hydra.run.dir="${CKPT}" \
  hydra/job_logging=disabled \
  hydra/hydra_logging=disabled \
  seed="${SEED}" \
  mode="gen_ppl_eval" \
  eval.checkpoint_path="${CKPT}/checkpoints/best.ckpt" \
  data=text8 \
  backbone=dit \
  model=small \
  model.length=256 \
  training.guidance=null \
  parameterization="${parameterization}" \
  diffusion="${diffusion}" \
  time_conditioning="${time_conditioning}" \
  T="${TRAIN_T}" \
  sampling.num_sample_batches=32 \
  sampling.batch_size=32 \
  sampling.steps="${SAMPLING_STEPS}" \
  sampling.use_cache="${sampling_use_cache}" \
  eval.generated_samples_path="${generated_seqs_path}" \
  +eval.generative_ppl_model_name_or_path="gpt2-large"