import json import torch from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity from transformers import AutoTokenizer, AutoModelForCausalLM with open("documents.json", "r", encoding="utf-8") as f: documents = json.load(f) docs_list = list(documents.values()) embedding_model = SentenceTransformer("MongoDB/mdbr-leaf-ir") doc_embeddings = embedding_model.encode(docs_list) tokenizer = AutoTokenizer.from_pretrained("PleIAs/Pleias-RAG-350M") language_model = AutoModelForCausalLM.from_pretrained("PleIAs/Pleias-RAG-350M") def recuperar_documentos(consulta, top_k=2, umbral=0.4): consulta_embedding = embedding_model.encode([consulta]) similitudes = cosine_similarity(consulta_embedding, doc_embeddings)[0] resultados = [] for i, score in enumerate(similitudes): if score >= umbral: resultados.append((score, docs_list[i])) resultados.sort(key=lambda x: x[0], reverse=True) documentos_recuperados = [doc for _, doc in resultados[:top_k]] return documentos_recuperados def generar_respuesta(consulta, documentos_recuperados): contexto = " ".join(documentos_recuperados) prompt = f"""Answer the question based only on the context provided Context: {contexto} Question: {consulta} Answer:""" inputs = tokenizer(prompt, return_tensors="pt", truncation=True) with torch.no_grad(): output = language_model.generate( **inputs, max_new_tokens=60, do_sample=False, repetition_penalty=1.2, pad_token_id=tokenizer.eos_token_id ) respuesta_completa = tokenizer.decode(output[0], skip_special_tokens=True) if "Answer:" in respuesta_completa: respuesta = respuesta_completa.split("Answer:")[-1].strip() else: respuesta = respuesta_completa.strip() return respuesta def preguntar(consulta, top_k=2, umbral=0.4): documentos_recuperados = recuperar_documentos(consulta, top_k=top_k, umbral=umbral) respuesta = generar_respuesta(consulta, documentos_recuperados) return respuesta if __name__ == "__main__": pregunta = "Where is the hospital?" docs = recuperar_documentos(pregunta) respuesta = generar_respuesta(pregunta, docs) print("Documentos encontrados:") for i, doc in enumerate(docs, start=1): print(f"{i}. {doc}") print("\nRespuesta generada:") print(respuesta)