| | import argparse |
| | import json |
| | import os |
| | import os.path as osp |
| | import re |
| |
|
| | import mmengine |
| | from mmengine import Config |
| | from mmengine.utils import mkdir_or_exist |
| |
|
| | from opencompass.datasets.humanevalx import _clean_up_code |
| | from opencompass.utils import (dataset_abbr_from_cfg, get_infer_output_path, |
| | get_logger, model_abbr_from_cfg) |
| |
|
| |
|
def parse_args(argv=None):
    """Parse the command-line arguments of the prediction collector.

    Args:
        argv (list[str] | None): Argument strings to parse. Defaults to
            ``None``, in which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Namespace with two attributes: ``config`` (the
        positional config file path) and ``reuse`` (``None`` when ``-r`` is
        absent, ``'latest'`` when ``-r`` is given without a value, otherwise
        the string passed to it).
    """
    parser = argparse.ArgumentParser(
        description='Collect Humanevalx dataset predictions.')
    parser.add_argument('config', help='Config file path')
    # NOTE: the original call ended with a stray trailing comma, which
    # turned the statement into a discarded one-element tuple.
    parser.add_argument(
        '-r',
        '--reuse',
        nargs='?',
        type=str,
        const='latest',
        help='Reuse previous outputs & results, and run any '
        'missing jobs presented in the config. If its '
        'argument is not specified, the latest results in '
        'the work_dir will be reused. The argument should '
        'also be a specific timestamp, e.g. 20230516_144254')
    return parser.parse_args(argv)
| |
|
| |
|
# Maps the language key looked up (as a substring) in a dataset abbreviation
# to the name used as the prefix of each record's ``task_id``.
_LANGUAGE_NAME_DICT = {
    'cpp': 'CPP',
    'go': 'Go',
    'java': 'Java',
    'js': 'JavaScript',
    'python': 'Python',
    'rust': 'Rust',
}

# Status flags returned as the first element of ``collect_preds``' result;
# ``FAILED`` is falsy and ``SUCCEED`` is truthy so callers can test them
# directly.
FAILED = 0
SUCCEED = 1
| |
|
| |
|
def gpt_python_postprocess(ori_prompt: str, text: str) -> str:
    """Extract a Python completion from an instruction-following reply.

    Better answer postprocessor for instruction-aligned models like GPT.
    The reply is reduced in four steps:

    1. If it contains a Markdown code fence, keep the first fenced block
       (without the language tag line); with a single, unterminated fence,
       keep what follows the fence marker.
    2. If the reply repeats the function header of the prompt (same
       ``def <name>(`` prefix), drop that header line.
    3. Left-pad the text so that its first non-space character, when found
       within the first five characters, sits at column 4.
    4. Cut everything after the first run of two blank lines.

    Args:
        ori_prompt (str): The prompt that was given to the model.
        text (str): The raw model reply.

    Returns:
        str: The postprocessed completion.
    """
    if '```' in text:
        blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
        if not blocks:
            # Only one fence marker: fall back to the text following it.
            text = text.split('```')[1]
        else:
            text = blocks[0]
            if not text.startswith('\n'):
                # The fence line carries a tag (e.g. ```python): drop the
                # remainder of that line. ``find`` returns -1 when there is
                # no newline, which leaves the text untouched.
                text = text[text.find('\n') + 1:]

    match_ori = re.search(r'def(.*?)\(', ori_prompt)
    match = re.search(r'def(.*?)\(', text)
    # ``match_ori`` is None when the prompt holds no function header; the
    # original code dereferenced it unconditionally and crashed.
    if match and match_ori and match.group() == match_ori.group():
        text = re.sub('def(.*?)\n', '', text, count=1)

    # Only the first five characters are inspected; a first non-space
    # character at index 4 therefore needs no padding at all.
    for c_index, c in enumerate(text[:5]):
        if c != ' ':
            text = ' ' * (4 - c_index) + text
            break

    return text.split('\n\n\n')[0]
| |
|
| |
|
def wizardcoder_postprocess(text: str) -> str:
    """Strip chat decoration from a WizardCoder reply.

    When the reply holds a Markdown code fence, the first fenced block is
    returned without its language-tag line (or, for a single unterminated
    fence, the text following the marker). Otherwise the first line
    fragment starting at ``Here`` and running to its newline is removed.

    Args:
        text (str): The raw model reply.

    Returns:
        str: The postprocessed reply.
    """
    fence = '```'

    if fence not in text:
        # ``re.sub`` leaves the text unchanged when nothing matches, so no
        # separate search is needed.
        return re.sub('Here(.*?)\n', '', text, count=1)

    fenced = re.findall(r'```(.*?)```', text, re.DOTALL)
    if not fenced:
        # Lone fence marker: take what comes right after it.
        return text.split(fence)[1]

    code = fenced[0]
    if code.startswith('\n'):
        return code
    # The fence line carries a tag such as ``python``: skip past it. With
    # no newline ``find`` yields -1 and the slice keeps the whole string.
    return code[code.find('\n') + 1:]
| |
|
| |
|
def collect_preds(filename: str):
    """Load the predictions stored at *filename* or in its partial files.

    The complete file is preferred. When it does not exist, the partial
    files ``<root>_0<ext>``, ``<root>_1<ext>``, ... are read in order until
    the first missing index and their contents are concatenated.

    Args:
        filename (str): Path of the complete prediction file.

    Returns:
        tuple: ``(SUCCEED, origin_prompts, predictions)`` with two lists on
        success, or ``(FAILED, None, None)`` when neither the complete file
        nor the first partial file exists.
    """

    def _unpack(records):
        """Split loaded records into ``(origin_prompts, predictions)``.

        Records are looked up by the string keys ``'0'`` .. ``'n-1'``.
        """
        keys = [str(idx) for idx in range(len(records))]
        predictions = [records[key]['prediction'] for key in keys]
        prompts = [records[key]['origin_prompt'] for key in keys]
        return prompts, predictions

    root, ext = osp.splitext(filename)
    partial_filename = root + '_0' + ext

    # Complete result file available: nothing to stitch together.
    if osp.exists(osp.realpath(filename)):
        ori_prompt_strs, pred_strs = _unpack(mmengine.load(filename))
        return SUCCEED, ori_prompt_strs, pred_strs

    if not osp.exists(osp.realpath(partial_filename)):
        print(f'No predictions found for {filename}')
        return FAILED, None, None

    # Walk the numbered partial files until the sequence breaks.
    ori_prompt_strs, pred_strs = [], []
    shard_path = partial_filename
    next_index = 1
    while osp.exists(osp.realpath(shard_path)):
        prompts, predictions = _unpack(mmengine.load(shard_path))
        pred_strs += predictions
        ori_prompt_strs += prompts
        shard_path = root + f'_{next_index}' + ext
        next_index += 1
    return SUCCEED, ori_prompt_strs, pred_strs
| |
|
| |
|
# Model abbreviations whose replies need the WizardCoder postprocessing.
_WIZARDCODER_MODEL_ABBRS = (
    'WizardCoder-1B-V1.0',
    'WizardCoder-3B-V1.0',
    'WizardCoder-15B-V1.0',
    'WizardCoder-Python-13B-V1.0',
    'WizardCoder-Python-34B-V1.0',
)


def _infer_language(dataset_abbr):
    """Return the first ``(lang, task)`` whose key occurs in *dataset_abbr*.

    Returns ``None`` when no key of ``_LANGUAGE_NAME_DICT`` is a substring
    of the abbreviation.
    """
    for lang, task in _LANGUAGE_NAME_DICT.items():
        if lang in dataset_abbr:
            return lang, task
    return None


def _build_predictions(model_abbr, lang, task, ori_prompt_strs, pred_strs):
    """Postprocess raw replies into ``task_id``/``generation`` records."""
    if model_abbr in _WIZARDCODER_MODEL_ABBRS:
        generations = [wizardcoder_postprocess(pred) for pred in pred_strs]
    elif 'CodeLlama' not in model_abbr and lang == 'python':
        generations = [
            gpt_python_postprocess(ori_prompt, pred)
            for ori_prompt, pred in zip(ori_prompt_strs, pred_strs)
        ]
    else:
        generations = [_clean_up_code(pred, lang) for pred in pred_strs]
    return [{
        'task_id': f'{task}/{i}',
        'generation': generation,
    } for i, generation in enumerate(generations)]


def main():
    """Collect HumanEval-X predictions of a previous run into JSONL files.

    For every model/dataset pair of the config, the stored predictions are
    loaded, postprocessed according to the model and language, and written
    one JSON object per line to
    ``<work_dir>/<timestamp>/humanevalx/<model_abbr>/humanevalx_<lang>.json``.
    Existing result files are left untouched.

    Raises:
        ValueError: If ``-r/--reuse`` was not given on the command line.
    """
    args = parse_args()
    logger = get_logger(log_level='INFO')
    cfg = Config.fromfile(args.config)
    cfg.setdefault('work_dir', './outputs/default/')

    # Explicit check rather than ``assert`` so that it survives ``python -O``.
    if not args.reuse:
        raise ValueError('Please provide the experienment work dir.')

    if args.reuse == 'latest':
        dirs = os.listdir(cfg.work_dir) if os.path.exists(
            cfg.work_dir) else []
        if not dirs:
            # Nothing to collect from; the original fell through here and
            # crashed on the unbound ``dir_time_str``.
            logger.warning('No previous results to reuse!')
            return
        dir_time_str = sorted(dirs)[-1]
    else:
        dir_time_str = args.reuse
    logger.info(f'Reusing experiements from {dir_time_str}')

    # Point the config at the "actual" work dir of the reused experiment.
    cfg['work_dir'] = osp.join(cfg.work_dir, dir_time_str)
    pred_root = osp.join(cfg.work_dir, 'predictions')

    for model in cfg.models:
        model_abbr = model_abbr_from_cfg(model)
        for dataset in cfg.datasets:
            dataset_abbr = dataset_abbr_from_cfg(dataset)
            filename = get_infer_output_path(model, dataset, pred_root)

            succeed, ori_prompt_strs, pred_strs = collect_preds(filename)
            if not succeed:
                continue

            # Without this guard ``lang``/``task`` would be unbound, or
            # silently reused from the previous dataset.
            language = _infer_language(dataset_abbr)
            if language is None:
                logger.warning(
                    f'Cannot infer the language of {dataset_abbr}, skip.')
                continue
            lang, task = language

            result_file_path = os.path.join(cfg['work_dir'], 'humanevalx',
                                            model_abbr,
                                            f'humanevalx_{lang}.json')
            # Checked before postprocessing so no work is thrown away.
            if osp.exists(result_file_path):
                logger.info(
                    f'File exists for {model_abbr}, skip copy from predictions.'
                )
                continue

            predictions = _build_predictions(model_abbr, lang, task,
                                             ori_prompt_strs, pred_strs)
            mkdir_or_exist(osp.dirname(result_file_path))
            with open(result_file_path, 'w', encoding='utf-8') as f:
                f.writelines(json.dumps(pred) + '\n' for pred in predictions)
| |
|
| |
|
# Run the collection only when the file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
| |
|