| from typing import List |
| import tiktoken |
| import os |
| import copy |
| import time |
|
|
| from models.Base import BaseModel |
| from datasets.Dataset import Dataset |
| from results.Results import Results |
| from utils.parse import parse_response |
|
|
|
|
class BaseStrategy(object):
    """Drive a model over every item of a dataset and record the outcome.

    The class owns the generic evaluation loop (`run`): for each dataset item
    it asks `run_single_pass` for a model response up to `pass_at_k` times,
    extracts source code from the response, lets the dataset evaluate it, and
    stores the per-item record in `results`. Concrete strategies supply
    `run_single_pass` and may optionally define `parse_code(response)` to
    override the default `parse_response` extraction.
    """

    def __init__(
        self,
        model: "BaseModel",
        data: "Dataset",
        language: str,
        pass_at_k: int,
        results: "Results",
        verbose: bool = True,
    ):
        """Store the collaborators and settings used by `run`.

        Args:
            model: Object whose `prompt(processed_input=...)` is called by
                `gpt_chat`.
            data: Iterable, sized dataset exposing `evaluate(...)` and
                `id_key`.
            language: Language name forwarded to `data.evaluate` and written
                into every result record.
            pass_at_k: Maximum number of attempts per item.
            results: Result store; `run` uses `len()`, indexing, `.results`,
                `add_result` and `save_results` on it.
            verbose: When true, print a progress line after every item.
        """
        self.model = model
        self.data = data
        self.pass_at_k = pass_at_k
        self.results = results
        self.language = language
        self.verbose = verbose

    def gpt_chat(self, processed_input: List[dict]) -> "tuple[str, int, int]":
        """Send `processed_input` to the model and return its reply unchanged.

        Returns:
            Whatever `self.model.prompt` returns; annotated as a
            `(response, prompt_tokens, completion_tokens)` triple.
        """
        return self.model.prompt(processed_input=processed_input)

    def run_single_pass(self, item: dict):
        """Produce one model attempt for `item`; meant to be overridden.

        `run` unpacks the return value into
        `(response, prompt_tokens, completion_tokens)`. This base
        implementation returns `None`, so using it directly makes every
        attempt in `run` fail with an unpacking error that `run` logs.
        """
        pass

    def run(self):
        """Evaluate every dataset item, resuming from already-stored results.

        Items whose index already exists in `self.results` are resumed from
        their stored record (keeping earlier attempts); other items start
        fresh. Each item is attempted until it is solved or `pass_at_k`
        attempts have been used. An attempt that raises is logged, counted
        and skipped. The finished record is written back to `self.results`.
        """
        num_items = len(self.data)
        num_success = 0

        for i, item in enumerate(self.data):
            # Flush stdout so earlier progress output is visible before the
            # (potentially slow) model calls for this item begin.
            print("", flush=True, end="")

            # Decide once whether this index was already recorded; the result
            # store is not modified again until the write-back below.
            resumed = i < len(self.results)

            if resumed:
                # Work on a copy so a failure mid-item cannot corrupt the
                # stored record.
                item = copy.deepcopy(self.results[i])
                cur_pass = len(item["source_codes"])
                is_solved = item["is_solved"]
                # A stored record can have no source codes at all (every
                # earlier attempt raised before anything was appended), so
                # indexing [-1] unconditionally would raise IndexError.
                cur_imp = item["source_codes"][-1] if item["source_codes"] else ""
            else:
                item = copy.deepcopy(item)
                item["source_codes"] = []
                item["responses"] = []
                item["prompt_tokens"] = []
                item["completion_tokens"] = []
                item["no_of_try"] = 0

                cur_pass = 0
                is_solved = False
                cur_imp = ""

            while cur_pass < self.pass_at_k and not is_solved:
                try:
                    response, prompt_tokens, completion_tokens = self.run_single_pass(
                        item)
                except Exception as e:
                    # Best effort: a failed attempt still consumes one pass so
                    # the loop is guaranteed to terminate.
                    print(f"Error processing item {item.get('task_id', 'unknown')}: {e}")
                    cur_pass += 1
                    continue

                # Strategies may provide their own code extractor.
                if hasattr(self, "parse_code"):
                    cur_imp = self.parse_code(response)
                else:
                    cur_imp = parse_response(response)

                item["source_codes"].append(cur_imp)
                item["responses"].append(response)
                item["prompt_tokens"].append(prompt_tokens)
                item["completion_tokens"].append(completion_tokens)
                item["no_of_try"] += 1

                is_solved = self.data.evaluate(
                    item=item,
                    cur_imp=cur_imp,
                    language=self.language
                )

                cur_pass += 1

            if is_solved:
                num_success += 1

            item["is_solved"] = is_solved
            item["language"] = self.language
            item["task_id"] = item[self.data.id_key]

            if resumed:
                # Replace the stored record in place and persist explicitly.
                self.results.results[i] = item
                self.results.save_results()
            else:
                self.results.add_result(item)

            if self.verbose:
                print(
                    f'completed {i+1}/{num_items}, Solved: {self.results[i]["is_solved"]}, number of success = {num_success}/{i+1}, acc = {round(num_success/(i+1)*100, 2)}')
|
|
| |
|
|