# OminiControlRotation/omini/rotation/rotation_config.py
# nvan15's picture
# Batch upload part 19
# b816a2c verified
from dataclasses import dataclass, field
from typing import List, Optional
from peft.config import PeftConfig
@dataclass
class RotationConfig(PeftConfig):
"""
Configuration class for Rotation-based Parameter-Efficient Fine-Tuning.
This configuration stores all parameters needed to apply the Rotation method
(based on Cayley transformation) to a model's linear layers.
Args:
r (`int`, *optional*, defaults to 8):
The rank parameter for the low-rank approximation in rotation matrices.
T (`float`, *optional*, defaults to 1.0):
Temperature parameter for the transformation.
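num_rotations (`int`, *optional*, defaults to 4):
Number of rotation matrices to use in parallel.
bias (`str`, *optional*, defaults to 'none'):
Bias training configuration. Options: 'none', 'all', 'rotation_only'.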
target_modules (`Union[List[str], str]`, *optional*, defaults to None):
Module names to apply rotation to (e.g., ["q_proj", "v_proj"]).
target_modules_to_skip (`Union[List[str], str]`, *optional*):
Module names to skip when applying rotation.
modules_to_save (`Union[List[str], str]`, *optional*):
Modules to save in addition to rotation parameters.
layers_to_transform (`Union[List[int], int]`, *optional*):
Layers to transform. If None, all layers matching target_modules are transformed.
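apply_before (`bool`, *optional*, defaults to False):
If True, apply rotation before the base linear layer. If False, apply after.
NOTE(review): no matching `apply_before` dataclass field is visible in this class;
confirm it is defined or drop this entry.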
"""
peft_type: str = field(default="ROTATION", init=False)
target_modules: Optional[List[str]] = field(
default=None,
metadata={
"help": "List of module names to apply rotation to (e.g., ['q_proj', 'v_proj', 'linear'])"
},
)
target_modules_to_skip: Optional[List[str]] = field(
default=None,
metadata={"help": "List of module names to skip when applying rotation"},
)
modules_to_save: Optional[List[str]] = field(
default=None,
metadata={"help": "List of modules to save in addition to rotation parameters"},
)
r: int = field(
default=8,
metadata={"help": "Rank parameter for low-rank approximation"},
)
T: float = field(
default=1.0,
metadata={"help": "Temperature parameter for Cayley transformation"},
)
num_rotations: int = field(
default=4,
metadata={"help": "Number of rotation matrices to use in parallel"},
)
bias: str = field(
default="none",
metadata={
"help": "Bias training configuration. Options: 'none', 'all', 'rotation_only'"
}
)
layers_to_transform: Optional[List[int]] = field(
default=None,
metadata={"help": "Layers to transform. If None, all matching layers are transformed"},
)
def __post_init__(self):
self.peft_type = "ROTATION"
self.target_modules = (
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
)
self.target_modules_to_skip = (
set(self.target_modules_to_skip)
if isinstance(self.target_modules_to_skip, list)
else self.target_modules_to_skip
)