|
|
| import time |
| import os |
| import json |
| from werkzeug.utils import secure_filename |
| import re |
| import ast |
| import sqlite3 |
| import random |
|
|
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from llmware.models import ModelCatalog |
| from llmware.prompts import Prompt |
|
|
def model_test_run_general(fp="", fn="sql_test_100_simple_s.jsonl"):
    """Benchmark a fine-tuned SQL generation model against a JSONL test set.

    Loads the HuggingFace model and tokenizer, wraps them in an llmware
    generative model, runs inference over each (shuffled) test question,
    and prints a running exact-match accuracy against the gold answers.

    Parameters
    ----------
    fp : str
        Folder path containing the test file (default: current directory).
    fn : str
        JSONL test file name; each row must contain "question", "answer"
        and "context" keys.

    Returns
    -------
    list[dict]
        One dict per test question: number, llm_response, gold_answer,
        prompt and token usage.
    """

    t0 = time.time()

    model_name = "llmware/slim-sql-1b-v0"

    print("update: model_name - ", model_name)

    # NOTE(review): trust_remote_code executes code shipped with the model
    # repository — acceptable only because the publisher is trusted here.
    custom_hf_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)

    # wrap the raw HF model in the llmware generative interface
    model = ModelCatalog().load_hf_generative_model(custom_hf_model, hf_tokenizer,
                                                    instruction_following=False,
                                                    prompt_wrapper="human_bot")

    model.temperature = 0.3

    print("\nupdate: Starting Generative Instruct Custom Fine-tuned Test")

    t1 = time.time()

    print("update: time loading model - ", t1 - t0)

    prompt_list = _load_test_set(os.path.join(fp, fn))

    # shuffle so repeated partial runs sample different questions
    random.shuffle(prompt_list)

    total_response_output = []
    perfect_match = 0

    for i, entries in enumerate(prompt_list):

        prompt = entries["question"]
        context = _clean_context(entries["context"])
        answer = entries.get("answer", "")

        output = model.inference(prompt, add_context=context, add_prompt_engineering=True)

        print("\nupdate: model question - ", prompt)

        # strip quote characters so cosmetic quoting differences do not
        # break the exact-match comparison
        llm_response = re.sub(r"['\"]", "", output["llm_response"])
        answer = re.sub(r"['\"]", "", answer)

        print("update: model response - ", i, llm_response)
        print("update: model gold answer - ", answer)

        if llm_response.strip().lower() == answer.strip().lower():
            perfect_match += 1
            print("update: 100% MATCH")

        print("update: perfect match accuracy - ", perfect_match / (i + 1))

        core_output = {"number": i,
                       "llm_response": output["llm_response"],
                       "gold_answer": answer,
                       "prompt": prompt,
                       "usage": output["usage"]}

        total_response_output.append(core_output)

    t2 = time.time()

    print("update: total processing time: ", t2 - t1)

    return total_response_output


def _load_test_set(path):
    """Read a JSONL test file into a list of question/answer/context dicts."""
    prompt_list = []
    # 'with' guarantees the file handle is closed (original leaked it)
    with open(path, "r", encoding="utf-8") as opened_file:
        for line in opened_file:
            row = json.loads(line)
            prompt_list.append({"question": row["question"],
                                "answer": row["answer"],
                                "context": row["context"]})
    return prompt_list


def _clean_context(context):
    """Flatten a context string: drop newlines, collapse runs of whitespace, remove double quotes."""
    context = re.sub(r"[\n\r]", "", context)
    context = re.sub(r"\s+", " ", context)
    return re.sub(r"\"", "", context)
|
|
if __name__ == "__main__":

    # guard the entry point so importing this module does not trigger a
    # full (slow, model-downloading) benchmark run
    output = model_test_run_general()
|
|