| from transformers import AutoModelForQuestionAnswering, AutoTokenizer |
| import streamlit as st |
| import json |
| from predict import run_prediction |
|
|
| st.set_page_config(layout="wide") |
|
|
| model_list = ['akdeniz27/roberta-base-cuad', |
| 'akdeniz27/roberta-large-cuad', |
| 'akdeniz27/deberta-v2-xlarge-cuad'] |
| st.sidebar.header("Select CUAD Model") |
| model_checkpoint = st.sidebar.radio("", model_list) |
|
|
| if model_checkpoint == "akdeniz27/deberta-v2-xlarge-cuad": import sentencepiece |
|
|
| st.sidebar.write("Project: https://www.atticusprojectai.org/cuad") |
| st.sidebar.write("Git Hub: https://github.com/TheAtticusProject/cuad") |
| st.sidebar.write("CUAD Dataset: https://huggingface.co/datasets/cuad") |
| st.sidebar.write("License: CC BY 4.0 https://creativecommons.org/licenses/by/4.0/") |
|
|
| @st.cache(allow_output_mutation=True) |
| def load_model(): |
| model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint) |
| tokenizer = AutoTokenizer.from_pretrained(model_checkpoint , use_fast=False) |
| return model, tokenizer |
|
|
| @st.cache(allow_output_mutation=True) |
| def load_questions(): |
| with open('test.json') as json_file: |
| data = json.load(json_file) |
| |
|
|
| questions = [] |
| for i, q in enumerate(data['data'][0]['paragraphs'][0]['qas']): |
| question = data['data'][0]['paragraphs'][0]['qas'][i]['question'] |
| questions.append(question) |
| return questions |
|
|
| @st.cache(allow_output_mutation=True) |
| def load_contracts(): |
| with open('test.json') as json_file: |
| data = json.load(json_file) |
|
|
| contracts = [] |
| for i, q in enumerate(data['data']): |
| contract = ' '.join(data['data'][i]['paragraphs'][0]['context'].split()) |
| contracts.append(contract) |
| return contracts |
|
|
| model, tokenizer = load_model() |
| questions = load_questions() |
| contracts = load_contracts() |
|
|
| contract = contracts[0] |
|
|
| st.header("Contract Understanding Atticus Dataset (CUAD) Demo") |
| st.write("Based on https://github.com/marshmellow77/cuad-demo") |
|
|
|
|
| selected_question = st.selectbox('Choose one of the 41 queries from the CUAD dataset:', questions) |
| question_set = [questions[0], selected_question] |
|
|
| contract_type = st.radio("Select Contract", ("Sample Contract", "New Contract")) |
| if contract_type == "Sample Contract": |
| sample_contract_num = st.slider("Select Sample Contract #") |
| contract = contracts[sample_contract_num] |
| with st.expander(f"Sample Contract #{sample_contract_num}"): |
| st.write(contract) |
| else: |
| contract = st.text_area("Input New Contract", "", height=256) |
|
|
| Run_Button = st.button("Run", key=None) |
| if Run_Button == True and not len(contract)==0 and not len(question_set)==0: |
| predictions = run_prediction(question_set, contract, 'akdeniz27/roberta-base-cuad') |
| |
| for i, p in enumerate(predictions): |
| if i != 0: st.write(f"Question: {question_set[int(p)]}\n\nAnswer: {predictions[p]}\n\n") |
| |