| | import re
|
| |
|
# Per-language extraction settings, keyed by the short language code that
# extract_generation_code receives as `lang_code`.
#   full_name: language tag. Lower-cased, it is the fence label searched for in
#              a model response (```<tag> ... ```). It is also passed to
#              get_function_name, where "Python" selects def-based parsing.
#   indent:    fallback indentation (in spaces) assumed for the closing brace
#              when the target function cannot be located in the code block.
#   main:      optional entry-point marker; the extracted code block is cut at
#              its first occurrence.
# NOTE: "languge" is a misspelling; the name is kept so existing references to
# it keep working.
languge_settings = dict(
    python=dict(full_name='Python', indent=4),
    cpp=dict(full_name='cpp', indent=0, main="int main()"),
    java=dict(full_name='Java', indent=4, main="public static void main"),
    cs=dict(full_name="csharp", indent=0, main="public static void Main"),
    php=dict(full_name="PHP", indent=0),
    ts=dict(full_name="TypeScript", indent=0),
    js=dict(full_name="JavaScript", indent=0),
    sh=dict(full_name="Bash", indent=0),
)
|
| |
|
def get_function_name(question: str, lang: str):
    """Split a prompt into the target function's signature text and what precedes it.

    Blank (whitespace-only) lines are discarded first.

    * Python (``lang`` compared case-insensitively): the target is the last
      line that begins with ``def `` (indented defs do not match, because the
      check runs on unstripped lines). The returned name is that line up to
      its first ``(``, stripped, so it still carries the ``def`` keyword.
    * Any other language: the target is the last non-blank line. The returned
      name is that line up to its first ``{``, stripped.

    Returns:
        ``(func_name, func_prefix)``, where ``func_prefix`` is the non-blank
        lines before the target line joined with newlines.

    Raises:
        IndexError: if no target line exists (empty prompt, or a Python prompt
        with no line starting with ``def ``).
    """
    lines = [line for line in question.strip().split('\n') if line.strip()]

    if lang.lower() == 'python':
        def_rows = [row for row, line in enumerate(lines) if line.startswith("def ")]
        target = def_rows[-1]
        separator = '('
    else:
        target = len(lines) - 1
        separator = '{'

    # Indexing here raises IndexError when there is no target line.
    signature = lines[target]
    head, _, _ = signature.partition(separator)
    return head.strip(), "\n".join(lines[:target])
|
| |
|
def _locate_function(code_block: str, func_name: str, indent: int, brace_delimited: bool = True):
    """Return ``(start, end, indent)`` bounding the target function inside *code_block*.

    ``start`` is the first case-insensitive occurrence of *func_name*, and
    ``indent`` becomes the number of spaces directly before it. When
    *func_name* is absent, ``start`` is 0 and the caller's *indent* is kept.
    For brace-delimited languages, ``end`` is the position of the last
    ``'\\n' + ' ' * indent + '}'`` in the block. Otherwise, or when no such
    brace exists, ``end`` is the end of the block.
    """
    found = code_block.lower().find(func_name.lower())
    if found == -1:
        start = 0
    else:
        start = found
        indent = 0
        # Stop at index 0: the previous `>= 0` bound read code_block[-1] and
        # wrapped around to the last character of the block.
        while start - indent > 0 and code_block[start - indent - 1] == ' ':
            indent += 1

    end = len(code_block)
    if brace_delimited:
        closing = code_block.rfind('\n' + ' ' * indent + '}')
        if closing != -1:
            end = closing
    return start, end, indent


def extract_generation_code(example: dict, lang_code: str, verbose: bool = False):
    """Extract the target function from a model response into ``example['generation']``.

    Steps:
      1. Take the first fenced block labelled with the language's
         ``full_name`` (matched case-insensitively).
      2. Drop everything from the language's optional ``main`` marker onward.
      3. Keep the text from the function's signature up to the last closing
         brace at the signature's indentation. Python code is not truncated
         at braces.
      4. Prepend the prompt text that precedes the signature.

    Args:
        example: record with ``task_id`` and ``prompt`` keys; the model
            response is read from ``output`` or, if absent, ``gpt_completion``.
            The dict is modified in place.
        lang_code: key into ``languge_settings`` (e.g. ``'python'``, ``'cpp'``).
        verbose: print the extracted code block.

    Returns:
        The same ``example`` dict with ``'generation'`` set. If extraction
        fails, the error is printed and the generation falls back to
        ``prompt + '\\n' + response``; a missing response counts as ``''``.

    Raises:
        KeyError: if ``task_id`` or ``prompt`` is missing, or ``lang_code`` is
        not a key of ``languge_settings``.
    """
    task_id = example['task_id']
    output = example.get('output', example.get("gpt_completion"))
    question = example["prompt"].strip()
    setting = languge_settings[lang_code]
    lang = setting['full_name']
    indent = setting['indent']

    try:
        pattern = f'```{re.escape(lang.lower())}\n(.*?)```'
        code_block: str = re.findall(pattern, output, re.DOTALL | re.IGNORECASE)[0]
        if verbose:
            print(">>> Task: {}\n{}".format(task_id, code_block))

        # Drop the program entry point (e.g. `int main()`) and everything after it.
        main_marker = setting.get('main')
        if main_marker and main_marker in code_block:
            code_block = code_block[:code_block.index(main_marker)]

        func_name, func_prefix = get_function_name(question, lang)
        start, end, indent = _locate_function(
            code_block,
            func_name,
            indent,
            # Braces do not delimit Python blocks; a top-level `}` (e.g. a dict
            # literal) must not cut the code mid-statement.
            brace_delimited=lang.lower() != 'python',
        )
        body = code_block[start:end]

        # The closing-brace search excludes the brace itself; restore it for these languages.
        if lang_code.lower() in ['php', 'ts', 'js']:
            body += '\n' + ' ' * indent + '}'

        example['generation'] = func_prefix + '\n' + body + '\n'

    except Exception as ex:
        print("Failed to extract code block with error `{}`:\n>>> Task: {}\n>>> Output:\n{}".format(
            ex, task_id, output
        ))
        # `output` may be None (no response key); don't let the fallback crash too.
        example['generation'] = example['prompt'] + '\n' + (output or '')

    return example
|
| |
|
def cleanup_code(
    code: str,
    language_type: str = None,
    dataset: str = None,
    issft: bool = False,
    stop_words=None,
):
    """Cut generated code at the earliest stop word for its language.

    Args:
        code: raw generated text.
        language_type: language name, matched case-insensitively. ``"python"``
            and ``"ts"`` get extra handling; any other value, including
            ``None``, uses only *stop_words*.
        dataset: unused; accepted for interface compatibility.
        issft: Python only; first pass *code* through ``_clean_python_code_for_sft``.
        stop_words: extra stop words (any iterable of strings). For Python they
            are replaced, not extended, by a fixed list.

    Returns:
        *code* truncated just before the earliest stop word found, or
        unchanged if none occurs.
    """
    # Fresh list per call: avoids the shared mutable-default pitfall.
    stop_words = [] if stop_words is None else list(stop_words)
    language = (language_type or "").lower()

    if language == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        python_stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
        code = _truncate_code_at_stopwords(code, python_stop_words)
    elif language == "ts":
        ts_stop_words = ["\nexport", "\nimport", "\nexport default", "\nimport default", "\nconsole.log"]
        code = _truncate_code_at_stopwords(code, stop_words + ts_stop_words)
    else:
        code = _truncate_code_at_stopwords(code, stop_words)

    return code
|
| |
|
| | def _clean_python_code_for_sft(code):
|
| | code = code.replace("\r", "")
|
| | if "```python" in code:
|
| | code_start_idx = code.index("```python")
|
| | code = code[code_start_idx:].replace("```python", "").strip()
|
| | end_idx = code.find("```") if "```" in code else len(code)
|
| | code = code[:end_idx].strip()
|
| |
|
| | return code
|
| |
|
| | def _truncate_code_at_stopwords(code, stop_words):
|
| | min_stop_idx = len(code)
|
| | for stop_word in stop_words:
|
| | stop_index = code.find(stop_word)
|
| | if 0 <= stop_index < min_stop_idx:
|
| | min_stop_idx = stop_index
|
| | return code[:min_stop_idx] |