| | import argparse |
| | import os |
| | from collections import defaultdict |
| | from io import StringIO |
| |
|
| | import pandas as pd |
| | from tqdm import tqdm |
| |
|
| | from perplexity import get_model_for |
| | from subsampler import PerplexitySubsampler |
| |
|
| |
|
def _prefix_lang_key(file_name, prefix_lang_mapping):
    """Return the (possibly remapped) ``"<prefix>_<lang>"`` key of a file name.

    The prefix is the text before the first underscore. The language is the
    first two characters of the last underscore-separated part, cut at its
    first dot. If the resulting key is present in *prefix_lang_mapping*, the
    mapped value is returned instead.
    """
    parts = file_name.split("_")
    key = f"{parts[0]}_{parts[-1].split('.')[0][:2]}"
    return prefix_lang_mapping.get(key, key)


def process_files(
    directory,
    reject_level,
    model_override,
    output_file,
    group_by_prefix_lang,
    prefix_lang_mapping=None,
    ratio=None,
    ratio_per_lang=None,
    pa=None,
    pb=None,
    include=None,
):
    """Compute perplexity quantiles per file (or per prefix/language group).

    Every entry of *directory* is read as a JSON-lines file whose records
    carry a ``"perplexities"`` object; the column ``"<model>_pp"`` of those
    objects is the series the statistics are computed on. One CSV row is
    emitted per file, or per group when *group_by_prefix_lang* is set.

    Args:
        directory: Directory whose entries are processed.
        reject_level: Quantile (0..1) reported in the ``reject`` column.
        model_override: Model name to use for every row. When falsy, the
            model is looked up with ``get_model_for(doc_type)``.
        output_file: Path of the CSV file to write. When falsy, rows are
            printed to standard output instead (and no progress bar is shown).
        group_by_prefix_lang: Pool all files sharing the same (remapped)
            ``prefix_lang`` key into one row.
        prefix_lang_mapping: Optional ``{"prefix_lang": "prefix_lang"}``
            remapping applied to the key derived from the file name.
        ratio: Sampling ratio applied to every row; enables the
            ``norm,mean,std`` columns.
        ratio_per_lang: ``{lang: ratio}`` used when *ratio* is falsy; languages
            missing from the dict fall back to ``1.0``.
        pa: Forwarded to ``PerplexitySubsampler.set`` as ``pa``.
        pb: Forwarded to ``PerplexitySubsampler.set`` as ``pb``.
        include: Optional collection of prefixes; other files are skipped.
            Grouped mode requires the prefix to equal an entry, ungrouped mode
            only requires the file name to start with one.
    """
    if prefix_lang_mapping is None:
        prefix_lang_mapping = {}

    with_sampling = bool(ratio or ratio_per_lang)
    header = "doc_type,model,language,reject,bad,medium,good"
    if with_sampling:
        header += ",norm,mean,std"
    rows = [header]

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # order of the emitted rows reproducible across runs and platforms.
    files = sorted(os.listdir(directory))

    if group_by_prefix_lang:
        description = "Processing files in groups"
        grouped_files = defaultdict(list)
        for file in files:
            if include and file.split("_")[0] not in include:
                continue
            key = _prefix_lang_key(file, prefix_lang_mapping)
            grouped_files[key].append(file)
        file_groups = list(grouped_files.values())
    else:
        description = "Processing files"
        # str.startswith accepts a tuple, so build it once outside the loop.
        prefixes = tuple(include) if include else None
        file_groups = [
            [file]
            for file in files
            if not prefixes or file.startswith(prefixes)
        ]

    if output_file:
        progress = tqdm(file_groups, desc=description)
    else:
        progress = file_groups
        print(header)

    for group in progress:
        # All files of a group share the same remapped key, so the first
        # file is enough to recover the doc type and language of the row.
        doc_type, lang = _prefix_lang_key(
            group[0], prefix_lang_mapping
        ).rsplit("_", 1)

        # Collect the per-file frames and concatenate once: concatenating
        # inside the loop re-copies the accumulated data on every iteration.
        frames = []
        for file in group:
            perp = pd.read_json(os.path.join(directory, file), lines=True)
            # Round-trip through JSON lines to expand the "perplexities"
            # objects into one column per key.
            expanded = perp["perplexities"].to_json(lines=True, orient="records")
            frames.append(pd.read_json(StringIO(expanded), lines=True))
        combined_perplexities = pd.concat(frames, ignore_index=True)

        if model_override:
            model = model_override
        else:
            model, _ = get_model_for(doc_type)
        scores = combined_perplexities[f"{model}_pp"]

        reject = round(scores.quantile(q=reject_level), 2)
        bad = round(scores.quantile(q=0.75), 2)
        medium = round(scores.quantile(q=0.50), 2)
        good = round(scores.quantile(q=0.25), 2)

        sampling_stats = ""
        if with_sampling:
            # A global ratio wins; otherwise use the per-language ratio,
            # defaulting to 1.0 for languages that were not listed.
            group_ratio = ratio or ratio_per_lang.get(lang, 1.0)
            subsampler = PerplexitySubsampler(scores.values)
            subsampler.set(ratio=group_ratio, pa=pa, pb=pb)
            sampling_stats = (
                f",{subsampler.norm},{subsampler.mean},{subsampler.sdev}"
            )

        row = f"{doc_type},{model},{lang},{reject},{bad},{medium},{good}{sampling_stats}"
        if output_file:
            rows.append(row)
        else:
            print(row)

    if output_file:
        with open(output_file, "w", encoding="utf-8") as f:
            f.writelines(f"{row}\n" for row in rows)
| |
|
| |
|
def _parse_key_value_list(text, convert=str):
    """Parse ``"k1:v1,k2:v2"`` into a dict, applying *convert* to each value.

    Keys and values are stripped of surrounding whitespace. Empty items
    (e.g. produced by a trailing comma) are ignored. An item that is not of
    the form ``key:value`` raises ``ValueError``.
    """
    pairs = {}
    for item in text.split(","):
        if not item.strip():
            continue
        fields = item.split(":")
        if len(fields) != 2:
            raise ValueError(f"Expected 'key:value', got {item!r}")
        key, value = fields
        pairs[key.strip()] = convert(value.strip())
    return pairs


def main():
    """Parse the command line and run ``process_files``.

    Each doc_type prefix needs to have a "no" lang, even if there's no real data.
    These rows are crucial for the rest of the process.
    """
    parser = argparse.ArgumentParser(description="Process files and compute perplexity metrics.")
    parser.add_argument('directory', type=str, help='Directory containing the files to process')
    parser.add_argument('--reject_level', type=float, default=0.95, help='Rejection quantile level (default: 0.95)')
    parser.add_argument('--model_override', type=str, help='Override the model used')
    parser.add_argument('--output_file', type=str, help='Output file in CSV format. If not given, prints to standard output.')
    parser.add_argument('--group_by_prefix_lang', action='store_true', help='Group and calculate quantiles for files with the same prefix and language')
    parser.add_argument('--overwrite_prefix_lang', type=str, help='Overwrite the assignment of languages to doc_type prefixes, e.g., "starcoder_en:starcoder_code,hplt_en:hplt_no"')
    parser.add_argument('--sampling_ratio', type=float, help='Ratio of documents to keep for sampling. If passed, it generate distribution statistics (norm, mean, std) needed for sampling')
    parser.add_argument('--sampling_ratio_per_lang', type=str, help='Ratio of documents per lang, e.g., "en:0.25,sv:0.34"')
    parser.add_argument('--sampling_q1_prob', type=float, default=0.20, help='Probabilty for keeping documents in the Q1 range')
    parser.add_argument('--sampling_q3_prob', type=float, default=0.05, help='Probabilty for keeping documents in the Q3 range')
    parser.add_argument('--include', type=str, help='Comma separeted list of doc type prefixes to include')

    args = parser.parse_args()

    # "en:0.25,sv:0.34" -> {"en": 0.25, "sv": 0.34}; None disables the option.
    if args.sampling_ratio_per_lang:
        ratio_per_lang = _parse_key_value_list(args.sampling_ratio_per_lang, float)
    else:
        ratio_per_lang = None

    # "hplt_en:hplt_no" -> {"hplt_en": "hplt_no"}
    if args.overwrite_prefix_lang:
        prefix_lang_mapping = _parse_key_value_list(args.overwrite_prefix_lang)
    else:
        prefix_lang_mapping = {}

    # Strip whitespace and drop empty entries so "a, b," behaves like "a,b".
    include = None
    if args.include:
        include = [p.strip() for p in args.include.split(",") if p.strip()] or None

    process_files(
        args.directory,
        args.reject_level,
        args.model_override,
        args.output_file,
        group_by_prefix_lang=args.group_by_prefix_lang,
        prefix_lang_mapping=prefix_lang_mapping,
        pa=args.sampling_q1_prob,
        pb=args.sampling_q3_prob,
        ratio=args.sampling_ratio,
        ratio_per_lang=ratio_per_lang,
        include=include,
    )
| |
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|