plot.py
Go to the documentation of this file.
1 #! /usr/bin/env python3
2 # Plot benchmark results
3 
4 import argparse
5 import itertools
6 import json
7 import typing
8 from pathlib import Path
9 
10 import numpy as np
11 from bokeh.layouts import column
12 from bokeh.models import ColumnDataSource, Whisker
13 from bokeh.plotting import figure, show
14 from bokeh.transform import factor_cmap, jitter
15 
# A duration in whole nanoseconds. NewType lets a type checker keep these
# apart from arbitrary ints; at runtime TimeNS(x) simply returns x.
TimeNS = typing.NewType("TimeNS", int)
17 
18 
class Variance:
    """Summary statistics (mean, median, standard deviation) of one benchmark run.

    All three values are ``TimeNS``, i.e. integer nanoseconds.
    """

    def __init__(self, mean: TimeNS, median: TimeNS, stddev: TimeNS):
        self.mean, self.median, self.stddev = mean, median, stddev

    @staticmethod
    def default():
        """Return a ``Variance`` whose three statistics are all zero."""
        zero = TimeNS(0)
        return Variance(zero, zero, zero)

    def __repr__(self):
        return f"<Variance {self.mean=}, {self.median=}, {self.stddev=}>"
38 
39 
class Data:
    """Measurements collected for a single benchmark run.

    ``name`` is the benchmark run name, ``times`` is the list of measured
    times (``TimeNS``) and ``variance`` holds the summary statistics, or
    ``None`` when none are available.
    """

    def __init__(self, name: str, times: list[TimeNS], variance: None | Variance):
        self.name, self.times, self.variance = name, times, variance

    @staticmethod
    def from_name(name: str):
        """Create an empty ``Data`` for *name*: no times and no variance."""
        return Data(name=name, times=[], variance=None)

    def __repr__(self):
        return f"<Data {self.name=}, {self.times=}, {self.variance=}>"
54 
55 
def make_time_ns(bench: dict) -> TimeNS:
    """Convert the ``real_time`` of a benchmark JSON entry into nanoseconds.

    :param bench: One entry of the ``benchmarks`` array. It must contain
        ``real_time`` and a ``time_unit`` of ``"ns"``, ``"us"``, ``"ms"``
        or ``"s"``.
    :return: ``real_time`` expressed in whole nanoseconds, rounded to the
        nearest integer.
    :raises ValueError: if ``time_unit`` is not one of the supported units.
    """
    ns_per_unit = {"ns": 1, "us": 1_000, "ms": 1_000_000, "s": 1_000_000_000}
    unit = bench["time_unit"]
    try:
        scale = ns_per_unit[unit]
    except KeyError:
        # Fail loudly here instead of returning None, which would only
        # surface later as a confusing error while plotting.
        raise ValueError(f"unsupported time_unit {unit!r}") from None
    # Round rather than truncate: float error (e.g. 1.005 us * 1000 ->
    # 1004.999...) would otherwise drop a nanosecond.
    return TimeNS(round(float(bench["real_time"]) * scale))
67 
68 
def parse_json_output(json_content: dict) -> list[Data]:
    """Parse Google Benchmark JSON output into one ``Data`` per run.

    Entries with ``run_type == "iteration"`` append their time to
    ``Data.times``. Entries with ``run_type == "aggregate"`` create
    ``Data.variance``, and the ``mean``/``median``/``stddev`` aggregates
    fill the matching field. Any other run type is ignored.

    :param json_content: Decoded JSON document with a ``benchmarks`` list.
    :return: One ``Data`` per ``run_name``, in order of first appearance.
    """
    data_dict: dict[str, Data] = {}
    for bench in json_content["benchmarks"]:
        run_name = bench["run_name"]
        run_type = bench["run_type"]
        if run_type not in ("aggregate", "iteration"):
            continue

        # Create the entry on first sight of either record type, so a run
        # reported with aggregates only (no iteration entries) does not
        # raise KeyError.
        data = data_dict.get(run_name)
        if data is None:
            data = data_dict[run_name] = Data.from_name(run_name)

        if run_type == "iteration":
            data.times.append(make_time_ns(bench))
            continue

        if data.variance is None:
            data.variance = Variance.default()
        aggregate_name = bench["aggregate_name"]
        if aggregate_name == "mean":
            data.variance.mean = make_time_ns(bench)
        elif aggregate_name == "median":
            data.variance.median = make_time_ns(bench)
        elif aggregate_name == "stddev":
            data.variance.stddev = make_time_ns(bench)

    return list(data_dict.values())
93 
94 
def class_name(s: str) -> str:
    """Convert a benchmark name into a short label for the X axis.

    The label is the second ``/``-separated component of *s*
    (``"BM_foo/MyClass/64"`` -> ``"MyClass"``). A name without any ``/``
    is returned unchanged instead of raising ``IndexError``.
    """
    parts = s.split("/")
    return parts[1] if len(parts) > 1 else s
98 
99 
def compute_data_lower_upper(data: Data) -> tuple[float, float]:
    """Return the ``(lower, upper)`` whisker bounds for *data*.

    The bounds are ``median - stddev`` and ``median + stddev`` from
    ``data.variance``. When ``data.variance`` is falsy (no statistics),
    both bounds are NaN.
    """
    stats = data.variance
    if not stats:
        nan = float("nan")
        return (nan, nan)
    low = float(stats.median - stats.stddev)
    high = float(stats.median + stats.stddev)
    return (low, high)
107 
108 
def flatten_data(data: Data) -> list[tuple[str, TimeNS]]:
    """Pair every measured time of *data* with its display label.

    :return: ``[(label, time), ...]`` with one tuple per entry of
        ``data.times``, where ``label`` is ``class_name(data.name)``.
    """
    if not data.times:
        return []
    label = class_name(data.name)
    return [(label, measured) for measured in data.times]
112 
113 
def plot_batch(datas: list[Data]):
    """Build one Bokeh figure for a batch of benchmark runs.

    Each run is placed at a categorical X position named by ``class_name``.
    A whisker spans the bounds from ``compute_data_lower_upper``, and every
    measured time is drawn as a jittered, colored scatter point.

    :param datas: Runs to draw. Colors come from the ``Light7`` palette, so
        only up to 7 distinct labels get distinct colors.
    :return: The configured Bokeh figure; it is not shown here.
    """
    # Extract all batch classes
    classes = [class_name(d.name) for d in datas]
    # Bokeh flags duplicate factors in a categorical range as an error, so
    # runs that map to the same label share one X position instead.
    factors = list(dict.fromkeys(classes))

    # Compute lower and upper bounds
    lower_upper = np.array(
        [compute_data_lower_upper(d) for d in datas], dtype=float
    ).reshape(-1, 2)
    lower = lower_upper[:, 0]
    upper = lower_upper[:, 1]

    p = figure(
        height=768,
        sizing_mode="stretch_width",
        x_range=factors,
        background_fill_color="#efefef",
        title="Benchmark results",
    )
    p.xgrid.grid_line_color = None

    # Create whisker plot (one whisker per run, at the run's label)
    whisker_source = ColumnDataSource(data=dict(base=classes, upper=upper, lower=lower))
    error = Whisker(
        base="base",
        upper="upper",
        lower="lower",
        source=whisker_source,
        level="annotation",
        line_width=2,
    )
    error.upper_head.size = 20
    error.lower_head.size = 20
    p.add_layout(error)

    # Create scatter plot. Build the columns as plain lists. A numpy array of
    # (str, int) tuples would coerce the times to strings, and it has no
    # second axis when there are no points at all (e.g. aggregate-only runs).
    flat_data = list(itertools.chain.from_iterable(flatten_data(d) for d in datas))
    scatter_source = ColumnDataSource(
        data=dict(
            cl=[label for label, _ in flat_data],
            time=[int(t) for _, t in flat_data],
        )
    )
    # We use jitter to avoid plotting all class timing on the same X
    p.scatter(
        jitter("cl", 0.3, range=p.x_range),
        "time",
        source=scatter_source,
        alpha=0.5,
        size=13,
        line_color="white",
        color=factor_cmap("cl", "Light7", factors),
    )
    return p
165 
166 
def plot(datas: list[Data]):
    """Display every benchmark run, stacked as figures of at most 7 runs each.

    The batch size matches the 7 colors of the ``Light7`` palette used by
    ``plot_batch``.
    """
    figures = [plot_batch(list(group)) for group in itertools.batched(datas, 7)]
    layout = column(figures, sizing_mode="stretch_width")
    show(layout)
178 
179 
def is_file(file: str) -> Path:
    """``argparse`` ``type=`` callback that accepts only existing regular files.

    :param file: Path given on the command line.
    :return: *file* as a ``Path``.
    :raises argparse.ArgumentTypeError: if *file* is not an existing file.
    """
    candidate = Path(file)
    if candidate.is_file():
        return candidate
    raise argparse.ArgumentTypeError(f"{file} is not a file")
185 
186 
def argument_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    :return: A parser with one positional ``json_output`` argument,
        validated by ``is_file`` and converted to a ``Path``.
    """
    parser = argparse.ArgumentParser(description="Plot benchmark results")
    # The positional is the benchmark JSON file parsed by main(), not a
    # directory, so the help text says so.
    parser.add_argument(
        "json_output",
        help="Google Benchmark JSON output file",
        type=is_file,
    )
    return parser
191 
192 
def main(args: list[str] | None = None):
    """Parse *args*, load the benchmark JSON file and display the plots.

    :param args: Command-line arguments without the program name. ``None``
        (the default) lets argparse read ``sys.argv[1:]``.
    """
    namespace = argument_parser().parse_args(args)
    # read_text opens and closes the file itself (open().read() leaked the
    # handle); JSON is UTF-8, so don't depend on the locale encoding.
    content = json.loads(namespace.json_output.read_text(encoding="utf-8"))
    plot(parse_json_output(content))
199 
200 
if __name__ == "__main__":
    from sys import argv

    # Hand the command-line arguments, minus the program name, to main().
    main(argv[1:])
plot.Data.__init__
def __init__(self, str name, list[TimeNS] times, None|Variance variance)
Definition: plot.py:43
plot.Variance.mean
mean
Definition: plot.py:23
plot.Variance.median
median
Definition: plot.py:24
plot.make_time_ns
TimeNS make_time_ns(dict bench)
Definition: plot.py:56
plot.class_name
str class_name(str s)
Definition: plot.py:95
plot.Data
Definition: plot.py:40
plot
Definition: plot.py:1
plot.main
def main(list[str] args)
Definition: plot.py:193
plot.parse_json_output
list[Data] parse_json_output(dict json_content)
Definition: plot.py:69
plot.Variance.stddev
stddev
Definition: plot.py:25
plot.Data.times
times
Definition: plot.py:45
plot.TimeNS
TimeNS
Definition: plot.py:16
plot.flatten_data
list[tuple[str, TimeNS]] flatten_data(Data data)
Definition: plot.py:109
plot.Data.name
name
Definition: plot.py:44
plot.Variance.__repr__
def __repr__(self)
Definition: plot.py:36
plot.plot_batch
def plot_batch(list[Data] datas)
Definition: plot.py:114
plot.Variance
Definition: plot.py:19
plot.compute_data_lower_upper
tuple[float, float] compute_data_lower_upper(Data data)
Definition: plot.py:100
plot.Variance.default
def default()
Definition: plot.py:33
plot.is_file
Path is_file(str file)
Definition: plot.py:180
plot.Data.__repr__
def __repr__(self)
Definition: plot.py:52
plot.Variance.__init__
def __init__(self, TimeNS mean, TimeNS median, TimeNS stddev)
Definition: plot.py:22
plot.Data.variance
variance
Definition: plot.py:46
plot.argument_parser
argparse.ArgumentParser argument_parser()
Definition: plot.py:187
plot.Data.from_name
def from_name(str name)
Definition: plot.py:49
plot.plot
def plot(list[Data] datas)
Definition: plot.py:167


pinocchio
Author(s):
autogenerated on Fri Apr 25 2025 02:41:41