python-ra3 / rag_engine.py
mauroserralvo's picture
initial commit
cd839b3
import json
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModelForCausalLM
with open("documents.json", "r", encoding="utf-8") as f:
documents = json.load(f)
docs_list = list(documents.values())
embedding_model = SentenceTransformer("MongoDB/mdbr-leaf-ir")
doc_embeddings = embedding_model.encode(docs_list)
tokenizer = AutoTokenizer.from_pretrained("PleIAs/Pleias-RAG-350M")
language_model = AutoModelForCausalLM.from_pretrained("PleIAs/Pleias-RAG-350M")
def recuperar_documentos(consulta, top_k=2, umbral=0.4):
consulta_embedding = embedding_model.encode([consulta])
similitudes = cosine_similarity(consulta_embedding, doc_embeddings)[0]
resultados = []
for i, score in enumerate(similitudes):
if score >= umbral:
resultados.append((score, docs_list[i]))
resultados.sort(key=lambda x: x[0], reverse=True)
documentos_recuperados = [doc for _, doc in resultados[:top_k]]
return documentos_recuperados
def generar_respuesta(consulta, documentos_recuperados):
contexto = " ".join(documentos_recuperados)
prompt = f"""Answer the question based only on the context provided
Context: {contexto}
Question: {consulta}
Answer:"""
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
with torch.no_grad():
output = language_model.generate(
**inputs,
max_new_tokens=60,
do_sample=False,
repetition_penalty=1.2,
pad_token_id=tokenizer.eos_token_id
)
respuesta_completa = tokenizer.decode(output[0], skip_special_tokens=True)
if "Answer:" in respuesta_completa:
respuesta = respuesta_completa.split("Answer:")[-1].strip()
else:
respuesta = respuesta_completa.strip()
return respuesta
def preguntar(consulta, top_k=2, umbral=0.4):
documentos_recuperados = recuperar_documentos(consulta, top_k=top_k, umbral=umbral)
respuesta = generar_respuesta(consulta, documentos_recuperados)
return respuesta
if __name__ == "__main__":
pregunta = "Where is the hospital?"
docs = recuperar_documentos(pregunta)
respuesta = generar_respuesta(pregunta, docs)
print("Documentos encontrados:")
for i, doc in enumerate(docs, start=1):
print(f"{i}. {doc}")
print("\nRespuesta generada:")
print(respuesta)