| | from prompt import TA_prompt |
| | import re |
| | from utils import generate_response, run_code |
| |
|
| |
|
def post_process_code(code, question):
    """Append a ``print(<func>(...))`` call to a generated function.

    The first ``def name(params)`` header found in *code* supplies the
    function name and its parameter names. Numeric literals found in
    *question* are bound, in order of appearance, to those parameters as
    keyword arguments.

    Args:
        code: Source text containing a function definition.
        question: Text whose numeric literals become the argument values.

    Returns:
        *code* followed by a ``print(name(p1=v1, ...))`` line. If *code*
        contains no function definition it is returned unchanged, because
        there is nothing to call.
    """
    # Match on the ``def`` keyword followed by an identifier so that names
    # which merely contain "def" (e.g. ``default_rate``) are not truncated,
    # and so that a return annotation or trailing whitespace after the
    # closing parenthesis does not corrupt the parameter list.
    header = re.search(r"\bdef\s+(\w+)\s*\(([^)]*)\)", code)
    if header is None:
        return code
    func_name, raw_params = header.groups()

    parameters = []
    for piece in raw_params.split(","):
        # Drop any annotation (``x: int``) or default (``y=0``); only the
        # bare name is valid on the left of ``=`` in a call.
        name = piece.split(":")[0].split("=")[0].strip()
        # isidentifier() also filters out empty pieces, ``*args``, ``/``, ``*``.
        if name.isidentifier():
            parameters.append(name)

    matches = re.findall(r"[-+]?\d*\.\d+|\d+", question)[:len(parameters)]
    # The first alternative of the pattern always contains a ".", the
    # second never does, so this cleanly separates floats from ints.
    values = [float(m) if "." in m else int(m) for m in matches]

    # zip() stops at the shorter sequence, so surplus parameters are
    # simply left out of the call.
    arg_string = ", ".join(
        f"{param}={val}" for param, val in zip(parameters, values)
    )
    return f"{code}\nprint({func_name}({arg_string}))"
| |
|
| |
|
def solve_ta(question):
    """Answer *question* with the TA prompt, running generated code if any.

    The question is prefixed with ``"Human: "``, appended to ``TA_prompt``
    and sent to ``generate_response``. The first ``len(TA_prompt.strip())``
    characters of the response are discarded and the remainder is cut at
    the first ``"-----"`` separator.
    NOTE(review): the prefix slice assumes the response echoes the prompt,
    and 0.9 is presumably a sampling temperature -- confirm against
    ``utils.generate_response``.

    Args:
        question: The user's question as plain text.

    Returns:
        A ``(prediction, code)`` tuple:

        * If the response contains a fenced code block, ``code`` is that
          block after ``post_process_code`` and ``prediction`` is the result
          of ``run_code(code)``. ``prediction`` is ``None`` when the code
          calls ``input(`` or when ``run_code`` raises.
        * Otherwise ``prediction`` is the text after ``"Assistant:"`` (or
          the whole response if that marker is missing), cut at the next
          ``"Human:"``, and ``code`` is ``""``.
    """
    question = "Human: " + question.strip()
    query = (TA_prompt + question).strip() + "\n"

    response = generate_response(query, 0.9)
    prompt_len = len(TA_prompt.strip())
    response = response[prompt_len:].strip().split("-----")[0]

    if "```" not in response:
        # Plain-text answer. Fall back to the whole response when the
        # "Assistant:" marker is absent instead of raising IndexError.
        _, marker, tail = response.partition("Assistant:")
        answer = tail if marker else response
        return answer.split("Human:")[0].strip(), ""

    # Prefer the language-tagged fence so the "python" tag is not left
    # at the start of the extracted code.
    fence = "```python" if "```python" in response else "```"
    code = response.split(fence)[1].split("```")[0].strip()

    code = post_process_code(code, question)
    print(code)

    # Code that waits on stdin cannot be executed unattended.
    if "input(" in code:
        return None, code

    try:
        return run_code(code), code
    except Exception:
        # Best effort: generated code may fail in arbitrary ways, so hand
        # the code back without a prediction rather than propagating.
        return None, code
| |
|
| |
|