| | |
| |
|
| | |
| | |
| | |
| |
|
| | """ |
| | Base class for trainable models. |
| | """ |
| |
|
| | from abc import ABCMeta, abstractmethod |
| | from copy import copy |
| |
|
| | from omegaconf import OmegaConf |
| | from torch import nn |
| |
|
| |
|
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Abstract base class for trainable models.

    Subclasses implement ``_init`` (build the model from ``conf``) and
    ``_forward`` (the actual computation), and typically ``loss``.
    ``forward`` first checks that the input ``data`` contains every key
    listed in ``required_data_keys`` and then delegates to ``_forward``.
    """

    # Keys that ``forward`` requires in its input. Either an iterable of keys,
    # or a dict mapping a key to the keys required inside ``data[key]``
    # (checked recursively).
    required_data_keys = []
    # Not read by this class; presumably consumed by subclasses or config
    # tooling elsewhere — TODO confirm.
    strict_conf = True

    def __init__(self, conf):
        """Freeze ``conf`` and call the child model's ``_init``.

        Args:
            conf: OmegaConf config for this model. It is switched to
                read-only and struct mode *in place*, so the caller's
                config object is frozen as well.
        """
        # Local import keeps this class self-contained. A deep copy is
        # required because required_data_keys may be a nested dict of
        # lists; a shallow copy would share the inner containers with the
        # class attribute, so edits made in a subclass's _init would leak
        # into every other instance of that class.
        from copy import deepcopy

        super().__init__()
        self.conf = conf
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        self.required_data_keys = deepcopy(self.required_data_keys)
        self._init(conf)

    def forward(self, data):
        """Check that ``data`` has the required keys, then call ``_forward``.

        Args:
            data: mapping of model inputs.

        Returns:
            Whatever the child's ``_forward`` returns.

        Raises:
            AssertionError: if a key from ``required_data_keys`` is missing
                (nested dict specs are checked recursively). NOTE: this is
                an ``assert`` and is therefore skipped under ``python -O``.
        """

        def recursive_key_check(expected, given):
            for key in expected:
                assert key in given, f"Missing key {key} in data"
                if isinstance(expected, dict):
                    # Nested spec: expected[key] lists the keys that must be
                    # present inside given[key].
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """Build the model from ``conf``. To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """Run the model on ``data``. To be implemented by the child class."""
        raise NotImplementedError

    def loss(self, pred, data):
        """Compute the training loss. To be implemented by the child class."""
        raise NotImplementedError

    def metrics(self):
        """Return extra metrics; the base implementation returns an empty dict."""
        return {}
| |
|