Spaces:
Running
Running
| from src.model.baseline_model import BartSummarizer | |
| from src.model.baseline_extractive_model import BartExtractiveSummarizer | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from src.model.extabs import EXTABSModel | |
| import gc | |
| import os | |
| from safetensors.torch import load_file | |
| active_model_info = { | |
| "repo_id": None, | |
| "model": None | |
| } | |
| def clear_memory(): | |
| """Hàm dọn dẹp bộ nhớ triệt để trước khi load model mới""" | |
| if active_model_info["model"] is not None: | |
| del active_model_info["model"] | |
| active_model_info["model"] = None | |
| active_model_info["repo_id"] = None | |
| # Ép Python dọn rác | |
| gc.collect() | |
| # Dọn VRAM nếu đang dùng GPU | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| def get_summarizer(repo_id: str): | |
| if active_model_info["repo_id"] != repo_id: | |
| clear_memory() # XÓA SẠCH MODEL CŨ | |
| print(f"Đang tải mô hình Baseline từ repo: {repo_id}...") | |
| active_model_info["model"] = BartSummarizer(model_path=repo_id) | |
| active_model_info["repo_id"] = repo_id | |
| return active_model_info["model"] | |
| def get_extractive_model(repo_id: str, base_model_name: str = "facebook/bart-large", device: torch.device = "cpu"): | |
| """Tải và lưu cache mô hình Custom Extractive từ Hugging Face Hub""" | |
| if active_model_info["repo_id"] != repo_id: | |
| clear_memory() # XÓA SẠCH MODEL CŨ | |
| print(f"Đang tải mô hình Extractive từ repo: {repo_id}...") | |
| # Khởi tạo khung kiến trúc trống | |
| model = BartExtractiveSummarizer.from_pretrained(repo_id, model_name=base_model_name) | |
| # Sử dụng hf_hub_download để kéo file trọng số về local cache | |
| # Load trọng số vào model | |
| model = model.to(torch.float32) | |
| model.to(device) | |
| model.eval() # Chuyển mô hình sang chế độ inference | |
| active_model_info["model"] = model | |
| active_model_info["repo_id"] = repo_id | |
| return active_model_info["model"] | |
| def get_extractive_abstractive(repo_id: str,base_model_name: str = "facebook/bart-large", device: torch.device = "cpu"): | |
| """Tải và lưu cache mô hình Custom Extractive từ Hugging Face Hub""" | |
| if active_model_info["repo_id"] != repo_id: | |
| clear_memory() # XÓA SẠCH MODEL CŨ | |
| print(f"Đang tải mô hình Extractive từ repo: {repo_id}...") | |
| # Khởi tạo khung kiến trúc trống | |
| model = EXTABSModel.from_pretrained(repo_id, model_name=base_model_name) | |
| # Sử dụng hf_hub_download để kéo file trọng số về local cache | |
| # Load trọng số vào model | |
| model = model.to(torch.float32) | |
| model.to(device) | |
| model.eval() # Chuyển mô hình sang chế độ inference | |
| active_model_info["model"] = model | |
| active_model_info["repo_id"] = repo_id | |
| return active_model_info["model"] |