| import sys |
| import os |
| from datetime import datetime |
| from constants.paths import * |
|
|
| from models.Gemini import Gemini |
| from models.OpenAI import OpenAIModel |
|
|
| from results.Results import Results |
|
|
| from promptings.PromptingFactory import PromptingFactory |
| from datasets.DatasetFactory import DatasetFactory |
| from models.ModelFactory import ModelFactory |
|
|
| import argparse |
|
|
# Command-line interface for a single benchmark run.
parser = argparse.ArgumentParser()

# One row per option: (flag, value type, default, allowed values or None for free-form).
# NOTE: argparse only runs ``type`` over *string* defaults, so the non-string
# defaults below are stored exactly as written (temperature defaults to int 0).
for _flag, _kind, _default, _allowed in (
    ("--dataset", str, "HumanEval",
     ["HumanEval", "MBPP", "APPS", "xCodeEval", "CC"]),
    ("--strategy", str, "MapCoder",
     ["Direct", "CoT", "SelfPlanning", "Analogical", "MapCoder", "DebateCoder"]),
    ("--model", str, "ChatGPT",
     ["ChatGPT", "GPT4", "Gemini", "DeepSeek", "Pangu", "Qwen", "Pangu72B"]),
    ("--temperature", float, 0, None),
    ("--pass_at_k", int, 1, None),
    ("--language", str, "Python3",
     ["C", "C#", "C++", "Go", "PHP", "Python3", "Ruby", "Rust"]),
):
    parser.add_argument(_flag, type=_kind, default=_default, choices=_allowed)
# Drop the loop variables so they do not linger as module globals.
del _flag, _kind, _default, _allowed

args = parser.parse_args()
|
|
# Unpack the parsed CLI options into module-level run settings.
DATASET = args.dataset
STRATEGY = args.strategy
MODEL_NAME = args.model
# argparse does not apply ``type=float`` to the non-string default ``0``, so
# without this coercion TEMPERATURE would be int 0 when the flag is omitted but
# float 0.0 when ``--temperature 0`` is given explicitly.
TEMPERATURE = float(args.temperature)
PASS_AT_K = args.pass_at_k
LANGUAGE = args.language

# Render integral temperatures without a trailing ".0" so that the default run
# and an explicit ``--temperature 0`` share one run name (and thus one results
# file), keeping the "...-0-<k>" naming that default runs already produced.
_temperature_tag = (
    str(int(TEMPERATURE)) if TEMPERATURE.is_integer() else str(TEMPERATURE)
)
RUN_NAME = f"{MODEL_NAME}-{STRATEGY}-{DATASET}-{LANGUAGE}-{_temperature_tag}-{PASS_AT_K}"
RESULTS_PATH = f"./outputs/{RUN_NAME}.jsonl"

print(f"#########################\nRunning start {RUN_NAME}, Time: {datetime.now()}\n##########################\n")

# Resolve the model implementation registered under MODEL_NAME and build it
# with the run temperature.
model_class = ModelFactory.get_model_class(MODEL_NAME)
model_instance = model_class(temperature=TEMPERATURE)
|
|
| |
# Lower-case substrings that mark attribute names whose scalar values must not
# be printed (e.g. "api_key", "secret_key", "access_token").
_SECRET_ATTR_MARKERS = (
    "key",
    "secret",
    "password",
    "credential",
    "access_token",
    "auth_token",
)


def _format_model_info(model_obj, *, redact=_SECRET_ATTR_MARKERS):
    """Summarize a model object's class and instance attributes on one line.

    Scalar attributes (str/int/float/bool) are rendered as ``name=value``; any
    other attribute is rendered as ``name=<TypeName>`` so nested objects are
    not dumped. Scalar attributes whose name contains one of ``redact``
    (case-insensitive substring match) are rendered as ``name=<redacted>``:
    the summary is printed to stdout, and model clients may keep credentials
    on the instance.

    Args:
        model_obj: Any object; its ``__dict__`` (if it has one) is inspected.
        redact: Substrings marking sensitive attribute names. Pass ``()`` to
            disable redaction.

    Returns:
        ``"; "``-joined ``name=value`` pairs, starting with
        ``model_class=<class name>``.
    """
    markers = tuple(marker.lower() for marker in redact)
    info_lines = [f"model_class={model_obj.__class__.__name__}"]

    for k, v in getattr(model_obj, "__dict__", {}).items():
        try:
            if not isinstance(v, (str, int, float, bool)):
                info_lines.append(f"{k}=<{type(v).__name__}>")
            elif any(marker in str(k).lower() for marker in markers):
                info_lines.append(f"{k}=<redacted>")
            else:
                info_lines.append(f"{k}={v}")
        except Exception:
            # Best-effort: a scalar subclass with a broken __str__/__format__
            # must not abort the run just because of a diagnostic line.
            info_lines.append(f"{k}=<unrepr>")
    return "; ".join(info_lines)
|
|
# Report the resolved model configuration before the (potentially long) run.
model_info_str = _format_model_info(model_instance)
print(f"[MODEL INFO] {model_info_str}")

# Assemble the prompting strategy. Construction order is: strategy class
# lookup, then the dataset, then the results store.
prompting_cls = PromptingFactory.get_prompting_class(STRATEGY)
dataset_instance = DatasetFactory.get_dataset_class(DATASET)()
run_results = Results(RESULTS_PATH)
strategy = prompting_cls(
    model=model_instance,
    data=dataset_instance,
    language=LANGUAGE,
    pass_at_k=PASS_AT_K,
    results=run_results,
)

# Execute the benchmark run, then print the closing banner.
strategy.run()

print(f"#########################\nRunning end {RUN_NAME}, Time: {datetime.now()}\n##########################\n")
|
|
|
|