| | import torch |
| | import os |
| | import requests |
| | from tqdm import tqdm |
| | from diffusers import DDPMScheduler, EulerDiscreteScheduler |
| | from typing import Any, Optional, Union |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
def make_1step_sched(pretrained_path, device="cuda"):
    """Load a DDPM scheduler configured for a single denoising timestep.

    Args:
        pretrained_path: Location passed to ``DDPMScheduler.from_pretrained``;
            the scheduler config is read from its ``scheduler`` subfolder.
        device: Device for the scheduler's timesteps and ``alphas_cumprod``.
            Defaults to ``"cuda"``, the value this function previously
            hard-coded, so existing callers are unaffected.

    Returns:
        The loaded ``DDPMScheduler`` with exactly one timestep set.
    """
    noise_scheduler_1step = DDPMScheduler.from_pretrained(pretrained_path, subfolder="scheduler")
    noise_scheduler_1step.set_timesteps(1, device=device)
    # set_timesteps only receives the device for the timesteps; alphas_cumprod is
    # moved explicitly so it lives on the same device as the tensors it is used with.
    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.to(device)
    return noise_scheduler_1step
| |
|
| |
|
def my_lora_fwd(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Forward pass for a LoRA-wrapped layer with a ``de_mod`` mixing step.

    Intended to be bound as ``forward`` on a PEFT-style LoRA layer; it reads
    ``base_layer``, ``lora_A``/``lora_B``, ``lora_dropout``, ``scaling``,
    ``use_dora``, ``active_adapters``, ``merged`` and ``disable_adapters``
    from ``self``.

    For each active non-DoRA adapter, the down-projected features
    ``lora_A(dropout(x))`` are contracted with ``self.de_mod`` over their
    feature/channel axis before being up-projected by ``lora_B``. DoRA
    adapters fall back to ``self._apply_dora`` and ignore ``de_mod``.

    Args:
        x: Input tensor for the base layer.
        *args, **kwargs: Forwarded to the base layer; ``adapter_names`` is
            popped from ``kwargs`` and, when given, routes to
            ``self._mixed_batch_forward``.

    Returns:
        The base-layer output plus the scaled LoRA contributions, cast back to
        the base-layer output dtype.

    Raises:
        NotImplementedError: If a non-DoRA ``lora_A`` is neither ``Conv2d``
            nor ``Linear``.
    """
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)

    # Adapters switched off: make sure no merged LoRA weights leak into the output.
    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        return self.base_layer(x, *args, **kwargs)
    if adapter_names is not None:
        return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
    # Merged weights already contain the LoRA update.
    if self.merged:
        return self.base_layer(x, *args, **kwargs)

    out = self.base_layer(x, *args, **kwargs)
    out_dtype = out.dtype

    for name in self.active_adapters:
        if name not in self.lora_A.keys():
            continue
        down = self.lora_A[name]
        up = self.lora_B[name]
        drop = self.lora_dropout[name]
        scale = self.scaling[name]
        # Run the adapter in its own weight dtype; note x stays converted for later adapters.
        x = x.to(down.weight.dtype)

        if self.use_dora[name]:
            x = drop(x)
            out = out + self._apply_dora(x, down, up, scale, name)
            continue

        hidden = down(drop(x))
        # k = lora_A output features/channels, r = de_mod output features;
        # the leading "..." dims of hidden and de_mod must broadcast together.
        if isinstance(down, torch.nn.Conv2d):
            equation = '...khw,...kr->...rhw'
        elif isinstance(down, torch.nn.Linear):
            equation = '...lk,...kr->...lr'
        else:
            raise NotImplementedError('only conv and linear are supported yet.')
        hidden = torch.einsum(equation, hidden, self.de_mod)
        out = out + up(hidden) * scale

    return out.to(out_dtype)
| |
|
def download_url(url, outf, block_size=1024, timeout=60):
    """Download ``url`` to ``outf`` with a progress bar, unless ``outf`` exists.

    The payload is streamed into ``outf + ".part"`` and renamed to ``outf``
    only after the transfer finishes and, when the server sent a plain
    ``Content-Length``, the received byte count matches it. An interrupted,
    truncated or HTTP-error download therefore never leaves a file at ``outf``
    that a later call would mistake for a finished download and skip.

    Args:
        url: HTTP(S) URL to fetch.
        outf: Destination file path. If it already exists, nothing is fetched.
        block_size: Chunk size in bytes passed to ``Response.iter_content``.
        timeout: Connect/read timeout in seconds passed to ``requests.get``.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.RequestException: On connection failures or timeouts.
        RuntimeError: If the received size does not match ``Content-Length``.
    """
    if os.path.exists(outf):
        print(f"Skipping download, {outf} already exists")
        return

    print(f"Downloading checkpoint to {outf}")
    tmp_path = f"{outf}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, a 404/500 error page would be saved as the checkpoint.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            # With a Content-Encoding (e.g. gzip), Content-Length counts encoded bytes
            # while iter_content yields decoded ones, so the sizes are not comparable.
            encoded = response.headers.get('content-encoding', 'identity').lower() != 'identity'
            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                with open(tmp_path, 'wb') as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                received = progress_bar.n
        if total_size_in_bytes != 0 and not encoded and received != total_size_in_bytes:
            raise RuntimeError(
                f"Incomplete download of {url}: received {received} of "
                f"{total_size_in_bytes} bytes"
            )
        # Atomic on the same filesystem: outf is either absent or complete.
        os.replace(tmp_path, outf)
    except BaseException:
        # Never leave a partial file behind, including on Ctrl-C.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"Downloaded successfully to {outf}")
| |
|