| import logging |
| import math |
| import os |
| from contextlib import contextmanager |
|
|
| import timm.models.hub as timm_hub |
| import torch |
| import torch.distributed as dist |
| import torch.nn as nn |
|
|
|
|
| def is_dist_avail_and_initialized(): |
| if not dist.is_available(): |
| return False |
| if not dist.is_initialized(): |
| return False |
| return True |
|
|
|
|
| def get_rank(): |
| if not is_dist_avail_and_initialized(): |
| return 0 |
| return dist.get_rank() |
|
|
|
|
| def is_main_process(): |
| return get_rank() == 0 |
|
|
|
|
| def download_cached_file(url, check_hash=True, progress=False): |
| """ |
| Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. |
| If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. |
| """ |
    def get_cached_file_path():
        # torch.hub.urlparse is urllib.parse.urlparse (imported at module level
        # in torch.hub); the resulting path mirrors the cache location that
        # timm_hub.download_cached_file writes to.
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
        return cached_file
|
|
    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        # Wait until the main process has finished downloading before any
        # other rank tries to read the cached file.
        dist.barrier()
|
|
| return get_cached_file_path() |
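

# Example usage (illustrative, not part of the original module): download a
# checkpoint on the main process only and load it on every rank; this assumes
# all ranks share the same cache directory. The URL is a placeholder, not a
# real asset.
#
#     cached_path = download_cached_file(
#         "https://example.com/checkpoints/model.pth",
#         check_hash=False, progress=True)
#     state_dict = torch.load(cached_path, map_location="cpu")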
|
|
|
|
| @contextmanager |
| def all_logging_disabled(highest_level=logging.CRITICAL): |
| """ |
| A context manager that will prevent any logging messages |
| triggered during the body from being processed. |
| :param highest_level: the maximum logging level in use. |
| This would only need to be changed if a custom level greater than CRITICAL |
| is defined. |
| """ |
| |
| |
| |
| |
|
|
| previous_level = logging.root.manager.disable |
|
|
| logging.disable(highest_level) |
|
|
| try: |
| yield |
| finally: |
| logging.disable(previous_level) |
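

# Example usage (illustrative): silence logging around a single noisy call.
# `load_noisy_pretrained_model` is a hypothetical function used only to show
# the pattern.
#
#     with all_logging_disabled():
#         model = load_noisy_pretrained_model()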
|
|
|
|
| class LoRALinear(nn.Linear): |
| def __init__(self, |
| in_features: int, |
| out_features: int, |
| bias: bool = True, |
| device=None, |
| dtype=None, |
| lora_r=8, |
| lora_alpha=16, |
| lora_dropout=0.05, |
| **kwargs) -> None: |
| super().__init__(in_features, out_features, bias, device, dtype) |
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Standard LoRA scaling: the low-rank update is multiplied by alpha / r.
        self.lora_scaling = self.lora_alpha / self.lora_r

        # Low-rank decomposition: A projects the input down to rank r and
        # B projects it back up to the output dimension.
        self.lora_A = nn.Linear(in_features,
                                self.lora_r,
                                bias=False,
                                device=device,
                                dtype=dtype)
        self.lora_B = nn.Linear(self.lora_r,
                                out_features,
                                bias=False,
                                device=device,
                                dtype=dtype)

        self.reset_parameters()
|
|
    def reset_parameters(self):
        # nn.Linear.__init__ calls reset_parameters() before the LoRA layers
        # exist, hence the hasattr guard below.
        super().reset_parameters()
        if hasattr(self, 'lora_A'):
            # Initialize A the same way nn.Linear initializes its weight, and
            # B to zero, so the LoRA branch starts as a no-op update.
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)
|
|
    def forward(self, x):
        orig_type = x.dtype
        res = super().forward(x)
        # Run the LoRA branch in float32 for numerical stability (this assumes
        # the LoRA layers themselves are kept in float32), then cast the result
        # back to the input dtype. The out-of-place add avoids an in-place
        # dtype-cast error when the base layer runs in half precision.
        x = x.float()
        res = res + self.lora_B(self.lora_A(
            self.lora_dropout(x))) * self.lora_scaling
        return res.to(orig_type)
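

# Illustrative sketch (not part of the original module): replace selected
# nn.Linear layers of an existing model with LoRALinear and train only the
# low-rank parameters. The target names ("q_proj", "v_proj") are placeholders
# for whatever projection layers the model actually uses; device/dtype
# handling is omitted for brevity.
#
#     def apply_lora(model, target_names=("q_proj", "v_proj"), **lora_kwargs):
#         for module in model.modules():
#             for child_name, child in module.named_children():
#                 if child_name in target_names and isinstance(child, nn.Linear):
#                     lora_layer = LoRALinear(child.in_features,
#                                             child.out_features,
#                                             bias=child.bias is not None,
#                                             **lora_kwargs)
#                     # Copy the pretrained weights into the LoRA-wrapped layer.
#                     lora_layer.weight.data.copy_(child.weight.data)
#                     if child.bias is not None:
#                         lora_layer.bias.data.copy_(child.bias.data)
#                     setattr(module, child_name, lora_layer)
#         # Freeze everything except the LoRA projections.
#         for name, param in model.named_parameters():
#             param.requires_grad = "lora_A" in name or "lora_B" in name
#         return model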
|
|