| | import os |
| |
|
| | import pandas as pd |
| | import argparse |
| |
|
| |
|
def main(input_file, output_file="pivoted.csv"):
    """Pivot per-dataset ``acc1`` values into one row per evaluation configuration.

    Reads ``input_file`` as CSV, multiplies ``acc1`` by 100, strips a leading
    ``wds/vtab/`` or ``wds/`` prefix from ``dataset`` values, and pivots so that
    each distinct dataset becomes its own column holding ``acc1``. The pivoted
    values are rounded to 2 decimals, written to ``output_file`` (without the
    row index), and echoed to stdout both as a table and as CSV text.

    Args:
        input_file: Path to the input CSV. It must contain the columns
            ``acc1``, ``dataset``, ``model``, ``pretrained``, ``attack``,
            ``eps`` and ``iterations_adv``.
        output_file: Path of the CSV to write. Defaults to ``"pivoted.csv"``
            (relative to the current working directory), the previously
            hard-coded value.

    Returns:
        The pivoted DataFrame, with the index columns restored as regular
        columns followed by one column per dataset.
    """
    df = pd.read_csv(input_file)
    # Scale to percent. Rounding is deferred until after the pivot so that
    # cells averaged from several rows are also rounded to 2 decimals.
    df["acc1"] = df["acc1"] * 100
    # Drop a leading "wds/vtab/" or "wds/" prefix from dataset names.
    df["dataset"] = df["dataset"].str.replace(r"^(wds/vtab/|wds/)", "", regex=True)

    columns_to_pivot = ["dataset"]
    index_columns = ["model", "pretrained", "attack", "eps", "iterations_adv"]

    # pivot_table uses aggfunc="mean", so rows sharing the same index/dataset
    # combination are averaged into one cell.
    # NOTE(review): pivot_table groups with dropna=True, so rows with NaN in any
    # index column (e.g. an empty "attack" or "eps") are silently dropped —
    # confirm that is intended for this input.
    df_pivot = (
        df.pivot_table(values="acc1", index=index_columns, columns=columns_to_pivot)
        .round(2)  # rounds only the pivoted values, not the index columns
        .reset_index()
    )

    df_pivot.to_csv(output_file, index=False)
    print(df_pivot, "\n")
    print(df_pivot.to_csv(index=False))
    print(f"Pivoted CSV saved as {output_file}")
    return df_pivot
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Pivot a CSV file.")
    # nargs="?" makes the positional argument optional. Without it argparse
    # rejects a missing path, default=None never applies, and the interactive
    # prompt below can never run.
    parser.add_argument(
        "input_file",
        type=str,
        nargs="?",
        default=None,
        help="The input CSV file to be pivoted.",
    )
    args = parser.parse_args()

    # Fall back to asking interactively when no (or an empty) path was given.
    if not args.input_file:
        input_file = input("enter input file: ")
    else:
        input_file = args.input_file

    main(input_file)
| |
|