Source code for src.plotting.policy_gantt_chart
import matplotlib.pyplot as plt
import pandas as pd
from estimagic.visualization.colors import get_colors
def make_gantt_chart_of_policy_dict(
    policies, title=None, bar_height=0.8, bar_color=None, edge_color=None, alpha=1
):
    """Draw a Gantt chart with one row per affected contact model.

    Every policy is drawn as a horizontal bar that spans from its start to its
    end on the x-axis. Bars of policies that affect the same contact model
    share a row; rows are ordered by the sorted contact model names.

    Args:
        policies (dict): The values are dictionaries that contain at least the
            keys "affected_contact_model", "start" and "end". "start" and
            "end" are converted with ``pd.Timestamp``.
        title (str, optional): Title of the plot. Underscores are replaced by
            spaces and the result is title-cased. No title is set if None.
        bar_height (float): Height of each bar in y-axis units. Bars are
            centered on the integer position of their contact model.
        bar_color (optional): Face color of the bars. Defaults to
            "#ffffff00".
        edge_color (optional): Edge color of the bars. Defaults to the first
            color returned by ``get_colors("categorical", 1)``.
        alpha (float): Passed on to ``ax.broken_barh`` as ``alpha``.

    Returns:
        tuple: The matplotlib figure and axes, ``(fig, ax)``.

    """
    cm_names = sorted({pol["affected_contact_model"] for pol in policies.values()})
    positions = {name: pos for pos, name in enumerate(cm_names)}

    # One inch of height per contact model. Keep at least one inch so that an
    # empty policy dict does not request a zero-height figure.
    fig, ax = plt.subplots(figsize=(12, max(len(cm_names), 1)))

    if edge_color is None:
        edge_color = get_colors("categorical", 1)[0]
    if bar_color is None:
        bar_color = "#ffffff00"

    for pol in policies.values():
        start = pd.Timestamp(pol["start"])
        end = pd.Timestamp(pol["end"])
        y_center = positions[pol["affected_contact_model"]]
        # broken_barh expects (left edge, width) pairs, hence the duration.
        ax.broken_barh(
            xranges=[(start, end - start)],
            yrange=(y_center - 0.5 * bar_height, bar_height),
            edgecolors=edge_color,
            facecolors=bar_color,
            alpha=alpha,
        )

    ax.set_yticks(range(len(cm_names)))
    ax.set_yticklabels(cm_names)
    if title is not None:
        ax.set_title(title.replace("_", " ").title())
    return fig, ax