| import gradio as gr |
| from transformers import pipeline |
| from huggingface_hub import InferenceClient, login, snapshot_download |
| from langchain_community.vectorstores import FAISS, DistanceStrategy |
| from langchain_huggingface import HuggingFaceEmbeddings |
| import os |
| import pandas as pd |
| from datetime import datetime |
|
|
|
|
| """ |
| For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference |
| """ |
| HF_TOKEN=os.getenv('TOKEN') |
| login(HF_TOKEN) |
|
|
| |
| |
| |
| model = "swiss-ai/Apertus-8B-Instruct-2509" |
|
|
| client = InferenceClient(model) |
|
|
| folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd()) |
|
|
| embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") |
|
|
| vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE) |
|
|
| df = pd.read_csv("bger_cedh_db 1954-2024.csv") |
|
|
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    score,
):
    """Stream a RAG answer over Swiss Federal Court / ECtHR case law.

    The user's question is rewritten into English, German, French and Italian,
    each variant is run through a FAISS similarity search, the four best hits
    overall are stitched (with surrounding case text) into the system prompt,
    and the LLM's answer is streamed back token by token.

    Args:
        message: The user's question.
        history: Chat history as (user, assistant) tuples; unused here but
            required by the gr.ChatInterface callback signature.
        system_message: Base system prompt supplied by the UI.
        max_tokens: Generation cap for the final answer.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        score: Value of the UI "Score Threshold" slider.
            NOTE(review): currently unused — either wire it into the document
            filtering below or remove the slider.

    Yields:
        The progressively accumulated answer string (Gradio streaming).
    """
    print(datetime.now())
    print(system_message)

    # FIX: the original used a backslash line-continuation *inside* the string
    # literal, which fused "prompt" and "without" into "promptwithout". Use
    # implicit concatenation with explicit spaces instead.
    prompt_template = (
        "Improve or translate the following user's prompt to {language} giving only the new prompt "
        "without explanations or additional text and if you can't improve it, just return the same "
        "prompt, do not extrapolate: "
    )

    # One reformulated prompt per search language (replaces 4x copy-paste).
    prompts = [
        client.text_generation(prompt_template.format(language=language) + message)
        for language in ("English", "German", "French", "Italian")
    ]

    # Top-4 hits per language, then the 4 best overall. With cosine distance,
    # a *lower* score means a closer match, so plain ascending sort is correct.
    documents = []
    for prompt in prompts:
        print(prompt)
        documents.extend(vector_db.similarity_search_with_score(prompt, k=4))
    documents = sorted(documents, key=lambda pair: pair[1])[:4]

    spacer = " \n"
    context = ""
    nb_char = 2000  # chars of surrounding case text kept on each side of the hit

    print(f"* Documents found: {len(documents)}")

    for doc, _distance in documents:
        # Recover the full decision text and keep a window of nb_char characters
        # around the retrieved chunk so the LLM sees enough context.
        case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
        index = case_text.find(doc.page_content)
        start = max(0, index - nb_char)
        end = min(len(case_text), index + len(doc.page_content) + nb_char)
        case_text_summary = case_text[start:end]

        context += "#######" + spacer
        context += "# Case number: " + doc.metadata["case_nb"] + spacer
        context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
        context += "# Case date: " + doc.metadata["case_date"] + spacer
        context += "# Case url: " + doc.metadata["case_url"] + spacer
        context += "Case extract: " + case_text_summary + spacer

    system_message += f"""A user is asking you the following question: {message}
Please answer the user in the same language that he used in his question using ONLY the following given context not any prior knowledge or information found on the internet.
# Context:
The following case extracts have been found in either Swiss Federal Court or European Court of Human Rights cases and could fit the question:
{context}
# Task:
If the retrieved context is not relevant cases or the issue has not been addressed within the context, just say "I can't find enough relevant information".
Don't make up an answer or give irrelevant information not requested by the user.
Otherwise, if relevant cases were found, answer in the user's question's language using the context that you found relevant and reference the sources, including the urls and dates.
# Instructions:
Always answer the user using the language used in his question: {message}
"""

    print(system_message)
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": message},
    ]

    response = ""
    # FIX: the loop variable used to shadow `message`, and the final streamed
    # chunk's delta.content is typically None — guard before concatenating to
    # avoid a TypeError at end of stream.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:
            response += token
        yield response
|
|
|
|
| """ |
| For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface |
| """ |
| demo = gr.ChatInterface( |
| respond, |
| additional_inputs=[ |
| gr.Textbox(value="You are assisting a jurist or a layer in finding relevant Swiss Jurisprudence cases to their question.", label="System message"), |
| gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"), |
| gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"), |
| gr.Slider( |
| minimum=0.1, |
| maximum=1.0, |
| value=0.95, |
| step=0.05, |
| label="Top-p (nucleus sampling)", |
| ), |
| gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"), |
| ], |
| description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence", |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| demo.launch(debug=True) |