| import importlib.util |
| import os |
| import sys |
| import typing as tp |
| import warnings |
| from pathlib import Path |
|
|
| from . import pipeline |
| from .translation import create_translation_models |
|
|
|
|
def import_model_module(file_path: os.PathLike):
    """Import the Python source file at *file_path* and return the module.

    The module is registered in ``sys.modules`` under a dotted name derived
    from its path relative to the current working directory
    (``a/b/model.py`` -> ``"a.b.model.py"``). When the file is not located
    under the working directory, the absolute path minus its root/drive is
    used instead, so the package can be imported from any directory.

    Raises:
        ImportError: if no import spec/loader can be created for *file_path*.
        Exception: anything raised while executing the module body; the
            partially initialised module is removed from ``sys.modules``
            before the exception propagates.
    """
    path = Path(file_path)
    try:
        dotted_source = path.relative_to(os.getcwd())
    except ValueError:
        # relative_to() raises when *path* is not inside the current working
        # directory (e.g. the process was started from somewhere else).
        dotted_source = Path(*path.resolve().parts[1:])
    module_name = str(dotted_source).replace(os.path.sep, ".")

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {file_path}")
    module = importlib.util.module_from_spec(spec)
    # Registered before execution (as in the importlib "import a source file
    # directly" recipe) so code in the module that looks itself up works.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind in sys.modules.
        sys.modules.pop(module_name, None)
        raise
    return module
|
|
|
|
# Registry of discovered models: model name -> factory callable.
models = {}
# Per-language registry: language code -> {model name -> factory callable}.
language_to_models = {}

# Directory containing this file; its sub-directories are scanned for
# ``model.py`` plugins below.
file_dir = Path(__file__).parents[0]


def check_attr_exists(attr_name, attr, module_path=None):
    """Return True if *attr* is not None; otherwise warn and return False.

    *module_path* is only used to make the warning name the offending file.
    """
    if attr is None:
        warnings.warn(
            f"Module {module_path} should define attribute '{attr_name}'"
        )
        return False
    return True


def check_attr_type(attr_name, attr, expected_type):
    """Return True if *attr* is an *expected_type* instance; else warn, False."""
    if isinstance(attr, expected_type):
        return True
    # The parameter is deliberately not called ``type``: the warning needs the
    # builtin ``type`` to report the actual type of *attr*.
    warnings.warn(
        f"'{attr_name}' should be of type {expected_type}, "
        f"but it is of type {type(attr)}"
    )
    return False


def check_attr_callable(attr_name, attr):
    """Return True if *attr* is callable; otherwise warn and return False."""
    if callable(attr):
        return True
    warnings.warn(f"'{attr_name}' should be callable")
    return False


def _load_model_attrs(model_file_path):
    """Import *model_file_path* and return ``(name, get_model, supported_langs)``.

    Returns None, after emitting a warning, when the module lacks one of the
    three attributes or one of them has the wrong kind. A bare string in
    ``supported_langs`` is treated as a single language code instead of being
    iterated character by character.
    """
    module = import_model_module(model_file_path)
    name = getattr(module, "name", None)
    factory = getattr(module, "get_model", None)
    supported_langs = getattr(module, "supported_langs", None)

    required = (
        ("name", name),
        ("get_model", factory),
        ("supported_langs", supported_langs),
    )
    for attr_name, attr in required:
        if not check_attr_exists(attr_name, attr, model_file_path):
            return None
    if not check_attr_type("name", name, str):
        return None
    if not check_attr_callable("get_model", factory):
        return None
    if not check_attr_type("supported_langs", supported_langs, tp.Iterable):
        return None
    if isinstance(supported_langs, str):
        supported_langs = (supported_langs,)
    return name, factory, supported_langs


# Sorted so registration order (and which model wins a name clash) does not
# depend on the filesystem's directory listing order.
for path in sorted(file_dir.glob("*")):
    model_file_path = path / "model.py"
    if not path.is_dir() or not model_file_path.exists():
        continue
    loaded = _load_model_attrs(model_file_path)
    if loaded is None:
        continue
    model_name, model_factory, model_langs = loaded
    if model_name in models:
        warnings.warn(
            f"Model '{model_name}' from {model_file_path} replaces an "
            f"earlier model with the same name"
        )
    models[model_name] = model_factory
    for lang in model_langs:
        language_to_models.setdefault(lang, {})[model_name] = model_factory


# Build additional models from the English ones via create_translation_models
# and register them under "ru" and in the global registry. ``.get`` avoids a
# KeyError at import time when no model declares "en" support.
translation_models = create_translation_models(language_to_models.get("en", {}))
language_to_models.setdefault("ru", {}).update(translation_models)
models.update(translation_models)
|
|
|
|
def get_model(name: str):
    """Build and return a fresh instance of the model registered as *name*.

    Looks *name* up in the module-level ``models`` registry and calls the
    stored factory with no arguments.

    Raises:
        KeyError: if no model is registered under *name*.
    """
    if name in models:
        factory = models[name]
        return factory()
    raise KeyError(f"No model with name {name}")
|
|
|
|
def get_all_model_names():
    """Return a new list of every registered model name, in registry order."""
    return [*models]
|
|
|
|
def get_model_names_by_lang(lang):
    """Return the models registered for language *lang*.

    The result maps model name -> factory callable; iterating it yields the
    model names. A new dict is returned on every call, so callers cannot
    mutate the shared ``language_to_models`` registry through it. An unknown
    language yields an empty dict, which keeps the return type consistent
    while still being empty, falsy and iterable like the previous ``[]``.
    """
    return dict(language_to_models.get(lang, {}))
|
|