| import ast |
| import os |
| import subprocess |
| from random import randint |
| from tqdm import tqdm |
| from shutil import copyfile |
| import datetime |
| import json |
| from parse_llm_code import extract_code_blocks |
| import numpy as np |
| import re |
|
|
def get_temp_bash_file(prefix='temp_code'):
    """Return a random ``.sh`` filename based on *prefix* that does not already exist.

    Bug fix: ``str.replace`` returns a new string; the original discarded the
    result, so the loop never changed the candidate name and spun forever
    whenever the first candidate already existed on disk.
    """
    temp_file_name = f'{prefix}_{randint(999, 999999)}.sh'
    while os.path.exists(temp_file_name):
        # Keep appending random suffixes until the name is free.
        temp_file_name = temp_file_name.replace('.sh', f'_{randint(999, 999999)}.sh')
    return temp_file_name
|
|
def parse_profiler_content(profile_content):
    """Split rocprof-compute analyzer text into numbered sections.

    Sections are separated by an 80-dash ruler. A section that contains a
    line beginning with ``<number>.`` is stored under that number (as a
    string key), with the ruler re-attached at the front of the stored text.
    """
    delimiter = "--------------------------------------------------------------------------------"
    header_re = re.compile(r"^\s*(\d+)\..*$", re.MULTILINE)

    sections = {}
    for chunk in profile_content.split(delimiter):
        body = chunk.strip()
        if not body:
            continue
        header = header_re.search(body)
        if header is None:
            continue
        # Re-attach the ruler so each stored section is self-contained.
        sections[header.group(1)] = delimiter + chunk
    return sections
|
|
| |
def passk(n, c, k):
    """Unbiased pass@k estimator.

    n: total samples drawn, c: samples that passed, k: evaluation budget.
    Computes 1 - C(n-c, k)/C(n, k) via a numerically stable product.
    """
    # If fewer than k samples failed, any k-subset must contain a pass.
    if n - c < k:
        return 1.0
    tail = np.arange(n - c + 1, n + 1)
    return 1 - np.prod(1 - k / tail)
|
|
def get_time():
    """Current local time as a filesystem-safe ``YYYY-MM-DD_HH-MM-SS`` string."""
    now = datetime.datetime.now()
    return now.strftime("%Y-%m-%d_%H-%M-%S")
|
|
def get_temp_file(prefix='temp_code'):
    """Return a random ``.py`` filename based on *prefix* that does not already exist.

    Bug fix: ``str.replace`` returns a new string; the original discarded the
    result, so the loop never changed the candidate name and spun forever
    whenever the first candidate already existed on disk.
    """
    temp_file_name = f'{prefix}_{randint(999, 999999)}.py'
    while os.path.exists(temp_file_name):
        # Keep appending random suffixes until the name is free.
        temp_file_name = temp_file_name.replace('.py', f'_{randint(999, 999999)}.py')
    return temp_file_name
|
|
def code_call_exec_success_stdout(code, fname, temp_root="tmp2", tolerance=2, verbose=False):
    """Run generated code and the reference Triton file, then compare stdout.

    The reference file's test harness (everything after a 146-char ``#``
    ruler) is appended to both the reference and the generated code, with
    ``torch.rand*`` calls redirected to deterministic ``rand_utils`` helpers
    so both runs draw identical inputs.

    Parameters:
        code: generated candidate implementation (source text).
        fname: reference file name inside the TritonBench data folder.
        temp_root: root folder for temporary run artifacts.
        tolerance: print precision for torch.set_printoptions (controls how
            strict the stdout comparison effectively is).
        verbose: print diagnostics on failure.

    Returns:
        (call_ok, output_match, stdout, stderr). ``(None, None, ...)`` means
        the reference itself failed or a run timed out (no verdict possible).

    Bug fix: ``subprocess.TimeoutExpired`` is a subclass of ``Exception``,
    so it must be caught before the generic handler — the original order
    made the timeout branch unreachable and reported timeouts as plain
    execution failures.
    """
    # Working folders for the reference ("triton") and generated runs.
    tmp_triton_folder = os.path.join(temp_root, "triton")
    tmp_gen_folder = os.path.join(temp_root, "gen")
    os.makedirs(tmp_triton_folder, exist_ok=True)
    os.makedirs(tmp_gen_folder, exist_ok=True)

    triton_root = "dataloaders/TB_eval/TritonBench/data/TritonBench_G_v1"
    RAND_FILE = os.path.join(triton_root, "rand_utils.py")

    # Both runs import rand_utils for deterministic tensor initialisation.
    copyfile(RAND_FILE, os.path.join(tmp_triton_folder, "rand_utils.py"))
    copyfile(RAND_FILE, os.path.join(tmp_gen_folder, "rand_utils.py"))

    gen_file = get_temp_file(prefix=f'{fname}_gen_triton_code')
    triton_file = os.path.join(triton_root, fname)
    temp_triton_file = get_temp_file(prefix=f'{fname}_temp_triton')

    gen_file = os.path.join(tmp_gen_folder, gen_file)
    temp_triton_file = os.path.join(tmp_triton_folder, temp_triton_file)

    IMPORT_STATEMENT = f"""
from rand_utils import torch_rand, torch_randint, torch_randn
import torch
torch.set_printoptions(precision={tolerance},profile='full',sci_mode=False)
"""

    hash_line = "#"*146

    with open(triton_file, 'r') as f:
        lines = f.readlines()

    # Everything after the '#' ruler line is the reference test harness.
    for iL, line in enumerate(lines):
        if line.strip() == hash_line:
            break
    test_code_lines = lines[iL+1:]
    test_code_lines = IMPORT_STATEMENT.split('\n') + test_code_lines
    test_code_lines_procs = []
    for line in test_code_lines:
        # Redirect torch.rand/randint/randn to the seeded helpers so both
        # runs see identical random inputs.
        if "torch.rand" in line:
            line = line.replace("torch.rand", "torch_rand")
        test_code_lines_procs.append(line)

    with open(temp_triton_file, 'w') as f:
        triton_lines = lines[:iL] + [hash_line] + test_code_lines_procs
        for line in triton_lines:
            # NOTE(review): readlines() entries keep their '\n', so this adds
            # blank lines between statements; harmless for plain code but it
            # would alter multi-line string literals — kept as-is to match
            # the original comparison behaviour. TODO confirm intent.
            f.write(line + "\n")

    code = code + '\n\n' + hash_line + '\n' + '\n' + '\n'.join(test_code_lines_procs)
    with open(gen_file, 'w') as f:
        f.write(code)

    try:
        result_gen = subprocess.run(['python3', gen_file], capture_output=True, text=True, timeout=2*60)
        stdout_gen = result_gen.stdout
        stderr_gen = result_gen.stderr

        if result_gen.returncode != 0:
            if verbose:
                print(f"Error in generated code: {stderr_gen}")
            return False, False, stdout_gen, stderr_gen

        result_triton = subprocess.run(['python3', temp_triton_file], capture_output=True, text=True, timeout=2*60)
        stdout_triton = result_triton.stdout
        stderr_triton = result_triton.stderr

        if result_triton.returncode != 0:
            if verbose:
                print(f"Error in Triton code: {stderr_triton}")
            # The reference itself failed: no verdict can be given.
            return None, None, None, None

        # Persist raw outputs for offline inspection.
        with open(gen_file+".out", 'w') as f:
            f.write(stdout_gen)
        with open(temp_triton_file+".out", 'w') as f:
            f.write(stdout_triton)

        with open(gen_file+".err", 'w') as f:
            f.write(stderr_gen)
        with open(temp_triton_file+".err", 'w') as f:
            f.write(stderr_triton)

        if stdout_gen == stdout_triton:
            return True, True, None, None
        else:
            return True, False, stdout_gen, "Error: not all test cases passed. The generated code and ground truth code produced different outputs."
    # Specific handler first: TimeoutExpired is an Exception subclass.
    except subprocess.TimeoutExpired:
        if verbose:
            print(f"File: {fname} timed out!")
        return None, None, None, "Time out"
    except Exception as e:
        if verbose:
            print(f"File: {fname}, Execution error: {e}")
        return False, False, None, str(e)

    # Defensive fallback; every path above already returns.
    return False, False, None, None
|
|
def code_kernel_profiling(code, fname, py_folder, target_gpu, temp_root="tmp2", atol=1e-3, rtol=1e-1, timeout=6*60, verbose=False):
    """Profile a generated Triton kernel with rocprof-compute and summarise it.

    Appends the reference file's test harness (everything after the 146-char
    ``#`` ruler) to *code*, writes it plus a driver bash script to a temp
    folder, profiles the run, then parses selected sections of the analyzer
    output into a prompt-ready summary string.

    Returns:
        (profile_status, stdout_profile, stderr_profile, stdout_analyze).
        ``stdout_analyze`` is ``None`` when profiling failed; on analyzer
        parsing errors the tuple is ``(None, None, str(e), None)``.

    Bug fixes:
      * ``subprocess.TimeoutExpired`` is caught before the generic handler
        (it is an ``Exception`` subclass; the original order made the
        timeout branch unreachable).
      * ``stdout_analyze`` is initialised before the failure branch, which
        previously raised ``NameError`` at the final return when profiling
        did not succeed.
    """
    tmp_gen_folder = os.path.join(temp_root, "gen")
    os.makedirs(tmp_gen_folder, exist_ok=True)

    triton_root = py_folder
    triton_file = os.path.join(triton_root, fname)

    gen_file = get_temp_file(prefix=f'{fname}_gen_triton_code')
    gen_file = os.path.join(tmp_gen_folder, gen_file)

    fname_split = fname.split('.')[0]
    gen_bash_file = get_temp_bash_file(prefix=f'{fname_split}_gen_triton_code')
    gen_bash_file = os.path.join(tmp_gen_folder, gen_bash_file)

    hash_line = "#"*146

    with open(triton_file, 'r') as f:
        lines = f.readlines()
    # Everything after the '#' ruler line is the reference test harness.
    for iL, line in enumerate(lines):
        if line.strip() == hash_line:
            break
    test_code_lines = lines[iL+1:]
    test_code_lines_procs = test_code_lines

    code = code + '\n\n' + hash_line + '\n' + '\n' + '\n'.join(test_code_lines_procs)

    code_bash = f"python3 {gen_file}"
    with open(gen_file, 'w') as f:
        f.write(code)
    with open(gen_bash_file, 'w') as f:
        f.write(code_bash)
    try:
        # NOTE(review): shell=True with interpolated paths — safe only as
        # long as fname/temp_root are trusted, locally generated names.
        result_profile = subprocess.run([f'rocprof-compute profile -n {fname_split} -- /bin/bash {gen_bash_file}'], capture_output=True, text=True, timeout=timeout, shell=True)
        analyze_profile = subprocess.run([f'rocprof-compute analyze -p workloads/{fname_split}/{target_gpu}'], capture_output=True, text=True, timeout=timeout, shell=True)

        profile_status = result_profile.returncode == 0
        stdout_profile = result_profile.stdout
        stderr_profile = result_profile.stderr
    # Specific handler first: TimeoutExpired is an Exception subclass.
    except subprocess.TimeoutExpired:
        if verbose:
            print(f"File: {fname} timed out!")
        return None, None, "Time out", None
    except Exception as e:
        if verbose:
            print(f"File: {fname}, Execution error: {e}")
        return None, None, str(e), None

    # Initialise so the failure branch below returns None instead of
    # raising NameError at the final return.
    stdout_analyze = None
    if result_profile.returncode != 0:
        if verbose:
            print(f"Error in profiling kernel")
    else:
        if verbose:
            print(f"Success in in profiling kernel")
        try:
            section_text = parse_profiler_content(analyze_profile.stdout)
            stdout_analyze = "\nBelow are some profiling info of this kernel generated by the tool of rocprof-compute on AMD MI250 gpu, you can reference these info to analyze and generate better kernel."
            stdout_analyze += "\n1.Overview:Briefly describe the kernel type along with its runtime and dispatch statistics, such as the main kernel name, invocation count, and average execution time."
            stdout_analyze += f"\n{section_text['0']}"
            stdout_analyze += "\n2.Hardware & Resources:Key hardware details including model, architecture, number of CUs, capacities of LDS/SMEM/registers, and maximum workgroup size."
            stdout_analyze += f"\n{section_text['1']}"
            stdout_analyze += "\n3.Performance Utilization & Bottlenecks:Core bottleneck indicators such as FLOPs utilization, active CUs, occupancy, and memory bandwidth/utilization."
            stdout_analyze += f"\n{section_text['2']}"
            stdout_analyze += "\n4.Instruction Mix & Memory Access:Distribution of arithmetic, memory, and branch instructions (e.g., MFMA/FMA/VALU/VMEM), cache hit rates (L1/L2), memory bandwidth, and conflict statistics."
            stdout_analyze += f"\n{section_text['10']}"
            stdout_analyze += f"\n{section_text['16']}"
            stdout_analyze += f"\n{section_text['17']}"
            stdout_analyze += "\n5.Threading & Allocation:Wavefront/workgroup counts, allocation of VGPRs/SGPRs/LDS, thread concurrency, and resource usage per thread or workgroup."
            stdout_analyze += f"\n{section_text['7']}"
        except Exception as e:
            return None, None, str(e), None
    return profile_status, stdout_profile, stderr_profile, stdout_analyze
|
|
def extract_code_from_llm_output(response):
    """Concatenate all fenced code blocks from an LLM response.

    Returns *response* unchanged when it contains no ``` fences; otherwise
    joins every extracted block's text, newline-separated.
    """
    # Bug fix: the original started with `code = None`, so the first
    # `code += ...` raised TypeError on any response containing fences.
    code = ""
    if "```" not in response:
        return response
    code_blocks = extract_code_blocks(response)
    for _code in code_blocks.code_dict_list:
        code += _code['context'] + "\n"
    return code
|
|
def get_fname_difficulty_from_label(label):
    """Look up the (file, difficulty) pair whose 'output' field equals *label*.

    Scans the TritonBench difficulty-annotated dataset; returns (None, None)
    when no entry matches.
    """
    dataset_path = "dataloaders/TB_eval/TritonBench/data/TritonBench_G_comp_alpac_v1_fixed_with_difficulty.json"
    with open(dataset_path, 'r') as f:
        entries = json.load(f)
    hit = next((item for item in entries if item['output'] == label), None)
    if hit is None:
        return None, None
    return hit['file'], hit['difficulty']
|
|
def process_code(code: str):
    """Reduce an LLM completion to its import statements and function defs.

    Strips a leading ```python fence and chat end-of-turn tokens, parses the
    result, and returns only imports and function definitions re-rendered by
    ``ast.unparse``. If parsing fails, the (cleaned) text is returned
    unchanged as a best-effort fallback.
    """
    if "```python" in code:
        code = code.split("```python")[-1].replace("<|im_end|>", "").replace("<|EOT|>", "")
        # Bug fix: also drop a trailing ``` fence — previously it was kept,
        # so every properly closed code block failed ast.parse and fell
        # through to the raw-text fallback.
        code = code.split("```")[0]

    try:
        tree = ast.parse(code)
        imports = []
        function_definitions = []

        # ast.walk visits nested nodes too, so imports and defs inside
        # functions are collected as well (matches original behaviour).
        for node in ast.walk(tree):
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                imports.append(ast.unparse(node))
            elif isinstance(node, ast.FunctionDef):
                function_code = ast.unparse(node)
                function_definitions.append(function_code)

        return "\n".join(imports) + "\n\n" + "\n".join(function_definitions)

    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; best-effort fallback is preserved.
        return code
|
|
|
|
def code_call_exec_success_allclose(code, fname, py_folder, temp_root="tmp2", atol=1e-3, rtol=1e-1, timeout=2*60, verbose=False, gpu_id=0):
    """Run generated code and check numerical closeness against the reference.

    Appends the reference file's test harness (after the 146-char ``#``
    ruler) to *code*, executes it, then runs the external correctness script
    which compares outputs with the given atol/rtol.

    Returns:
        (call_status, exec_status, call_stdout, call_stderr, corr_stdout,
        corr_stderr); ``None`` entries indicate an infrastructure error or
        timeout rather than a verdict.

    Bug fixes:
      * ``subprocess.TimeoutExpired`` is caught before the generic handler
        (it is an ``Exception`` subclass; the original order made the
        timeout branch unreachable).
      * When the fname regex does not match, fall back to *fname* itself —
        previously ``filename`` was left undefined and a ``NameError`` was
        raised below.
    """
    tmp_gen_folder = os.path.join(temp_root, "gen")
    os.makedirs(tmp_gen_folder, exist_ok=True)
    # Strip a numeric variant suffix (e.g. "op_3.py" -> "op.py") to locate
    # the shared reference file.
    match = re.match(r"^([a-zA-Z0-9_]+?)(?:_\d+)?\.py$", fname)
    if match:
        op = match.group(1)
        filename = op + '.py'
    else:
        filename = fname
    triton_root = py_folder
    triton_file = os.path.join(triton_root, filename)

    gen_file = get_temp_file(prefix=f'{fname}_gen_triton_code')
    gen_file = os.path.join(tmp_gen_folder, gen_file)

    hash_line = "#"*146

    with open(triton_file, 'r') as f:
        lines = f.readlines()
    # Everything after the '#' ruler line is the reference test harness.
    for iL, line in enumerate(lines):
        if line.strip() == hash_line:
            break
    test_code_lines = lines[iL+1:]
    test_code_lines_procs = test_code_lines

    code = code + '\n\n' + hash_line + '\n' + '\n' + '\n'.join(test_code_lines_procs)

    with open(gen_file, 'w') as f:
        f.write(code)

    try:
        # NOTE(review): shell=True with interpolated paths — safe only as
        # long as fname/temp_root are trusted, locally generated names.
        result_call = subprocess.run([f'HIP_VISIBLE_DEVICES={gpu_id} python3 {gen_file}'], capture_output=True, text=True, timeout=timeout, shell=True)
        call_status = result_call.returncode == 0

        result_corr = subprocess.run([f'HIP_VISIBLE_DEVICES={gpu_id} python3 dataloaders/TB_eval/correctness.py --gen_file {gen_file} --ref_file {triton_file} --atol {atol} --rtol {rtol}'], capture_output=True, text=True, timeout=timeout, shell=True)
        stdout_corr = result_corr.stdout
        stderr_corr = result_corr.stderr
    # Specific handler first: TimeoutExpired is an Exception subclass.
    except subprocess.TimeoutExpired:
        if verbose:
            print(f"File: {fname} timed out!")
        return None, None, None, "Time out", None, None
    except Exception as e:
        if verbose:
            print(f"File: {fname}, Execution error: {e}")
        return None, None, None, str(e), None, None

    # Persist the correctness script's outputs for offline inspection.
    with open(gen_file+".stdout", 'w') as f:
        f.write(stdout_corr)

    with open(gen_file+".stderr", 'w') as f:
        f.write(stderr_corr)

    if result_corr.returncode != 0:
        if verbose:
            print(f"Error in generated code: {stderr_corr}")
        return call_status, None, result_call.stdout, result_call.stderr, stdout_corr, stderr_corr
    else:
        if verbose:
            print(f"Success in generated code: {stdout_corr}")
        # Correctness script emits "*#*#"-delimited fields; presumably
        # always exactly four — TODO confirm against correctness.py.
        _, exec_status, gen_stdout, gen_stderr = stdout_corr.split("*#*#")
        return call_status, exec_status, result_call.stdout, result_call.stderr, gen_stdout, gen_stderr
|
|
| |
|
|
class bcolors:
    """ANSI escape sequences for coloured terminal output."""
    HEADER = '\033[95m'     # bright magenta
    OKBLUE = '\033[94m'     # bright blue
    OKCYAN = '\033[96m'     # bright cyan
    OKGREEN = '\033[92m'    # bright green — used for success status
    WARNING = '\033[93m'    # bright yellow
    FAIL = '\033[91m'       # bright red — used for failure status
    ENDC = '\033[0m'        # reset all attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
|
|
def green_or_red(status):
    """Return the ANSI colour code for *status*: green if truthy, red otherwise."""
    return bcolors.OKGREEN if status else bcolors.FAIL
|
|
def color_end():
    """Return the ANSI reset sequence that ends any colour span."""
    return bcolors.ENDC
|
|
def bool_colorize(status):
    """Render *status* as its string form wrapped in green (truthy) or red (falsy) ANSI codes."""
    color = bcolors.OKGREEN if status else bcolors.FAIL
    return color + str(status) + bcolors.ENDC