| | |
| | |
| | |
| | |
| | |
| |
|
| | from abc import ABC |
| |
|
| | from torch import nn, Tensor |
| |
|
| |
|
class ModelWrapper(ABC, nn.Module):
    """
    Wrap another model so that subclasses can add custom forward-pass logic.

    The base implementation is a pure pass-through: :meth:`forward` calls the
    wrapped model with ``x``, ``t`` and any extra keyword arguments unchanged.
    Although the class derives from :class:`abc.ABC`, it declares no abstract
    methods, so it can be instantiated directly.

    Args:
        model (nn.Module): the model to wrap. It is stored as ``self.model``.
            Because it is assigned as a submodule attribute, its parameters
            are visible through this wrapper.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""
        This method defines how inputs should be passed through the wrapped model.
        Here, we're assuming that the wrapped model takes both :math:`x` and :math:`t` as input,
        along with any additional keyword arguments.

        Optional things to do here:
            - check that t is in the dimensions that the model is expecting.
            - add a custom forward pass logic.
            - call the wrapped model.

        | given x, t
        | returns the model output for input x at time t, with extra information `extras`.

        Args:
            x (Tensor): input data to the model (batch_size, ...).
            t (Tensor): time (batch_size).
            **extras: additional information forwarded to the model, e.g., text condition.

        Returns:
            Tensor: model output.
        """
        # Pass everything by keyword. The wrapped model must therefore accept
        # parameters literally named ``x`` and ``t``.
        return self.model(x=x, t=t, **extras)
| |
|