from src.model.baseline_model import BartSummarizer from src.model.baseline_extractive_model import BartExtractiveSummarizer import torch from huggingface_hub import hf_hub_download from src.model.extabs import EXTABSModel import gc import os from safetensors.torch import load_file active_model_info = { "repo_id": None, "model": None } def clear_memory(): """Hàm dọn dẹp bộ nhớ triệt để trước khi load model mới""" if active_model_info["model"] is not None: del active_model_info["model"] active_model_info["model"] = None active_model_info["repo_id"] = None # Ép Python dọn rác gc.collect() # Dọn VRAM nếu đang dùng GPU if torch.cuda.is_available(): torch.cuda.empty_cache() def get_summarizer(repo_id: str): if active_model_info["repo_id"] != repo_id: clear_memory() # XÓA SẠCH MODEL CŨ print(f"Đang tải mô hình Baseline từ repo: {repo_id}...") active_model_info["model"] = BartSummarizer(model_path=repo_id) active_model_info["repo_id"] = repo_id return active_model_info["model"] def get_extractive_model(repo_id: str, base_model_name: str = "facebook/bart-large", device: torch.device = "cpu"): """Tải và lưu cache mô hình Custom Extractive từ Hugging Face Hub""" if active_model_info["repo_id"] != repo_id: clear_memory() # XÓA SẠCH MODEL CŨ print(f"Đang tải mô hình Extractive từ repo: {repo_id}...") # Khởi tạo khung kiến trúc trống model = BartExtractiveSummarizer.from_pretrained(repo_id, model_name=base_model_name) # Sử dụng hf_hub_download để kéo file trọng số về local cache # Load trọng số vào model model = model.to(torch.float32) model.to(device) model.eval() # Chuyển mô hình sang chế độ inference active_model_info["model"] = model active_model_info["repo_id"] = repo_id return active_model_info["model"] def get_extractive_abstractive(repo_id: str,base_model_name: str = "facebook/bart-large", device: torch.device = "cpu"): """Tải và lưu cache mô hình Custom Extractive từ Hugging Face Hub""" if active_model_info["repo_id"] != repo_id: clear_memory() # XÓA SẠCH MODEL CŨ print(f"Đang tải mô hình Extractive từ repo: {repo_id}...") # Khởi tạo khung kiến trúc trống model = EXTABSModel.from_pretrained(repo_id, model_name=base_model_name) # Sử dụng hf_hub_download để kéo file trọng số về local cache # Load trọng số vào model model = model.to(torch.float32) model.to(device) model.eval() # Chuyển mô hình sang chế độ inference active_model_info["model"] = model active_model_info["repo_id"] = repo_id return active_model_info["model"]