"""pfd.exploration.inference.eval_model: model inference and evaluation interfaces."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from math import sqrt
from pathlib import Path
from typing import Optional, Union
import dpdata
import numpy as np


@dataclass
class TestReport:
    """Error statistics gathered by evaluating a model on a single system.

    Scalar metrics default to ``0`` and array fields default to ``None``;
    :meth:`report` exposes the scalar metrics as a plain dictionary.
    """

    # Identifier of the evaluated system.
    name: str = "default_system"
    # The evaluated system itself, if kept.
    system: Optional[dpdata.System] = None
    # Number of atoms per frame.
    atom_numb: int = 0
    # Number of frames evaluated.
    numb_frame: int = 0
    # Force errors (reported as "MAE_force" / "RMSE_force").
    mae_f: float = 0
    rmse_f: float = 0
    # Energy errors (reported as "MAE_energy" / "RMSE_energy").
    mae_e: float = 0
    rmse_e: float = 0
    # Per-atom energy errors (reported as "*_energy_per_at").
    mae_e_atom: float = 0
    rmse_e_atom: float = 0
    # Virial errors (reported as "MAE_virial" / "RMSE_virial").
    mae_v: float = 0
    rmse_v: float = 0
    # Presumably reference ("lab") vs. predicted ("pred") raw values — TODO confirm.
    lab_e: Optional[np.ndarray] = None
    pred_e: Optional[np.ndarray] = None
    lab_f: Optional[np.ndarray] = None
    pred_f: Optional[np.ndarray] = None
    lab_v: Optional[np.ndarray] = None
    pred_v: Optional[np.ndarray] = None

    def report(self):
        """Return the scalar metrics of this report as a dictionary.

        Array fields and ``system`` are not included. Keys appear in a fixed
        order: name, sizes, then force, energy, per-atom energy and virial
        errors.
        """
        # (output key, attribute name) pairs; tuple order fixes dict order.
        key_to_attr = (
            ("name", "name"),
            ("atom_numb", "atom_numb"),
            ("numb_frame", "numb_frame"),
            ("MAE_force", "mae_f"),
            ("RMSE_force", "rmse_f"),
            ("MAE_energy", "mae_e"),
            ("RMSE_energy", "rmse_e"),
            ("MAE_energy_per_at", "mae_e_atom"),
            ("RMSE_energy_per_at", "rmse_e_atom"),
            ("MAE_virial", "mae_v"),
            ("RMSE_virial", "rmse_v"),
        )
        return {key: getattr(self, attr) for key, attr in key_to_attr}
class TestReports:
    """An ordered collection of ``TestReport`` objects.

    Supports ``len()``, integer indexing and iteration over the contained
    reports, and provides aggregate error metrics pooled across them.

    Args:
        name: Label for the collection.
    """

    def __init__(self, name: str = "default_reports"):
        self._reports = []
        self.name = name

    def __iter__(self):
        return iter(self._reports)

    def __getitem__(self, index):
        return self._reports[index]

    def __len__(self):
        return len(self._reports)

    def add_report(self, report: "TestReport"):
        """Append ``report`` to the end of the collection."""
        self._reports.append(report)

    def get_weighted_rmse_f(self):
        """Return the force RMSE pooled over all reports.

        Each report's squared ``rmse_f`` is weighted by its number of force
        components, ``numb_frame * 3 * atom_numb``, and the square root of the
        weighted mean is returned.

        Returns:
            The pooled RMSE, or ``None`` when there are no reports or the
            total weight is zero (which previously raised ZeroDivisionError).
        """
        weights = [res.numb_frame * 3 * res.atom_numb for res in self._reports]
        total = sum(weights)
        if not total:
            return None
        weighted = sum(w * res.rmse_f**2 for w, res in zip(weights, self._reports))
        return np.sqrt(weighted / total)

    def get_weighted_rmse_e_atom(self):
        """Return the per-atom energy RMSE pooled over all reports.

        Each report's squared ``rmse_e_atom`` is weighted by ``numb_frame``.

        Returns:
            The pooled RMSE, or ``None`` when there are no reports or the
            total number of frames is zero.
        """
        total = sum(res.numb_frame for res in self._reports)
        if not total:
            return None
        weighted = sum(res.numb_frame * res.rmse_e_atom**2 for res in self._reports)
        return np.sqrt(weighted / total)

    def get_systems(self):
        """Return the ``system`` of every report, in order (may contain ``None``)."""
        return [res.system for res in self._reports]

    def get_and_output_systems(self, prefix: Union[Path, str] = "."):
        """Write every report's system to ``prefix / report.name`` as ``deepmd/npy``.

        The ``prefix`` directory (and any missing parents) is created first.
        Every report must hold a system; a ``None`` system raises
        ``AttributeError``.

        Args:
            prefix: Output directory.

        Returns:
            The list of written paths, in report order.
        """
        prefix = Path(prefix)
        # parents=True so that nested output prefixes can be used directly.
        prefix.mkdir(parents=True, exist_ok=True)
        systems = []
        for res in self._reports:
            path = prefix / res.name
            res.system.to("deepmd/npy", path)
            systems.append(path)
        return systems

    def sub_reports(self, index):
        """Return a new ``TestReports`` holding the reports at positions ``index``.

        Args:
            index: Iterable of integer positions into this collection.

        Returns:
            A new collection with the default name. If this collection is
            empty, the result is empty regardless of ``index``.
        """
        reports = TestReports()
        if self._reports:
            for ii in index:
                reports.add_report(self._reports[ii])
        return reports

    def get_nframes(self):
        """Return the total frame count of all reports that hold a system."""
        return sum(
            res.system.get_nframes()
            for res in self._reports
            if res.system is not None
        )
class EvalModel(ABC):
    """Abstract base class for model inference and evaluation back-ends.

    Concrete back-ends subclass this class and implement the abstract methods.
    They register themselves under a string key with :meth:`EvalModel.register`
    and are looked up by that key with :meth:`EvalModel.get_driver`.

    Args:
        model: Optional model path. When truthy, it is passed together with
            ``**kwargs`` to :meth:`load_model` during construction.
        data: Optional data path. When truthy, it is passed to
            :meth:`read_data` during construction. ``**kwargs`` are not
            forwarded to ``read_data``.
        **kwargs: Extra keyword arguments forwarded to :meth:`load_model`.
    """

    # Registry mapping driver key -> registered class, shared by all
    # subclasses. The double underscore name-mangles it to
    # ``_EvalModel__ModelTypes``.
    __ModelTypes = {}

    def __init__(
        self,
        model: Optional[Union[Path, str]] = None,
        data: Optional[Union[Path, str]] = None,
        **kwargs,
    ):
        self._data = None
        self._model = None
        if model:
            self.load_model(model, **kwargs)
        if data:
            self.read_data(data)

    @property
    def model(self):
        """The value of ``self._model`` (``None`` until a subclass sets it)."""
        return self._model

    @property
    def data(self):
        """The value of ``self._data`` (``None`` until a subclass sets it)."""
        return self._data

    @staticmethod
    def register(key: str):
        """Return a class decorator that registers the decorated class under ``key``.

        The decorated class is returned unchanged. Registering an existing key
        silently replaces the previous entry.

        Args:
            key: Name under which the class can later be retrieved with
                :meth:`get_driver`.
        """

        def decorator(cls):
            EvalModel.__ModelTypes[key] = cls
            return cls

        return decorator

    @staticmethod
    def get_driver(key: str):
        """Return the class registered under ``key``.

        Args:
            key: Registration key passed to :meth:`register`.

        Raises:
            RuntimeError: If no class is registered under ``key``. The
                original ``KeyError`` is chained as the cause.

        Returns:
            The registered class.
        """
        try:
            return EvalModel.__ModelTypes[key]
        except KeyError as e:
            raise RuntimeError("unknown driver: " + key) from e

    @staticmethod
    def get_drivers() -> dict:
        """Return the registry mapping each key to its registered class.

        The live registry dict is returned, not a copy.
        """
        return EvalModel.__ModelTypes

    @abstractmethod
    def load_model(self, model: Union[Path, str], **kwargs):
        """Load a model from ``model``.

        Implementations presumably store the result in ``self._model``, which
        the ``model`` property exposes. TODO confirm against subclasses.
        """

    @abstractmethod
    def read_data(self, data: Union[Path, str], **kwargs):
        """Read (labelled) data from ``data``.

        Implementations presumably store the result in ``self._data``, which
        the ``data`` property exposes. TODO confirm against subclasses.
        """

    @abstractmethod
    def read_data_unlabeled(self, data: Union[Path, str], **kwargs):
        """Read data from ``data`` that is assumed, from the name, to carry no labels."""

    @abstractmethod
    def evaluate(self, **kwargs):
        """Evaluate the loaded model on the loaded data (back-end specific)."""

    @abstractmethod
    def inference(self, **kwargs):
        """Run model inference on the loaded data (back-end specific)."""

    def clear_data(self):
        """Drop the reference to the loaded data by resetting it to ``None``."""
        self._data = None

    def clear_model(self):
        """Drop the reference to the loaded model by resetting it to ``None``."""
        self._model = None