from __future__ import annotations

import base64
import glob
import json
import logging
import os
import time
from io import BytesIO

from aiohttp import web
from PIL import Image

import comfy.utils
import folder_paths
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types


class ModelFileManager:
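    """Serves the experimental /experiment/models endpoints and caches per-folder model file listings."""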
    def __init__(self) -> None:
        self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}

    def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
        return self.cache.get(key, default)

    def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
        self.cache[key] = value

    def clear_cache(self):
        self.cache.clear()

    def add_routes(self, routes):
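        """Register the /experiment/models endpoints on the given aiohttp route table."""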
        @routes.get("/experiment/models")
        async def get_model_folders(request):
            model_types = list(folder_paths.folder_names_and_paths.keys())
            folder_black_list = ["configs", "custom_nodes"]
            output_folders: list[dict] = []
            for folder in model_types:
                if folder in folder_black_list:
                    continue
                output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
            return web.json_response(output_folders)

        @routes.get("/experiment/models/{folder}")
        async def get_all_models(request):
            folder = request.match_info.get("folder", None)
            if folder not in folder_paths.folder_names_and_paths:
                return web.Response(status=404)
            files = self.get_model_file_list(folder)
            return web.json_response(files)

| @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
| async def get_model_preview(request):
|
| folder_name = request.match_info.get("folder", None)
|
| path_index = int(request.match_info.get("path_index", None))
|
| filename = request.match_info.get("filename", None)
|
|
|
| if not folder_name in folder_paths.folder_names_and_paths:
|
| return web.Response(status=404)
|
|
|
| folders = folder_paths.folder_names_and_paths[folder_name]
|
| folder = folders[0][path_index]
|
| full_filename = os.path.join(folder, filename)
|
|
|
| previews = self.get_model_previews(full_filename)
|
| default_preview = previews[0] if len(previews) > 0 else None
|
| if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
| return web.Response(status=404)
|
|
|
| try:
|
| with Image.open(default_preview) as img:
|
| img_bytes = BytesIO()
|
| img.save(img_bytes, format="WEBP")
|
| img_bytes.seek(0)
|
| return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
|
| except:
|
| return web.Response(status=404)
|
|
|
    def get_model_file_list(self, folder_name: str):
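        """Collect model files from every configured path for the given folder type, reusing cached listings when the directories are unchanged."""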
        folder_name = map_legacy(folder_name)
        folders = folder_paths.folder_names_and_paths[folder_name]
        output_list: list[dict] = []

        for index, folder in enumerate(folders[0]):
            if not os.path.isdir(folder):
                continue
            out = self.cache_model_file_list_(folder)
            if out is None:
                out = self.recursive_search_models_(folder, index)
                self.set_cache(folder, out)
            output_list.extend(out[0])

        return output_list

    def cache_model_file_list_(self, folder: str):
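        """Return the cached (files, dir mtimes, timestamp) tuple for a folder, or None when there is no entry or any recorded directory has changed."""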
        model_file_list_cache = self.get_cache(folder)

        if model_file_list_cache is None:
            return None
        if not os.path.isdir(folder):
            return None
        # Invalidate the cache if any recorded directory was removed or its
        # modification time has changed since the listing was built.
        for path, time_modified in model_file_list_cache[1].items():
            if not os.path.isdir(path) or os.path.getmtime(path) != time_modified:
                return None

        return model_file_list_cache

    def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[dict], dict[str, float], float]:
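        """Walk a model directory and return (file entries, directory mtimes, scan timestamp) for the file-list cache."""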
        if not os.path.isdir(directory):
            return [], {}, time.perf_counter()

        excluded_dir_names = [".git"]
        include_hidden_files = False

        result: list[str] = []
        dirs: dict[str, float] = {}

        # Record the root directory's mtime so cache_model_file_list_ can detect
        # files added or removed directly under it.
        try:
            dirs[directory] = os.path.getmtime(directory)
        except FileNotFoundError:
            logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")

        for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
            subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
            if not include_hidden_files:
                subdirs[:] = [d for d in subdirs if not d.startswith(".")]
                filenames = [f for f in filenames if not f.startswith(".")]

            filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)

            for file_name in filenames:
                try:
                    relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
                    result.append(relative_path)
                except Exception:
                    logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
                    continue

            for d in subdirs:
                path: str = os.path.join(dirpath, d)
                try:
                    dirs[path] = os.path.getmtime(path)
                except FileNotFoundError:
                    logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
                    continue

        return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()

    def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
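        """Return preview images for a model file: sibling image files (<name>.<ext> or <name>.preview.<ext>) and any cover images embedded in safetensors metadata."""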
        dirname = os.path.dirname(filepath)

        if not os.path.exists(dirname):
            return []

        basename = os.path.splitext(filepath)[0]
        # Escape glob metacharacters so model names containing "[", "]" or "*" still match.
        match_files = glob.glob(f"{glob.escape(basename)}.*", recursive=False)
        image_files = filter_files_content_types(match_files, "image")
        safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
        safetensors_metadata = {}

        result: list[str | BytesIO] = []

        for filename in image_files:
            _basename = os.path.splitext(filename)[0]
            if _basename == basename:
                result.append(filename)
            if _basename == f"{basename}.preview":
                result.append(filename)

        if safetensors_file:
            # glob already returns the full path, so no further joining is needed.
            safetensors_filepath = safetensors_file
            header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
            if header:
                safetensors_metadata = json.loads(header)
            # Embedded covers are stored as base64-encoded images in the metadata.
            safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
            if safetensors_images:
                safetensors_images = json.loads(safetensors_images)
                for image in safetensors_images:
                    result.append(BytesIO(base64.b64decode(image)))

        return result

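    def __enter__(self):
        # Assumed counterpart to __exit__ below so the manager can be used in a
        # `with` block; it simply returns the manager itself.
        return self
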
    def __exit__(self, exc_type, exc_value, traceback):
        self.clear_cache()
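
# Usage sketch (illustrative, not part of this module): the manager is expected to
# be wired up by the surrounding aiohttp server, roughly:
#
#     routes = web.RouteTableDef()
#     manager = ModelFileManager()
#     manager.add_routes(routes)
#     app = web.Application()
#     app.add_routes(routes)
#
# The variable names and wiring above are assumptions about the caller, not code
# defined here.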