Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions pypesto/visualize/_style.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Visual style for ``pypesto.visualize``.

Default constants, the ``style_kwargs`` registry, and small cross-module
helpers.

Users override any default per call via ``style_kwargs``, validated against
:data:`_DEFAULTS`::

waterfall(result, style_kwargs={"mle_color": "tab:purple"})
"""

from __future__ import annotations

import warnings

# Colors — semantic roles
# -----------------------
MLE_COLOR = "#d62728" # tab:red — best cluster + MLE markers
OUTLIER_COLOR = "#b3b3b3" # mid-grey — singleton / outlier starts

# Colormaps
# ---------
CMAP_DISCRETE = "tab10" # qualitative: cluster + per-variable colours

# Style registry
# --------------

_DEFAULTS: dict[str, object] = {
"mle_color": MLE_COLOR,
"outlier_color": OUTLIER_COLOR,
"cmap_discrete": CMAP_DISCRETE,
}


def resolve_style(style_kwargs: dict | None = None) -> dict:
"""Return the effective style dict, merging defaults with caller overrides.

Parameters
----------
style_kwargs:
User-supplied overrides. Unknown keys raise a ``UserWarning`` so
typos surface immediately.

Comment thread
Doresic marked this conversation as resolved.
Returns
-------
dict
Merged style dict with all keys from :data:`_DEFAULTS`, with
caller overrides applied on top.
"""
style = dict(_DEFAULTS)
if style_kwargs:
unknown = set(style_kwargs) - set(_DEFAULTS)
if unknown:
warnings.warn(
f"Unknown style_kwargs keys: {sorted(unknown)}. "
f"Valid keys: {sorted(_DEFAULTS)}.",
UserWarning,
stacklevel=3,
)
style.update(style_kwargs)
return style
Comment thread
Doresic marked this conversation as resolved.
100 changes: 86 additions & 14 deletions pypesto/visualize/clust_color.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,64 @@
import matplotlib.cm as cm
from __future__ import annotations

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import is_color_like

from pypesto.util import assign_clusters

# for typehints
from ..C import COLOR
from ._style import resolve_style


def _build_cluster_palette(style: dict) -> np.ndarray:
"""Sample non-best cluster colors from ``cmap_discrete``.

Colors close to ``mle_color`` or ``outlier_color`` are filtered out so
those reserved roles remain visually distinct from cycled cluster colors.
"""
cmap = plt.get_cmap(style["cmap_discrete"])
reserved = [
np.array(mcolors.to_rgb(style["mle_color"])),
np.array(mcolors.to_rgb(style["outlier_color"])),
]

# We remove colors from the cluster palette that are too close to the
# reserved colors (MLE red and outlier grey). The distance threshold is
# just a reasonable heuristic.
_RESERVED_COLOR_DISTANCE = 0.2

# Number of evenly-spaced samples taken from a continuous cmap (e.g. viridis)
# when ``cmap_discrete`` is set to one. Categorical cmaps (e.g. tab10) use
# all their listed colors and ignore this.
_CMAP_DISCRETE_SAMPLES = 10
Comment thread
Doresic marked this conversation as resolved.

if hasattr(cmap, "colors"):
candidates = [mcolors.to_rgba(c) for c in cmap.colors]
else:
candidates = [
cmap(i / (_CMAP_DISCRETE_SAMPLES - 1))
for i in range(_CMAP_DISCRETE_SAMPLES)
]
palette = [
c
for c in candidates
if all(
np.linalg.norm(np.array(c[:3]) - r) > _RESERVED_COLOR_DISTANCE
for r in reserved
)
]
if not palette:
palette = candidates
return np.array(palette)


def assign_clustered_colors(
vals: np.ndarray, balance_alpha: bool = True, highlight_global: bool = True
vals: np.ndarray,
balance_alpha: bool = True,
highlight_global: bool = True,
style: dict | None = None,
):
"""
Cluster and assign colors.
Expand All @@ -23,6 +72,10 @@ def assign_clustered_colors(
avoid overplotting
highlight_global:
flag indicating whether global optimum should be highlighted
style:
Pre-resolved visualization style dict, as returned by
:func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults
are used.

Returns
-------
Expand All @@ -36,20 +89,25 @@ def assign_clustered_colors(
# assign clusters
clusters, cluster_size = assign_clusters(vals)

if style is None:
style = resolve_style(None)
palette = _build_cluster_palette(style)
mle_rgba = list(mcolors.to_rgba(style["mle_color"]))
outlier_rgb = list(mcolors.to_rgb(style["outlier_color"]))

# create list of colors, which has the correct shape
n_clusters = 1 + max(clusters) - sum(cluster_size == 1)

# if best value was found more than once: we need one color less
if highlight_global and cluster_size[0] > 1:
n_clusters -= 1

# fill color array from colormap
colormap = cm.ScalarMappable().to_rgba
color_list = colormap(np.linspace(0.0, 1.0, n_clusters))
# fill color array by cycling through the categorical cluster palette
color_list = palette[np.arange(n_clusters) % len(palette)].copy()

# best optimum should be colored in red
# best optimum should be colored in MLE red
if highlight_global and cluster_size[0] > 1:
color_list = np.concatenate(([[1.0, 0.0, 0.0, 1.0]], color_list))
color_list = np.concatenate(([mle_rgba], color_list))

# We have clustered the results. However, clusters may have size 1,
# so we need to rearrange the regroup the results into "no_clusters",
Expand All @@ -64,8 +122,8 @@ def assign_clustered_colors(
if balance_alpha:
# set minimal alpha value to avoid non-visible colors
min_alpha = 0.01
# assign neutral color, add 1 for avoiding division by zero
grey = [0.7, 0.7, 0.7, min(1.0, 5.0 / (no_clusters.size + 1.0))]
# alpha shrinks with the number of singletons to avoid overplotting
grey = [*outlier_rgb, min(1.0, 5.0 / (no_clusters.size + 1.0))]

# reduce alpha level depend on size of each cluster
n_cluster_size = np.delete(cluster_size, no_clusters)
Expand All @@ -74,8 +132,7 @@ def assign_clustered_colors(
1.0, max(5.0 / n_cluster_size[icluster], min_alpha)
)
else:
# assign neutral color
grey = [0.7, 0.7, 0.7, 1.0]
grey = [*outlier_rgb, 1.0]

# create a color list, prfilled with grey values
Comment thread
Doresic marked this conversation as resolved.
Outdated
colors = np.array([grey] * clusters.size)
Expand All @@ -86,9 +143,9 @@ def assign_clustered_colors(
ind_of_iclust = np.argwhere(clusters == iclust).flatten()
colors[ind_of_iclust, :] = color_list[icol, :]

# if best value was found only once: replace it with red
# if best value was found only once: replace it with MLE red
if highlight_global and cluster_size[0] == 1:
colors[0] = [1.0, 0.0, 0.0, 1.0]
colors[0] = mle_rgba

return colors

Expand All @@ -98,6 +155,7 @@ def assign_colors(
colors: COLOR | list[COLOR] | np.ndarray | None = None,
balance_alpha: bool = True,
highlight_global: bool = True,
style: dict | None = None,
) -> np.ndarray:
"""
Assign colors or format user specified colors.
Expand All @@ -113,6 +171,10 @@ def assign_colors(
avoid overplotting
highlight_global:
flag indicating whether global optimum should be highlighted
style:
Pre-resolved visualization style dict, as returned by
:func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults
are used.

Returns
-------
Expand All @@ -129,6 +191,7 @@ def assign_colors(
vals,
balance_alpha=balance_alpha,
highlight_global=highlight_global,
style=style,
)

# Get number of elements and use user assigned colors
Expand Down Expand Up @@ -160,6 +223,7 @@ def assign_colors(
def assign_colors_for_list(
num_entries: int,
colors: COLOR | list[COLOR] | np.ndarray | None = None,
style: dict | None = None,
) -> list[list[float]] | np.ndarray:
"""
Create a list of colors for a list of items.
Expand All @@ -173,6 +237,10 @@ def assign_colors_for_list(
number of results in list
colors:
list of colors, or single color
style:
Pre-resolved visualization style dict, as returned by
:func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults
are used.

Returns
-------
Expand All @@ -188,7 +256,10 @@ def assign_colors_for_list(

# we don't want alpha levels for all plotting routines in this case...
colors = assign_colors(
dummy_clusters, balance_alpha=False, highlight_global=False
dummy_clusters,
balance_alpha=False,
highlight_global=False,
style=style,
)

# dummy cluster had twice as many entries as really there. Reduce.
Expand All @@ -201,4 +272,5 @@ def assign_colors_for_list(
colors=colors,
balance_alpha=False,
highlight_global=False,
style=style,
)
7 changes: 6 additions & 1 deletion pypesto/visualize/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def process_result_list(
results: Result | list[Result],
colors: COLOR | list[COLOR] | np.ndarray | None = None,
legends: str | list[str] | None = None,
style: dict | None = None,
) -> tuple[list[Result], list[COLOR], list[str]]:
"""
Assign colors and legends to a list of results, check user provided lists.
Expand All @@ -47,6 +48,10 @@ def process_result_list(
list of colors recognized by matplotlib, or single color
legends:
labels for line plots
style:
Pre-resolved visualization style dict, as returned by
:func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults
are used.

Returns
-------
Expand Down Expand Up @@ -84,7 +89,7 @@ def process_result_list(
legend_type_error = True
else:
# if more than one result is passed, we use one color per result
colors = assign_colors_for_list(len(results), colors)
colors = assign_colors_for_list(len(results), colors, style=style)

# check whether list of legends has the correct length
if legends is None:
Expand Down
Loading