| |
from typing import List, Optional

from transformers import PretrainedConfig
|
|
|
|
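'''
Reference training configuration (an unassigned string literal, so it has no
runtime effect; kept for documentation only):

network_config = {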
| "epochs": 150, |
| "batch_size": 250, |
| "n_steps": 16, # timestep |
| "dataset": "CAPS", |
| "in_channels": 1, |
| "data_path": "./data", |
| "lr": 0.001, |
| "n_class": 10, |
| "latent_dim": 128, |
| "input_size": 32, |
| "model": "FSVAE" ,# FSVAE or FSVAE_large |
| "k": 20, # multiplier of channel |
| "scheduled": True, # whether to apply scheduled sampling |
| "loss_func": 'kld', # mmd or kld |
| "accum_iter" : 1, |
| "devices": [0], |
| } |
| |
| hidden_dims = [32, 64, 128, 256] |
| |
| ''' |
|
|
class FSAEConfig(PretrainedConfig):
    """Configuration for the ``"fsae"`` model type.

    Every constructor argument is stored verbatim as an attribute of the same
    name. Any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        in_channels: Number of input channels (1 in the reference training
            config kept in this module).
        hidden_dims: Per-layer widths. Defaults to ``[32, 64, 128, 256]``.
            Each instance gets its own list, so mutating one config's
            ``hidden_dims`` never affects another config or the caller's list.
        k: Described in the reference training config as a "multiplier of
            channel".
        n_steps: Described in the reference training config as the number of
            timesteps.
        latent_dim: Presumably the latent dimensionality (inferred from the
            name; confirm against the model code).
        scheduled: Described in the reference training config as "whether to
            apply scheduled sampling".
        dt, a, aa, Vth, tau: Dynamics hyper-parameters stored as-is. Their
            meaning is defined by the model code that consumes this config,
            which is not visible here. NOTE(review): ``Vth``/``tau`` look like a
            spiking-neuron threshold and time constant; confirm.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "fsae"

    def __init__(
        self,
        in_channels: int = 1,
        hidden_dims: Optional[List[int]] = None,
        k: int = 20,
        n_steps: int = 16,
        latent_dim: int = 128,
        scheduled: bool = True,
        dt: float = 5,
        a: float = 0.25,
        aa: float = 0.5,
        Vth: float = 0.2,
        tau: float = 0.25,
        **kwargs,
    ):
        self.in_channels = in_channels
        # A list literal used as a default argument is created once and shared
        # by every call, so build the default here instead. Copying a
        # caller-supplied sequence likewise avoids aliasing the caller's list.
        self.hidden_dims = (
            [32, 64, 128, 256] if hidden_dims is None else list(hidden_dims)
        )
        self.k = k
        self.n_steps = n_steps
        self.latent_dim = latent_dim
        self.scheduled = scheduled
        self.dt = dt
        self.a = a
        self.aa = aa
        self.Vth = Vth
        self.tau = tau
        # Attributes are set before delegating to the base class, matching the
        # pattern used for custom PretrainedConfig subclasses.
        super().__init__(**kwargs)