| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import csv |
| | from collections import defaultdict |
| | from dataclasses import dataclass, field |
| | from typing import Optional |
| |
|
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | from matplotlib.ticker import ScalarFormatter |
| |
|
| | from transformers import HfArgumentParser |
| |
|
| |
|
| | def list_field(default=None, metadata=None): |
| | return field(default_factory=lambda: default, metadata=metadata) |
| |
|
| |
|
| | @dataclass |
| | class PlotArguments: |
| | """ |
| | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. |
| | """ |
| |
|
| | csv_file: str = field( |
| | metadata={"help": "The csv file to plot."}, |
| | ) |
| | plot_along_batch: bool = field( |
| | default=False, |
| | metadata={"help": "Whether to plot along batch size or sequence length. Defaults to sequence length."}, |
| | ) |
| | is_time: bool = field( |
| | default=False, |
| | metadata={"help": "Whether the csv file has time results or memory results. Defaults to memory results."}, |
| | ) |
| | no_log_scale: bool = field( |
| | default=False, |
| | metadata={"help": "Disable logarithmic scale when plotting"}, |
| | ) |
| | is_train: bool = field( |
| | default=False, |
| | metadata={ |
| | "help": "Whether the csv file has training results or inference results. Defaults to inference results." |
| | }, |
| | ) |
| | figure_png_file: Optional[str] = field( |
| | default=None, |
| | metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."}, |
| | ) |
| | short_model_names: Optional[list[str]] = list_field( |
| | default=None, metadata={"help": "List of model names that are used instead of the ones in the csv file."} |
| | ) |
| |
|
| |
|
| | def can_convert_to_int(string): |
| | try: |
| | int(string) |
| | return True |
| | except ValueError: |
| | return False |
| |
|
| |
|
| | def can_convert_to_float(string): |
| | try: |
| | float(string) |
| | return True |
| | except ValueError: |
| | return False |
| |
|
| |
|
| | class Plot: |
| | def __init__(self, args): |
| | self.args = args |
| | self.result_dict = defaultdict(lambda: {"bsz": [], "seq_len": [], "result": {}}) |
| |
|
| | with open(self.args.csv_file, newline="") as csv_file: |
| | reader = csv.DictReader(csv_file) |
| | for row in reader: |
| | model_name = row["model"] |
| | self.result_dict[model_name]["bsz"].append(int(row["batch_size"])) |
| | self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"])) |
| | if can_convert_to_int(row["result"]): |
| | |
| | self.result_dict[model_name]["result"][(int(row["batch_size"]), int(row["sequence_length"]))] = ( |
| | int(row["result"]) |
| | ) |
| | elif can_convert_to_float(row["result"]): |
| | |
| | self.result_dict[model_name]["result"][(int(row["batch_size"]), int(row["sequence_length"]))] = ( |
| | float(row["result"]) |
| | ) |
| |
|
| | def plot(self): |
| | fig, ax = plt.subplots() |
| | title_str = "Time usage" if self.args.is_time else "Memory usage" |
| | title_str = title_str + " for training" if self.args.is_train else title_str + " for inference" |
| |
|
| | if not self.args.no_log_scale: |
| | |
| | ax.set_xscale("log") |
| | ax.set_yscale("log") |
| |
|
| | for axis in [ax.xaxis, ax.yaxis]: |
| | axis.set_major_formatter(ScalarFormatter()) |
| |
|
| | for model_name_idx, model_name in enumerate(self.result_dict.keys()): |
| | batch_sizes = sorted(set(self.result_dict[model_name]["bsz"])) |
| | sequence_lengths = sorted(set(self.result_dict[model_name]["seq_len"])) |
| | results = self.result_dict[model_name]["result"] |
| |
|
| | (x_axis_array, inner_loop_array) = ( |
| | (batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes) |
| | ) |
| |
|
| | label_model_name = ( |
| | model_name if self.args.short_model_names is None else self.args.short_model_names[model_name_idx] |
| | ) |
| |
|
| | for inner_loop_value in inner_loop_array: |
| | if self.args.plot_along_batch: |
| | y_axis_array = np.asarray( |
| | [results[(x, inner_loop_value)] for x in x_axis_array if (x, inner_loop_value) in results], |
| | dtype=int, |
| | ) |
| | else: |
| | y_axis_array = np.asarray( |
| | [results[(inner_loop_value, x)] for x in x_axis_array if (inner_loop_value, x) in results], |
| | dtype=np.float32, |
| | ) |
| |
|
| | (x_axis_label, inner_loop_label) = ( |
| | ("batch_size", "len") if self.args.plot_along_batch else ("in #tokens", "bsz") |
| | ) |
| |
|
| | x_axis_array = np.asarray(x_axis_array, int)[: len(y_axis_array)] |
| | plt.scatter( |
| | x_axis_array, y_axis_array, label=f"{label_model_name} - {inner_loop_label}: {inner_loop_value}" |
| | ) |
| | plt.plot(x_axis_array, y_axis_array, "--") |
| |
|
| | title_str += f" {label_model_name} vs." |
| |
|
| | title_str = title_str[:-4] |
| | y_axis_label = "Time in s" if self.args.is_time else "Memory in MB" |
| |
|
| | |
| | plt.title(title_str) |
| | plt.xlabel(x_axis_label) |
| | plt.ylabel(y_axis_label) |
| | plt.legend() |
| |
|
| | if self.args.figure_png_file is not None: |
| | plt.savefig(self.args.figure_png_file) |
| | else: |
| | plt.show() |
| |
|
| |
|
| | def main(): |
| | parser = HfArgumentParser(PlotArguments) |
| | plot_args = parser.parse_args_into_dataclasses()[0] |
| | plot = Plot(args=plot_args) |
| | plot.plot() |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|