# NOTE: "Spaces: Running" page-status header from the Hugging Face Spaces UI
# (extraction residue, not part of the program) removed so the file parses.
import json

import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Corpus and model setup (runs once at import time) ---

# Load the corpus: documents.json is a JSON object whose values are the
# document texts (keys are ids — assumed; TODO confirm against the file).
with open("documents.json", "r", encoding="utf-8") as f:
    documents = json.load(f)

# Keep only the texts; index i of docs_list lines up with row i of
# doc_embeddings below, which the retrieval function relies on.
docs_list = list(documents.values())

# Dense retriever: embed every document once up front.
embedding_model = SentenceTransformer("MongoDB/mdbr-leaf-ir")
doc_embeddings = embedding_model.encode(docs_list)

# Small causal LM used to generate the final answer from retrieved context.
tokenizer = AutoTokenizer.from_pretrained("PleIAs/Pleias-RAG-350M")
language_model = AutoModelForCausalLM.from_pretrained("PleIAs/Pleias-RAG-350M")
def recuperar_documentos(consulta, top_k=2, umbral=0.4):
    """Retrieve up to *top_k* corpus documents most similar to *consulta*.

    Args:
        consulta: Query string; embedded with the same model as the corpus.
        top_k: Maximum number of documents to return.
        umbral: Minimum cosine-similarity score a document must reach.

    Returns:
        List of document texts, best match first. May be shorter than
        *top_k* — or empty — if few documents clear the threshold.
    """
    consulta_embedding = embedding_model.encode([consulta])
    # One similarity score per corpus document, aligned with docs_list.
    similitudes = cosine_similarity(consulta_embedding, doc_embeddings)[0]
    # Keep only documents that clear the threshold, then rank by score.
    candidatos = [
        (score, docs_list[i])
        for i, score in enumerate(similitudes)
        if score >= umbral
    ]
    candidatos.sort(key=lambda par: par[0], reverse=True)
    return [doc for _, doc in candidatos[:top_k]]
def generar_respuesta(consulta, documentos_recuperados):
    """Answer *consulta* with the language model, grounded in the given docs.

    The retrieved documents are joined into one context string and embedded
    in an instruction prompt; the model's continuation after the final
    "Answer:" marker is returned as the answer.
    """
    ctx = " ".join(documentos_recuperados)
    # Renders exactly: instruction line, Context:, Question:, Answer: stub.
    prompt = (
        "Answer the question based only on the context provided\n"
        f"Context: {ctx}\n"
        f"Question: {consulta}\n"
        "Answer:"
    )
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Greedy decoding; gradients are unnecessary at inference time.
    with torch.no_grad():
        generated = language_model.generate(
            **encoded,
            max_new_tokens=60,
            do_sample=False,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id
        )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # The decoded text echoes the prompt, so keep only what follows the
    # last "Answer:" marker; if the marker is absent, rpartition leaves
    # the whole text in slot 2, matching the original fallback branch.
    return decoded.rpartition("Answer:")[2].strip()
def preguntar(consulta, top_k=2, umbral=0.4):
    """End-to-end RAG query: retrieve relevant documents, then answer.

    Args:
        consulta: Question to answer.
        top_k: Maximum number of documents to retrieve.
        umbral: Minimum similarity score for retrieval.

    Returns:
        The generated answer string.
    """
    docs = recuperar_documentos(consulta, top_k=top_k, umbral=umbral)
    return generar_respuesta(consulta, docs)
if __name__ == "__main__":
    # Demo run: retrieve supporting documents for a sample question and
    # print them alongside the generated answer.
    pregunta = "Where is the hospital?"
    docs = recuperar_documentos(pregunta)
    respuesta = generar_respuesta(pregunta, docs)

    print("Documentos encontrados:")
    for numero, doc in enumerate(docs, start=1):
        print(f"{numero}. {doc}")

    print("\nRespuesta generada:")
    print(respuesta)