"""configuration_sedd.py
====================================
HuggingFace *Transformers* configuration class for the `SEDD` architecture.

This mirrors the structure of other community models in 🤗 Transformers so that
`AutoConfig` can correctly instantiate the model (see the usage sketch below the
imports).

The default values roughly reproduce the "small" setup shipped in
`configs/model/small.yaml` of this repository.
"""

from __future__ import annotations

from typing import Any, Dict

from transformers.configuration_utils import PretrainedConfig

# `omegaconf` is optional at import time; it is only needed by `to_hydra` below.
try:
    from omegaconf import OmegaConf
except ImportError:
    OmegaConf = None

__all__ = ["SEDDConfig"]

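# Hypothetical usage sketch, kept as comments so nothing runs at import time.
# `AutoConfig.register`, `AutoConfig.from_pretrained` and `save_pretrained` are
# standard Transformers APIs; the checkpoint and output paths are illustrative
# placeholders, not paths shipped with this repository:
#
#     from transformers import AutoConfig
#     AutoConfig.register("sedd", SEDDConfig)
#     config = AutoConfig.from_pretrained("path/to/sedd-checkpoint")
#     config.save_pretrained("out_dir")  # writes config.json (incl. tie_word_embeddings)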


class SEDDConfig(PretrainedConfig):
    """Configuration class for the SEDD score-based model.

    Parameters
    ----------
    tokens:
        Size of the tokenizer vocabulary (default: 50257, the GPT-2 vocabulary size).
    graph_type:
        Type of token graph to use ("absorb" matches the reference implementation).
    model_hidden_size:
        Dimension of the transformer hidden states.
    model_cond_dim:
        Dimension of the conditional embedding for the noise level.
    model_length:
        Maximum (fixed) sequence length the model was trained with.
    model_n_blocks:
        Number of *DDiT* blocks in the network.
    model_n_heads:
        Number of attention heads per *DDiT* block.
    model_scale_by_sigma:
        Whether to scale the output logits by the noise level (see
        `SEDD.forward`).
    model_dropout:
        Dropout probability used throughout the network.
    tie_word_embeddings:
        Standard Transformers flag; not used by SEDD but required by the base
        class. Must be present so that the value is serialised in the resulting
        JSON file.
    """

    model_type: str = "sedd"

    def __init__(
        self,
        *,
        tokens: int = 50257,
        graph_type: str = "absorb",
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Vocabulary size and token-graph settings.
        self.tokens = tokens
        self.graph_type = graph_type

        # Backbone (DDiT) hyper-parameters.
        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    def to_hydra(self):
        """Convert this *flat* configuration into the nested OmegaConf structure
        expected by the reference `SEDD` implementation.
        """
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        nested: Dict[str, Any] = {
            "tokens": self.tokens,
            "graph": {
                "type": self.graph_type,
            },
            "model": {
                "hidden_size": self.model_hidden_size,
                "cond_dim": self.model_cond_dim,
                "length": self.model_length,
                "n_blocks": self.model_n_blocks,
                "n_heads": self.model_n_heads,
                "scale_by_sigma": self.model_scale_by_sigma,
                "dropout": self.model_dropout,
            },
        }
        return OmegaConf.create(nested)
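
# Minimal usage sketch for `to_hydra` (illustrative comment only; assumes
# `omegaconf` is installed, and the hyper-parameter values are made up):
#
#     cfg = SEDDConfig(model_hidden_size=512, model_n_blocks=8)
#     hydra_cfg = cfg.to_hydra()
#     assert hydra_cfg.model.hidden_size == 512
#     assert hydra_cfg.graph.type == "absorb"
#
# The nested object can then be handed to the reference model, e.g.
# `SEDD(hydra_cfg)`, assuming the reference constructor accepts this config.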