| | import json |
| | from dataclasses import dataclass, field, asdict |
| | from datetime import datetime |
| | from pathlib import Path |
| | from typing import Optional, Set |
| | import warnings |
| |
|
| | import torch |
| |
|
| |
|
@dataclass
class EvaluationState:
    """Bookkeeping for an evaluation made of several attacks, persisted as JSON.

    The state records which attacks are planned, which have already been run,
    the current robust flags and the clean accuracy.  When ``path`` is set the
    state is written to that file by :meth:`to_disk` and can be restored with
    :meth:`from_disk`.
    """

    # Names of the attacks the evaluation is expected to run (read-only after
    # construction, see the ``attacks_to_run`` property).
    _attacks_to_run: Set[str]
    # JSON file the state is written to; ``None`` disables persistence.
    path: Optional[Path] = None
    # Names of the attacks that have been run so far.
    _run_attacks: Set[str] = field(default_factory=set)
    # Boolean tensor of robustness flags (restored with dtype ``torch.bool``);
    # presumably one entry per evaluated sample -- confirm against the caller.
    _robust_flags: Optional[torch.Tensor] = None
    # Time of the last write to disk; the far-past default makes the first
    # non-forced save go through.
    _last_saved: datetime = datetime(1, 1, 1)
    # Minimum number of seconds between two non-forced saves.
    _SAVE_TIMEOUT: int = 60
    # Clean accuracy; NaN until it is set through the property.
    _clean_accuracy: float = float("nan")

    @staticmethod
    def _to_attack_set(value) -> Set[str]:
        """Convert a deserialised collection of attack names back to a set.

        Accepts the list written by :meth:`to_disk`, ``None`` (treated as
        empty), and the legacy form in which the set was stored as its
        ``str()`` representation (e.g. ``"{'a', 'b'}"``) because the JSON
        encoder fell back to ``default=str`` for sets.
        """
        if value is None:
            return set()
        if isinstance(value, str):
            # ``str(set())`` is "set()", which older ``ast.literal_eval``
            # versions refuse to parse, so handle it explicitly.
            if value == "set()":
                return set()
            import ast  # local import: only needed for legacy state files

            value = ast.literal_eval(value)
        return set(value)

    def to_disk(self, force: bool = False) -> None:
        """Write the state to ``self.path`` as JSON.

        Nothing is written when ``path`` is ``None``, or when fewer than
        ``_SAVE_TIMEOUT`` seconds have elapsed since the last save and
        ``force`` is false.

        Args:
            force: Save even if the save timeout has not elapsed yet.
        """
        if self.path is None:
            return
        elapsed = (datetime.now() - self._last_saved).total_seconds()
        if elapsed < self._SAVE_TIMEOUT and not force:
            return

        self._last_saved = datetime.now()
        d = asdict(self)
        # Make every non-JSON-native field explicit so that `from_disk` can
        # restore it faithfully.  Sets in particular must become lists:
        # otherwise the encoder's `default=str` fallback stores their repr.
        d["path"] = str(self.path)
        d["_last_saved"] = self._last_saved.isoformat()
        d["_attacks_to_run"] = sorted(self._attacks_to_run)
        d["_run_attacks"] = sorted(self._run_attacks)
        if self._robust_flags is not None:
            d["_robust_flags"] = self._robust_flags.cpu().tolist()
        with self.path.open("w") as f:
            # `default=str` is kept only as a last-resort fallback for
            # unexpected value types (it was the original behaviour).
            json.dump(d, f, default=str)

    @classmethod
    def from_disk(cls, path: Path) -> "EvaluationState":
        """Restore a state previously written by :meth:`to_disk`.

        A ``UserWarning`` is emitted when ``path`` differs from the path
        recorded inside the file; the recorded path is the one kept on the
        returned state.

        Args:
            path: Location of the JSON state file.

        Returns:
            The reconstructed state, with attack collections as sets and
            robust flags as a ``torch.bool`` tensor (or ``None`` if the file
            was saved before any flags were set).
        """
        with path.open("r") as f:
            d = json.load(f)

        flags = d.get("_robust_flags")
        # A state saved before the flags were set stores `null`.
        d["_robust_flags"] = (None if flags is None else torch.tensor(
            flags, dtype=torch.bool))
        d["_attacks_to_run"] = cls._to_attack_set(d.get("_attacks_to_run"))
        d["_run_attacks"] = cls._to_attack_set(d.get("_run_attacks"))
        d["path"] = Path(d["path"])
        if path != d["path"]:
            warnings.warn(
                UserWarning(
                    "The given path is different from the one found in the state file."
                ))
        d["_last_saved"] = datetime.fromisoformat(d["_last_saved"])
        return cls(**d)

    @property
    def robust_flags(self) -> Optional[torch.Tensor]:
        """Current robust flags, or ``None`` if they have not been set."""
        return self._robust_flags

    @robust_flags.setter
    def robust_flags(self, robust_flags: torch.Tensor) -> None:
        """Store new robust flags and force a save to disk."""
        self._robust_flags = robust_flags
        self.to_disk(force=True)

    @property
    def run_attacks(self) -> Set[str]:
        """Names of the attacks that have been run so far."""
        return self._run_attacks

    def add_run_attack(self, attack: str) -> None:
        """Record ``attack`` as run and save, subject to the save timeout."""
        self._run_attacks.add(attack)
        self.to_disk()

    @property
    def attacks_to_run(self) -> Set[str]:
        """Names of the attacks the evaluation is expected to run."""
        return self._attacks_to_run

    @attacks_to_run.setter
    def attacks_to_run(self, _: Set[str]) -> None:
        """Reject assignment; the planned attacks are fixed at construction."""
        raise ValueError("attacks_to_run cannot be set outside of the constructor")

    @property
    def clean_accuracy(self) -> float:
        """Clean accuracy, NaN until set."""
        return self._clean_accuracy

    @clean_accuracy.setter
    def clean_accuracy(self, accuracy) -> None:
        """Store the clean accuracy and force a save to disk."""
        self._clean_accuracy = accuracy
        self.to_disk(force=True)

    @property
    def robust_accuracy(self) -> float:
        """Mean of the robust flags as a Python float.

        Warns when some of the planned attacks have not been run yet.

        Raises:
            ValueError: If the robust flags have not been set.
        """
        if self.robust_flags is None:
            raise ValueError("robust_flags is not set yet. Start the attack first.")
        if self.attacks_to_run - self.run_attacks:
            warnings.warn("You are checking `robust_accuracy` before all the attacks"
                          " have been run.")
        return self.robust_flags.float().mean().item()