import itertools
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from estimagic.batch_evaluators import joblib_batch_evaluator
warnings.filterwarnings("ignore", message="Polyfit may be poorly conditioned")
def run_1d_gridsearch(func, params, loc, gridspec, n_seeds, n_cores):
    """Run a grid search over one parameter.

    Every grid point is combined with every seed. For each combination a deep
    copy of ``params`` is made, its ``"value"`` entry at ``loc`` is set to the
    grid point, and the pair is handed to ``joblib_batch_evaluator``.

    Args:
        func (callable): Function to evaluate. The batch evaluator is called
            with ``unpack_symbol="**"``, so ``func`` presumably receives the
            keyword arguments ``params`` and ``seed``. Each result must
            support ``result["value"]``.
        params: Object with ``copy(deep=True)`` and ``.loc[loc, "value"]``
            assignment (presumably a pandas DataFrame with a "value" column).
        loc: Row selector used in ``params.loc[loc, "value"]``.
        gridspec (tuple): Positional arguments forwarded to ``np.linspace``,
            e.g. ``(start, stop, num)``.
        n_seeds (int): Number of seeds per grid point.
        n_cores (int): Forwarded to ``joblib_batch_evaluator``.

    Returns:
        tuple: ``(reshaped_results, grid, best_index, fig)`` where

        - ``reshaped_results`` is a list with one row per grid point and one
          entry per seed,
        - ``grid`` is the array of evaluated grid points,
        - ``best_index`` is the position in ``grid`` with the smallest mean
          ``"value"`` across seeds,
        - ``fig`` is a matplotlib figure with a polynomial regression plot of
          all values against the grid.

    """
    seeds = _get_seeds(n_seeds)
    grid = np.linspace(*gridspec)
    # Take the number of points from the grid itself: gridspec[-1] is only the
    # point count when gridspec is exactly (start, stop, num).
    n_points = len(grid)

    arguments = []
    for point, seed in itertools.product(grid, seeds):
        p = params.copy(deep=True)
        p.loc[loc, "value"] = point
        arguments.append({"params": p, "seed": seed})

    results = joblib_batch_evaluator(
        func=func,
        arguments=arguments,
        n_cores=n_cores,
        unpack_symbol="**",
        error_handling="raise",
    )

    reshaped_results = _reshape_flat_list_2d(results, (n_points, len(seeds)))

    avg_values = []
    for row in reshaped_results:
        avg_values.append(np.mean([res["value"] for res in row]))
    best_index = np.argmin(avg_values)

    # Keep the polynomial order low when there are only few grid points.
    if n_points <= 3:
        order = 1
    elif n_points <= 5:
        order = 2
    else:
        order = 3

    fig, ax = plt.subplots(figsize=(5, 4))
    sns.regplot(
        # itertools.product iterates seeds fastest, so each grid point is
        # repeated len(seeds) times in the flat result list.
        x=np.repeat(grid, len(seeds)),
        y=[res["value"] for res in results],
        order=order,
        ax=ax,
    )
    # Close exactly this figure (not whatever happens to be current); the
    # Figure object is still returned to the caller.
    plt.close(fig)

    return reshaped_results, grid, best_index, fig
def run_2d_gridsearch(
    func,
    params,
    loc1,
    gridspec1,
    loc2,
    gridspec2,
    n_seeds,
    n_cores,
    mask=None,
    names=("x_1", "x_2"),
):
    """Run a grid search over two parameters.

    Only grid points where ``mask`` is True are evaluated. Each selected point
    is combined with every seed, written into a deep copy of ``params`` and
    handed to ``joblib_batch_evaluator``.

    Args:
        func (callable): Function to evaluate. The batch evaluator is called
            with ``unpack_symbol="**"``, so ``func`` presumably receives the
            keyword arguments ``params`` and ``seed``. Each result must
            support ``result["value"]``.
        params: Object with ``copy(deep=True)`` and ``.loc[..., "value"]``
            assignment (presumably a pandas DataFrame with a "value" column).
        loc1: Row selector of the first parameter in ``params``.
        gridspec1 (tuple): Positional arguments forwarded to ``np.linspace``
            for the first parameter, e.g. ``(start, stop, num)``.
        loc2: Row selector of the second parameter in ``params``.
        gridspec2 (tuple): Like ``gridspec1`` for the second parameter.
        n_seeds (int): Number of seeds per grid point.
        n_cores (int): Forwarded to ``joblib_batch_evaluator``.
        mask (array-like, optional): Boolean array of shape
            ``(len(grid_1), len(grid_2))``. Defaults to all True.
        names (sequence of str): Labels for the first and second parameter.
            If exactly two are given they label the heatmap rows (first
            parameter) and columns (second parameter).

    Returns:
        tuple: ``(reshaped_results, dense_grid, best_index, fig)`` where

        - ``reshaped_results`` is a list with one row per selected grid point
          and one entry per seed,
        - ``dense_grid`` is an array of shape ``(n_selected, 2)`` holding the
          selected (first, second) parameter values in row-major mask order,
        - ``best_index`` is the row of ``dense_grid`` with the smallest mean
          ``"value"`` across seeds,
        - ``fig`` is a matplotlib figure with an annotated heatmap of the mean
          values (NaN where the mask is False).

    Raises:
        ValueError: If ``mask`` does not match the grid shape or selects no
            grid point.

    """
    # naming: _x refers to loc1, _y to loc2 and z to function values
    names = list(names)
    seeds = _get_seeds(n_seeds)
    grid_x = np.linspace(*gridspec1)
    grid_y = np.linspace(*gridspec2)
    # Use the actual grid lengths: gridspec[-1] is only the point count when
    # the gridspec is exactly (start, stop, num).
    grid_shape = (len(grid_x), len(grid_y))

    if mask is None:
        mask = np.full(grid_shape, True)
    else:
        # Boolean dtype is required for the boolean indexing further down.
        mask = np.asarray(mask, dtype=bool)
        if mask.shape != grid_shape:
            raise ValueError(
                f"mask has shape {mask.shape} but the grid has shape {grid_shape}."
            )

    n_selected = int(mask.sum())
    # Fail before any (potentially expensive) function evaluation is started.
    if n_selected == 0:
        raise ValueError("mask does not select any grid point.")

    # Boolean indexing walks the mask in row-major order, i.e. x varies slowest.
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y, indexing="ij")
    dense_grid = np.column_stack([mesh_x[mask], mesh_y[mask]])

    arguments = []
    for (x, y), seed in itertools.product(dense_grid, seeds):
        p = params.copy(deep=True)
        p.loc[loc1, "value"] = x
        p.loc[loc2, "value"] = y
        arguments.append({"params": p, "seed": seed})

    results = joblib_batch_evaluator(
        func=func,
        arguments=arguments,
        n_cores=n_cores,
        unpack_symbol="**",
        error_handling="raise",
    )

    reshaped_results = _reshape_flat_list_2d(results, (n_selected, len(seeds)))

    avg_values = []
    for row in reshaped_results:
        avg_values.append(np.mean([res["value"] for res in row]))
    best_index = np.argmin(avg_values)

    # Scatter the averages back onto the full grid; unselected cells stay NaN.
    filled_z = np.full(mask.shape, np.nan)
    filled_z[mask] = avg_values

    fig, ax = plt.subplots(figsize=(6, 5))
    df = pd.DataFrame(
        data=filled_z.round(1),
        index=[str(x.round(3)) for x in grid_x],
        columns=[str(y.round(3)) for y in grid_y],
    )
    sns.heatmap(df, ax=ax, cmap="YlOrBr", annot=True)
    # Rows of the heatmap correspond to loc1, columns to loc2.
    if len(names) == 2:
        ax.set_ylabel(names[0])
        ax.set_xlabel(names[1])
    fig.tight_layout()
    # Release the figure from pyplot's registry, as run_1d_gridsearch does;
    # the Figure object is still returned to the caller.
    plt.close(fig)

    return reshaped_results, dense_grid, best_index, fig
def get_mask_around_diagonal(dim, offset=1, flip=True):
    """Build a square boolean mask that selects a band around a diagonal.

    An entry is True when it lies at most ``offset`` rows/columns away from
    the main diagonal. With ``flip=True`` the mask is mirrored left to right,
    so the band runs along the diagonal from bottom left to top right.

    Args:
        dim (int): Number of rows and columns of the mask.
        offset (int): Number of off-diagonals included on each side of the
            diagonal.
        flip (bool): If True, return the band around the flipped diagonal
            instead of the standard one.

    Returns:
        np.ndarray: Boolean array of shape ``(dim, dim)``.

    """
    row_idx, col_idx = np.indices((dim, dim))
    # Distance of each cell from the main diagonal, measured in off-diagonals.
    band = np.abs(row_idx - col_idx) <= offset
    return np.fliplr(band) if flip else band
[docs]def _get_seeds(n_seeds):
return [500 + 100_000 * i for i in range(n_seeds)]
[docs]def _reshape_flat_list_2d(flat_list, shape):
n_rows, n_cols = shape
assert len(shape) == 2
assert np.prod(shape) == len(flat_list), f"{np.prod(shape)}: {len(flat_list)}"
reshaped = []
entries = iter(flat_list)
for _r in range(n_rows):
row = []
for _c in range(n_cols):
row.append(next(entries))
reshaped.append(row)
return reshaped