| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Dict, Optional |
|
|
| import torch |
|
|
| |
# Key fragments marking state-dict entries that should be dropped when a
# state dict is processed (any key containing one of these is skipped).
substrings_to_ignore = ["_extra_state"]
|
|
|
|
def get_partial_state_dict(
    state_dict: Dict[str, torch.Tensor],
    prefix: str,
) -> Dict[str, torch.Tensor]:
    """
    Select the entries of a state dict whose key begins with ``prefix``.

    Keys are returned unchanged (the prefix is not stripped) and values are
    the same objects as in ``state_dict``; only the dict itself is new.

    Args:
        state_dict (Dict[str, torch.Tensor]): The state dict to filter
        prefix (str): The prefix a key must start with to be kept

    Returns:
        Dict[str, torch.Tensor]: The matching entries, in their original order
    """
    selected = {}
    for name, tensor in state_dict.items():
        if not name.startswith(prefix):
            continue
        selected[name] = tensor
    return selected
|
|
|
|
def process_state_dict(
    state_dict: Dict[str, torch.Tensor],
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    prefix_to_remove: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    """
    Build a cleaned copy of a state dict.

    - Remove items whose key contains any entry of ``substrings_to_ignore``,
      e.g. "_extra_state" (TransformerEngine adds these for FP8)
    - Move tensors to specified device and dtype if provided
    - Strip ``prefix_to_remove`` from the start of keys that begin with it

    The input dict is not modified; a new dict is returned.

    Args:
        state_dict (Dict[str, torch.Tensor]): The state dict to process
        device (str, optional): The device to move tensors to. Defaults to None.
        dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None.
        prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None.

    Returns:
        Dict[str, torch.Tensor]: The processed state dict
    """
    # Collect only the conversion arguments that were actually requested, so
    # that `.to()` is skipped entirely when neither is given.
    tensor_kwargs = {}
    if device is not None:
        tensor_kwargs["device"] = device
    if dtype is not None:
        tensor_kwargs["dtype"] = dtype

    new_state_dict = {}
    for key, value in state_dict.items():
        if any(substr in key for substr in substrings_to_ignore):
            continue
        # Non-tensor entries are passed through untouched, matching what
        # happens when no device/dtype is requested, instead of failing on a
        # missing `.to` method.
        # NOTE(review): the requested dtype is applied to every tensor,
        # including integer ones -- confirm that is intended for all callers.
        if tensor_kwargs and isinstance(value, torch.Tensor):
            value = value.to(**tensor_kwargs)
        # Keys that do not start with the prefix are kept as they are. If two
        # keys become equal after stripping, the later one wins.
        if prefix_to_remove is not None and key.startswith(prefix_to_remove):
            key = key[len(prefix_to_remove):]
        new_state_dict[key] = value
    return new_state_dict
|
|