import argparse
import os

import torch


def parse_arguments(argv=None):
    """Parse command-line options for fine-tuning / evaluation runs.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            ``None``, in which case ``sys.argv[1:]`` is used — identical
            to the original no-argument behavior, so existing callers
            are unaffected. Passing an explicit list makes the function
            usable programmatically and testable.

    Returns:
        argparse.Namespace with all parsed options, plus two
        post-processing tweaks applied below:
          * ``device`` — set to ``"cuda"`` when a GPU is available,
            else ``"cpu"``.
          * ``load`` — when exactly one path was given via ``--load``,
            the single-element list is unwrapped to a plain string.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-location",
        type=str,
        default=os.path.expanduser("/path/datasets/"),
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--eval-datasets",
        default=None,
        type=lambda x: x.split(","),
        help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. ",
    )
    parser.add_argument(
        "--train-dataset",
        default=None,
        type=lambda x: x.split(","),
        help="Which dataset(s) to patch on.",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default=None,
        help="Name of the experiment, for organization purposes only.",
    )
    parser.add_argument(
        "--results-db",
        type=str,
        default=None,
        help="Where to store the results, else does not store",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="ViT-B-32",
        help="The type of model (e.g. RN50, ViT-B-32).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--num-grad-accumulation",
        type=int,
        default=1,
        help="Number of gradient accumulation steps.",
    )
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.")
    parser.add_argument("--wd", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--ls", type=float, default=0.0, help="Label smoothing.")
    parser.add_argument(
        "--warmup_length",
        type=int,
        default=500,
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--load",
        type=lambda x: x.split(","),
        default=None,
        help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",
    )
    parser.add_argument(
        "--save",
        type=str,
        default=None,
        help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory for caching features and encoder",
    )
    parser.add_argument(
        "--openclip-cachedir",
        type=str,
        default=os.path.expanduser("~/openclip-cachedir/open_clip"),
        help="Directory for caching models from OpenCLIP",
    )
    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of processes for distributed training.",
    )
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=-1,
        help="How often to checkpoint the model.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=12355,
        help="Port for distributed training.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1993,
        help="Random seed.",
    )
    parser.add_argument(
        "--finetuning-mode",
        choices=[
            "standard",
            "standard_ortho",
            "linear",
            "linear_ortho",
            "linear-2",
            "linear-2_ortho",
        ],
        help="Finetuning mode: standard/linear/linear-2 with optional ortho regularization.",
    )
    parser.add_argument(
        "--n-eval-points",
        type=int,
        default=21,
        help="Number of evaluation points used to find optimal coefficient in task arithmetic.",
    )
    parser.add_argument(
        "--ortho-lambda",
        type=float,
        default=0.0,
        help="Weight of the orthogonality regularization term. Default 0.0 means no regularization.",
    )
    parser.add_argument(
        "--control_threshold",
        type=float,
        default=0.95,
        help="Control dataset performance degradation threshold.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=None,
        help="Manually specify the scaling coefficient for task vectors. If None, it will be optimized on validation set.",
    )
    parsed_args = parser.parse_args(argv)

    # Downstream code reads args.device directly; prefer GPU when present.
    parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # Convenience: a single --load path is unwrapped from its one-element
    # list so callers can treat it as a plain string.
    if parsed_args.load is not None and len(parsed_args.load) == 1:
        parsed_args.load = parsed_args.load[0]
    return parsed_args