| import subprocess |
| import os |
| import re |
| import sys |
| import filecmp |
| import logging |
| import shutil |
| import sysconfig |
| import datetime |
| import platform |
| import pkg_resources |
|
|
# Failure counter shared by git() and pip(); each non-ignored non-zero exit adds one.
errors = 0
# Logger used by every helper in this module; setup_logging() attaches its handlers.
log = logging.getLogger(name='sd')
|
|
| |
def setup_logging(clean=False):
    """Attach a rich console handler to the ``sd`` logger and start a log file.

    - Creates ``../logs/setup`` (relative to this module's directory) when it
      is missing, and (re)configures the root logger with
      ``logging.basicConfig(force=True)`` so it writes to a new timestamped
      file there at level ERROR. The ``sd`` logger is set to DEBUG; nothing
      here disables propagation, so its records also reach that file handler.
    - Installs rich's pretty-printer and traceback hook on one shared console.
    - Replaces every handler currently on ``sd`` with a single RichHandler
      at INFO level.

    Args:
        clean: Accepted for compatibility; this body does not use it.
    """
    # rich is imported here rather than at module level;
    # ensure_base_requirements() installs it when it is missing.
    from rich.theme import Theme
    from rich.logging import RichHandler
    from rich.console import Console
    from rich.pretty import install as pretty_install
    from rich.traceback import install as traceback_install

    time_format = '%H:%M:%S-%f'
    border_colours = {
        'traceback.border': 'black',
        'traceback.border.syntax_error': 'black',
        'inspect.value.border': 'black',
    }
    console = Console(
        log_time=True,
        log_time_format=time_format,
        theme=Theme(border_colours),
    )

    stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    log_file = os.path.join(
        os.path.dirname(__file__),
        f'../logs/setup/kohya_ss_gui_{stamp}.log',
    )
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    console_level = logging.INFO
    logging.basicConfig(
        level=logging.ERROR,
        format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s',
        filename=log_file,
        filemode='a',
        encoding='utf-8',
        force=True,
    )
    # Let the sd logger emit everything; the handlers decide what is shown.
    log.setLevel(logging.DEBUG)

    pretty_install(console=console)
    traceback_install(
        console=console,
        extra_lines=1,
        width=console.width,
        word_wrap=False,
        indent_guides=False,
        suppress=[],
    )

    handler = RichHandler(
        show_time=True,
        omit_repeated_times=False,
        show_level=True,
        show_path=False,
        markup=False,
        rich_tracebacks=True,
        log_time_format=time_format,
        level=console_level,
        console=console,
    )
    handler.set_name(console_level)
    # Iterate over a snapshot: removeHandler mutates log.handlers.
    for stale_handler in list(log.handlers):
        log.removeHandler(stale_handler)
    log.addHandler(handler)
|
|
|
|
def configure_accelerate(run_accelerate=False):
    """Seed the user's accelerate config from the repository's default config.

    The source file is ``../config_files/accelerate/default_config.yaml``
    relative to this module. The target is ``<root>/accelerate/default_config.yaml``
    where ``<root>`` is taken from the first non-empty of ``$HF_HOME``,
    ``$LOCALAPPDATA/huggingface`` or ``$USERPROFILE/.cache/huggingface``.
    The file is copied only when nothing exists at the target yet.

    Whenever the default file is not copied (no source file, no target
    location, or a target file already present) the fallback is used: run
    ``accelerate config`` when *run_accelerate* is True, otherwise log a
    warning.

    Args:
        run_accelerate: Run ``accelerate config`` instead of only warning
            when the default config is not copied.
    """
    from pathlib import Path

    def env_var_exists(var_name):
        # An empty variable counts as unset.
        return var_name in os.environ and os.environ[var_name] != ''

    def configure_manually(warning):
        # Shared fallback for every path where the default file is not copied.
        if run_accelerate:
            run_cmd('accelerate config')
        else:
            log.warning(warning)

    manual_hint = 'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.'

    log.info('Configuring accelerate...')

    source_accelerate_config_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        '..',
        'config_files',
        'accelerate',
        'default_config.yaml',
    )

    if not os.path.exists(source_accelerate_config_file):
        configure_manually(
            f'Could not find the accelerate configuration file in {source_accelerate_config_file}. Please configure accelerate manually by running the option in the menu.'
        )
        # Nothing to copy: continuing would reach shutil.copyfile with a missing
        # source (FileNotFoundError) or run the fallback a second time.
        return

    log.debug(
        f'Source accelerate config location: {source_accelerate_config_file}'
    )

    log.debug(
        f"Environment variables: HF_HOME: {os.environ.get('HF_HOME')}, "
        f"LOCALAPPDATA: {os.environ.get('LOCALAPPDATA')}, "
        f"USERPROFILE: {os.environ.get('USERPROFILE')}"
    )
    target_config_location = None
    if env_var_exists('HF_HOME'):
        target_config_location = Path(
            os.environ['HF_HOME'], 'accelerate', 'default_config.yaml'
        )
    elif env_var_exists('LOCALAPPDATA'):
        target_config_location = Path(
            os.environ['LOCALAPPDATA'],
            'huggingface',
            'accelerate',
            'default_config.yaml',
        )
    elif env_var_exists('USERPROFILE'):
        target_config_location = Path(
            os.environ['USERPROFILE'],
            '.cache',
            'huggingface',
            'accelerate',
            'default_config.yaml',
        )

    log.debug(f'Target config location: {target_config_location}')

    if target_config_location is None:
        configure_manually(manual_hint)
        return

    if target_config_location.is_file():
        # An existing config is never overwritten; defer to the fallback instead.
        configure_manually(manual_hint)
        return

    target_config_location.parent.mkdir(parents=True, exist_ok=True)
    log.debug(
        f'Target accelerate config location: {target_config_location}'
    )
    shutil.copyfile(source_accelerate_config_file, target_config_location)
    log.info(f'Copied accelerate config file to: {target_config_location}')
|
|
|
|
def check_torch():
    """Log the GPU toolkit, the Torch build and visible GPUs; return Torch's major version.

    Toolkit detection only looks for vendor tools: ``nvidia-smi`` (on PATH or
    in the Windows System32 directory) for nVidia, ``rocminfo`` (on PATH or in
    ``/opt/rocm/bin``) for AMD; otherwise CPU-only Torch is reported.

    Returns:
        The integer major version from ``torch.__version__`` (e.g. 2 for
        ``2.1.2+cu118``), or 0 when importing or querying torch raises.
    """
    nvidia_smi_windows = os.path.join(
        os.environ.get('SystemRoot') or r'C:\Windows',
        'System32',
        'nvidia-smi.exe',
    )
    if shutil.which('nvidia-smi') is not None or os.path.exists(nvidia_smi_windows):
        log.info('nVidia toolkit detected')
    elif shutil.which('rocminfo') is not None or os.path.exists(
        '/opt/rocm/bin/rocminfo'
    ):
        log.info('AMD toolkit detected')
    else:
        log.info('Using CPU-only Torch')

    try:
        import torch

        log.info(f'Torch {torch.__version__}')

        if not torch.cuda.is_available():
            log.warning('Torch reports CUDA not available')
        else:
            if torch.version.cuda:
                cudnn = (
                    torch.backends.cudnn.version()
                    if torch.backends.cudnn.is_available()
                    else "N/A"
                )
                log.info(f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {cudnn}')
            elif torch.version.hip:
                log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
            else:
                log.warning('Unknown Torch backend')

            for index in range(torch.cuda.device_count()):
                # Fetch the properties once per device instead of twice.
                props = torch.cuda.get_device_properties(index)
                log.info(
                    f'Torch detected GPU: {torch.cuda.get_device_name(index)} '
                    f'VRAM {round(props.total_memory / 1024 / 1024)} '
                    f'Arch {torch.cuda.get_device_capability(index)} '
                    f'Cores {props.multi_processor_count}'
                )
        # Parse the whole leading component: indexing [0] would read "10.0" as 1.
        return int(torch.__version__.split('.')[0])
    except Exception as e:
        # Best effort: a missing or broken torch is reported, not raised.
        log.warning(f'Could not load torch: {e}')
        return 0
|
|
|
|
| |
def check_repo_version():
    """Log the contents of ``.release`` in the current directory, if that file exists.

    When the file is absent only a debug message is logged.
    """
    if not os.path.exists('.release'):
        log.debug('Could not read release...')
        return

    with open('./.release', 'r', encoding='utf8') as release_file:
        release = release_file.read()
    log.info(f'Version: {release}')
| |
| |
def git(arg: str, folder: str = None, ignore: bool = False) -> str:
    """Run a git command through the shell and return its combined output.

    The executable is taken from the ``GIT`` environment variable (default
    ``git``) and run with the current environment. stdout and stderr are
    captured, decoded as UTF-8 (undecodable bytes dropped), joined with a
    newline when both are non-empty, and stripped.

    Args:
        arg: Argument string appended to the quoted executable (shell-parsed).
        folder: Working directory; the current directory when falsy.
        ignore: When True, a non-zero exit status is neither counted in the
            module-level ``errors`` nor logged.

    Returns:
        The combined output text. It is returned even when the command fails;
        callers such as check_python() rely on it.
    """
    global errors
    git_cmd = os.environ.get('GIT', "git")
    result = subprocess.run(
        f'"{git_cmd}" {arg}',
        check=False,
        shell=True,
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=folder or '.',
    )
    txt = result.stdout.decode(encoding="utf8", errors="ignore")
    if len(result.stderr) > 0:
        txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
    txt = txt.strip()
    if result.returncode != 0 and not ignore:
        errors += 1
        log.error(f'Error running git: {folder} / {arg}')
        if 'or stash them' in txt:
            log.error('Local changes detected: check log for details...')
        log.debug(f'Git output: {txt}')
    return txt
|
|
|
|
def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False):
    """Run ``python -m pip <arg>`` with the current interpreter.

    Args:
        arg: Argument string for pip (shell-parsed).
        ignore: When True, a non-zero exit status is neither counted in the
            module-level ``errors`` nor logged.
        quiet: Suppress the "Installing package" info message.
        show_stdout: Let pip write straight to the console instead of
            capturing its output.

    Returns:
        The combined, stripped stdout/stderr text when output is captured;
        None when *show_stdout* is True.
    """
    global errors
    if not quiet:
        friendly = arg
        for flag in ("install", "--upgrade", "--no-deps", "--force"):
            friendly = friendly.replace(flag, "")
        # Collapse the whitespace runs the removals above leave behind.
        friendly = " ".join(friendly.split())
        log.info(f'Installing package: {friendly}')
    log.debug(f"Running pip: {arg}")
    command = f'"{sys.executable}" -m pip {arg}'
    if show_stdout:
        result = subprocess.run(command, shell=True, check=False, env=os.environ)
        txt = None
    else:
        result = subprocess.run(command, shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        txt = result.stdout.decode(encoding="utf8", errors="ignore")
        if len(result.stderr) > 0:
            txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
        txt = txt.strip()
    # Check the exit status on both paths so streamed failures are counted too.
    if result.returncode != 0 and not ignore:
        errors += 1
        log.error(f'Error running pip: {arg}')
        if txt is not None:
            log.debug(f'Pip output: {txt}')
    return txt
|
|
|
|
def installed(package, friendly: str = None):
    """Check whether the requirement(s) in *package* are already satisfied.

    Args:
        package: pip-style requirement string. ``[extras]`` are removed. When
            *friendly* is not given, tokens starting with ``-`` or ``=`` are
            ignored and only the last ``/``-separated segment of each
            remaining token is used.
        friendly: Optional whitespace-separated requirement tokens to check
            instead of those derived from *package*.

    Each token may carry one ``>=`` or ``==`` constraint. Versions are compared
    as PEP 440 versions via ``pkg_resources.parse_version``, falling back to
    plain string comparison when a version string cannot be parsed.

    Returns:
        True when every token names an installed distribution that meets its
        constraint, otherwise False.
    """

    def version_ok(found, required, op):
        # Compare as versions: as strings, "0.10.0" >= "0.9.0" would be False.
        try:
            found_v = pkg_resources.parse_version(found)
            required_v = pkg_resources.parse_version(required)
        except ValueError:
            # Unparseable version text: fall back to comparing the raw strings.
            found_v, required_v = found, required
        if op == '>=':
            return found_v >= required_v
        return found_v == required_v

    # Drop "[extras]" so only the distribution name and constraint remain.
    package = re.sub(r'\[.*?\]', '', package)
    # Bound before the try so the except clause can always log it.
    pkgs = []
    try:
        if friendly:
            pkgs = friendly.split()
        else:
            pkgs = [
                p
                for p in package.split()
                if not p.startswith('-') and not p.startswith('=')
            ]
            # For URL / path installs keep only the last path segment.
            pkgs = [p.split('/')[-1] for p in pkgs]

        by_key = pkg_resources.working_set.by_key
        for pkg in pkgs:
            if '>=' in pkg:
                op = '>='
            elif '==' in pkg:
                op = '=='
            else:
                op = None

            if op is None:
                pkg_name, pkg_version = pkg.strip(), None
            else:
                # maxsplit=1 so a second operator cannot break the 2-way unpack.
                pkg_name, pkg_version = [x.strip() for x in pkg.split(op, 1)]

            # Try the name as given, lower-cased, and with "_" replaced by "-".
            spec = None
            for candidate in (pkg_name, pkg_name.lower(), pkg_name.replace('_', '-')):
                spec = by_key.get(candidate, None)
                if spec is not None:
                    break

            if spec is None:
                log.debug(f'Package version not found: {pkg_name}')
                return False

            # Read the version from the distribution actually found above.
            version = spec.version
            log.debug(f'Package version found: {pkg_name} {version}')

            if pkg_version is not None and not version_ok(version, pkg_version, op):
                log.warning(f'Package wrong version: {pkg_name} {version} required {pkg_version}')
                return False

        return True
    except ModuleNotFoundError:
        log.debug(f'Package not installed: {pkgs}')
        return False
|
|
|
|
| |
def install(
    package,
    friendly: str = None,
    ignore: bool = False,
    reinstall: bool = False,
    show_stdout: bool = False,
):
    """Run ``pip install --upgrade`` for *package* unless it is already satisfied.

    Anything from the first ``#`` onward in *package* is discarded first. The
    "already satisfied" test is installed(package, friendly). *reinstall*
    skips that test and also sets the module-level ``quick_allowed`` to False.

    Args:
        package: Requirement text, possibly with pip options and an inline comment.
        friendly: Optional name list handed to installed() for the check.
        ignore: Forwarded to pip(); when True a failed install is not counted.
        reinstall: Always run pip, even if the requirement looks satisfied.
        show_stdout: Forwarded to pip() so pip's output goes to the console.
    """
    global quick_allowed
    requirement = package.split('#')[0].strip()

    if reinstall:
        quick_allowed = False

    needs_install = reinstall or not installed(requirement, friendly)
    if needs_install:
        pip(f'install --upgrade {requirement}', ignore=ignore, show_stdout=show_stdout)
|
|
|
|
|
|
def process_requirements_line(line, show_stdout: bool = False):
    """Install the requirement written on one requirements-file line.

    The full *line* is passed to install(), which discards anything after a
    ``#``. The friendly name used for the "already installed" check is built
    from the same comment-free text, with ``[extras]`` removed.

    Args:
        line: One line from a requirements file (install_requirements passes
            stripped, non-blank lines).
        show_stdout: Forwarded to install() so pip's output goes to the console.
    """
    # Without this, a trailing "# comment" leaves tokens such as "#" in the
    # friendly name. installed() looks them up as packages and never finds them,
    # so the line would be reinstalled on every run.
    requirement = line.split('#')[0].strip()
    package_name = re.sub(r'\[.*?\]', '', requirement)
    install(line, package_name, show_stdout=show_stdout)
|
|
|
|
def install_requirements(requirements_file, check_no_verify_flag=False, show_stdout: bool = False):
    """Install every requirement listed in *requirements_file*.

    Blank lines and comment lines (including indented ones) are skipped.
    ``-r <file>`` lines are followed recursively, using the path exactly as
    written. Every other line is handed to process_requirements_line().

    Args:
        requirements_file: Path to a pip-style requirements file.
        check_no_verify_flag: Log a "verifying" message instead of
            "installing", and skip lines that contain ``no_verify``.
        show_stdout: Forwarded down to pip so its output goes to the console.
    """
    if check_no_verify_flag:
        log.info(f'Verifying modules installation status from {requirements_file}...')
    else:
        log.info(f'Installing modules from {requirements_file}...')

    lines = []
    with open(requirements_file, 'r', encoding='utf8') as f:
        for raw_line in f:
            line = raw_line.strip()
            # Test the stripped text so indented "# ..." lines are skipped too.
            if not line or line.startswith('#'):
                continue
            if check_no_verify_flag and 'no_verify' in line:
                continue
            lines.append(line)

    for line in lines:
        if line.startswith('-r'):
            included_file = line[2:].strip()
            install_requirements(included_file, check_no_verify_flag=check_no_verify_flag, show_stdout=show_stdout)
        else:
            process_requirements_line(line, show_stdout=show_stdout)
|
|
|
|
def ensure_base_requirements():
    """Install ``rich`` via install() when it cannot be imported."""
    rich_available = True
    try:
        import rich  # noqa: F401 - imported only to probe availability
    except ImportError:
        rich_available = False

    if not rich_available:
        install('--upgrade rich', 'rich')
|
|
|
|
def run_cmd(run_cmd):
    """Run a command line through the shell and report a non-zero exit status.

    Args:
        run_cmd: Command string, executed with ``shell=True`` and the current
            environment. (The name shadows this function inside the body; it
            is kept so keyword callers keep working.)

    Returns:
        The command's exit status.
    """
    # check=False means subprocess.run never raises CalledProcessError, so the
    # exit status has to be inspected explicitly.
    result = subprocess.run(run_cmd, shell=True, check=False, env=os.environ)
    if result.returncode != 0:
        print(f'Error occurred while running command: {run_cmd}')
        print(f'Error: command exited with return code {result.returncode}')
    return result.returncode
|
|
|
|
| |
def check_python(ignore=True, skip_git=False):
    """Log the interpreter version and check that Python and git are usable.

    Python 3.9 and 3.10 are accepted. Unless *skip_git* is set, the git
    executable (``$GIT`` or ``git``) must be found by ``shutil.which``; its
    version is then logged at debug level.

    Args:
        ignore: When False, call ``sys.exit(1)`` on an unsupported Python or a
            missing git; when True (the default) only log the error.
        skip_git: Skip the git checks entirely.
    """
    supported_minors = [9, 10]
    log.info(f'Python {platform.python_version()} on {platform.system()}')
    if not (
        sys.version_info.major == 3
        and sys.version_info.minor in supported_minors
    ):
        # Render as "3.9, 3.10" rather than interpolating the list itself.
        supported = ', '.join(f'3.{minor}' for minor in supported_minors)
        log.error(
            f'Incompatible Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} required one of: {supported}'
        )
        if not ignore:
            sys.exit(1)

    if skip_git:
        return

    git_cmd = os.environ.get('GIT', 'git')
    if shutil.which(git_cmd) is None:
        log.error('Git not found')
        if not ignore:
            sys.exit(1)
        return

    # Guard against git() returning None instead of its output text.
    git_version = git('--version', folder=None, ignore=False) or ''
    log.debug(f'Git {git_version.replace("git version", "").strip()}')
|
|
|
|
def delete_file(file_path):
    """Delete *file_path*; a file that does not exist is silently ignored.

    Other errors from ``os.remove`` (for example when the path is a directory
    or permission is denied) propagate to the caller.
    """
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # EAFP: no race between an existence check and the removal.
        pass
|
|
|
|
def write_to_file(file_path, content, encoding='utf-8'):
    """Write *content* to *file_path*, replacing any existing contents.

    Args:
        file_path: Destination path.
        content: Text to write.
        encoding: Text encoding. Defaults to UTF-8 instead of the platform
            locale encoding, so non-ASCII text cannot raise
            UnicodeEncodeError on e.g. cp1252 systems.

    OSError raised while opening or writing the file is printed, not raised.
    """
    try:
        with open(file_path, 'w', encoding=encoding) as file:
            file.write(content)
    except OSError as e:
        print(f'Error occurred while writing to file: {file_path}')
        print(f'Error: {e}')
|
|
|
|
def clear_screen():
    """Clear the terminal: ``cls`` when ``os.name == 'nt'``, otherwise ``clear``."""
    command = 'cls' if os.name == 'nt' else 'clear'
    os.system(command)
|
|
|
|