File size: 2,446 Bytes
cd839b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import json
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModelForCausalLM

with open("documents.json", "r", encoding="utf-8") as f:
    documents = json.load(f)

docs_list = list(documents.values())

embedding_model = SentenceTransformer("MongoDB/mdbr-leaf-ir")
doc_embeddings = embedding_model.encode(docs_list)

tokenizer = AutoTokenizer.from_pretrained("PleIAs/Pleias-RAG-350M")
language_model = AutoModelForCausalLM.from_pretrained("PleIAs/Pleias-RAG-350M")

def recuperar_documentos(consulta, top_k=2, umbral=0.4):
    consulta_embedding = embedding_model.encode([consulta])
    similitudes = cosine_similarity(consulta_embedding, doc_embeddings)[0]

    resultados = []

    for i, score in enumerate(similitudes):
        if score >= umbral:
            resultados.append((score, docs_list[i]))

    resultados.sort(key=lambda x: x[0], reverse=True)

    documentos_recuperados = [doc for _, doc in resultados[:top_k]]
    return documentos_recuperados

def generar_respuesta(consulta, documentos_recuperados):
    contexto = " ".join(documentos_recuperados)

    prompt = f"""Answer the question based only on the context provided
Context: {contexto}
Question: {consulta}
Answer:"""

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    with torch.no_grad():
        output = language_model.generate(
            **inputs,
            max_new_tokens=60,
            do_sample=False,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id
        )

    respuesta_completa = tokenizer.decode(output[0], skip_special_tokens=True)

    if "Answer:" in respuesta_completa:
        respuesta = respuesta_completa.split("Answer:")[-1].strip()
    else:
        respuesta = respuesta_completa.strip()

    return respuesta

def preguntar(consulta, top_k=2, umbral=0.4):
    documentos_recuperados = recuperar_documentos(consulta, top_k=top_k, umbral=umbral)
    respuesta = generar_respuesta(consulta, documentos_recuperados)
    return respuesta

if __name__ == "__main__":
    pregunta = "Where is the hospital?"
    docs = recuperar_documentos(pregunta)
    respuesta = generar_respuesta(pregunta, docs)

    print("Documentos encontrados:")
    for i, doc in enumerate(docs, start=1):
        print(f"{i}. {doc}")

    print("\nRespuesta generada:")
    print(respuesta)