# awq-exllama / plot_awq.py
# Uploaded by IlyasMoutawwakil to the Hugging Face Hub
# ("Upload folder using huggingface_hub"), commit 999b799 (verified).
import glob
import json
import os
from io import StringIO

import flatten_dict
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Collect one single-row DataFrame per benchmark run found under awq_exllama/.
# (An alternative results directory used earlier was "pytorch_awq_exllama/".)
dfs = []
# Drive the loop from the report files and read the experiment config sitting in
# the same directory as each report. Zipping two independent glob() results does
# not guarantee that the i-th report and the i-th config belong to the same run:
# glob order is filesystem-dependent, and a run missing either file would shift
# every following pair. Sorting makes the resulting row order deterministic.
for report_path in sorted(
    glob.glob("awq_exllama/**/benchmark_report.json", recursive=True)
):
    experiment_path = os.path.join(
        os.path.dirname(report_path), "experiment_config.json"
    )
    with open(report_path, "r", encoding="utf-8") as f:
        report = json.load(f)
    with open(experiment_path, "r", encoding="utf-8") as f:
        experiment = json.load(f)
    # Flatten nested dicts into dot-separated keys (e.g. "prefill.latency.mean")
    # and turn each mapping into a one-row frame whose columns are those keys.
    report = flatten_dict.flatten(report, reducer="dot")
    report_df = pd.read_json(StringIO(json.dumps(report)), orient="index").T
    experiment = flatten_dict.flatten(experiment, reducer="dot")
    experiment_df = pd.read_json(StringIO(json.dumps(experiment)), orient="index").T
    # Place this run's report columns and config columns side by side.
    dfs.append(pd.concat([report_df, experiment_df], axis=1))
# Aggregate every run into one table (one row per run) and save it as CSV.
if not dfs:
    # pd.concat([]) would otherwise fail with a generic "No objects to concatenate".
    raise ValueError(
        "No benchmark reports found: expected files matching "
        "awq_exllama/**/benchmark_report.json"
    )
# Each per-run frame is built independently with its own default row label, so the
# labels repeat across runs; ignore_index gives every row a unique label. The CSV
# itself is unaffected because the index is not written.
df = pd.concat(dfs, ignore_index=True)
df.to_csv("awq_exllama/pytorch_awq_exllama.csv", index=False)
# Draw one bar chart per latency metric and save each as a PNG next to the CSV.
# Keys are the aggregated-table column to plot; values are the chart titles.
for metric, title in {
    "per_token.latency.mean": "Per Token Latency",
    "prefill.latency.mean": "Prefill Latency",
}.items():
    fig, ax = plt.subplots()
    # Bars grouped by batch size, one colour per AWQ quantization-config version.
    sns.barplot(
        data=df,
        x="benchmark.input_shapes.batch_size",
        y=metric,
        hue="backend.quantization_config.version",
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel("Batch Size")
    # NOTE(review): the "ms" unit is hard-coded; confirm it matches the unit the
    # benchmark reports its latencies in.
    ax.set_ylabel("Latency (ms)")
    ax.legend(title="AWQ Version")
    fig.tight_layout()
    fig.savefig(f"awq_exllama/pytorch_awq_exllama_{metric}.png")
    # Release the figure so the next metric is drawn on a fresh one.
    plt.close(fig)