Source code for src.contact_models.task_plot_assortativity

import matplotlib.pyplot as plt
import pandas as pd
import pytask
import seaborn as sns

from src.config import BLD

# One entry per heatmap: (pickled input, row label passed to ``.loc``, output PDF).
_PARAMETRIZATION = [
    (
        BLD / "contact_models" / "age_assort_params" / "other_non_recurrent.pkl",
        "assortative_matching_other_non_recurrent_age_group",
        BLD / "figures" / "data" / "assortativity_other_non_recurrent.pdf",
    ),
    (
        BLD / "contact_models" / "age_assort_params" / "work_non_recurrent.pkl",
        "assortative_matching_work_non_recurrent_age_group",
        BLD / "figures" / "data" / "assortativity_work_non_recurrent.pdf",
    ),
]


@pytask.mark.parametrize("depends_on, loc, produces", _PARAMETRIZATION)
def task_create_assortativity_heatmap(depends_on, loc, produces):
    """Plot one assortativity heatmap and save it to ``produces``.

    Args:
        depends_on: Path of the pickled pandas object to read.
        loc: Label selected from the unstacked data by ``_create_heatmap``.
        produces: Path the figure is written to.

    """
    # NOTE(review): ``depends_on`` is expected to be an artifact of this
    # project's own build; unpickling untrusted files is unsafe.
    sr = pd.read_pickle(depends_on)
    fig, _ = _create_heatmap(sr, loc)
    fig.tight_layout()
    fig.savefig(produces)
    # Close exactly the figure created here instead of "the current figure",
    # so no figure leaks if the pyplot state changed in between.
    plt.close(fig)
def _create_heatmap(sr, loc):
    """Draw an annotated heatmap of the matrix stored under ``loc`` in ``sr``.

    Args:
        sr: pandas Series; ``sr.unstack().loc[loc]`` must yield the matrix to
            plot (presumably age group by age group of contact - confirm
            against the producer of the pickle).
        loc (str): Label selecting the matrix. If it contains ``"work"``, the
            age groups "0-9", "70-79" and "80-100" are removed from both axes.

    Returns:
        tuple: The matplotlib figure and the axes holding the heatmap.

    Raises:
        ValueError: If, after removing the non-working age groups, any row sums
            to 0.85 or less.

    """
    probs = sr.unstack().loc[loc]
    probs.index.name = "age group"
    probs.columns.name = "age group of contact"

    if "work" in loc:
        non_working_groups = ["0-9", "70-79", "80-100"]
        probs = probs.drop(index=non_working_groups, columns=non_working_groups)

        # NOTE(review): the source formatting was lost; this check is assumed
        # to belong to the "work" branch (it guards the mass removed by the
        # drop above) - confirm against the original file.
        # Raised explicitly rather than asserted so that it survives
        # ``python -O`` and reports which rows are affected.
        row_sums = probs.sum(axis=1)
        large_enough = row_sums > 0.85
        if not large_enough.all():
            raise ValueError(
                "Rows must sum to more than 0.85 after dropping the "
                f"non-working age groups; got {row_sums[~large_enough].to_dict()} "
                f"for {loc!r}."
            )

    fig, ax = plt.subplots(figsize=(10, 8))
    ax = sns.heatmap(
        probs,
        annot=True,
        fmt=".2f",
        cbar=False,
        cmap="coolwarm",
        center=0,
        ax=ax,
    )
    # Keep the row labels horizontal.
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    return fig, ax