| import argparse |
| import os |
| from collections import defaultdict |
| from io import StringIO |
|
|
| import pandas as pd |
| from tqdm import tqdm |
|
|
| from perplexity import get_model_for |
| from subsampler import PerplexitySubsampler |
|
|
|
|
def _split_doc_type_and_lang(filename):
    """Return the ``(doc_type, lang)`` pair encoded in a file name.

    The doc type is the text before the first underscore. The language is
    the first two characters of the last underscore-separated part, cut at
    its first dot, e.g. ``"hplt_0001_en.jsonl"`` -> ``("hplt", "en")``.
    """
    parts = filename.split("_")
    return parts[0], parts[-1].split(".")[0][:2]


def process_files(
    directory,
    reject_level,
    model_override,
    output_file,
    group_by_prefix_lang,
    prefix_lang_mapping=None,
    ratio=None,
    ratio_per_lang=None,
    pa=None,
    pb=None,
    include=None,
):
    """Compute perplexity quantiles for the JSON-lines files in a directory.

    Every file is read with ``pd.read_json(..., lines=True)`` and must have
    a ``perplexities`` column whose records contain a ``"<model>_pp"`` key.
    One CSV row is produced per file, or per ``doc_type``/``lang`` group
    when ``group_by_prefix_lang`` is set.

    Args:
        directory: Directory whose entries are processed.
        reject_level: Quantile (0..1) reported in the ``reject`` column.
        model_override: Model name to use for every row. When falsy, the
            model is the first item returned by ``get_model_for(doc_type)``.
        output_file: Path of the CSV file to write. When falsy, the rows
            are printed to standard output instead.
        group_by_prefix_lang: If true, pool all files that share the same
            (possibly remapped) ``doc_type_lang`` key into one row.
        prefix_lang_mapping: Optional ``{"doctype_lang": "doctype_lang"}``
            overrides applied to the key parsed from each file name.
        ratio: Global sampling ratio handed to ``PerplexitySubsampler.set``.
        ratio_per_lang: Optional ``{lang: ratio}`` mapping. When given it
            takes precedence over ``ratio``; languages missing from it fall
            back to ``ratio`` and then to ``1.0``.
        pa: Passed through as ``pa`` to ``PerplexitySubsampler.set``.
        pb: Passed through as ``pb`` to ``PerplexitySubsampler.set``.
        include: Optional list of doc type prefixes to keep. In grouped
            mode the parsed prefix must be in the list; otherwise the file
            name must start with one of the entries.
    """
    # The norm/mean/std columns are only emitted when sampling is requested.
    with_sampling = bool(ratio or ratio_per_lang)
    header = "doc_type,model,language,reject,bad,medium,good"
    if with_sampling:
        header += ",norm,mean,std"
    rows = [header]

    if prefix_lang_mapping is None:
        prefix_lang_mapping = {}

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # order of the emitted rows reproducible between runs and machines.
    files = sorted(os.listdir(directory))

    if group_by_prefix_lang:
        description = "Processing files in groups"
        grouped_files = defaultdict(list)
        for file in files:
            prefix, lang = _split_doc_type_and_lang(file)
            if include and prefix not in include:
                continue
            key = f"{prefix}_{lang}"
            grouped_files[prefix_lang_mapping.get(key, key)].append(file)
        file_groups = list(grouped_files.values())
    else:
        description = "Processing files"
        file_groups = [
            [file]
            for file in files
            if not include or any(file.startswith(prefix) for prefix in include)
        ]

    if output_file:
        # Only show a progress bar when stdout is not carrying the CSV.
        progress = tqdm(file_groups, desc=description)
    else:
        progress = file_groups
        print(header)

    for group in progress:
        # All files of a group resolve to the same remapped key, so the
        # first file is enough to determine the row's doc type and language.
        raw_doc_type, raw_lang = _split_doc_type_and_lang(group[0])
        key = f"{raw_doc_type}_{raw_lang}"
        doc_type, lang = prefix_lang_mapping.get(key, key).rsplit("_", 1)

        # Collect the per-file frames and concatenate once, instead of
        # growing a DataFrame (starting from an empty one) inside the loop.
        frames = []
        for file in group:
            perp = pd.read_json(os.path.join(directory, file), lines=True)
            # Expand the column of per-document dicts into one column per
            # key by round-tripping it through JSON lines.
            expanded = perp["perplexities"].to_json(lines=True, orient="records")
            frames.append(pd.read_json(StringIO(expanded), lines=True))
        combined_perplexities = pd.concat(frames, ignore_index=True)

        if model_override:
            model = model_override
        else:
            model, _ = get_model_for(doc_type)
        scores = combined_perplexities[f"{model}_pp"]

        reject = round(scores.quantile(q=reject_level), 2)
        bad = round(scores.quantile(q=0.75), 2)
        medium = round(scores.quantile(q=0.50), 2)
        good = round(scores.quantile(q=0.25), 2)

        sampling_stats = ""
        if with_sampling:
            # Per-language ratios win over the global ratio; previously a
            # global ratio made the per-language mapping unreachable.
            if ratio_per_lang:
                sampling_ratio = ratio_per_lang.get(lang, ratio or 1.0)
            else:
                sampling_ratio = ratio
            subsampler = PerplexitySubsampler(scores.values)
            subsampler.set(ratio=sampling_ratio, pa=pa, pb=pb)
            sampling_stats = (
                f",{subsampler.norm},{subsampler.mean},{subsampler.sdev}"
            )

        row = f"{doc_type},{model},{lang},{reject},{bad},{medium},{good}{sampling_stats}"
        if output_file:
            rows.append(row)
        else:
            print(row)

    if output_file:
        with open(output_file, "w", encoding="utf-8") as f:
            f.writelines(f"{row}\n" for row in rows)
|
|
|
|
def _parse_key_value_list(text, convert=str):
    """Parse a ``"k1:v1,k2:v2"`` string into a dict.

    Keys and values are whitespace-stripped and ``convert`` is applied to
    each value. Empty items (e.g. from a trailing comma) are ignored.

    Raises:
        ValueError: If an item is not of the form ``key:value`` or if
            ``convert`` rejects a value.
    """
    pairs = {}
    for item in text.split(","):
        item = item.strip()
        if not item:
            continue
        parts = item.split(":")
        if len(parts) != 2:
            raise ValueError(f"expected 'key:value', got {item!r}")
        key, value = parts
        pairs[key.strip()] = convert(value.strip())
    return pairs


def main():
    """Parse the command line and run ``process_files``.

    Each doc_type prefix needs to have a "no" lang, even if there's no real
    data. These rows are crucial for the rest of the process.
    """
    parser = argparse.ArgumentParser(description="Process files and compute perplexity metrics.")
    parser.add_argument('directory', type=str, help='Directory containing the files to process')
    parser.add_argument('--reject_level', type=float, default=0.95, help='Rejection quantile level (default: 0.95)')
    parser.add_argument('--model_override', type=str, help='Override the model used')
    parser.add_argument('--output_file', type=str, help='Output file in CSV format. If not given, prints to standard output.')
    parser.add_argument('--group_by_prefix_lang', action='store_true', help='Group and calculate quantiles for files with the same prefix and language')
    parser.add_argument('--overwrite_prefix_lang', type=str, help='Overwrite the assignment of languages to doc_type prefixes, e.g., "starcoder_en:starcoder_code,hplt_en:hplt_no"')
    parser.add_argument('--sampling_ratio', type=float, help='Ratio of documents to keep for sampling. If passed, it generate distribution statistics (norm, mean, std) needed for sampling')
    parser.add_argument('--sampling_ratio_per_lang', type=str, help='Ratio of documents per lang, e.g., "en:0.25,sv:0.34"')
    parser.add_argument('--sampling_q1_prob', type=float, default=0.20, help='Probabilty for keeping documents in the Q1 range')
    parser.add_argument('--sampling_q3_prob', type=float, default=0.05, help='Probabilty for keeping documents in the Q3 range')
    parser.add_argument('--include', type=str, help='Comma separeted list of doc type prefixes to include')

    args = parser.parse_args()

    # Report malformed "key:value" lists as regular usage errors instead of
    # letting an unpacking ValueError escape with a traceback.
    ratio_per_lang = None
    if args.sampling_ratio_per_lang:
        try:
            ratio_per_lang = _parse_key_value_list(args.sampling_ratio_per_lang, float)
        except ValueError as err:
            parser.error(f"--sampling_ratio_per_lang: {err}")

    prefix_lang_mapping = {}
    if args.overwrite_prefix_lang:
        try:
            prefix_lang_mapping = _parse_key_value_list(args.overwrite_prefix_lang)
        except ValueError as err:
            parser.error(f"--overwrite_prefix_lang: {err}")

    # Strip whitespace and drop empty entries so "a, b" and "a,b," behave
    # like "a,b", consistent with the other comma-separated options.
    include = None
    if args.include:
        include = [p.strip() for p in args.include.split(",") if p.strip()] or None

    process_files(
        args.directory,
        args.reject_level,
        args.model_override,
        args.output_file,
        group_by_prefix_lang=args.group_by_prefix_lang,
        prefix_lang_mapping=prefix_lang_mapping,
        pa=args.sampling_q1_prob,
        pb=args.sampling_q3_prob,
        ratio=args.sampling_ratio,
        ratio_per_lang=ratio_per_lang,
        include=include,
    )
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|