Source code for cornerhex.plot

import matplotlib
from matplotlib import colormaps
from matplotlib.colors import Colormap
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import gaussian_filter
from scipy.special import erf
from typing import Tuple


def cornerplot(
    data: np.ndarray,
    cmap="Blues",
    correlation_textcolor=None,
    hex_gridsize=30,
    highlight=None,
    highlight_linecolor=None,
    highlight_markercolor=None,
    hist_backgroundcolor=None,
    hist_bins=20,
    hist_edgecolor=None,
    hist_facecolor=None,
    labels=None,
    limits=None,
    scatter_alpha=0.5,
    scatter_markercolor=None,
    scatter_outside_sigma=None,
    show_correlations=False,
    sigma_levels=None,
    sigma_linecolor=None,
    sigma_smooth=3.0,
    title_quantiles=None,
    width=3.0,
) -> Tuple[matplotlib.figure.Figure, np.ndarray]:
    """
    Function creates hexbin corner plot matrix to visualize
    multidimensional data.

    The diagonal tiles show one histogram per feature, the tiles below
    the diagonal show hexbin plots of every feature pair, and the tiles
    above the diagonal are removed from the figure. The figure is
    displayed with ``plt.show()`` before the function returns.

    Parameters:
    -----------
    data : (N_sample, N_dims) array-like
        Two-dimensional data array to be visualized. First dimension
        are the samples, second dimension the features.
    cmap : str or matplotlib.colors.Colormap, optional, default: "Blues"
        Either a string of a valid matplotlib colormap or a custom
        Colormap object to be used in the plot.
    correlation_textcolor : str, optional, default: None
        Color of correlation text. Defaults to a value from the chosen
        colormap.
    hex_gridsize : int, optional, default: 30
        Number of hexagons in x-direction.
    highlight : (N_dims) array-like, optional, default: None
        If not None, array of values to be highlighted in the plot.
        Typically the truth values.
    highlight_linecolor : str or tuple, optional, default: None
        If not None linecolor of the highlighted values. Defaults to a
        value from the chosen colormap.
    highlight_markercolor : str or tuple, optional, default: None
        If not None markercolor of the highlighted values. Defaults to
        a value from the chosen colormap.
    hist_backgroundcolor : str or tuple, optional, default: None
        If not None background color of the histogram axes. Defaults to
        a value from the chosen colormap.
    hist_bins : int, optional, default: 20
        Number of bins in histograms.
    hist_edgecolor : str or tuple, optional, default: None
        If not None edgecolor of the histogram bars. Defaults to a
        value from the chosen colormap.
    hist_facecolor : str or tuple, optional, default: None
        If not None facecolor of the histogram bars. Defaults to a
        value from the chosen colormap.
    labels : (N_dims) list, optional, default: None
        If not None list of strings with the feature names to be used
        as axis labels and in the histogram titles.
    limits : (2,) or (N_dims, 2) array-like, optional, default: None
        If not None the axis limits giving the minimum and maximum
        value of each feature. A shape of (2,) applies the same limits
        to every feature. Defaults to the data minimum and maximum.
    scatter_alpha : float, optional, default: 0.5
        Alpha transparency value of scatter plot.
    scatter_markercolor : str, optional, default: None
        If not None markercolor of scatter plot markers. Defaults to a
        value from the chosen colormap.
    scatter_outside_sigma : float, optional, default: None
        If not None displays scatter plot of individual data points
        outside of given sigma contour.
    show_correlations : boolean, optional, default: False
        If True show Pearson's correlation coefficient in each hexbin
        tile.
    sigma_levels : array_like, optional, default: None
        If not None contour levels to be plotted in units of the
        standard deviation.
    sigma_linecolor : str or tuple, optional, default: None
        If not None linecolor of the sigma contour lines. Defaults to
        a value from the chosen colormap.
    sigma_smooth : float, optional, default: 3.
        Width of the Gaussian filter applied to the gridded hexbin
        counts to smooth out contour lines.
    title_quantiles : (1,) or (3,) array-like, optional, default: None
        One-dimensional array of either size one or size three with
        the feature percentiles to be plotted as histogram titles.
    width : float, optional, default: 3.
        Width of a single tile of the corner plot.

    Returns:
    --------
    (fig, ax) : tuple
        Tuple containing the figure and the (N_dims, N_dims) array of
        axes objects. The entries above the diagonal have been removed
        from the figure.

    Raises:
    -------
    ValueError
        If 'data' is not two-dimensional, 'cmap' has an unsupported
        type, or 'labels', 'highlight', 'title_quantiles' or 'limits'
        do not have a valid size or shape.
    """

    # Accept any array-like input and make sure it is two-dimensional
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(
            "'data' has to be two-dimensional with shape "
            "(N_sample, N_dims).")

    # Number of dimensions
    _, Nd = data.shape

    # Validate colormap
    if isinstance(cmap, str):
        cm = colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cm = cmap
    else:
        raise ValueError(
            "'cmap' has to be either type str or "
            "matplotlib.colors.Colormap.")

    # Setting colors: user supplied values take precedence, otherwise
    # the colors are picked from the colormap
    hl_lc = cm(1.) if highlight_linecolor is None else highlight_linecolor
    hl_mc = cm(0.) if highlight_markercolor is None else highlight_markercolor
    sig_lc = cm(1.) if sigma_linecolor is None else sigma_linecolor
    hist_bc = cm(0.) if hist_backgroundcolor is None else hist_backgroundcolor
    hist_ec = cm(1.) if hist_edgecolor is None else hist_edgecolor
    hist_fc = cm(0.5) if hist_facecolor is None else hist_facecolor
    scat_mc = cm(0.5) if scatter_markercolor is None else scatter_markercolor

    # Validate labels
    set_labels = labels is not None
    if set_labels and len(labels) != Nd:
        raise ValueError(
            "Size of 'labels' does not match number of dimensions.")

    # Validate highlights
    set_highlights = highlight is not None
    if set_highlights and len(highlight) != Nd:
        raise ValueError(
            "Size of 'highlight' does not match number of dimensions.")

    # Produce levels
    if sigma_levels is not None:
        # Converting to 1d percentiles, symmetric around the median
        quants = []
        for s in sigma_levels:
            q = sigma_to_quantile(s)
            quants.extend([50. - q, 50. + q])
        quants = np.sort(np.array(quants))
        # Converting to 2d sigma levels (enclosed probability mass)
        levels2d = 1. - np.exp(-0.5 * np.array(sigma_levels)**2)

    # Threshold for scatter plot (enclosed probability mass)
    if scatter_outside_sigma is not None:
        scatter_thr = 1. - np.exp(-0.5 * np.array(scatter_outside_sigma)**2)

    # Validate title quantiles
    if title_quantiles is not None:
        if len(title_quantiles) not in [1, 3]:
            raise ValueError(
                "'title_quantiles' has to have either size 1 or size 3.")
        title_quantiles = np.sort(np.array(title_quantiles))

    # Compute correlation coefficients
    if show_correlations:
        rho = np.corrcoef(data.T)
        cor_tc = cm(1.) if correlation_textcolor is None \
            else correlation_textcolor

    # Setting limits
    lim = np.empty((Nd, 2))
    if limits is not None:
        lim_arr = np.array(limits)
        if lim_arr.shape not in [(2,), (Nd, 2)]:
            raise ValueError("'limits' does not have the correct shape.")
        # A single (min, max) pair is broadcast to all features
        lim[...] = lim_arr[None, :] if lim_arr.ndim == 1 else lim_arr
    else:
        lim[:, 0] = data.min(axis=0)
        lim[:, 1] = data.max(axis=0)

    # squeeze=False keeps 'ax' two-dimensional even for a single feature
    fig, ax = plt.subplots(
        nrows=Nd,
        ncols=Nd,
        figsize=(Nd * width, Nd * width),
        squeeze=False,
    )

    # ix is the row (feature on the y-axis), iy the column (feature on
    # the x-axis) of the tile
    for ix in range(Nd):
        for iy in range(Nd):

            tile = ax[ix, iy]

            if iy > ix:
                # Remove axes outside of corner plot; nothing else to
                # configure on a removed tile
                fig.delaxes(tile)
                continue

            if ix == iy:

                # Histograms
                tile.set_facecolor(hist_bc)
                tile.hist(
                    data[:, ix],
                    bins=hist_bins,
                    color=cm(1.0),
                    facecolor=[hist_fc],
                    edgecolor=[hist_ec],
                )

                if sigma_levels is not None:
                    for p in np.percentile(data[:, ix], quants):
                        tile.axvline(p, ls="--", lw=1, c=sig_lc)

                if set_highlights:
                    tile.axvline(highlight[iy], ls="-", color=hl_lc)

                if title_quantiles is not None:
                    perc = np.percentile(data[:, ix], title_quantiles)
                    if set_labels:
                        prefix = "{} $= ".format(labels[iy])
                    else:
                        prefix = "$"
                    if len(title_quantiles) == 1:
                        msg = "{}{{{:.2f}}}$".format(prefix, perc[0])
                    else:
                        # Central value with upper and lower difference
                        diff = np.diff(perc)
                        msg = "{}{{{:.2f}}}^{{+{:.2f}}}_{{-{:.2f}}}$".format(
                            prefix, perc[1], diff[1], diff[0])
                    tile.set_title(msg, fontsize="x-large")

            else:

                # Hexbin plots
                h = tile.hexbin(
                    data[:, iy],
                    data[:, ix],
                    cmap=cm,
                    gridsize=hex_gridsize,
                    linewidths=0.5,
                    edgecolor=cm(0),
                    extent=(lim[iy, 0], lim[iy, 1], lim[ix, 0], lim[ix, 1]),
                )

                # Data gridding if sigma_levels given or scatter plot
                # required
                if sigma_levels is not None \
                        or scatter_outside_sigma is not None:

                    # Get positions and values from hexbin plots
                    xy = h.get_offsets()
                    z = h.get_array()

                    # Compute gridded data for contour plot on a regular
                    # grid with three points per hexagon
                    n_grid = 3 * hex_gridsize * 1j
                    grid_x, grid_y = np.mgrid[
                        xy[:, 0].min():xy[:, 0].max():n_grid,
                        xy[:, 1].min():xy[:, 1].max():n_grid,
                    ]
                    # Points that cannot be interpolated are treated as
                    # empty; NaNs would otherwise be spread over the map
                    # by the Gaussian filter
                    zi = griddata(
                        xy, z, (grid_x, grid_y),
                        method="linear", fill_value=0.)

                    # Smooth data to forget about hexplot grid
                    zi = gaussian_filter(zi, sigma_smooth)

                    # Convert sigma levels to data levels: sort the
                    # density in descending order and accumulate the
                    # enclosed fraction
                    z_ordered = np.sort(zi.ravel())[::-1]
                    cumsum = z_ordered.cumsum()
                    cumsum /= cumsum[-1]

                    # Contourplot
                    if sigma_levels is not None:
                        # Density value whose enclosed fraction is
                        # closest to each requested level. np.unique
                        # sorts and removes duplicates, since contour
                        # requires strictly increasing levels.
                        levels = np.unique([
                            z_ordered[np.abs(cumsum - frac).argmin()]
                            for frac in levels2d
                        ])
                        # Draw plot
                        tile.contour(
                            grid_x,
                            grid_y,
                            zi,
                            levels=levels,
                            colors=[sig_lc],
                            linewidths=1.,
                            linestyles="--",
                        )

                    # Scatter plot
                    if scatter_outside_sigma is not None:
                        # Compute threshold value
                        scat_thr = z_ordered[
                            np.abs(cumsum - scatter_thr).argmin()]
                        # Linear interpolation of the smoothed density
                        # at the position of every sample; samples
                        # outside of the grid get a density of zero
                        density = RegularGridInterpolator(
                            (grid_x[:, 0], grid_y[0, :]),
                            zi,
                            method="linear",
                            bounds_error=False,
                            fill_value=0.,
                        )
                        mask = density((data[:, iy], data[:, ix])) < scat_thr
                        tile.scatter(
                            data[mask, iy],
                            data[mask, ix],
                            marker=".",
                            s=0.5,
                            alpha=scatter_alpha,
                            c=[scat_mc],
                        )

                # Correlation coefficient
                if show_correlations:
                    msg = r"$\rho = {{{:.2f}}}$".format(rho[ix, iy])
                    tile.text(
                        0.02, 0.98, msg,
                        va="top",
                        ha="left",
                        transform=tile.transAxes,
                        c=cor_tc,
                    )

                # Set the highlights
                if set_highlights:
                    tile.axvline(highlight[iy], ls="-", color=hl_lc)
                    tile.axhline(highlight[ix], ls="-", color=hl_lc)
                    tile.plot(
                        highlight[iy],
                        highlight[ix],
                        marker=(6, 0, 0),
                        markersize=8,
                        c=hl_mc,
                        markeredgecolor=hl_lc,
                    )

            # Setting limits: only on the bottom row (x) and the first
            # column (y); all other tiles inherit them via shared axes
            if ix == Nd - 1:
                tile.set_xlim(lim[iy, 0], lim[iy, 1])
            if ix > 0 and iy == 0:
                tile.set_ylim(lim[ix, 0], lim[ix, 1])

            # Set y-labels
            if ix > 0 and iy == 0:
                if set_labels:
                    tile.set_ylabel(labels[ix], fontsize="x-large")
            elif iy < ix:
                tile.sharey(ax[ix, 0])
                plt.setp(tile.get_yticklabels(), visible=False)
            else:
                # Histograms keep their own y-axis without tick labels
                plt.setp(tile.get_yticklabels(), visible=False)

            # Set x-labels
            if ix == Nd - 1:
                if set_labels:
                    tile.set_xlabel(labels[iy], fontsize="x-large")
            else:
                tile.sharex(ax[Nd - 1, iy])
                plt.setp(tile.get_xticklabels(), visible=False)

    fig.tight_layout()
    plt.show()

    return fig, ax
def sigma_to_quantile(sig: float) -> float:
    """
    Function converts a number of standard deviations of a normal
    distribution to a distance from the median in percent.

    The result is the percentage of the probability mass of a normal
    distribution that lies between the mean and `sig` standard
    deviations away from it, e.g. about 34.13 for `sig = 1`. The
    percentiles enclosing `sig` standard deviations are therefore
    `50 - q` and `50 + q`.

    Parameters:
    -----------
    sig : float
        Number of standard deviations

    Returns:
    --------
    q : float
        Distance from the median in percent
    """
    # The helper returns the negated antiderivative of the standard
    # normal distribution, hence it is evaluated at -sig
    half_mass = _gaussian_primitive(-sig, 0., 1.)
    return half_mass * 100
def _gaussian_primitive(x: float, mu: float, sig: float) -> float: """ Function returns the primitive of the normal distribution Parameters: ----------- x : float, array-like Evaluation coordinate mu : float Mean of normal distribution sig : float Standard deviation of normal distribution Returns: -------- F(x) : float, array-like Primitive at x """ return -0.5*erf((x-mu)/(np.sqrt(2.)*sig))