| | from click import Parameter |
| | import numpy as np |
| | from joblib import load |
| | from typing import List |
| | import pandas as pd |
| | import random |
| | from pydantic import ( |
| | BaseModel, |
| | ValidationError, |
| | ValidationInfo, |
| | field_validator, |
| | model_validator, |
| | ) |
| |
|
def _range_param(name, low, high):
    """Build a search-space entry for a parameter limited to ``[low, high]``."""
    return {"name": name, "type": "range", "bounds": [low, high]}


def _choice_param(name, options):
    """Build a search-space entry for a parameter limited to ``options``."""
    return {"name": name, "type": "choice", "values": list(options)}


# Hyperparameter search space. Each entry has a "name" and a "type":
# "range" entries carry "bounds" = [low, high]; "choice" entries carry the
# allowed "values". Parameterization validates its fields against this table.
PARAM_BOUNDS = [
    _range_param("N", 1, 10),
    _range_param("alpha", 0.0, 1.0),
    _range_param("d_model", 100, 1024),
    _range_param("dim_feedforward", 1024, 4096),
    _range_param("dropout", 0.0, 1.0),
    _range_param("emb_scaler", 0.0, 1.0),
    _range_param("epochs_step", 5, 20),
    _range_param("eps", 1e-7, 1e-4),
    _range_param("fudge", 0.0, 0.1),
    _range_param("heads", 1, 10),
    _range_param("k", 2, 10),
    _range_param("lr", 1e-4, 6e-3),
    _range_param("pe_resolution", 2500, 10000),
    _range_param("ple_resolution", 2500, 10000),
    _range_param("pos_scaler", 0.0, 1.0),
    _range_param("weight_decay", 0.0, 1.0),
    _range_param("batch_size", 32, 256),
    _range_param("out_hidden4", 32, 512),
    _range_param("betas1", 0.5, 0.9999),
    _range_param("betas2", 0.5, 0.9999),
    _choice_param("bias", (False, True)),
    _choice_param("criterion", ("RobustL1", "RobustL2")),
    _choice_param("elem_prop", ("mat2vec", "magpie", "onehot")),
    _range_param("train_frac", 0.01, 1.0),
]

# Absolute slack used when comparing against range bounds and the
# cross-field constraints in Parameterization.
tol = 1e-6
| |
|
class Parameterization(BaseModel):
    """Validated set of CrabNet hyperparameters for one surrogate evaluation.

    Every field is checked against its ``PARAM_BOUNDS`` entry (range bounds
    get an absolute slack of ``tol``). Two cross-field constraints are then
    enforced once all fields have validated.
    """

    N: float
    alpha: float
    d_model: float
    dim_feedforward: float
    dropout: float
    emb_scaler: float
    epochs_step: float
    eps: float
    fudge: float
    heads: float
    k: float
    lr: float
    pe_resolution: float
    ple_resolution: float
    pos_scaler: float
    weight_decay: float
    batch_size: float
    out_hidden4: float
    betas1: float
    betas2: float
    bias: bool
    criterion: str
    elem_prop: str
    train_frac: float

    @field_validator("*")
    @classmethod
    def check_bounds(cls, v: object, info: ValidationInfo) -> object:
        """Reject a field value that falls outside its ``PARAM_BOUNDS`` entry.

        Runs for every field after Pydantic's own type coercion, so ``v`` may
        be a float, bool or str depending on the field. Fields without a
        ``PARAM_BOUNDS`` entry are returned unchanged.

        Raises:
            ValueError: if a "range" value lies outside
                ``[min - tol, max + tol]`` or a "choice" value is not one of
                the allowed values.
        """
        param = next(
            (item for item in PARAM_BOUNDS if item["name"] == info.field_name),
            None,
        )
        if param is None:
            return v

        if param["type"] == "range":
            min_val, max_val = param["bounds"]
            # The ``tol`` slack lets values within round-off of a bound pass.
            if not (min_val - tol) <= v <= (max_val + tol):
                raise ValueError(
                    f"{info.field_name} must be between {min_val} and {max_val}"
                )
        elif param["type"] == "choice":
            if v not in param["values"]:
                raise ValueError(f"{info.field_name} must be one of {param['values']}")

        return v

    @model_validator(mode="after")
    def check_constraints(self) -> "Parameterization":
        """Enforce cross-field constraints after all fields have validated.

        Returns:
            The validated instance itself. Pydantic v2 requires an ``after``
            model validator to hand back ``self``.

        Raises:
            ValueError: if ``betas1`` exceeds ``betas2``, or if
                ``emb_scaler + pos_scaler`` exceeds 1.0 (both with ``tol``
                slack).
        """
        if (self.betas1 - tol) > (self.betas2 + tol):
            raise ValueError(
                f"Received betas1={self.betas1} which should be less than or equal to betas2={self.betas2}"
            )
        if self.emb_scaler + self.pos_scaler - tol > 1.0:
            raise ValueError(
                f"Received emb_scaler={self.emb_scaler} and pos_scaler={self.pos_scaler} which should sum to less than or equal to 1.0"
            )
        return self
| |
|
| |
|
class CrabNetSurrogateModel(object):
    """Surrogate for CrabNet outcomes backed by pre-fitted regression models.

    ``self.models`` is a mapping with keys ``"mae"``, ``"rmse"``,
    ``"runtime"`` and ``"model_size"``. Each value must expose ``predict``
    and ``feature_names_in_`` (presumably scikit-learn estimators — confirm
    against the model file).
    """

    def __init__(self, fpath="models/surrogate_models_hgbr_opt.pkl"):
        """Load the surrogate model mapping from ``fpath``.

        Args:
            fpath: Path to a joblib file holding the model mapping. A relative
                path is resolved against the current working directory.
        """
        # SECURITY: joblib.load unpickles arbitrary objects; only load model
        # files that come from a trusted source.
        self.models = load(fpath)

    def prepare_params_for_eval(self, raw_params: dict):
        """Convert one parameter set into the surrogate's feature columns.

        ``bias`` becomes 0/1. ``criterion`` is replaced by the boolean
        ``use_RobustL1``. ``elem_prop`` is one-hot encoded into
        ``elem_prop_magpie`` / ``elem_prop_mat2vec`` / ``elem_prop_onehot``.

        Args:
            raw_params: Mapping of parameter name to value. This is a ``dict``
                or, when called via ``DataFrame.apply(axis=1)``, a pandas
                ``Series`` row. It is not modified.

        Returns:
            A transformed copy of ``raw_params`` (same container type).
        """
        params = raw_params.copy()
        params["bias"] = int(params["bias"])
        params["use_RobustL1"] = params["criterion"] == "RobustL1"
        del params["criterion"]

        elem_prop = params["elem_prop"]
        params["elem_prop_magpie"] = 0
        params["elem_prop_mat2vec"] = 0
        params["elem_prop_onehot"] = 0
        params[f"elem_prop_{elem_prop}"] = 1
        del params["elem_prop"]

        return params

    @staticmethod
    def _rank_percentiles(n, seed=None, remove_noise=False):
        """Return the (mae, rmse, runtime) rank features for ``n`` rows."""
        if remove_noise:
            return [0.5] * n, [0.5] * n, [0.5] * n

        rng = np.random.default_rng(seed)
        mae_percentiles = rng.uniform(0, 1, size=n)
        # RMSE reuses the MAE rank and runtime uses its complement, so the
        # sampled noise is tied together across the three targets.
        return mae_percentiles, mae_percentiles, 1 - mae_percentiles

    def surrogate_evaluate(
        self, params_list: List[dict], seed=None, remove_noise=False
    ):
        """Predict MAE, RMSE, runtime and model size for each parameter set.

        Args:
            params_list: Parameter dicts. Each is validated with
                ``Parameterization`` before any prediction is made.
            seed: Seed for ``numpy.random.default_rng``, used to draw the
                rank features. Unused when ``remove_noise`` is true.
            remove_noise: If true, every rank feature is fixed at 0.5, so
                results do not depend on ``seed``.

        Returns:
            One dict per input, in input order, with keys ``"mae"``,
            ``"rmse"``, ``"runtime"`` and ``"model_size"``. An empty
            ``params_list`` yields ``[]``.

        Raises:
            AssertionError: if ``params_list`` is not a list.
            pydantic.ValidationError: if any parameter set fails validation.
        """
        assert isinstance(params_list, list), "Input must be a list of dictionaries"
        if not params_list:
            # An empty frame has none of the feature columns selected below,
            # so short-circuit instead of raising KeyError.
            return []

        # Validation only; the resulting model instances are not needed.
        for params in params_list:
            Parameterization(**params)

        parameters = pd.DataFrame(params_list)
        parameters = parameters.apply(self.prepare_params_for_eval, axis=1)

        mae_percentiles, rmse_percentiles, runtime_percentiles = (
            self._rank_percentiles(len(parameters), seed, remove_noise)
        )

        mae_model = self.models["mae"]
        rmse_model = self.models["rmse"]
        runtime_model = self.models["runtime"]
        model_size_model = self.models["model_size"]

        # Each noisy target gets its rank column, then columns are reordered
        # to exactly what that model was fitted on.
        mae = mae_model.predict(
            parameters.assign(mae_rank=mae_percentiles)[mae_model.feature_names_in_]
        )
        rmse = rmse_model.predict(
            parameters.assign(rmse_rank=rmse_percentiles)[rmse_model.feature_names_in_]
        )
        runtime = runtime_model.predict(
            parameters.assign(runtime_rank=runtime_percentiles)[
                runtime_model.feature_names_in_
            ]
        )
        # Model size takes no rank feature.
        model_size = model_size_model.predict(
            parameters[model_size_model.feature_names_in_]
        )

        return [
            {"mae": m, "rmse": r, "runtime": rt, "model_size": ms}
            for m, r, rt, ms in zip(mae, rmse, runtime, model_size)
        ]
| |
|
| |
|
| | |
| |
|
| | |
| | |
| | |
| |
|