DemoApp / src /utils /get_model.py
Reality8081's picture
Update src
1dde759
from src.model.baseline_model import BartSummarizer
from src.model.baseline_extractive_model import BartExtractiveSummarizer
import torch
from huggingface_hub import hf_hub_download
from src.model.extabs import EXTABSModel
import gc
import os
from safetensors.torch import load_file
active_model_info = {
"repo_id": None,
"model": None
}
def clear_memory():
"""Hàm dọn dẹp bộ nhớ triệt để trước khi load model mới"""
if active_model_info["model"] is not None:
del active_model_info["model"]
active_model_info["model"] = None
active_model_info["repo_id"] = None
# Ép Python dọn rác
gc.collect()
# Dọn VRAM nếu đang dùng GPU
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_summarizer(repo_id: str):
if active_model_info["repo_id"] != repo_id:
clear_memory() # XÓA SẠCH MODEL CŨ
print(f"Đang tải mô hình Baseline từ repo: {repo_id}...")
active_model_info["model"] = BartSummarizer(model_path=repo_id)
active_model_info["repo_id"] = repo_id
return active_model_info["model"]
def get_extractive_model(repo_id: str, base_model_name: str = "facebook/bart-large", device: torch.device = "cpu"):
"""Tải và lưu cache mô hình Custom Extractive từ Hugging Face Hub"""
if active_model_info["repo_id"] != repo_id:
clear_memory() # XÓA SẠCH MODEL CŨ
print(f"Đang tải mô hình Extractive từ repo: {repo_id}...")
# Khởi tạo khung kiến trúc trống
model = BartExtractiveSummarizer.from_pretrained(repo_id, model_name=base_model_name)
# Sử dụng hf_hub_download để kéo file trọng số về local cache
# Load trọng số vào model
model = model.to(torch.float32)
model.to(device)
model.eval() # Chuyển mô hình sang chế độ inference
active_model_info["model"] = model
active_model_info["repo_id"] = repo_id
return active_model_info["model"]
def get_extractive_abstractive(repo_id: str,base_model_name: str = "facebook/bart-large", device: torch.device = "cpu"):
"""Tải và lưu cache mô hình Custom Extractive từ Hugging Face Hub"""
if active_model_info["repo_id"] != repo_id:
clear_memory() # XÓA SẠCH MODEL CŨ
print(f"Đang tải mô hình Extractive từ repo: {repo_id}...")
# Khởi tạo khung kiến trúc trống
model = EXTABSModel.from_pretrained(repo_id, model_name=base_model_name)
# Sử dụng hf_hub_download để kéo file trọng số về local cache
# Load trọng số vào model
model = model.to(torch.float32)
model.to(device)
model.eval() # Chuyển mô hình sang chế độ inference
active_model_info["model"] = model
active_model_info["repo_id"] = repo_id
return active_model_info["model"]