| import importlib |
| import torch |
|
|
| from modules import shared |
|
|
|
|
def check_for_npu():
    """Report whether an NPU backend is usable through ``torch_npu``.

    Returns:
        bool: ``False`` when the ``torch_npu`` package cannot be located, or
        when probing the device count raises ``RuntimeError``; otherwise the
        value returned by ``torch.npu.is_available()``.
    """
    # A bare ``import importlib`` does not guarantee that the ``util``
    # submodule has been loaded, so import it explicitly before using it.
    import importlib.util

    if importlib.util.find_spec("torch_npu") is None:
        return False

    import torch_npu

    try:
        # Probe the device count first; a RuntimeError raised here is
        # treated as "no usable NPU" by the handler below.
        torch_npu.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False
|
|
|
|
def get_npu_device_string():
    """Build the torch device string for the configured NPU.

    Returns:
        str: ``"npu:<id>"`` using ``shared.cmd_opts.device_id`` when that
        option is not ``None``, otherwise ``"npu:0"``.
    """
    selected_id = shared.cmd_opts.device_id
    if selected_id is None:
        return "npu:0"
    return f"npu:{selected_id}"
|
|
|
|
def torch_npu_gc():
    """Empty the torch NPU cache within the configured device's context."""
    device_string = get_npu_device_string()
    with torch.npu.device(device_string):
        torch.npu.empty_cache()
|
|
|
|
# Module-level flag computed once, when this module is first imported, by
# calling check_for_npu().
has_npu: bool = check_for_npu()
|
|