| import json |
| from dataclasses import dataclass, field, asdict |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Optional, Set |
| import warnings |
|
|
| import torch |
|
|
|
|
@dataclass
class EvaluationState:
    """Resumable bookkeeping for an attack evaluation, persisted as JSON.

    When ``path`` is set, the state is written to that file by
    :meth:`to_disk` and can be restored with :meth:`from_disk`. Unforced
    writes are throttled to at most one every ``_SAVE_TIMEOUT`` seconds;
    assigning ``robust_flags`` or ``clean_accuracy`` forces a write.
    """

    # Names of the attacks the evaluation should run; fixed at construction
    # (the ``attacks_to_run`` setter refuses reassignment).
    _attacks_to_run: Set[str]
    # JSON file backing this state; ``None`` disables persistence.
    path: Optional[Path] = None
    # Names recorded through ``add_run_attack``.
    _run_attacks: Set[str] = field(default_factory=set)
    # Boolean flags whose mean is reported as ``robust_accuracy``.
    _robust_flags: Optional[torch.Tensor] = None
    # Time of the last write; the year-1 default means the first save is
    # never throttled.
    _last_saved: datetime = datetime(1, 1, 1)
    # Minimum number of seconds between two unforced saves.
    _SAVE_TIMEOUT: int = 60
    # NaN until ``clean_accuracy`` is assigned.
    _clean_accuracy: float = float("nan")

    @staticmethod
    def _to_str_set(value) -> Set[str]:
        """Rebuild a set of names from its serialized form.

        Current files store sets as JSON lists. Older files stored
        ``_attacks_to_run`` through ``json.dump(..., default=str)``, i.e. as
        the set's ``repr`` (``"{'a', 'b'}"`` or ``"set()"``); both forms are
        accepted so existing state files can still be resumed.
        """
        if isinstance(value, str):
            if value == "set()":
                return set()
            import ast  # only needed for legacy state files
            value = ast.literal_eval(value)
        return set(value)

    def to_disk(self, force: bool = False) -> None:
        """Write the state to ``self.path`` as JSON.

        Does nothing when ``path`` is ``None``, or when the previous save was
        less than ``_SAVE_TIMEOUT`` seconds ago and ``force`` is false.

        The JSON is first written to a sibling ``<name>.tmp`` file and then
        moved over ``path``, so an interruption mid-write cannot leave a
        truncated state file behind.
        """
        if self.path is None:
            return
        now = datetime.now()
        elapsed = (now - self._last_saved).total_seconds()
        if elapsed < self._SAVE_TIMEOUT and not force:
            return
        self._last_saved = now
        d = asdict(self)
        # Convert every non-JSON field explicitly so it round-trips through
        # ``from_disk``; ``default=str`` would turn sets into repr strings.
        d["_attacks_to_run"] = sorted(self._attacks_to_run)
        d["_run_attacks"] = sorted(self._run_attacks)
        if self._robust_flags is not None:
            d["_robust_flags"] = self._robust_flags.cpu().tolist()
        d["path"] = str(self.path)
        d["_last_saved"] = self._last_saved.isoformat()
        tmp_path = self.path.with_name(self.path.name + ".tmp")
        with tmp_path.open("w", encoding="utf-8") as f:
            # ``default=str`` stays as a best-effort fallback for values of
            # unexpected types (e.g. a non-float clean accuracy).
            json.dump(d, f, default=str)
        tmp_path.replace(self.path)

    @classmethod
    def from_disk(cls, path: Path) -> "EvaluationState":
        """Load a state previously written by :meth:`to_disk`.

        Args:
            path: State file to read (a ``str`` is also accepted).

        Returns:
            The restored state. It keeps the path stored in the file; a
            ``UserWarning`` is emitted when that differs from ``path``.

        Raises:
            OSError: if ``path`` cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        path = Path(path)
        with path.open("r", encoding="utf-8") as f:
            d = json.load(f)
        d["_attacks_to_run"] = cls._to_str_set(d["_attacks_to_run"])
        d["_run_attacks"] = cls._to_str_set(d.get("_run_attacks", []))
        # ``null`` means the flags had not been set when the file was saved.
        if d.get("_robust_flags") is not None:
            d["_robust_flags"] = torch.tensor(d["_robust_flags"],
                                              dtype=torch.bool)
        d["path"] = Path(d["path"])
        if path != d["path"]:
            warnings.warn(
                UserWarning(
                    "The given path is different from the one found in the state file."
                ))
        d["_last_saved"] = datetime.fromisoformat(d["_last_saved"])
        return cls(**d)

    @property
    def robust_flags(self) -> Optional[torch.Tensor]:
        """Boolean robustness flags, or ``None`` if not set yet."""
        return self._robust_flags

    @robust_flags.setter
    def robust_flags(self, robust_flags: torch.Tensor) -> None:
        self._robust_flags = robust_flags
        # Persist immediately, bypassing the save throttle.
        self.to_disk(force=True)

    @property
    def run_attacks(self) -> Set[str]:
        """Names recorded through :meth:`add_run_attack`."""
        return self._run_attacks

    def add_run_attack(self, attack: str) -> None:
        """Record ``attack`` as run and save (subject to the save throttle)."""
        self._run_attacks.add(attack)
        self.to_disk()

    @property
    def attacks_to_run(self) -> Set[str]:
        """Names of the attacks this evaluation was configured to run."""
        return self._attacks_to_run

    @attacks_to_run.setter
    def attacks_to_run(self, _: Set[str]) -> None:
        raise ValueError("attacks_to_run cannot be set outside of the constructor")

    @property
    def clean_accuracy(self) -> float:
        """Clean accuracy; NaN until it is assigned."""
        return self._clean_accuracy

    @clean_accuracy.setter
    def clean_accuracy(self, accuracy: float) -> None:
        self._clean_accuracy = accuracy
        # Persist immediately, bypassing the save throttle.
        self.to_disk(force=True)

    @property
    def robust_accuracy(self) -> float:
        """Mean of ``robust_flags`` (fraction of ``True`` entries).

        Raises:
            ValueError: if ``robust_flags`` has not been set.

        Emits a warning if some names in ``attacks_to_run`` have not been
        recorded in ``run_attacks`` yet.
        """
        if self.robust_flags is None:
            raise ValueError("robust_flags is not set yet. Start the attack first.")
        if self.attacks_to_run - self.run_attacks:
            warnings.warn("You are checking `robust_accuracy` before all the attacks"
                          " have been run.")
        return self.robust_flags.float().mean().item()