| | """ |
| | This file comes from: https://github.com/microsoft/ToRA/blob/main/src/utils/python_executor.py |
| | """ |
| | import io |
| | import regex |
| | import pickle |
| | import traceback |
| | import copy |
| | import datetime |
| | import multiprocessing |
| | import dateutil.relativedelta |
| | import multiprocess |
| | from multiprocess import Pool |
| | from typing import Any, Dict, Optional |
| | from pebble import ProcessPool |
| | from tqdm import tqdm |
| | from concurrent.futures import TimeoutError |
| | from functools import partial |
| | from timeout_decorator import timeout |
| | from contextlib import redirect_stdout |
| |
|
| |
|
class GenericRuntime:
    """Execution environment that runs code against its own globals namespace.

    Subclasses customise the environment through three class attributes:

    * ``GLOBAL_DICT`` -- names pre-loaded into the globals of every instance
      (shallow-copied per instance).
    * ``LOCAL_DICT`` -- shallow-copied into ``_local_vars`` when truthy;
      ``exec_code`` and ``eval_code`` do not consult it.
    * ``HEADERS`` -- code snippets executed once, in order, at construction.
    """

    GLOBAL_DICT = {}
    LOCAL_DICT = None
    HEADERS = []

    # Calls that are rejected before anything is executed. The dot in
    # ``os.system`` is escaped so that only the real attribute call matches
    # (the unescaped form also rejected names such as ``os_system(``).
    _FORBIDDEN_CALL = regex.compile(r'input\(|os\.system\(')

    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None

        for c in self.HEADERS:
            self.exec_code(c)

    def exec_code(self, code_piece: str) -> None:
        """Execute ``code_piece`` in this runtime's globals.

        Raises:
            RuntimeError: if the code contains an ``input(`` or
                ``os.system(`` call; nothing is executed in that case.
        """
        if self._FORBIDDEN_CALL.search(code_piece):
            raise RuntimeError(
                "code contains a forbidden call (input() or os.system())"
            )
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        """Evaluate ``expr`` in this runtime's globals and return its value."""
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        """Add (or overwrite) every ``name -> value`` pair in the globals."""
        self._global_vars.update(var_dict)

    @property
    def answer(self):
        """Value of the global named ``answer``; ``KeyError`` if it is unset."""
        return self._global_vars['answer']
| |
|
class DateRuntime(GenericRuntime):
    """Runtime whose globals expose date helpers to the executed code.

    Both ``timedelta`` and ``relativedelta`` resolve to
    ``dateutil.relativedelta.relativedelta``; ``datetime`` is the
    ``datetime.datetime`` class.
    """

    GLOBAL_DICT = dict(
        datetime=datetime.datetime,
        timedelta=dateutil.relativedelta.relativedelta,
        relativedelta=dateutil.relativedelta.relativedelta,
    )
| |
|
| |
|
class CustomDict(dict):
    """A ``dict`` whose iteration walks a snapshot of the keys.

    The keys are copied into a list when iteration starts, so the iterator
    handed back is independent of later changes to the dictionary.
    """

    def __iter__(self):
        frozen_keys = [*super().__iter__()]
        return iter(frozen_keys)
| |
|
class ColorObjectRuntime(GenericRuntime):
    """Runtime in which the name ``dict`` is bound to ``CustomDict``."""

    GLOBAL_DICT = {"dict": CustomDict}
| |
|
| |
|
class PythonExecutor:
    """Run batches of Python snippets in worker processes and collect answers.

    How a snippet's answer is obtained is chosen by the first option that is
    set, in this order: the last line written to stdout
    (``get_answer_from_stdout``), a named global (``get_answer_symbol``), an
    expression evaluated after the snippet (``get_answer_expr``), or -- when
    none is set -- the value of the snippet's last line.
    """

    def __init__(
        self,
        runtime: Optional[Any] = None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = False,
        timeout_length: int = 5,
    ) -> None:
        """Store the configuration; a ``GenericRuntime`` is used by default.

        ``timeout_length`` (seconds) bounds both each exec/eval call inside
        the worker and the pool-level wait in ``batch_apply``.
        """
        self.runtime = runtime if runtime else GenericRuntime()
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout
        self.timeout_length = timeout_length

    def process_generation_to_code(self, gens):
        """Split every snippet in ``gens`` (an iterable of strings) into lines.

        Returns a list with one list of lines per snippet.
        """
        return [g.split('\n') for g in gens]

    @staticmethod
    def execute(
        code,
        get_answer_from_stdout=None,
        runtime=None,
        answer_symbol=None,
        answer_expr=None,
        timeout_length=10,
    ):
        """Run one snippet (a list of source lines) and extract its answer.

        Returns:
            ``(result, exec_info)``. On success ``exec_info`` is ``"Done"``.
            On any failure ``result`` is ``''`` and ``exec_info`` is the last
            line of the traceback.
        """
        try:
            # Built inside the try so that an unusable ``runtime`` is reported
            # through ``exec_info`` like every other failure.
            limited = timeout(timeout_length)
            if get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    limited(runtime.exec_code)('\n'.join(code))
                program_io.seek(0)
                # Last printed line; raises (and is reported) if nothing
                # was printed.
                result = program_io.readlines()[-1]
            elif answer_symbol:
                limited(runtime.exec_code)('\n'.join(code))
                result = runtime._global_vars[answer_symbol]
            elif answer_expr:
                limited(runtime.exec_code)('\n'.join(code))
                result = limited(runtime.eval_code)(answer_expr)
            else:
                # No explicit answer source: the final line is the expression.
                limited(runtime.exec_code)('\n'.join(code[:-1]))
                result = limited(runtime.eval_code)(code[-1])
            exec_info = "Done"
            # Treat results that cannot be stringified or pickled as failures
            # (presumably because the result is sent back from a worker
            # process).
            str(result)
            pickle.dumps(result)
        except BaseException:
            # Deliberately catches everything, as the original bare ``except``
            # did: executed code may raise SystemExit (e.g. by calling
            # ``exit()``), and that must become a failed result rather than
            # escape from this function.
            result = ''
            exec_info = traceback.format_exc().split('\n')[-2]
        return result, exec_info

    def apply(self, code):
        """Run a single snippet and return its ``(result, exec_info)`` pair."""
        return self.batch_apply([code])[0]

    def batch_apply(self, batch_code):
        """Run every snippet in ``batch_code`` in a process pool.

        Returns:
            A list of ``(result, exec_info)`` tuples, one per snippet, in
            input order. A snippet that exceeds the pool timeout yields
            ``("", "Timeout Error")``. An empty batch yields ``[]``.

        Raises:
            Exception: any non-timeout error surfaced by the pool is printed
                and re-raised.
        """
        all_code_snippets = self.process_generation_to_code(batch_code)
        if not all_code_snippets:
            # Nothing to run; also avoids requesting a pool with zero workers.
            return []

        all_exec_results = []
        worker_count = min(len(all_code_snippets), multiprocessing.cpu_count())
        with ProcessPool(max_workers=worker_count) as pool:
            executor = partial(
                self.execute,
                get_answer_from_stdout=self.get_answer_from_stdout,
                runtime=self.runtime,
                answer_symbol=self.answer_symbol,
                answer_expr=self.answer_expr,
                timeout_length=self.timeout_length,
            )
            future = pool.map(executor, all_code_snippets, timeout=self.timeout_length)
            iterator = future.result()

            # Only show a progress bar for large batches.
            if len(all_code_snippets) > 100:
                progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
            else:
                progress_bar = None

            try:
                while True:
                    try:
                        outcome = next(iterator)
                    except StopIteration:
                        break
                    except TimeoutError as error:
                        print(error)
                        outcome = ("", "Timeout Error")
                    except Exception as error:
                        # Surface the failure to the caller instead of ending
                        # the interpreter via exit() with a success status.
                        print(error)
                        raise
                    all_exec_results.append(outcome)
                    if progress_bar is not None:
                        progress_bar.update(1)
            finally:
                if progress_bar is not None:
                    progress_bar.close()

        return all_exec_results
| |
|
| |
|
def _test():
    """Smoke test: run one printing snippet and show what the executor returns."""
    snippet = '\nprint("Hello world!")\n'
    runner = PythonExecutor(get_answer_from_stdout=True)
    print(runner.apply(snippet))
| |
|
| |
|
# Run the smoke test when this module is executed as a script.
if __name__ == "__main__":
    _test()