# OrthoReg / src/args.py
# gezi2333's picture
# Upload folder using huggingface_hub
# 3589275 verified
import argparse
import os
import torch
def parse_arguments(args=None) -> argparse.Namespace:
    """Build the command-line parser and return the parsed options.

    Args:
        args: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what a
            plain ``parse_arguments()`` call has always done. Passing an
            explicit list lets tests and other entry points use the parser
            without touching ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` of parsed options, with two
        post-processing steps applied:

        * ``device`` is added: ``"cuda"`` if ``torch.cuda.is_available()``,
          otherwise ``"cpu"``.
        * ``load`` is unwrapped to a plain string when exactly one value was
          supplied; with several comma-separated values it stays a list, and
          it stays ``None`` when the option is absent.
    """

    def split_commas(value):
        # Shared converter for the comma-separated list options
        # (--eval-datasets, --train-dataset, --load).
        return value.split(",")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-location",
        type=str,
        default=os.path.expanduser("/path/datasets/"),
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--eval-datasets",
        default=None,
        type=split_commas,
        help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. ",
    )
    parser.add_argument(
        "--train-dataset",
        default=None,
        type=split_commas,
        help="Which dataset(s) to patch on.",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default=None,
        help="Name of the experiment, for organization purposes only.",
    )
    parser.add_argument(
        "--results-db",
        type=str,
        default=None,
        help="Where to store the results, else does not store",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="ViT-B-32",
        help="The type of model (e.g. RN50, ViT-B-32).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--num-grad-accumulation",
        type=int,
        default=1,
        help="Number of gradient accumulation steps.",
    )
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.")
    parser.add_argument("--wd", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--ls", type=float, default=0.0, help="Label smoothing.")
    parser.add_argument(
        "--warmup_length",
        type=int,
        default=500,
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--load",
        type=split_commas,
        default=None,
        help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",
    )
    parser.add_argument(
        "--save",
        type=str,
        default=None,
        help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory for caching features and encoder",
    )
    parser.add_argument(
        "--openclip-cachedir",
        type=str,
        default=os.path.expanduser("~/openclip-cachedir/open_clip"),
        help="Directory for caching models from OpenCLIP",
    )
    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of processes for distributed training.",
    )
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=-1,
        help="How often to checkpoint the model.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=12355,
        help="Port for distributed training.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1993,
        help="Random seed.",
    )
    # No default is given here, so the attribute is None unless the flag is
    # passed on the command line.
    parser.add_argument(
        "--finetuning-mode",
        choices=["standard", "standard_ortho", "linear", "linear_ortho", "linear-2", "linear-2_ortho"],
        help="Finetuning mode: standard/linear/linear-2 with optional ortho regularization.",
    )
    parser.add_argument(
        "--n-eval-points",
        type=int,
        default=21,
        help="Number of evaluation points used to find optimal coefficient in task arithmetic.",
    )
    parser.add_argument(
        "--ortho-lambda",
        type=float,
        default=0.0,
        help="Weight of the orthogonality regularization term. Default 0.0 means no regularization.",
    )
    parser.add_argument(
        "--control_threshold",
        type=float,
        default=0.95,
        help="Control dataset performance degradation threshold.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=None,
        help="Manually specify the scaling coefficient for task vectors. If None, it will be optimized on validation set.",
    )

    # args=None makes argparse fall back to sys.argv[1:].
    parsed_args = parser.parse_args(args)

    # The device is derived from the runtime environment, not a CLI flag.
    parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # A single --load value is exposed as a bare string instead of a
    # one-element list.
    if parsed_args.load is not None and len(parsed_args.load) == 1:
        parsed_args.load = parsed_args.load[0]
    return parsed_args