# -----------------------------------------------------------------------------.
# MIT License
# Copyright (c) 2024 pycolorbar developers
#
# This file is part of pycolorbar.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----------------------------------------------------------------------------.
"""Module defining BivariateColormap functionalities."""
import base64
import io
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import BoundaryNorm, LogNorm, Normalize, SymLogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from pycolorbar.norm import CategorizeNorm, CategoryNorm, is_categorical_norm
from pycolorbar.utils.docstring import copy_docstring
from pycolorbar.utils.mpl_legend import add_colorbar_inset
# Import optional packages
try:
import pandas as pd
_PANDAS_AVAILABLE = True
except ImportError:
_PANDAS_AVAILABLE = False
try:
import geopandas as gpd
_GEOPANDAS_AVAILABLE = True
except ImportError:
_GEOPANDAS_AVAILABLE = False
try:
import xarray as xr
_XARRAY_AVAILABLE = True
except ImportError:
_XARRAY_AVAILABLE = False
# Global settings
_BIVAR_REPR_PNG_SIZE = (256, 256)
####----------------------------------------------------------------------------------------------.
[docs]
def check_n(n, both_integers=True):
    """Validate the number of colors along the x and y axes.

    Parameters
    ----------
    n : int, None or sequence of two elements
        Either a single value, which is applied to both axes, or a
        ``(n_x, n_y)`` pair. Each element can be an integer or None.
    both_integers : bool, optional
        If True (the default), None is not accepted for n_x or n_y.

    Returns
    -------
    tuple or sequence
        The ``(n_x, n_y)`` pair. A scalar input is expanded to a tuple,
        a pair is returned as provided.

    Raises
    ------
    ValueError
        If an integer element is smaller than 2, or if ``both_integers``
        is True and one of the elements is None.
    """
    # Ensure n is handled as a pair
    if isinstance(n, (int, type(None))):
        n = (n, n)
    n_x, n_y = n
    # Each axis needs at least two colors
    if isinstance(n_x, int) and n_x < 2:
        raise ValueError("Expected n_x >= 2.")
    if isinstance(n_y, int) and n_y < 2:
        raise ValueError("Expected n_y >= 2.")
    if both_integers and (n_x is None or n_y is None):
        raise ValueError("n_x and n_y must be integers.")
    return n
[docs]
def ensure_rgba_array(rgba_array):
    """Return an array with an alpha channel.

    A 3-channel (RGB) input gets a fully opaque alpha channel appended
    along the last axis. Any other input is returned unchanged.
    """
    if rgba_array.shape[2] != 3:
        return rgba_array
    n_rows, n_cols = rgba_array.shape[0], rgba_array.shape[1]
    # Opaque alpha channel sharing the dtype of the color channels
    opaque = np.ones((n_rows, n_cols, 1), dtype=rgba_array.dtype)
    return np.concatenate([rgba_array, opaque], axis=2)
[docs]
def apply_luminance_gradient(image_rgb, luminance_factor=None):
    """
    Apply a bivariate luminance gradient to an RGB image.

    It adds a radial whitening/darkening effect. The input array is never
    modified: the adjustment is computed on a floating-point copy.

    Parameters
    ----------
    image_rgb : numpy.ndarray
        An (N, N, 4) or (N, N, 3) array of RGBA or RGB values.
    luminance_factor : float or None
        Radial darkening is obtained with values < 1.
        Radial whitening is obtained with values > 1.
        None or 1 produce no change.

    Returns
    -------
    numpy.ndarray
        The luminance-adjusted image array, same shape as `image_rgb`.
        If `luminance_factor` is None or 1, `image_rgb` itself is returned.
    """
    if luminance_factor is None or luminance_factor == 1:
        return image_rgb  # No change
    # Work on a copy so the caller's array is left untouched.
    # Non-floating inputs are promoted to float, otherwise the rescaled
    # values would be truncated when written back.
    adjusted = np.array(image_rgb, copy=True)
    if not np.issubdtype(adjusted.dtype, np.floating):
        adjusted = adjusted.astype(float)
    height, width = adjusted.shape[:2]
    # Create a 2D radial pattern from -1..+1 in x and y
    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    radial_dist = np.sqrt(x**2 + y**2)
    # The exponent goes from 0.5 at the center to 0 at the corners
    white_term = luminance_factor ** ((np.sqrt(2) - radial_dist) / (2 * np.sqrt(2)))
    white_term = np.expand_dims(white_term, axis=2)  # for broadcasting
    # Add the term to the first 3 channels (RGB); alpha is left as it is
    adjusted[..., :3] = adjusted[..., :3] + white_term
    # Rescale each color channel to [0..1]
    for i in range(3):
        channel = adjusted[..., i]
        min_val, max_val = np.nanmin(channel), np.nanmax(channel)
        if max_val > min_val:  # Avoid division by zero
            adjusted[..., i] = (channel - min_val) / (max_val - min_val)
    return adjusted
[docs]
def interpolate_bivariate_cmap_colors(coords, rgba_colors, n_x, n_y, method="cubic"):
    """Interpolate scattered RGBA colors onto a regular grid of the unit square.

    Parameters
    ----------
    coords : array-like, shape (n_points, 2)
        The (x, y) coordinates of the known colors.
    rgba_colors : array-like, shape (n_points, 4)
        The RGBA values at the known coordinates.
    n_x : int
        Number of grid points along the x axis.
    n_y : int
        Number of grid points along the y axis.
    method : str, optional
        Interpolation method passed to ``scipy.interpolate.griddata``
        ('linear', 'nearest' or 'cubic'). The default is 'cubic'.

    Returns
    -------
    rgba_array : ndarray, shape (n_y, n_x, 4)
        The interpolated RGBA values, clipped to [0, 1].
        Row 0 corresponds to y=0 and column 0 to x=0.
    """
    # scipy is an optional dependency
    try:
        from scipy.interpolate import griddata
    except ImportError:
        raise ImportError(
            "The 'scipy' package is required but not found. "
            "Please install it using the following command: "
            "conda install -c conda-forge scipy",
        ) from None

    # Target locations: regular grid over [0, 1] x [0, 1]
    grid_x, grid_y = np.meshgrid(np.linspace(0, 1, n_x), np.linspace(0, 1, n_y))
    targets = np.column_stack((grid_x.ravel(), grid_y.ravel()))

    # Interpolate the four channels one at a time
    flat_rgba = np.full((n_y * n_x, 4), np.nan)
    for channel in range(4):
        interpolated = griddata(
            points=coords,
            values=rgba_colors[:, channel],
            xi=targets,
            method=method,
        )
        # Interpolation can overshoot the valid color range
        flat_rgba[:, channel] = np.clip(interpolated, 0, 1)

    return flat_rgba.reshape(n_y, n_x, 4)
[docs]
def resample_rgba_array(rgba_array, n_x=None, n_y=None, interp_method=None):
    """Resample an RGBA array to a different number of colors per axis.

    Parameters
    ----------
    rgba_array : numpy.ndarray
        Input RGB or RGBA array of shape (height, width, 3 or 4).
    n_x : int, optional
        Desired width. If None, the input width is kept.
    n_y : int, optional
        Desired height. If None, the input height is kept.
    interp_method : str, optional
        Interpolation method. If None, 'nearest' is used when both axes
        are downsampled and 'cubic' otherwise.

    Returns
    -------
    numpy.ndarray
        RGBA array of shape (n_y, n_x, 4).

    Notes
    -----
    No interpolation is performed (the RGBA-ensured input is returned) when
    both `n_x` and `n_y` are None or when the requested size equals the
    input size.
    """
    rgba_array = ensure_rgba_array(rgba_array)

    # Nothing requested: nothing to do
    if n_x is None and n_y is None:
        return rgba_array

    # Missing sizes default to the input size
    src_height, src_width = rgba_array.shape[0], rgba_array.shape[1]
    n_x = src_width if n_x is None else n_x
    n_y = src_height if n_y is None else n_y
    if n_x == src_width and n_y == src_height:
        return rgba_array

    # Default method depends on whether both axes shrink
    if interp_method is None:
        shrinks_both_axes = n_x < src_width and n_y < src_height
        interp_method = "nearest" if shrinks_both_axes else "cubic"

    # Coordinates of the input colors on the unit square, shape (height*width, 2)
    src_x, src_y = np.meshgrid(np.linspace(0, 1, src_width), np.linspace(0, 1, src_height))
    src_coords = np.column_stack((src_x.ravel(), src_y.ravel()))

    # Interpolate the flattened colors onto the new grid
    return interpolate_bivariate_cmap_colors(
        coords=src_coords,
        rgba_colors=rgba_array.reshape(-1, 4),
        n_x=n_x,
        n_y=n_y,
        method=interp_method,
    )
# def interpolate_corners_colors(color_array, n_x, n_y):
# """
# Zoom a 2x2 corner color array to an (n_y, n_x) shape.
# Parameters
# ----------
# color_array : numpy.ndarray
# Shape (2, 2, 4) RGBA color corners.
# n_x : int
# Output width.
# n_y : int
# Output height.
# Returns
# -------
# numpy.ndarray
# An (n_y, n_x, 4) array of interpolated RGBA colors.
# """
# from scipy.ndimage import zoom
# zoom_factor_y = n_y / 2.0
# zoom_factor_x = n_x / 2.0
# return zoom(color_array, (zoom_factor_y, zoom_factor_x, 1), order=1) # order=1 => bilinear interpolation
####-----------------------------------------------------------------------------.
##################################
#### Create bivariate palette ####
##################################
# def get_bivariate_cmap_from_corners(
# colors=("grey", "green", "red", "blue"),
# n=(3, 3),
# ):
# """
# Create a bivariate colormap from the colors at the four corners.
# The interpolation is performed into the sRGB colorspace.
# Parameters
# ----------
# colors : list or tuple
# Four color recognized by matplotlib (e.g. 'red', '#RRGGBB', etc.).
# The order correspond to [top_left, top_right, bottom_right, bottom_left]
# n : int or tuple
# Either a single integer or a (n_x, n_y) tuple specifying the number of colormap colors
# on the x and y axis.
# Returns
# -------
# numpy.ndarray
# A 2D array of shape (n_y, n_x, 4) representing RGBA colors.
# """
# # Check 4 colors are specified
# if len(colors) != 4:
# raise ValueError("The BivariateColormap definition from corners requires the specification of 4 colors.")
# # Retrieve number of colors per axis
# n_x, n_y = check_n(n)
# # Convert corner colors into an RGBA 2x2 matrix
# rgba_corners = np.array([
# [mpl.colors.to_rgba(colors[0]),
# mpl.colors.to_rgba(colors[1])],
# [mpl.colors.to_rgba(colors[3]),
# mpl.colors.to_rgba(colors[2])],
# ])
# # Zoom up to desired (n_y, n_x)
# rgba_array = interpolate_corners_colors(rgba_corners, n_x=n_x, n_y=n_y)
# return rgba_array
[docs]
def get_bivariate_cmap_from_colors(colors, n=5, interp_method=None):
    """
    Create a bivariate colormap from a set of color points.

    The colors are anchored on the perimeter of the unit square in clockwise
    order (plus an optional center color) and interpolated over an
    (n_y, n_x) grid in sRGB colorspace.

    Parameters
    ----------
    colors : list or tuple
        A list or array of color specifications recognized by Matplotlib
        (e.g., ['red', 'blue', 'green', 'black']) or an array of shape (N, 3)
        or (N, 4) of RGB(A) values.
        The length of `colors` must be one of {4, 5, 8, 9}. See the Notes below.
    n : int or tuple
        Either a single integer or a (n_x, n_y) tuple specifying the number of colormap colors
        on the x and y axis.
    interp_method : str, optional
        The interpolation method to use for generating the colormap.
        The default is 'cubic'.

    Notes
    -----
    "Top" refers to y=1 and "left" to x=0 of the unit square:

    - 4 colors: [top_left, top_right, bottom_right, bottom_left].
    - 5 colors: the 4 corners as above, followed by the center color.
    - 8 colors: [top_left, top_mid, top_right, right_mid, bottom_right,
      bottom_mid, bottom_left, left_mid].
    - 9 colors: the 8 perimeter colors as above, followed by the center color.

    Returns
    -------
    numpy.ndarray
        Shape (n_y, n_x, 4) array of interpolated RGBA.

    Raises
    ------
    ValueError
        If the number of colors is not 4, 5, 8 or 9.
    """
    if interp_method is None:
        interp_method = "cubic"
    # Check number of colors
    colors = np.array(colors)
    n_colors = colors.shape[0]
    if n_colors not in [4, 5, 8, 9]:
        raise ValueError("You must specify either 4, 5, 8 or 9 colors.")
    # Retrieve number of colors per axis
    n_x, n_y = check_n(n)
    # Convert list of color specs to RGBA
    rgba_colors = mpl.colors.to_rgba_array(colors)
    # Anchor (x, y) coordinates, clockwise from the top-left corner.
    # The eighth perimeter point is the left-edge midpoint (0, 0.5): placing it
    # at the center would leave the left edge unconstrained and, with 9 colors,
    # would duplicate the center point.
    corners = [(0, 1), (1, 1), (1, 0), (0, 0)]
    perimeter = [(0, 1), (0.5, 1), (1, 1), (1, 0.5), (1, 0), (0.5, 0), (0, 0), (0, 0.5)]
    center = (0.5, 0.5)
    dict_coords = {
        4: np.array(corners, dtype=float),
        5: np.array([*corners, center], dtype=float),
        8: np.array(perimeter, dtype=float),
        9: np.array([*perimeter, center], dtype=float),
    }
    coords = dict_coords[n_colors]
    rgba_array = interpolate_bivariate_cmap_colors(coords, rgba_colors, n_x=n_x, n_y=n_y, method=interp_method)
    return rgba_array
[docs]
def get_bivariate_cmap_from_two_cmaps(cmap_x=plt.cm.Blues, cmap_y=plt.cm.Reds, n=256):
    """
    Build a bivariate colormap by averaging two univariate colormaps.

    The first colormap varies along the x axis (columns) and the second one
    along the y axis (rows).

    Parameters
    ----------
    cmap_x : matplotlib.colors.Colormap or str
        Univariate colormap for the x axis.
    cmap_y : matplotlib.colors.Colormap or str
        Univariate colormap for the y axis.
    n : int or tuple
        Either a single integer or a (n_x, n_y) tuple specifying the number of colormap colors
        on the x and y axis.

    Returns
    -------
    numpy.ndarray
        Shape (n_y, n_x, 4) array of RGBA representing the combined colormap.
    """
    import pycolorbar

    # Resolve the colormaps
    cmap_x = pycolorbar.get_cmap(cmap_x)
    cmap_y = pycolorbar.get_cmap(cmap_y)
    n_x, n_y = check_n(n)
    # Fractions in [0..1] at which each colormap is sampled
    frac_x, frac_y = np.meshgrid(np.linspace(0, 1, n_x), np.linspace(0, 1, n_y))
    # Both samples have shape (n_y, n_x, 4); blend them with equal weights
    layers = [cmap_x(frac_x), cmap_y(frac_y)]
    return np.mean(layers, axis=0)
####-----------------------------------------------------------------------------.
################################
#### Load bivariate palette ####
################################
#### BIVARIATE_CMAPS_DICT
# Registry of the predefined bivariate palettes, keyed by palette name.
# Each value is either:
# - a nested list of hex colors (one inner list per row of the palette), or
# - the name of a .npy file, which load_bivariate_palette reads from the
#   package "bivariate/data" directory and divides by 255.
BIVARIATE_CMAPS_DICT = {
    # -------------------------------
    #### - N=3
    # Transparency
    "brewer.qualseq": [
        ["#cce8d7", "#cedced", "#fbb4d9"],
        ["#80c39b", "#85a8d0", "#f668b3"],
        ["#008837", "#0a50a1", "#d60066"],
    ],
    "brewer.blueyellow": [
        ["#efd100", "#4eb87b", "#007fc4"],
        ["#fef3a9", "#bedebc", "#a1c8ea"],
        ["#fffdef", "#e6f1df", "#d2e4f6"],
    ],
    # Sequentials
    "brewer.seqseq1": [
        ["#e8e6f2", "#b5d3e7", "#4fadd0"],
        ["#e5b4d9", "#b8b3d8", "#3983bb"],
        ["#de4fa6", "#b03598", "#2a1a8a"],
    ],
    "brewer.seqseq2": [
        ["#f3f3f3", "#b4d3e1", "#509dc2"],
        ["#f3e6b3", "#b3b3b3", "#376387"],
        ["#f3b300", "#b36600", "#000000"],
    ],
    "stevens.greenblue": [
        ["#e8e8e8", "#b5c0da", "#6c83b5"],
        ["#b8d6be", "#90b2b3", "#567994"],
        ["#73ae80", "#5a9178", "#2a5a5b"],
    ],
    "stevens.bluered": [
        ["#e8e8e8", "#e4acac", "#c85a5a"],
        ["#b0d5df", "#ad9ea5", "#985356"],
        ["#64acbe", "#627f8c", "#574249"],
    ],
    "stevens.pinkgreen": [
        ["#f3f3f3", "#c2f1ce", "#8be2af"],
        ["#eac5dd", "#9ec6d3", "#7fc6b1"],
        ["#e6a3d0", "#bc9fce", "#7b8eaf"],
    ],
    "stevens.purplegold": [
        ["#e8e8e8", "#e4d9ac", "#c8b35a"],
        ["#cbb8d7", "#c8ada0", "#af8e53"],
        ["#9972af", "#976b82", "#804d36"],
    ],
    "stevens.pinkblue": [
        ["#e8e8e8", "#ace4e4", "#5ac8c8"],
        ["#dfb0d6", "#a5add3", "#5698b9"],
        ["#be64ac", "#8c62aa", "#3b4994"],
    ],
    "tolochko.redblue": [
        ["#dddddd", "#7bb3d1", "#016eae"],
        ["#dd7c8a", "#8d6c8f", "#4a4779"],
        ["#cc0024", "#8a274a", "#4b264d"],
    ],
    # Divergents
    "brewer.divdiv": [
        ["#f37300", "#cce88b", "#008837"],
        ["#fe9aa6", "#e6e6e6", "#9ac9d5"],
        ["#f0047f", "#cd9acc", "#5a4da4"],
    ],
    "teuling.choro3": [
        ["#F4843A", "#F16879", "#A63F96"],
        ["#BCD85E", "#F0F0F0", "#5C6AB1"],
        ["#00AC51", "#00B18A", "#18B3E5"],
    ],
    # Divergent-sequentials
    "brewer.divseq": [
        ["#c3b3d8", "#e6e6e6", "#ffcc80"],
        ["#7b67ab", "#bfbfbf", "#f35926"],
        ["#240d5e", "#7f7f7f", "#b30000"],
    ],
    # -------------------------------
    #### - N=4
    "arc.bluepink": [
        ["#ffffff", "#ffe6fe", "#ffbdff", "#ff80fe"],
        ["#e7ffff", "#d7dafd", "#d8a6ff", "#c065fe"],
        ["#c0fcfd", "#a7caff", "#8d7efd", "#7f65fe"],
        ["#74feff", "#64c0ff", "#5873fe", "#4b4cff"],
    ],
    # -------------------------------
    #### - N=5
    # NOTE(review): several entries below repeat within a row ("#F9B381",
    # "#00B18A", "#18B3E5") and "#5C6AB1" appears in two rows - confirm
    # against the reference palette.
    "teuling.choro5": [
        ["#F4843A", "#F37C57", "#F16879", "#EE4497", "#A63F96"],
        ["#FCB73D", "#F9B381", "#F9B381", "#C980B5", "#8054A1"],
        ["#BCD85E", "#D4E6A3", "#F0F0F0", "#A9A4D0", "#5C6AB1"],
        ["#5C6AB1", "#80C99C", "#78CDCE", "#49B5E7", "#2477BC"],
        ["#00AC51", "#00B18A", "#00B18A", "#18B3E5", "#18B3E5"],
    ],
    #### - FROM NPY
    # Palettes stored on disk as arrays (see load_bivariate_palette)
    "bremm": "bremm.npy",
    "cubediagonal": "cubediagonal.npy",
    "schumann": "schumann.npy",
    "steiger": "steiger.npy",
    "ziegler": "ziegler.npy",
    "teuling2": "teuling2.npy",
}
#### TEULING_CMAPS
# Names of the colormaps that are not stored in BIVARIATE_CMAPS_DICT but
# generated on the fly by get_bivariate_cmap_teuling
TEULING_CMAPS = ["teuling.GRMB", "teuling.YGBR", "teuling.RBCG"]
[docs]
def available_bivariate_colormaps():
    """Return the sorted names of all predefined bivariate colormaps.

    The list gathers the palettes of BIVARIATE_CMAPS_DICT, the Teuling
    colormaps and the "LABspace" colormap.
    """
    return sorted([*BIVARIATE_CMAPS_DICT, *TEULING_CMAPS, "LABspace"])
[docs]
def check_name(name):
    """Validate a bivariate colormap name and return it.

    Raises
    ------
    TypeError
        If `name` is not a string.
    ValueError
        If `name` is not one of the available bivariate colormaps.
    """
    if isinstance(name, str):
        valid_names = available_bivariate_colormaps()
        if name in valid_names:
            return name
        raise ValueError(f"Invalid 'name' {name}. Available names are {valid_names}")
    raise TypeError("'name' must be a string.")
[docs]
def load_bivariate_palette(name, n, interp_method=None):
    """Load a predefined bivariate palette as an RGBA array.

    Parameters
    ----------
    name : str
        Key of BIVARIATE_CMAPS_DICT.
    n : int, None or tuple
        Number of colors per axis. None elements keep the native size.
    interp_method : str, optional
        Interpolation method used if the palette must be resampled.

    Returns
    -------
    numpy.ndarray
        RGBA array of the (possibly resampled) palette.
    """
    from pycolorbar import _root_path

    palette_spec = BIVARIATE_CMAPS_DICT[name]
    if not isinstance(palette_spec, str):
        # Nested list of color specs: convert each entry to RGBA
        grid = np.array(palette_spec)
        rgba_array = mpl.colors.to_rgba_array(grid.ravel()).reshape(*grid.shape, 4)
    else:
        # File name: values on disk are scaled by 255
        filepath = os.path.join(_root_path, "pycolorbar", "bivariate", "data", palette_spec)
        rgba_array = np.load(filepath) / 255
    rgba_array = ensure_rgba_array(rgba_array)
    # Resample to the requested size, if any
    n_x, n_y = check_n(n, both_integers=False)
    return resample_rgba_array(rgba_array, n_x=n_x, n_y=n_y, interp_method=interp_method)
def _get_teuling_colors(name="teuling.GRMB", diagonal_tilt=0.8, offdiag_tilt=1):
    """Return the five RGB anchor colors of a Teuling colormap.

    Parameters
    ----------
    name : str
        One of "teuling.GRMB", "teuling.YGBR" or "teuling.RBCG".
    diagonal_tilt : float
        Value used to build the first and third colors.
    offdiag_tilt : float
        Value used to build the second and fourth colors.

    Returns
    -------
    numpy.ndarray
        Array of shape (5, 3): four RGB colors followed by mid-grey.

    Raises
    ------
    KeyError
        If `name` is not one of the supported names.
    """
    dt = diagonal_tilt
    odt = offdiag_tilt
    # Each flat list holds 5 RGB triplets; the last one is always mid-grey
    color_dict = {
        "teuling.GRMB": np.array(
            [dt, 1, dt, odt, 0.5, 1 - odt, 1 - dt, 0, 1 - dt, 1 - odt, 0.5, odt, 0.5, 0.5, 0.5],
        ).reshape(5, 3),
        "teuling.YGBR": np.array(
            [dt, dt, 1, odt, 1 - odt, 0.5, 1 - dt, 1 - dt, 0, 1 - odt, odt, 0.5, 0.5, 0.5, 0.5],
        ).reshape(5, 3),
        "teuling.RBCG": np.array(
            [1, dt, dt, 0.5, odt, 1 - odt, 0, 1 - dt, 1 - dt, 0.5, 1 - odt, odt, 0.5, 0.5, 0.5],
        ).reshape(5, 3),
    }
    return color_dict[name]
[docs]
def get_bivariate_cmap_teuling(name, n, diagonal_tilt, offdiag_tilt, interp_method=None):
    """Generate a Teuling bivariate colormap.

    The five anchor colors of the requested Teuling colormap are built from
    the tilt values and interpolated over the colormap grid.

    Parameters
    ----------
    name : str
        The name of the Teuling colormap.
    n : int or tuple
        The number of colors per axis.
    diagonal_tilt : float
        The tilt applied to the diagonal colors.
    offdiag_tilt : float
        The tilt applied to the off-diagonal colors.
    interp_method : str, optional
        The interpolation method to use for generating the colormap.

    Returns
    -------
    rgba_array : numpy.ndarray
        An array of RGBA colors representing the bivariate colormap.
    """
    anchor_colors = _get_teuling_colors(
        name=name,
        diagonal_tilt=diagonal_tilt,
        offdiag_tilt=offdiag_tilt,
    )
    return get_bivariate_cmap_from_colors(colors=anchor_colors, n=n, interp_method=interp_method)
[docs]
def get_bivariate_cmap_labspace(n):
    """Generate a bivariate colormap in CIELAB space.

    A 2D grid of (L, a, b) values is built, the 'a' and 'b' channels are
    compressed with a tanh transformation, and the result is converted
    to sRGB.

    Parameters
    ----------
    n : int or tuple
        - If int, produce an (n, n) 2D grid.
        - If tuple (n_x, n_y), produce a (n_y, n_x) 2D grid.

    Returns
    -------
    rgb_array : numpy.ndarray
        A 3D array of shape (n_y, n_x, 3), representing the RGB image.
        Pixel values lie in [0, 1].

    Notes
    -----
    - 'a' varies along the x axis and 'b' along the y axis, both spanning
      [-100, 100] before compression.
    - The tanh scaling on 'a' and 'b' compresses large values, preventing
      extreme color shifts.
    """
    # colorspacious is an optional dependency
    try:
        from colorspacious import cspace_convert
    except ImportError:
        raise ImportError(
            "The 'colorspacious' package is required but not found. "
            "Please install it using the following command: "
            "conda install -c conda-forge colorspacious",
        ) from None

    n_x, n_y = check_n(n)

    # Grids of shape (n_y, n_x) spanning [-100, 100] on both axes
    a_grid, b_grid = np.meshgrid(np.linspace(-100, 100, n_x), np.linspace(-100, 100, n_y))

    # Lightness: centered at 75 and tilted by a linear combination of a and b
    lightness = np.ones_like(a_grid) * 75 + b_grid * 0.15 - a_grid * 0.3

    # Compress the chromatic channels
    a_compressed = np.tanh(a_grid / 130) * 100
    b_compressed = np.tanh(b_grid / 190) * 100

    # Stack into a LAB array, keeping each channel within its valid range
    lab_array = np.dstack(
        [
            np.clip(lightness, 0, 100),
            np.clip(a_compressed, -128, 127),
            np.clip(b_compressed, -128, 127),
        ],
    )

    # Convert to sRGB and drop out-of-gamut excursions
    rgb_array = cspace_convert(lab_array, "CIELab", "sRGB1")
    return np.clip(rgb_array, 0, 1)
[docs]
def get_bivariate_cmap_from_name(name, n, diagonal_tilt=0.8, offdiag_tilt=1, interp_method=None):
    """Retrieve a bivariate colormap based on the specified name and parameters.

    Parameters
    ----------
    name : str
        The name of the bivariate colormap to retrieve.
        See available_bivariate_colormaps().
    n : int or tuple
        Either a single integer or a (n_x, n_y) tuple specifying the number
        of colors on the x and y axis.
    diagonal_tilt : float, optional
        The tilt of the diagonal in the colormap, by default 0.8.
        Used only for Teuling colormaps.
    offdiag_tilt : float, optional
        The tilt of the off-diagonal in the colormap, by default 1.
        Used only for Teuling colormaps.
    interp_method : str or None, optional
        The interpolation method to use, by default None.
        Not used for the "LABspace" colormap.

    Returns
    -------
    rgba_array : numpy.ndarray
        An array of RGBA values representing the colormap.

    Raises
    ------
    TypeError
        If `name` is not a string.
    ValueError
        If `name` is not an available colormap or `n` is invalid.
    """
    # Check arguments
    n = check_n(n)
    name = check_name(name)
    # Teuling colormaps are generated from their anchor colors
    if name in TEULING_CMAPS:
        return get_bivariate_cmap_teuling(
            name=name,
            n=n,
            diagonal_tilt=diagonal_tilt,
            offdiag_tilt=offdiag_tilt,
            interp_method=interp_method,
        )
    # The LAB conversion yields RGB only: append an opaque alpha channel so
    # that every branch returns an RGBA array
    if name == "LABspace":
        return ensure_rgba_array(get_bivariate_cmap_labspace(n))
    # Remaining names are stored in BIVARIATE_CMAPS_DICT (colors or file)
    return load_bivariate_palette(name=name, n=n, interp_method=interp_method)
####------------------------------------------------------------------------.
#### Color Mapping
def _map_colors(x_normalized, y_normalized, n_x, n_y, rgba_array, origin="lower"):
    """Look up the colormap cell of each normalized (x, y) pair.

    Positions where x or y is not finite are left as NaN in the output.
    With origin "lower", y=0 maps to the last row of `rgba_array`.
    """
    mapped = np.full((*x_normalized.shape, 4), np.nan, dtype=float)

    # Only pairs with both coordinates finite can be mapped
    finite = np.isfinite(x_normalized) & np.isfinite(y_normalized)
    if np.any(finite):
        # Bins are left-inclusive and right-exclusive; clipping puts the value
        # 1.0 (and anything out of range) into the nearest edge bin
        col = np.clip(np.floor(x_normalized[finite] * n_x).astype(int), 0, n_x - 1)
        row = np.clip(np.floor(y_normalized[finite] * n_y).astype(int), 0, n_y - 1)
        if origin == "lower":
            row = (n_y - 1) - row
        mapped[finite] = rgba_array[row, col]
        # Final safety clip
        np.clip(mapped, 0.0, 1.0, out=mapped)
    return mapped
[docs]
def map_colors(
    x_normalized,
    y_normalized,
    n_x,
    n_y,
    rgba_array,
    mask,
    bad_color=(0.0, 0.0, 0.0, 0.0),
    origin="lower",
):
    """
    Map normalized (x, y) data in [0..1] to the RGBA colors of a 2D colormap.

    Each (x, y) pair is assigned the color of the colormap cell it falls in.

    Parameters
    ----------
    x_normalized : numpy.ndarray
        Values in [0..1], same shape as y_normalized.
    y_normalized : numpy.ndarray
        Values in [0..1], same shape as x_normalized.
    n_x, n_y : int
        Number of columns (x-axis) and rows (y-axis) in the colormap.
    rgba_array : numpy.ndarray
        The bivariate colormap RGBA array of shape (n_y, n_x, 4).
    mask : numpy.ndarray
        Boolean mask, same shape as the data, where True indicates NaN or invalid data.
    bad_color : tuple of float
        RGBA color to assign to masked points or out-of-bounds values.
    origin : str
        Either "upper" or "lower". The default is "lower".
        When "lower", the axis origin is on the bottom left and values (0,0) and (0,1) are
        mapped respectively to rgba_array[-1, 0, :] and rgba_array[0, 0, :]

    Returns
    -------
    rgba_mapped : numpy.ndarray
        An RGBA array with out-of-range values or masked points set to `bad_color`.
    """
    rgba_mapped = _map_colors(
        x_normalized=x_normalized,
        y_normalized=y_normalized,
        n_x=n_x,
        n_y=n_y,
        rgba_array=rgba_array,
        origin=origin,
    )
    # Values outside [0, 1] on either axis are flagged like masked points
    x_outside = (x_normalized < 0) | (x_normalized > 1)
    y_outside = (y_normalized < 0) | (y_normalized > 1)
    rgba_mapped[mask | x_outside | y_outside] = bad_color
    return rgba_mapped
[docs]
def define_norm(arr):
    """
    Build a matplotlib Normalize spanning the finite range of an array.

    Parameters
    ----------
    arr : array-like
        Input array to be normalized.

    Returns
    -------
    Normalize or None
        A Normalize object with the NaN-ignoring minimum and maximum of the
        array, or None if the array dtype is not numeric
        (e.g. categorical data).

    Raises
    ------
    ValueError
        If the array has all identical values or contains only NaNs.
    """
    values = np.asanyarray(arr)
    # Non-numeric dtype (e.g. categories): no norm can be derived
    if not np.issubdtype(values.dtype, np.number):
        return None
    vmin = np.nanmin(values)
    vmax = np.nanmax(values)
    has_no_range = vmin == vmax or np.isnan(vmin)
    if has_no_range:
        raise ValueError(f"Please specify the norm because all array values are {vmin}.")
    return Normalize(vmin=vmin, vmax=vmax)
[docs]
def check_expected_number_categories(n, n_categories, dim_name, obj="norm"):
    """Raise a ValueError if the number of colors differs from the number of categories.

    Parameters
    ----------
    n : int
        Number of colormap colors along the dimension.
    n_categories : int
        Number of categories reported by `obj`.
    dim_name : str
        Name of the dimension, used in the error message.
    obj : str, optional
        Description of the object defining the categories, used in the error message.
    """
    if n_categories == n:
        return
    raise ValueError(
        f"The colormap has {n} colors on {dim_name}, but the {obj} indicates {n_categories} categories. "
        "Please adapt the bivariate colormap to the number of expected categories.",
    )
[docs]
def check_cmap_ncolors(norm, n, dim_name):
    """Check that a categorical norm has as many categories as colormap colors.

    Nothing is checked when `norm` is None or is not a categorical norm.
    """
    if norm is None or not is_categorical_norm(norm):
        return
    check_expected_number_categories(n, norm.Ncmap, dim_name=dim_name)
[docs]
def create_pandas_category_norm(series):
    """Create a CategoryNorm object for a pandas Categorical series.

    Parameters
    ----------
    series : pandas.Series or geopandas.GeoSeries
        A pandas or geopandas Series with categorical dtype.

    Returns
    -------
    pycolorbar.norm.CategoryNorm
        A CategoryNorm built from a dictionary mapping the integer position
        of each category (0, 1, 2, ...) to the category itself.
    """
    # Position in the categories index -> category
    position_to_category = dict(enumerate(series.cat.categories))
    return CategoryNorm(position_to_category)
[docs]
def normalize_array(arr, norm):
    """
    Normalize a numeric numpy array to [0..1] using a Matplotlib norm.

    Norms exposing an ``Ncmap`` attribute (e.g. BoundaryNorm) return integer
    bin indices: these are further divided by ``Ncmap``.

    Parameters
    ----------
    arr : array-like
        Numeric array (possibly containing NaNs).
    norm : matplotlib.colors.Normalize
        The norm called on `arr`. It is not used if `arr` contains only NaNs.

    Returns
    -------
    numpy.ndarray
        Array of same shape as `arr`. An all-NaN input gives an all-NaN output.
    """
    # All NaN: there is nothing to normalize
    if np.all(np.isnan(arr)):
        return np.ones_like(arr) * np.nan
    normed = norm(arr)
    if not hasattr(norm, "Ncmap"):
        # The norm already scales values to [0-1]
        return normed
    # Integer bin indices: take the raw data and rescale by the number of bins
    return normed.data / norm.Ncmap
[docs]
def normalize_pandas_series(series, norm):
    """
    Normalize or encode a pandas Series to [0..1].

    - A categorical series is encoded from its category codes, divided by
      the number of categories minus one (`norm` is ignored).
    - Any other series is converted to float and passed to `normalize_array`.

    Parameters
    ----------
    series : pandas.Series or geopandas.GeoSeries
        The data series to normalize.
    norm : None or matplotlib.colors.Normalize
        Same as in `normalize_array`.

    Returns
    -------
    numpy.ndarray
        Array of float values.
    """
    if not isinstance(series.dtype, pd.CategoricalDtype):
        return normalize_array(series.to_numpy(float), norm=norm)
    codes = series.cat.codes.to_numpy(float)
    # With a single category the divisor is 1 to avoid dividing by zero
    divisor = max(len(series.cat.categories) - 1, 1)
    return codes / divisor
[docs]
def map_array_data(x, y, norm_x, norm_y, rgba_array, bad_color=(0.0, 0.0, 0.0, 0.0)):
    """
    Map two numeric arrays (x, y) to the colors of a bivariate colormap.

    Parameters
    ----------
    x : array-like
        Numeric array for the x dimension.
    y : array-like
        Numeric array for the y dimension. Must match shape of x.
    norm_x : matplotlib.colors.Normalize
        Normalization for x dimension.
    norm_y : matplotlib.colors.Normalize
        Normalization for y dimension.
    rgba_array : numpy.ndarray
        The base bivariate colormap of shape (n_y, n_x, 4).
    bad_color : tuple
        RGBA color for invalid points.

    Returns
    -------
    numpy.ndarray
        Mapped RGBA array of shape = x.shape + (4,).

    Raises
    ------
    ValueError
        If x and y have different shapes, or if a categorical norm does not
        match the number of colormap colors.
    """
    n_y, n_x, _ = rgba_array.shape
    # Convert to float numpy arrays (this loads lazy arrays into memory)
    x = np.asanyarray(x, dtype=float)
    y = np.asanyarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape.")
    # Categorical norms must agree with the colormap size
    check_cmap_ncolors(norm_y, n=n_y, dim_name="y")
    check_cmap_ncolors(norm_x, n=n_x, dim_name="x")
    # Points missing on either axis get the bad color
    nan_mask = np.isnan(x) | np.isnan(y)
    return map_colors(
        x_normalized=normalize_array(x, norm_x),
        y_normalized=normalize_array(y, norm_y),
        n_x=n_x,
        n_y=n_y,
        rgba_array=rgba_array,
        mask=nan_mask,
        bad_color=bad_color,
    )
[docs]
def map_pandas_data(x, y, norm_x, norm_y, rgba_array, bad_color=(0.0, 0.0, 0.0, 0.0)):
    """
    Map two pandas Series (x, y) to the colors of a bivariate colormap.

    Categorical Series are supported: their number of categories must match
    the number of colormap colors along the corresponding axis.

    Parameters
    ----------
    x : pandas.Series or geopandas.GeoSeries
        Data for the x dimension.
    y : pandas.Series or geopandas.GeoSeries
        Data for the y dimension. Must match shape of x.
    norm_x : None or matplotlib.colors.Normalize
        Normalization for x dimension. Ignored for a categorical series.
    norm_y : None or matplotlib.colors.Normalize
        Normalization for y dimension. Ignored for a categorical series.
    rgba_array : numpy.ndarray
        The bivariate colormap, shape (n_y, n_x, 4).
    bad_color : tuple
        RGBA color for invalid or out-of-bounds points.

    Returns
    -------
    numpy.ndarray
        Mapped RGBA array of shape = x.shape + (4,).
    """
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape.")
    n_y, n_x, _ = rgba_array.shape
    # Categorical series must have one category per color on their axis
    for series, n_colors, dim_name in ((x, n_x, "x"), (y, n_y, "y")):
        if isinstance(series.dtype, pd.CategoricalDtype):
            check_expected_number_categories(
                n=n_colors,
                n_categories=len(series.cat.categories),
                dim_name=dim_name,
                obj="categorical pd.Series",
            )
    # Points missing on either axis get the bad color
    nan_mask = (x.isna() | y.isna()).to_numpy()
    return map_colors(
        x_normalized=normalize_pandas_series(x, norm_x),
        y_normalized=normalize_pandas_series(y, norm_y),
        n_x=n_x,
        n_y=n_y,
        rgba_array=rgba_array,
        mask=nan_mask,
        bad_color=bad_color,
    )
[docs]
def map_xarray_data(x, y, norm_x, norm_y, rgba_array, bad_color=(0.0, 0.0, 0.0, 0.0)):
    """
    Map two xarray DataArrays (x, y) to a bivariate colormap.

    The inputs are broadcast against each other, their underlying arrays are
    mapped with ``map_array_data`` and the result is wrapped in a new
    DataArray carrying an extra trailing "rgba" dimension.

    Parameters
    ----------
    x : xarray.DataArray
        Data for the x dimension.
    y : xarray.DataArray
        Data for the y dimension.
    norm_x : None or matplotlib.colors.Normalize
        Normalization for the x dimension.
    norm_y : None or matplotlib.colors.Normalize
        Normalization for the y dimension.
    rgba_array : numpy.ndarray
        Bivariate colormap of shape (n_y, n_x, 4).
    bad_color : tuple
        RGBA color for invalid points.

    Returns
    -------
    xarray.DataArray
        DataArray with the dimensions of the broadcast ``x`` followed by an
        "rgba" dimension labelled ["R", "G", "B", "A"].
    """
    # Bring both inputs onto a common set of dimensions
    x_bcast, y_bcast = xr.broadcast(x, y)

    # Map the raw arrays to RGBA values
    rgba_values = map_array_data(
        x_bcast.data,
        y_bcast.data,
        norm_x=norm_x,
        norm_y=norm_y,
        rgba_array=rgba_array,
        bad_color=bad_color,
    )

    # Re-attach the coordinates of the broadcast x and label the RGBA channels
    coords = {**dict(x_bcast.coords), "rgba": ["R", "G", "B", "A"]}
    return xr.DataArray(rgba_values, coords=coords, dims=(*x_bcast.dims, "rgba"))
####------------------------------------------------------------------------.
#### Plotting
[docs]
def plot_bivariate_palette(
    rgba_array,
    ax=None,
    *,
    xlim=None,
    ylim=None,
    disable_axis=False,
    origin="upper",
    aspect="auto",
    **imshow_kwargs,
):
    """Plot the bivariate colormap.

    Parameters
    ----------
    rgba_array : numpy.ndarray
        Array of RGBA values whose first two axes are (n_y, n_x).
    ax : matplotlib.axes.Axes, optional
        The axes on which to plot the colormap.
        If None, a new figure and axes are created.
    xlim : list or tuple, optional
        The x-axis limits (outer pixel edges) of the image.
        If None, defaults to [-0.5, n_x - 0.5].
    ylim : list or tuple, optional
        The y-axis limits (outer pixel edges) of the image.
        If None, defaults to [-0.5, n_y - 0.5].
    disable_axis : bool, optional
        If True, the axis is turned off. Default is False.
    origin : {'upper', 'lower'}, optional
        Where the origin of the y axis is located. Default is 'upper'.
    aspect : str or float
        Either 'equal', 'auto' or a float; passed to ``ax.set_aspect``.
        The default is 'auto'.
    **imshow_kwargs : dict, optional
        Additional keyword arguments passed to ``imshow``.

    Returns
    -------
    matplotlib.image.AxesImage
        The image object created by ``imshow``.
    """
    n_rows, n_cols = rgba_array.shape[0:2]

    # Default limits sit on the outer edges of the first and last pixel
    x_bounds = [0 - 0.5, n_cols - 1 + 0.5] if xlim is None else list(xlim)
    y_bounds = [0 - 0.5, n_rows - 1 + 0.5] if ylim is None else list(ylim)

    # Create the axes on demand
    if ax is None:
        ax = plt.subplots(1, 1)[1]

    # Extent is (left, right, bottom, top) at the pixel outer corners
    extent = x_bounds + y_bounds

    # The image is always displayed from top to bottom: 'origin' only decides
    # where the origin of the y axis sits.
    # --> 'upper': keep the array as is and swap the vertical extent
    # --> otherwise: flip the rows so that imshow renders the same picture
    if origin == "upper":
        extent = (extent[0], extent[1], extent[3], extent[2])
    else:
        rgba_array = rgba_array[::-1, ...]

    image = ax.imshow(rgba_array, origin=origin, extent=extent, **imshow_kwargs)

    if disable_axis:
        ax.axis("off")

    ax.set_aspect(aspect)
    return image
[docs]
def get_log_ticks(vmin, vmax):
    """
    Generate logarithmic tick values for a given data range.

    Generates ticks at powers of 10 lying within the specified range,
    suitable for logarithmic scaling of axes.

    Parameters
    ----------
    vmin : float
        Minimum value (must be positive).
    vmax : float
        Maximum value (must be positive).

    Returns
    -------
    numpy.ndarray
        Array of tick positions at powers of 10 within [vmin, vmax].
        The array is empty if no power of 10 lies within the range.

    Raises
    ------
    ValueError
        If vmin or vmax are not positive or if vmin >= vmax.
    """
    if vmin <= 0 or vmax <= 0:
        raise ValueError("vmin and vmax must be positive for logarithmic ticks.")
    if vmin >= vmax:
        raise ValueError("vmin must be less than vmax.")

    # Generate ticks at powers of 10 over the decades enclosing [vmin, vmax]
    exponents = np.arange(np.floor(np.log10(vmin)), np.ceil(np.log10(vmax)) + 1)
    ticks = np.power(10, exponents)

    # Keep only the ticks inside the data range.
    # --> The enclosing decades can start below vmin and end above vmax.
    # --> A tick above vmax would be positioned outside the colorbar image.
    ticks = ticks[(ticks >= vmin) & (ticks <= vmax)]
    return ticks
[docs]
def get_symlog_ticks(vmin, vmax, linthresh, base=10):
    """
    Generate symmetric logarithmic tick values for a given data range.

    For symmetric log (SymLogNorm), the scale is linear near zero (within ±linthresh)
    and logarithmic outside this region. Ticks are placed at 0, at ±linthresh and
    at ± the integer powers of ``base`` that are >= linthresh. Only the ticks lying
    within [vmin, vmax] are returned.

    Parameters
    ----------
    vmin : float
        Minimum value of the data range (typically negative).
    vmax : float
        Maximum value of the data range (typically positive).
    linthresh : float
        The range around zero where scaling is linear (±linthresh).
        Must be positive.
    base : int, optional
        The logarithmic base. Default is 10.

    Returns
    -------
    numpy.ndarray
        Sorted array of unique tick positions within [vmin, vmax].
        Values are rounded to 10 decimals to merge float duplicates.

    Raises
    ------
    ValueError
        If linthresh is not positive or if vmin >= vmax.
    """
    if linthresh <= 0:
        raise ValueError("linthresh must be positive.")
    if vmin >= vmax:
        raise ValueError("vmin must be less than vmax.")

    log_base = np.log(base)
    # First integer exponent whose power of base is >= linthresh
    log_start = np.ceil(np.log(linthresh) / log_base)

    def _decade_ticks(limit):
        """Return the powers of base from the log region start up to the decade enclosing limit."""
        log_end = np.ceil(np.log(limit) / log_base)
        return np.power(base, np.arange(log_start, log_end + 1))

    # Collect candidate ticks
    candidates = [0]
    if vmax > linthresh:
        candidates.extend(_decade_ticks(vmax))
    if vmin < -linthresh:
        candidates.extend(-_decade_ticks(abs(vmin)))
    candidates.extend([-linthresh, linthresh])

    # Keep only the ticks inside the data range.
    # --> This also discards small positive decades when vmin > 0
    #     (and small negative decades when vmax < 0).
    ticks = [t for t in candidates if vmin <= t <= vmax]

    # Sort and remove duplicates (round to avoid float precision issues)
    ticks = np.array(sorted(set(np.round(ticks, 10))))
    return ticks
def _normalize_log_value(value, axis_min, axis_max):
    """Normalize a value using logarithmic scaling.

    Returns the relative position of ``value`` between ``axis_min`` and
    ``axis_max`` on a base-10 logarithmic scale. ``axis_min`` maps to 0 and
    ``axis_max`` maps to 1. The result is not clipped: values outside the
    axis range map outside [0, 1].

    Parameters
    ----------
    value : float
        The value to normalize (must be positive).
    axis_min : float
        Minimum value of the data range (must be positive).
    axis_max : float
        Maximum value of the data range (must be positive).

    Returns
    -------
    float
        Normalized position of ``value`` on the logarithmic axis.

    Raises
    ------
    ValueError
        If ``value``, ``axis_min`` or ``axis_max`` is not positive.
    """
    # A non-positive axis_max would otherwise silently yield NaN (or -0.0)
    if value <= 0 or axis_min <= 0 or axis_max <= 0:
        raise ValueError("Values must be positive for logarithmic normalization.")
    return (np.log10(value) - np.log10(axis_min)) / (np.log10(axis_max) - np.log10(axis_min))
def _normalize_symlog_value(value, axis_min, axis_max, linthresh, base=10, linscale=1.0):
    """Normalize a value using symmetric logarithmic scaling.

    Maps a value (positive or negative) to [0, 1]. Zero sits at 0.5 when the
    axis range straddles zero. Each sign owns one half of the unit interval,
    with the linear region [-linthresh, +linthresh] closest to 0.5 and the
    logarithmic region beyond it.

    Parameters
    ----------
    value : float
        The value to normalize.
    axis_min : float
        Minimum value of the data range.
    axis_max : float
        Maximum value of the data range.
    linthresh : float
        The range around zero where scaling is linear (±linthresh).
        Must be positive.
    base : int, optional
        The logarithmic base. Default is 10.
    linscale : float, optional
        Controls the share of each half allocated to the linear region.
        Default is 1.0.

    Returns
    -------
    float
        Normalized value, clipped to [0, 1].

    Notes
    -----
    Within each half, the linear region takes the fraction
    ``linscale / (linscale + 1)`` and the logarithmic region takes the rest.
    The positive half is scaled against ``axis_max`` and the negative half
    against ``abs(axis_min)``.
    NOTE(review): this piecewise mapping is not checked here against the
    transform applied by matplotlib's SymLogNorm - confirm they agree for
    asymmetric ranges.
    """
    # Share of each half devoted to the linear region
    lin_share = linscale / (linscale + 1)

    # Zero has a fixed position depending on which side(s) the axis covers
    if value == 0:
        if axis_min < 0 < axis_max:
            return 0.5
        return 0 if axis_max > 0 else 1

    # Pick the direction away from the centre and the outer bound of that half
    if value > 0:
        direction, outer_bound = 1.0, axis_max
    else:
        direction, outer_bound = -1.0, abs(axis_min)
    magnitude = abs(value)

    if magnitude <= linthresh:
        # Linear region: offset grows proportionally to |value|
        offset = (magnitude / linthresh) * 0.5 * lin_share
        position = 0.5 + direction * offset
    else:
        # Log region: fills the remainder of the half
        log_magnitude = np.log(magnitude) / np.log(base)
        log_threshold = np.log(linthresh) / np.log(base)
        log_bound = np.log(outer_bound) / np.log(base)
        log_share = (log_magnitude - log_threshold) / (log_bound - log_threshold)
        linear_part = lin_share * 0.5
        log_part = (1.0 - lin_share) * 0.5 * log_share
        position = 0.5 + direction * linear_part + direction * log_part
    return np.clip(position, 0, 1)
[docs]
def set_log_axis(ax, major_ticks, axis):
    """Set logarithmic-like ticks on a linear axis for imshow plots.

    The axis of an imshow plot is linear. Major and minor ticks are therefore
    placed at the positions that their values would occupy on a base-10
    logarithmic axis spanning the extent of the first image of ``ax``.
    Only suitable for data with values that are all positive.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes object to modify. It must already contain an image.
    major_ticks : array-like
        The values of the major ticks in data coordinates (must be positive).
    axis : {'x', 'y'}
        The axis to modify ("x" or "y"). Any value other than "y"
        (case-insensitive) selects the x axis.

    Raises
    ------
    ValueError
        If major_ticks contain non-positive values.
    """
    # Validate inputs
    major_ticks = np.asarray(major_ticks)
    if np.any(major_ticks <= 0):
        raise ValueError("All major_ticks must be positive for logarithmic axis.")

    # Get the image extent along the relevant dimension
    # --> The extent limits (actual vmin, vmax) are used for the normalization,
    #     not the tick extremes, so that positions stay correct when vmin is
    #     not a power of 10.
    is_y = axis.lower() == "y"
    extent = ax.get_images()[0].get_extent()
    extent_min, extent_max = (extent[2], extent[3]) if is_y else (extent[0], extent[1])

    # Generate minor ticks at 2x ... 9x each major tick
    minor_ticks = []
    for major in major_ticks:
        minor_ticks.extend(np.arange(2, 10) * major)
    minor_ticks = np.array(minor_ticks)

    # Keep only the minor ticks falling inside the image extent
    # --> Ticks beyond the extent would be positioned outside the image
    lower, upper = min(extent_min, extent_max), max(extent_min, extent_max)
    minor_ticks = minor_ticks[(minor_ticks >= lower) & (minor_ticks <= upper)]

    def _to_axis_position(values):
        """Convert tick values to positions on the linear image axis."""
        fractions = np.array([_normalize_log_value(v, extent_min, extent_max) for v in values])
        return extent_min + (extent_max - extent_min) * fractions

    major_data_pos = _to_axis_position(major_ticks)
    minor_data_pos = _to_axis_position(minor_ticks)

    # Create tick labels
    major_tick_labels = [f"{x:g}" for x in major_ticks]

    # Set ticks on the selected axis
    if is_y:
        set_ticks, set_ticklabels, axis_obj = ax.set_yticks, ax.set_yticklabels, ax.yaxis
    else:
        set_ticks, set_ticklabels, axis_obj = ax.set_xticks, ax.set_xticklabels, ax.xaxis
    set_ticks(major_data_pos)
    set_ticklabels(major_tick_labels)
    set_ticks(minor_data_pos, minor=True)
    axis_obj.set_tick_params(which="minor", length=4)
    axis_obj.set_tick_params(which="major", length=8)
[docs]
def set_symlog_axis(ax, major_ticks, axis, linthresh, base=10, linscale=1.0):
    """Set symmetric logarithmic ticks on a linear axis for imshow plots.

    The axis of an imshow plot is linear. Major and minor ticks are therefore
    placed at the positions returned by the symmetric logarithmic normalization
    over the extent of the first image of ``ax``. Suitable for data that
    includes both positive and negative values with a linear region near zero.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes object to modify. It must already contain an image.
    major_ticks : array-like
        The values of the major ticks in data coordinates (can include negative values).
    axis : {'x', 'y'}
        The axis to modify ("x" or "y"). Any value other than "y"
        (case-insensitive) selects the x axis.
    linthresh : float
        The range around zero where scaling is linear (±linthresh).
        Must be positive.
    base : int, optional
        The logarithmic base. Default is 10.
    linscale : float, optional
        The number of decades to use for the linear part. Default is 1.0.

    Raises
    ------
    ValueError
        If linthresh is not positive.
    """
    # Validate inputs
    if linthresh <= 0:
        raise ValueError("linthresh must be positive.")
    major_ticks = np.asarray(major_ticks)

    # Get the image extent along the relevant dimension
    # --> The extent limits (actual vmin, vmax) are used for the normalization,
    #     not the tick extremes.
    is_y = axis.lower() == "y"
    extent = ax.get_images()[0].get_extent()
    extent_min, extent_max = (extent[2], extent[3]) if is_y else (extent[0], extent[1])

    def _to_axis_position(values):
        """Convert tick values to positions on the linear image axis."""
        fractions = np.array(
            [_normalize_symlog_value(v, extent_min, extent_max, linthresh, base, linscale) for v in values],
        )
        return extent_min + (extent_max - extent_min) * fractions

    major_data_pos = _to_axis_position(major_ticks)

    # Generate minor ticks in the log regions (skip linear region to avoid clutter)
    # --> 2x ... 9x each major tick with |tick| >= linthresh (sign is preserved)
    minor_ticks = []
    for major in major_ticks:
        if abs(major) >= linthresh:
            minor_ticks.extend(np.arange(2, 10) * major)
    minor_ticks = np.array(minor_ticks)

    # Keep only the minor ticks falling inside the image extent
    # --> Out-of-range minor ticks would otherwise be clipped onto the axis edge
    lower, upper = min(extent_min, extent_max), max(extent_min, extent_max)
    minor_ticks = minor_ticks[(minor_ticks >= lower) & (minor_ticks <= upper)]
    minor_data_pos = _to_axis_position(minor_ticks) if len(minor_ticks) > 0 else np.array([])

    # Create tick labels
    # --> When linscale is very small, the ±linthresh ticks sit next to 0:
    #     their label is left empty to avoid overlap (the tick stays visible)
    skip_linthresh_labels = linscale < 0.1
    major_tick_labels = [
        "" if skip_linthresh_labels and (np.isclose(x, linthresh) or np.isclose(x, -linthresh)) else f"{x:g}"
        for x in major_ticks
    ]

    # Set ticks on the selected axis
    if is_y:
        set_ticks, set_ticklabels, axis_obj = ax.set_yticks, ax.set_yticklabels, ax.yaxis
    else:
        set_ticks, set_ticklabels, axis_obj = ax.set_xticks, ax.set_xticklabels, ax.xaxis
    set_ticks(major_data_pos)
    set_ticklabels(major_tick_labels)
    if len(minor_data_pos) > 0:
        set_ticks(minor_data_pos, minor=True)
    axis_obj.set_tick_params(which="minor", length=4)
    axis_obj.set_tick_params(which="major", length=8)
[docs]
def get_axis_defaults(norm):
    """
    Derive default axis limits, ticks and tick labels from a norm.

    Parameters
    ----------
    norm : matplotlib.colors.Normalize
        The normalization object. The following cases are distinguished:

        - CategoryNorm, CategorizeNorm: one tick at the centre of each category,
          labelled with the norm ``ticklabels``.
        - BoundaryNorm: one tick per boundary.
        - LogNorm: ticks from ``get_log_ticks``.
        - SymLogNorm: ticks from ``get_symlog_ticks``.
        - Any other norm exposing ``vmin`` and ``vmax``: three evenly spaced ticks.

    Returns
    -------
    tuple
        A tuple of (value_lims, ticks, ticklabels) where:

        - value_lims : tuple of (float, float)
            The (min, max) range for the axis.
        - ticks : numpy.ndarray
            Tick positions in data coordinates.
        - ticklabels : list or None
            Tick labels for categorical norms, None otherwise.

    Raises
    ------
    NotImplementedError
        If the norm is none of the supported types and lacks vmin/vmax attributes.
    """
    # Categorical norms: categories occupy unit-width bins starting at 0
    if isinstance(norm, (CategoryNorm, CategorizeNorm)):
        category_lims = (0, norm.Ncmap)
        category_centers = np.arange(0, norm.Ncmap) + 0.5
        return (category_lims, category_centers, norm.ticklabels.copy())

    # Discrete binned norms: ticks sit on the bin boundaries
    if isinstance(norm, BoundaryNorm):
        boundary_lims = (norm.boundaries[0], norm.boundaries[-1])
        return (boundary_lims, norm.boundaries.copy(), None)

    # Everything else must expose vmin and vmax
    if not (hasattr(norm, "vmin") and hasattr(norm, "vmax")):
        raise NotImplementedError(f"Unsupported {type(norm).__name__!s} norm.")

    value_lims = (norm.vmin, norm.vmax)
    if isinstance(norm, LogNorm):
        ticks = get_log_ticks(vmin=norm.vmin, vmax=norm.vmax)
    elif isinstance(norm, SymLogNorm):
        # linthresh and base are read from the scale attached to the norm
        ticks = get_symlog_ticks(
            vmin=norm.vmin,
            vmax=norm.vmax,
            linthresh=norm._scale.linthresh,
            base=norm._scale.base,
        )
    else:
        # Linear norms: ticks at the minimum, the midpoint and the maximum
        ticks = np.linspace(norm.vmin, norm.vmax, 3)
    return (value_lims, ticks, None)
[docs]
def _apply_bivariate_colorbar_ticks(cax, norm, ticks_kwargs, axis):
    """Set the ticks of one colorbar axis ('x' or 'y') according to the norm type."""
    if ticks_kwargs.get("ticks", None) is None:
        return
    if isinstance(norm, LogNorm):
        set_log_axis(ax=cax, major_ticks=ticks_kwargs["ticks"], axis=axis)
    elif isinstance(norm, SymLogNorm):
        set_symlog_axis(
            ax=cax,
            major_ticks=ticks_kwargs["ticks"],
            axis=axis,
            linthresh=norm._scale.linthresh,
            base=norm._scale.base,
            linscale=norm._scale.linscale,
        )
    elif axis == "y":
        cax.set_yticks(**ticks_kwargs)
    else:
        cax.set_xticks(**ticks_kwargs)


def add_bivariate_colorbar(
    *,
    bivariate_cmap,
    cax,
    origin="lower",
    aspect="auto",
    # Options
    xlabel=None,
    ylabel=None,
    title=None,
    title_kwargs=None,
    xlabel_kwargs=None,
    ylabel_kwargs=None,
    xticks_kwargs=None,
    yticks_kwargs=None,
    **imshow_kwargs,
):
    """Add a bivariate colorbar to the specified axis.

    Parameters
    ----------
    bivariate_cmap : pycolorbar.BivariateColormap
        The bivariate colormap object containing the color mapping and norms.
    cax : matplotlib.axes.Axes
        The axis on which to draw the colorbar.
    origin : {'lower', 'upper'}, optional
        The origin of the colorbar. Default is 'lower'.
        With 'upper', the y axis is inverted after drawing.
    aspect : str, optional
        Either 'equal' or 'auto'.
        If 'auto' fill the Axes position rectangle with data.
        If 'equal', ensure same scaling between x and y axis.
        The default is 'auto'.
    xlabel : str, optional
        The label for the x-axis. Default is None.
    ylabel : str, optional
        The label for the y-axis. Default is None.
    title : str, optional
        The title for the colorbar. Default is None.
    title_kwargs : dict, optional
        Additional keyword arguments for the title. Default is None.
    xlabel_kwargs : dict, optional
        Additional keyword arguments for the x-axis label. Default is None.
    ylabel_kwargs : dict, optional
        Additional keyword arguments for the y-axis label. Default is None.
    xticks_kwargs : dict, optional
        Keyword arguments for the x-axis ticks. Missing 'ticks' and 'labels'
        entries are filled with defaults derived from the x norm.
        For LogNorm and SymLogNorm only the 'ticks' entry is used.
        The dictionary is not modified. Default is None.
    yticks_kwargs : dict, optional
        Keyword arguments for the y-axis ticks (same rules as ``xticks_kwargs``).
        Default is None.
    **imshow_kwargs : dict, optional
        Additional keyword arguments for the `imshow` function.

    Returns
    -------
    matplotlib.image.AxesImage
        The image object representing the bivariate colorbar.

    Raises
    ------
    ValueError
        If the norms for the bivariate colormap are not defined.
        It occurs when the bivariate colormaps has not yet been used to
        map some values to RGBA colors.
    """
    # Initialize arguments
    # --> The tick dictionaries are copied because defaults are inserted below:
    #     the dictionaries provided by the caller must stay untouched, otherwise
    #     ticks computed for one colormap leak into later calls.
    xticks_kwargs = {} if xticks_kwargs is None else dict(xticks_kwargs)
    yticks_kwargs = {} if yticks_kwargs is None else dict(yticks_kwargs)
    xlabel_kwargs = {} if xlabel_kwargs is None else xlabel_kwargs
    ylabel_kwargs = {} if ylabel_kwargs is None else ylabel_kwargs
    title_kwargs = {} if title_kwargs is None else title_kwargs

    # Retrieve norms
    norm_x = bivariate_cmap.norm_x
    norm_y = bivariate_cmap.norm_y
    if norm_x is None or norm_y is None:
        raise ValueError("You first need to map some values before plotting the colorbar.")

    # Define default axis options
    xlim, x_ticks, x_ticklabels = get_axis_defaults(norm_x)
    ylim, y_ticks, y_ticklabels = get_axis_defaults(norm_y)

    # Display the bivariate colormap as an image
    # --> Always drawn with origin="lower"; "upper" is handled by inverting the y axis at the end
    p = plot_bivariate_palette(
        bivariate_cmap.rgba_array,
        ax=cax,
        xlim=xlim,
        ylim=ylim,
        disable_axis=False,
        origin="lower",
        aspect=aspect,
        **imshow_kwargs,
    )

    # Add ticks and ticklabels (user-provided values take precedence over defaults)
    xticks_kwargs.setdefault("ticks", x_ticks)
    xticks_kwargs.setdefault("labels", x_ticklabels)
    yticks_kwargs.setdefault("ticks", y_ticks)
    yticks_kwargs.setdefault("labels", y_ticklabels)
    _apply_bivariate_colorbar_ticks(cax, norm_x, xticks_kwargs, axis="x")
    _apply_bivariate_colorbar_ticks(cax, norm_y, yticks_kwargs, axis="y")

    # Add labels and title
    if title is not None:
        cax.set_title(title, **title_kwargs)
    if xlabel is not None:
        cax.set_xlabel(xlabel, **xlabel_kwargs)
    if ylabel is not None:
        cax.set_ylabel(ylabel, **ylabel_kwargs)

    # Invert axis origin if specified as "upper"
    if origin == "upper":
        cax.invert_yaxis()
    return p
[docs]
def add_bivariate_legend(
    *,
    bivariate_cmap,
    ax,
    # Inset options
    box_aspect=1,
    height=0.2,
    pad=0.005,
    loc="upper right",
    inside_figure=True,
    optimize_layout=True,
    # Fancybox options
    fancybox=False,
    fancybox_pad=0,
    fancybox_fc="white",
    fancybox_ec="none",
    fancybox_lw=0.5,
    fancybox_alpha=0.4,
    fancybox_shape="square",
    # Colorbar options
    **kwargs,
):
    """
    Add the bivariate colorbar legend to a plot.

    The colorbar is drawn by ``add_bivariate_colorbar`` inside an inset
    created by ``add_colorbar_inset``.

    Parameters
    ----------
    bivariate_cmap : pycolorbar.BivariateColormap
        The bivariate colormap to be used for the legend.
    ax : matplotlib.axes.Axes
        The axes to which the bivariate legend will be added.
    box_aspect : float, optional
        Aspect ratio of the inset Axes. Default is 1.
    height : float, optional
        Height of the inset as a fraction [0-1] of the main Axes. Default is 0.2.
    pad : float, optional
        Padding between the inset and main Axes in figure coordinates. Default is 0.005.
    loc : str or tuple, optional
        Location of the inset. Default is 'upper right'.
    inside_figure : bool, optional
        Whether inset is inside the figure region. Default is True.
    optimize_layout : bool, optional
        Whether to auto-adjust the inset position for ticklabels. Default is True.
        NOTE: If True, do not call `fig.tight_layout()` afterwards.
    fancybox : bool, optional
        Whether to draw a fancy box behind the inset. Default is False.
    fancybox_pad : float, optional
        Padding of the fancy box in figure coordinates. Default is 0.
    fancybox_fc : str, optional
        Face color of the fancy box. Default is 'white'.
    fancybox_ec : str, optional
        Edge color of the fancy box. Default is 'none'.
    fancybox_lw : float, optional
        Line width of the fancy box. Default is 0.5.
    fancybox_alpha : float, optional
        Alpha of the fancy box. Default is 0.4.
    fancybox_shape : {'circle', 'square'}, optional
        Shape of the fancy box. Default is 'square'.
    **kwargs : dict
        Additional keyword arguments passed to ``add_bivariate_colorbar``.

    Returns
    -------
    object
        The value returned by ``add_colorbar_inset``.
    """
    # Options controlling the inset placement and its background box
    inset_options = {
        "projection": None,
        "box_aspect": box_aspect,
        "height": height,
        "pad": pad,
        "loc": loc,
        "inside_figure": inside_figure,
        "optimize_layout": optimize_layout,
        "fancybox": fancybox,
        "fancybox_pad": fancybox_pad,
        "fancybox_fc": fancybox_fc,
        "fancybox_ec": fancybox_ec,
        "fancybox_lw": fancybox_lw,
        "fancybox_alpha": fancybox_alpha,
        "fancybox_shape": fancybox_shape,
    }
    # Delegate the drawing to the bivariate colorbar function
    return add_colorbar_inset(
        ax=ax,
        colorbar_func=add_bivariate_colorbar,
        colorbar_func_kwargs={"bivariate_cmap": bivariate_cmap, **kwargs},
        **inset_options,
    )
[docs]
def plot_bivariate_colorbar(
    *,
    bivariate_cmap,
    ax=None,
    cax=None,
    origin="lower",
    location="right",
    size="30%",
    pad=0.45,
    box_aspect=1,
    **kwargs,
):
    """
    Plot a bivariate colorbar.

    The colorbar Axes is selected as follows:

    - If ``cax`` is given, the colorbar is drawn directly on it
      (``ax`` is ignored).
    - Else if ``ax`` is given, a new Axes is appended to the ``location``
      side of ``ax`` and the colorbar is drawn on it.
    - Otherwise a new figure and Axes are created.

    Parameters
    ----------
    bivariate_cmap : pycolorbar.BivariateColormap
        The colormap to be used for the bivariate colorbar.
    ax : matplotlib.axes.Axes or cartopy.mpl.geoaxes.GeoAxesSubplot, optional
        The Axes to which the colorbar should be appended.
        Ignored if ``cax`` is provided.
    cax : matplotlib.axes.Axes, optional
        The Axes in which to directly draw the colorbar.
    origin : {'lower', 'upper'}, optional
        Indicates where to locate the origin in the colorbar Axes.
        Default is 'lower'.
    location : {'right', 'left', 'top', 'bottom'}, optional
        The side of ``ax`` where the colorbar Axes is appended.
        Used only when the colorbar is appended to ``ax``. Default is 'right'.
    size : float or str, optional
        The size of the appended Axes, passed to ``append_axes``.
        Used only when the colorbar is appended to ``ax``. Default is '30%'.
    pad : float, optional
        The padding between ``ax`` and the appended Axes, passed to ``append_axes``.
        Used only when the colorbar is appended to ``ax``. Default is 0.45.
    box_aspect : float, optional
        The box aspect of the appended Axes.
        Used only when the colorbar is appended to ``ax``. Default is 1.
    **kwargs : dict
        Additional keyword arguments passed to ``add_bivariate_colorbar``.

    Returns
    -------
    matplotlib.image.AxesImage
        The image object representing the bivariate colorbar.
    """
    # Determine the colorbar axis when it is not provided
    if cax is None:
        if ax is None:
            # No Axes at all: draw on a fresh figure
            cax = plt.subplots()[1]
        else:
            # Append a plain matplotlib Axes next to the parent Axes
            cax = make_axes_locatable(ax).append_axes(location, size=size, pad=pad, axes_class=plt.Axes)
            cax.set_box_aspect(box_aspect)

    # Draw the 2D colorbar
    return add_bivariate_colorbar(
        cax=cax,
        bivariate_cmap=bivariate_cmap,
        origin=origin,
        **kwargs,
    )
####------------------------------------------------------------------------.
[docs]
class BivariateColormap:
"""Class representing a bivariate colormap."""
def __init__(self, rgba_array, *, luminance_factor=None, n=None, interp_method=None):
    """
    Initialize the bivariate colormap with an RGBA array of shape (n_y, n_x, 4).

    Parameters
    ----------
    rgba_array : numpy.ndarray
        2D RGBA array (n_y, n_x, 4) providing colors from top to bottom.
        The (x, 0) values are mapped to the corresponding color in the bottom row of the 2D RGBA array.
        The (x, 1) values are mapped to the corresponding color in the top row of the 2D RGBA array.
    luminance_factor : float or None, optional
        If set, apply radial-based luminance gradient.
        Radial darkening is obtained with values < 1.
        Radial whitening is obtained with values > 1.
        None or 1 produce no change.
    n : int or tuple, optional
        Number of colormap colors, validated by ``check_n`` and used to
        resample the RGBA array.
        NOTE(review): the result of ``check_n`` is unpacked as (n_x, n_y) -
        confirm the expected tuple order against ``check_n``.
    interp_method : str, optional
        The interpolation method forwarded to ``resample_rgba_array``.
    """
    # Validate the input colors
    colors = ensure_rgba_array(rgba_array)

    # Resample to the requested number of colors
    n_x, n_y = check_n(n, both_integers=False)
    colors = resample_rgba_array(colors, n_x=n_x, n_y=n_y, interp_method=interp_method)

    # Apply the luminance gradient and store the final colors
    self.rgba_array = apply_luminance_gradient(colors, luminance_factor=luminance_factor)

    # Cache the array geometry: rows span y, columns span x
    self.shape = self.rgba_array.shape
    self.n_y, self.n_x = self.shape[0], self.shape[1]

    # Color of invalid points: fully transparent by default
    self._bad = (0.0, 0.0, 0.0, 0.0)

    # Norms are unknown until some values are mapped
    self.norm_x = None
    self.norm_y = None
[docs]
@classmethod
def from_corners(cls, colors, n, *, luminance_factor=None, interp_method=None):
    """
    Generate a bivariate colormap from the colors at four corners.

    This is a thin wrapper that forwards all its arguments to ``from_colors``.

    Parameters
    ----------
    colors : list of color specs
        E.g. ['red', 'blue', 'green', 'black'] or precomputed array of RGBA.
    n : int or tuple
        Either a single integer or a (n_y, n_x) tuple specifying the number of colormap colors.
    luminance_factor : float or None, optional
        If set, apply radial-based luminance gradient.
        Radial darkening is obtained with values < 1.
        Radial whitening is obtained with values > 1.
        None or 1 produce no change.
    interp_method : str, optional
        The interpolation method to use for generating the colormap.

    Returns
    -------
    pycolorbar.BivariateColormap
    """
    options = {
        "n": n,
        "luminance_factor": luminance_factor,
        "interp_method": interp_method,
    }
    return cls.from_colors(colors=colors, **options)
[docs]
@classmethod
def from_colors(cls, colors, n, *, luminance_factor=None, interp_method=None):
    """
    Generate a bivariate colormap from a set of color points interpolated onto an 2D (n_y, n_x) grid.

    The RGBA grid is built by ``get_bivariate_cmap_from_colors`` and then
    wrapped in a new instance.

    Parameters
    ----------
    colors : list of color specs
        E.g. ['red', 'blue', 'green', 'black'] or precomputed array of RGBA.
    n : int or tuple
        Either a single integer or a (n_y, n_x) tuple specifying the number of colormap colors.
    luminance_factor : float or None, optional
        If set, apply radial-based luminance gradient.
        Radial darkening is obtained with values < 1.
        Radial whitening is obtained with values > 1.
        None or 1 produce no change.
    interp_method : str, optional
        The interpolation method to use for generating the colormap.

    Returns
    -------
    pycolorbar.BivariateColormap
    """
    # Interpolate the color points onto the grid
    palette = get_bivariate_cmap_from_colors(colors, n=n, interp_method=interp_method)
    # The luminance gradient is applied by the constructor
    return cls(palette, luminance_factor=luminance_factor)
[docs]
@classmethod
def from_cmaps(cls, cmap_x=plt.cm.Blues, cmap_y=plt.cm.Reds, n=256, *, luminance_factor=None):
    """
    Generate a bivariate colormap by blending two univariate colormaps along x and y axes.

    The RGBA grid is built by ``get_bivariate_cmap_from_two_cmaps`` and then
    wrapped in a new instance.

    Parameters
    ----------
    cmap_x : matplotlib.colors.Colormap or str
        Univariate colormap for the x axis. Default is ``plt.cm.Blues``.
    cmap_y : matplotlib.colors.Colormap or str
        Univariate colormap for the y axis. Default is ``plt.cm.Reds``.
    n : int or tuple, optional
        Either a single integer or a (n_x, n_y) tuple specifying
        the number of colormap colors on the x and y axis.
        Default is 256.
    luminance_factor : float or None, optional
        If set, apply radial-based luminance gradient.
        Radial darkening is obtained with values < 1.
        Radial whitening is obtained with values > 1.
        None or 1 produce no change.

    Returns
    -------
    pycolorbar.BivariateColormap
    """
    # Blend the two univariate colormaps into a 2D RGBA grid
    palette = get_bivariate_cmap_from_two_cmaps(cmap_x=cmap_x, cmap_y=cmap_y, n=n)
    # The luminance gradient is applied by the constructor
    return cls(palette, luminance_factor=luminance_factor)
[docs]
@classmethod
def from_name(cls, name, n, *, diagonal_tilt=0.8, offdiag_tilt=1.0, luminance_factor=None, interp_method=None):
    """Instantiate one of the predefined bivariate colormaps.

    Parameters
    ----------
    name : str
        The name of the predefined bivariate colormap to load.
        See available_biviariate_colormaps().
    n : int
        The number of colors in the colormap.
    diagonal_tilt : float, optional
        The tilt of the diagonal in the colormap, by default 0.8.
        Used only for Teuling colormaps.
    offdiag_tilt : float, optional
        The tilt of the off-diagonal in the colormap, by default 1.
        Used only for Teuling colormaps.
    luminance_factor : float or None, optional
        If set, apply a radial-based luminance gradient.
        Values < 1 produce radial darkening, values > 1 produce radial whitening.
        None or 1 produce no change.
    interp_method : str or None, optional
        The interpolation method to use, by default None.

    Returns
    -------
    BivariateColormap
        An instance of ``cls``.
    """
    # Options that only shape the palette generation (not the luminance gradient)
    palette_options = {
        "diagonal_tilt": diagonal_tilt,
        "offdiag_tilt": offdiag_tilt,
        "interp_method": interp_method,
    }
    palette = get_bivariate_cmap_from_name(name=name, n=n, **palette_options)
    return cls(palette, luminance_factor=luminance_factor)
def __getitem__(self, key):
"""Retrieve a subset of the colormap."""
if not isinstance(key, tuple) or len(key) != 2:
raise ValueError("Exactly two slices (for y,x) are required.")
y_slice, x_slice = key
if not isinstance(x_slice, slice) or not isinstance(y_slice, slice):
raise ValueError("Subset both dimensions with slice objects.")
# Check slice length >= 2 for x dimension
start_x, stop_x, step_x = x_slice.indices(self.rgba_array.shape[1])
if ((stop_x - start_x) // step_x) < 2:
raise ValueError("x slice must include at least 2 elements.")
# Check slice length >= 2 for y dimension
start_y, stop_y, step_y = y_slice.indices(self.rgba_array.shape[0])
if ((stop_y - start_y) // step_y) < 2:
raise ValueError("y slice must include at least 2 elements.")
rgba_array = self.rgba_array.copy()[y_slice, x_slice, :]
return self._copy_attributes(BivariateColormap(rgba_array=rgba_array))
def __eq__(self, other):
"""Check equality of two BivariateColormap instances."""
if not isinstance(other, BivariateColormap):
return False
return np.all(self.rgba_array == other.rgba_array) and self._bad == other._bad
def __hash__(self):
"""Return a hash value for the BivariateColormap instance."""
# TODO: also hash norm?
return hash((self.rgba_array.tobytes(), self._bad))
[docs]
def copy(self):
    """Return a new BivariateColormap built from a copy of the RGBA array.

    The attributes handled by ``_copy_attributes`` are carried over to the
    new instance.
    """
    duplicate = BivariateColormap(rgba_array=self.rgba_array.copy())
    return self._copy_attributes(duplicate)
def _copy_attributes(self, new_instance):
new_instance._bad = self._bad
return new_instance
def __setitem__(self, key, value):
"""Modify the colormap palette."""
self.rgba_array[key] = value
self.shape = self.rgba_array.shape
self.n_x = self.shape[1]
self.n_y = self.shape[0]
[docs]
def adapt_interval(self, interval_x=None, interval_y=None):
    """
    Return the sub-colormap covering the requested fractions of each axis.

    Both intervals are first validated by ``check_interval``. Each fraction is
    then multiplied by the number of colors along its axis and truncated to an
    integer to obtain the start/end indices of the subset.

    Parameters
    ----------
    interval_x : tuple
        A tuple of two float values between 0 and 1, indicating the fraction of the colors to retain on the x axis.
        If None, no subsetting is performed.
    interval_y : tuple
        A tuple of two float values between 0 and 1, indicating the fraction of the colors to retain on the y axis.
        If None, no subsetting is performed.

    Returns
    -------
    BivariateColormap
        The colormap subset obtained by slicing this colormap.
    """
    from pycolorbar.univariate.cmap import check_interval

    # Validate the requested fractions
    frac_x = check_interval(interval_x)
    frac_y = check_interval(interval_y)
    # Convert the fractions into integer index bounds (truncated towards zero)
    x_window = slice(int(frac_x[0] * self.n_x), int(frac_x[1] * self.n_x))
    y_window = slice(int(frac_y[0] * self.n_y), int(frac_y[1] * self.n_y))
    # Indexing order is (y, x)
    return self[y_window, x_window]
[docs]
def set_bad(self, color, alpha=None):
    """
    Define the color used for bad (masked) values.

    Parameters
    ----------
    color : color spec
        The color to use for bad values.
    alpha : float, optional
        Alpha value forwarded to ``matplotlib.colors.to_rgba``.
        The default is None.
    """
    # Store the color as an RGBA value
    bad_rgba = mpl.colors.to_rgba(color, alpha=alpha)
    self._bad = bad_rgba
[docs]
def set_alpha(self, alpha):
    """
    Overwrite the alpha (transparency) channel of the whole colormap in place.

    Parameters
    ----------
    alpha : float
        The alpha value to set, where 0 is fully transparent and 1 is fully opaque.
    """
    alpha_channel = 3  # index of the alpha component on the last (RGBA) axis
    self.rgba_array[:, :, alpha_channel] = alpha
[docs]
def change_luminance_gradient(self, luminance_factor=None):
    """Return a colormap with a modified luminance gradient.

    A radial whitening/darkening effect is applied to the palette.

    Parameters
    ----------
    luminance_factor : float or None
        Radial darkening is obtained with values < 1.
        Radial whitening is obtained with values > 1.
        None or 1 produce no change.

    Returns
    -------
    BivariateColormap
        The colormap with the new luminance gradient.
    """
    adjusted = apply_luminance_gradient(self.rgba_array, luminance_factor=luminance_factor)
    new_cmap = BivariateColormap(rgba_array=adjusted)
    return self._copy_attributes(new_cmap)
[docs]
def rot90(self, *, clockwise=True):
    """Return the colormap rotated by 90 degrees.

    Parameters
    ----------
    clockwise : bool, optional
        If True, rotate clockwise.
        If False, rotate counterclockwise.
        Default is True.

    Returns
    -------
    BivariateColormap
        The colormap rotated by 90 degrees.
    """
    # np.rot90 rotates counterclockwise for positive k, hence k=-1 for clockwise
    quarter_turns = -1 if clockwise else 1
    rotated = np.rot90(self.rgba_array, k=quarter_turns, axes=(0, 1))
    return self._copy_attributes(BivariateColormap(rgba_array=rotated))
[docs]
def rot180(self, *, clockwise=True):
    """Return the colormap rotated by 180 degrees.

    The rotation is performed as two successive ``rot90`` calls in the same direction.

    Parameters
    ----------
    clockwise : bool, optional
        If True, rotate clockwise.
        If False, rotate counterclockwise.
        Default is True.

    Returns
    -------
    BivariateColormap
        The colormap rotated by 180 degrees.
    """
    quarter_turn = self.rot90(clockwise=clockwise)
    return quarter_turn.rot90(clockwise=clockwise)
[docs]
def fliplr(self):
    """Return the colormap mirrored in the left/right direction.

    The RGBA array is flipped horizontally, i.e. the order of its columns is reversed.

    Returns
    -------
    BivariateColormap
        The horizontally flipped colormap.
    """
    mirrored = BivariateColormap(rgba_array=np.fliplr(self.rgba_array))
    return self._copy_attributes(mirrored)
[docs]
def flipud(self):
    """Return the colormap mirrored in the up/down direction.

    The RGBA array is flipped vertically, i.e. the order of its rows is reversed.

    Returns
    -------
    BivariateColormap
        The vertically flipped colormap.
    """
    mirrored = BivariateColormap(rgba_array=np.flipud(self.rgba_array))
    return self._copy_attributes(mirrored)
[docs]
def resampled(self, *, n_x=None, n_y=None, interp_method="linear"):
    """
    Return a new BivariateColormap resampled to the desired number of colors.

    Parameters
    ----------
    n_x : int, optional
        The desired number of colormap colors along the x axis.
        If None, the original number of colors is kept.
    n_y : int, optional
        The desired number of colormap colors along the y axis.
        If None, the original number of colors is kept.
    interp_method : str
        Interpolation method (e.g. "nearest", "linear", "cubic").
        The default method is "linear".

    Returns
    -------
    BivariateColormap
        A new instance resampled to have an rgba_array of shape (n_y, n_x).
    """
    target_size = {"n_x": n_x, "n_y": n_y}
    resized = resample_rgba_array(self.rgba_array, interp_method=interp_method, **target_size)
    return self._copy_attributes(BivariateColormap(rgba_array=resized))
def __call__(self, x, y, *, norm_x=None, norm_y=None):
    """
    Translate paired (x, y) values into RGBA colors using this bivariate colormap.

    The norms used for the mapping are stored on the instance as ``norm_x``
    and ``norm_y``.

    Parameters
    ----------
    x, y : array-like, pd.Series, or xarray.DataArray
        Data arrays to be mapped. Must be of the same type and shape.
    norm_x : None or mpl.colors.Normalize or BoundaryNorm
        Normalization for the x dimension. If None, default 0-1 scaling is used or computed from data.
    norm_y : None or mpl.colors.Normalize or BoundaryNorm
        Normalization for the y dimension. If None, default 0-1 scaling is used or computed from data.

    Returns
    -------
    Mapped result:
        - A numpy.ndarray with shape = x.shape + (4,) if x is a numpy array or pd.Series
        - An xarray.DataArray with shape = x.shape + ("rgba",) if x is an xarray.DataArray

    Raises
    ------
    TypeError
        If ``x`` and ``y`` are not of exactly the same type.
    """
    # Both inputs must share the exact same container type
    if type(x) is not type(y):
        raise TypeError("`x` and `y` must be of the same type.")

    # Record the norms: a user-provided norm wins, otherwise derive one from the data
    self.norm_x = define_norm(x) if norm_x is None else norm_x
    self.norm_y = define_norm(y) if norm_y is None else norm_y

    # Pick the mapping routine matching the input container
    series_input = (_PANDAS_AVAILABLE and isinstance(x, pd.Series)) or (
        _GEOPANDAS_AVAILABLE and isinstance(x, gpd.GeoSeries)
    )
    if series_input:
        # Define special norm for category series
        # - The norm is ignored for remapping but it is used for plotting the colorbar !
        if isinstance(x.dtype, pd.CategoricalDtype):
            self.norm_x = create_pandas_category_norm(x)
        if isinstance(y.dtype, pd.CategoricalDtype):
            self.norm_y = create_pandas_category_norm(y)
        mapper = map_pandas_data
    elif _XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
        mapper = map_xarray_data
    else:
        mapper = map_array_data

    # All mapping routines share the same call signature
    return mapper(
        x,
        y,
        norm_x=self.norm_x,
        norm_y=self.norm_y,
        rgba_array=self.rgba_array,
        bad_color=self._bad,
    )
# Alias: expose ``__call__`` under the explicit method name ``map``
map = __call__
[docs]
@copy_docstring(plot_bivariate_palette)
def plot(self, ax=None, disable_axis=True, **kwargs):  # noqa: D102
    # Draw the RGBA palette of this colormap; extra keyword arguments are passed through
    palette = self.rgba_array
    return plot_bivariate_palette(
        palette,
        ax=ax,
        disable_axis=disable_axis,
        **kwargs,
    )
[docs]
@copy_docstring(plot_bivariate_colorbar)
def plot_colorbar(self, ax=None, cax=None, **kwargs):  # noqa: D102
    # Delegate to the module-level helper, passing this colormap as the one to draw
    return plot_bivariate_colorbar(
        bivariate_cmap=self,
        ax=ax,
        cax=cax,
        **kwargs,
    )
[docs]
@copy_docstring(add_bivariate_legend)
def add_legend(self, ax, **kwargs):  # noqa: D102
    # Delegate to the module-level helper, which adds the bivariate colorbar as a legend on ``ax``
    return add_bivariate_legend(
        bivariate_cmap=self,
        ax=ax,
        **kwargs,
    )
def _repr_png_(self):
    """
    Generate a PNG representation of the 2D RGBA array for this bivariate colormap.

    The palette is converted to 8-bit RGBA, resized to ``_BIVAR_REPR_PNG_SIZE``
    with nearest-neighbour resampling and encoded in memory. An "Author" text
    entry with the pycolorbar version is embedded in the PNG metadata.

    Returns
    -------
    bytes
        PNG-encoded bytes of the RGBA array.
    """
    import pycolorbar

    # Convert float RGBA -> 8-bit RGBA
    # shape = (height, width, 4)
    # Clip before casting: values outside [0, 1] (e.g. set through item assignment)
    # would otherwise wrap around when converted to uint8. In-range values are unaffected.
    img_8bit = np.clip(self.rgba_array * 255, 0, 255).astype(np.uint8)
    # Create a PIL Image from the array
    image = Image.fromarray(img_8bit, mode="RGBA")
    # Resize to a constant display size
    image = image.resize(_BIVAR_REPR_PNG_SIZE, resample=Image.NEAREST)  # Image.BICUBIC)
    # Encode as PNG in memory
    png_bytes = io.BytesIO()
    pnginfo = PngInfo()
    author = f"pycolorbar v{pycolorbar.__version__}, https://github.com/ghiggi/pycolorbar"
    pnginfo.add_text("Author", author)
    image.save(png_bytes, format="png", pnginfo=pnginfo)
    return png_bytes.getvalue()
def _repr_html_(self):
    """Build an HTML snippet embedding the PNG representation of the colormap.

    The PNG bytes returned by ``_repr_png_`` are base64-encoded and inlined in
    an ``<img>`` tag, which allows the colormap to be displayed in the IPython
    terminal and Jupyter Notebook cells.
    """
    # Encode the PNG bytes as ASCII base64 text
    encoded = base64.b64encode(self._repr_png_()).decode("ascii")
    # Display just the 2D colormap image, no title
    fragments = [
        '<div class="cmap" style="border: 1px solid #555;">',
        f' <img src="data:image/png;base64,{encoded}" />',
        "</div>",
    ]
    return "".join(fragments)
[docs]
def get_bivariate_transparancy_rgba_array(cmap, n_x, n_y, alpha_min=0.1, alpha_max=1):
    """Create the RGBA array for the bivariate transparency colormap.

    The returned array has shape (n_y, n_x, 4). Every row holds the same
    ``n_x`` RGB colors sampled from ``cmap`` over [0, 1]; the alpha channel is
    constant within a row and goes linearly from ``alpha_max`` (first row) to
    ``alpha_min`` (last row).
    """
    import pycolorbar

    cmap = pycolorbar.get_cmap(cmap)
    # RGB colors sampled along the x axis, shape (n_x, 3)
    rgb_row = cmap(np.linspace(0, 1, n_x))[:, :3]
    # One alpha value per row, shape (n_y,)
    alpha_column = np.linspace(alpha_max, alpha_min, n_y)
    # Fill the palette by broadcasting the colors over rows and the alphas over columns
    rgba_array = np.ones((n_y, n_x, 4))
    rgba_array[:, :, :3] = rgb_row
    rgba_array[:, :, 3] = alpha_column[:, np.newaxis]
    return rgba_array
[docs]
class BivariateTransparencyColormap(BivariateColormap):
    """Bivariate colormap pairing a univariate colormap (x) with transparency levels (y)."""

    def __init__(self, cmap, alpha_min=0.2, alpha_max=1, n=None):
        """
        Build the transparency-based bivariate colormap.

        Parameters
        ----------
        cmap : matplotlib.colors.Colormap or str
            The colormap to be used.
        alpha_min : float, optional
            The minimum alpha (transparency) value, by default 0.2.
        alpha_max : float, optional
            The maximum alpha (transparency) value, by default 1.
        n : int or tuple of int, optional
            The number of discrete colors in the colormap. If an integer is provided,
            it is used for both dimensions. If a tuple is provided, it should be of
            the form (n_x, n_y), where n_x is the number of colors in the x dimension
            and n_y is the number of colors in the y dimension. If None, a default
            value is used.

        Notes
        -----
        The y dimension corresponds to the transparency levels.
        """
        # Resolve the number of colors along each axis
        n_x, n_y = check_n(n)
        # Generate the palette: colors vary along x, alpha varies along y
        alpha_range = {"alpha_min": alpha_min, "alpha_max": alpha_max}
        palette = get_bivariate_transparancy_rgba_array(cmap=cmap, n_x=n_x, n_y=n_y, **alpha_range)
        super().__init__(palette)
        # Keep the transparency range available on the instance
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max