#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plotting utilities, leveraging Matplotlib.
"""
__all__ = ['set_style', 'closefig', 'savefig']
import matplotlib.pyplot as plt
import pathlib
import matplotlib as mpl
import numpy as np
import pandas as pd
from ..utils import lipdutils
# import pandas as pd
# from matplotlib.patches import Rectangle
# from matplotlib.collections import PatchCollection
# from matplotlib.colors import ListedColormap
# import seaborn as sns
# this is here because it's only used to set labels in plots
def infer_period_unit_from_time_unit(time_unit):
    ''' Derive the unit used to label periods from a time unit

    Parameters
    ----------
    time_unit : str or None
        The time unit of the series. It is classified with
        ``lipdutils.timeUnitsCheck``.

    Returns
    -------
    period_unit : str or None
        - None when `time_unit` is None
        - 'kyrs' when the unit is classified as 'kage_units'
        - 'yrs' for any other recognized classification
        - `time_unit` itself (formatted as a string) when the
          classification is 'unknown'
    '''
    if time_unit is None:
        return None

    unit_group = lipdutils.timeUnitsCheck(time_unit)
    if unit_group == 'unknown':
        # unrecognized units are passed through unchanged
        return f'{time_unit}'

    return 'kyrs' if unit_group == 'kage_units' else 'yrs'
def scatter_xy(x, y, c=None, figsize=None, xlabel=None, ylabel=None, title=None,
               xlim=None, ylim=None, savefig_settings=None, ax=None,
               legend=True, plot_kwargs=None, lgd_kwargs=None):
    """
    Make a scatter plot.

    Parameters
    ----------
    x : numpy.array
        x values
    y : numpy.array
        y values
    c : array-like or list of colors or color, optional
        The marker colors, forwarded to ``ax.scatter``. Possible values:

        - A scalar or sequence of n numbers to be mapped to colors using cmap and norm.
        - A 2D array in which the rows are RGB or RGBA.
        - A sequence of colors of length n.
        - A single color format string.

        Note that c should not be a single numeric RGB or RGBA sequence because
        that is indistinguishable from an array of values to be colormapped.
        To use the same RGB or RGBA value for all points, use a 2D array with a
        single row, or prefer the ``color`` keyword (via `plot_kwargs`).
        Defaults to None, in which case Matplotlib picks the color (from
        ``color``/``facecolor(s)`` if given, otherwise from the Axes' color cycle).
        If `c` is None and `plot_kwargs` contains a ``'c'`` entry, that entry is
        used; an explicit `c` always takes precedence.
    figsize : list, optional
        A list of two integers indicating the dimension of the figure.
        Only used when a new figure is created. The default is None.
    xlabel : str, optional
        x-axis label. The default is None.
    ylabel : str, optional
        y-axis label. The default is None.
    title : str, optional
        Title for the plot. The default is None.
    xlim : list, optional
        Limits for the x-axis. The default is None.
    ylim : list, optional
        Limits for the y-axis. The default is None.
    savefig_settings : dict, optional
        the dictionary of arguments for savefig(); only used when a new figure
        is created. Some notes below:

        - "path" must be specified; it can be any existing or non-existing path,
          with or without a suffix; if the suffix is not given in "path", it will follow "format"
        - "format" can be one of {"pdf", "eps", "png", "ps"}

        The default is None.
    ax : pyplot.axis, optional
        The axis object to draw on. If None, a new figure and axis are created.
        The default is None.
    legend : bool, optional
        Whether to include a legend. The default is True.
    plot_kwargs : dict, optional
        the keyword arguments for ax.scatter(). The default is None.
    lgd_kwargs : dict, optional
        the keyword arguments for ax.legend(). The default is None.

    Returns
    -------
    ax : pyplot.axis
        Returned alone when an existing `ax` was passed in.
    fig, ax : tuple
        Returned when `ax` was None and a new figure was created.
    """
    # work on copies so the caller's dictionaries are never modified
    savefig_settings = {} if savefig_settings is None else savefig_settings.copy()
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs.copy()
    lgd_kwargs = {} if lgd_kwargs is None else lgd_kwargs.copy()

    # Route the colors through plot_kwargs so that a 'c' entry supplied there
    # no longer clashes with the explicit keyword (previously a TypeError).
    if c is not None or 'c' not in plot_kwargs:
        plot_kwargs['c'] = c

    fig = None  # stays None when the caller supplied the axis
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    ax.scatter(x, y, **plot_kwargs)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    if legend:
        ax.legend(**lgd_kwargs)
    else:
        ax.legend().remove()

    if fig is None:
        return ax

    if 'path' in savefig_settings:
        savefig(fig, settings=savefig_settings)
    return fig, ax
def plot_scatter_xy(x1, y1, x2, y2, figsize=None, xlabel=None,
                    ylabel=None, title=None, xlim=None, ylim=None,
                    savefig_settings=None, ax=None, legend=True,
                    plot_kwargs=None, lgd_kwargs=None):
    ''' Plot a scatter on top of a line plot.

    Parameters
    ----------
    x1 : array
        x axis of timeseries1 - plotted as a line
    y1 : array
        values of timeseries1 - plotted as a line
    x2 : array
        x axis of scatter points
    y2 : array
        y of scatter points
    figsize : list
        a list of two integers indicating the figure size;
        only used when a new figure is created
    xlabel : str
        label for x-axis
    ylabel : str
        label for y-axis
    title : str
        the title for the figure
    xlim : list
        set the limits of the x axis
    ylim : list
        set the limits of the y axis
    ax : pyplot.axis
        the pyplot.axis object to draw on; a new figure is created if None
    legend : bool
        plot legend or not
    lgd_kwargs : dict
        the keyword arguments for ax.legend()
    plot_kwargs : dict
        the keyword arguments for ax.plot() (the line). The line is green
        unless a 'color' entry is given here.
    savefig_settings : dict
        the dictionary of arguments for savefig(); only used when a new
        figure is created. Some notes below:

        - "path" must be specified; it can be any existing or non-existing path,
          with or without a suffix; if the suffix is not given in "path", it will follow "format"
        - "format" can be one of {"pdf", "eps", "png", "ps"}

    Returns
    -------
    ax : pyplot.axis
        Returned alone when an existing `ax` was passed in.
    fig, ax : tuple
        Returned when `ax` was None and a new figure was created.

    See also
    --------
    pyleoclim.utils.plotting.set_style : set different styles for the figures. Should be set before invoking the plotting functions
    pyleoclim.utils.plotting.savefig : save figures
    '''
    # handle dict defaults (copies keep the caller's dictionaries untouched)
    savefig_settings = {} if savefig_settings is None else savefig_settings.copy()
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs.copy()
    lgd_kwargs = {} if lgd_kwargs is None else lgd_kwargs.copy()

    # Green is only a default: passing color='green' next to **plot_kwargs
    # raised a TypeError as soon as the caller supplied their own 'color'.
    plot_kwargs.setdefault('color', 'green')

    fig = None  # stays None when the caller supplied the axis
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    ax.plot(x1, y1, **plot_kwargs)
    ax.scatter(x2, y2, color='red')

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    if legend:
        ax.legend(**lgd_kwargs)
    else:
        ax.legend().remove()

    if fig is None:
        return ax

    if 'path' in savefig_settings:
        savefig(fig, settings=savefig_settings)
    return fig, ax
def plot_xy(x, y, figsize=None, xlabel=None, ylabel=None, title=None,
            xlim=None, ylim=None, savefig_settings=None, ax=None,
            legend=True, plot_kwargs=None, lgd_kwargs=None,
            invert_xaxis=False, invert_yaxis=False):
    ''' Draw a timeseries as a line plot

    Parameters
    ----------
    x : array
        The time axis for the timeseries
    y : array
        The values of the timeseries
    figsize : list
        a list of two integers indicating the figure size
        (used only when a new figure is created)
    xlabel : str
        label for x-axis
    ylabel : str
        label for y-axis
    title : str
        the title for the figure
    xlim : list
        set the limits of the x axis
    ylim : list
        set the limits of the y axis
    savefig_settings : dict
        the dictionary of arguments for savefig() (used only when a new
        figure is created); some notes below:

        - "path" must be specified; it can be any existing or non-existing path,
          with or without a suffix; if the suffix is not given in "path", it will follow "format"
        - "format" can be one of {"pdf", "eps", "png", "ps"}
    ax : pyplot.axis
        the pyplot.axis object to draw on; when None a new figure and
        axis are created
    legend : bool
        plot legend or not
    plot_kwargs : dict
        the keyword arguments for ax.plot()
    lgd_kwargs : dict
        the keyword arguments for ax.legend()
    invert_xaxis : bool, optional
        if True, the x-axis of the plot will be inverted
    invert_yaxis : bool, optional
        if True, the y-axis of the plot will be inverted

    Returns
    -------
    ax : pyplot.axis
        when an existing `ax` was passed in
    fig, ax : tuple
        when `ax` was None and a new figure was created

    See Also
    --------
    pyleoclim.utils.plotting.set_style : set different styles for the figures. Should be set before invoking the plotting functions
    pyleoclim.utils.plotting.savefig : save figures
    '''
    # private copies of the option dictionaries
    save_opts = dict(savefig_settings) if savefig_settings is not None else {}
    line_opts = dict(plot_kwargs) if plot_kwargs is not None else {}
    legend_opts = dict(lgd_kwargs) if lgd_kwargs is not None else {}

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=figsize)

    ax.plot(x, y, **line_opts)

    # apply each optional decoration only when a value was supplied
    decorations = (
        (ax.set_xlabel, xlabel),
        (ax.set_ylabel, ylabel),
        (ax.set_title, title),
        (ax.set_xlim, xlim),
        (ax.set_ylim, ylim),
    )
    for apply, value in decorations:
        if value is not None:
            apply(value)

    if not legend:
        ax.legend().remove()
    else:
        ax.legend(**legend_opts)

    if invert_xaxis:
        ax.invert_xaxis()
    if invert_yaxis:
        ax.invert_yaxis()

    if not owns_figure:
        return ax

    if 'path' in save_opts:
        savefig(fig, settings=save_opts)
    return fig, ax
def stripes_xy(x, y, cmap='coolwarm', figsize=None, ax=None,
               vmin=None, vmax=None, xlabel=None, ylabel=None,
               title=None, xlim=None, savefig_settings=None, label_color=None,
               x_offset=0.05, label_size=None, show_xaxis=False,
               invert_xaxis=False, top_label=None, bottom_label=None):
    '''
    Represent y = f(x) as an Ed Hawkins "warming stripes" pattern

    Uses Matplotlib's pcolormesh.

    Credit: https://esmvalgroup.github.io/ESMValTool_Tutorial/files/warming_stripes.py

    Parameters
    ----------
    x : array
        Independent variable
    y : array
        Dependent variable (assumes centered and normalized to unit standard deviation)
    cmap : str
        colormap name
    figsize : list
        a list of two integers indicating the figure size
        (used only when a new figure is created)
    ax : pyplot.axis
        the pyplot.axis object, default is None (a new figure is created)
    vmin : float
        lower bound for colormap normalization
    vmax : float
        upper bound for colormap normalization
    xlabel : str
        label for the x-axis; only applied when `show_xaxis` is truthy
    ylabel : str
        label for the y-axis (note that the y-axis itself is hidden)
    title : str
        the title for the figure
    top_label : str
        the "title" label for the stripe. Set to '' if no label is wanted
    bottom_label : str
        the "ylabel" explaining which variable is being plotted. Set to '' if no label is wanted
    label_color : color
        color of the two text labels
    label_size : int
        size of the text in labels (in points). Default is the Matplotlib 'axes.labelsize' rcParams
    xlim : list
        set the limits of the x axis
    x_offset : float (0-1)
        value controlling the horizontal offset between stripes and labels (default = 0.05).
        The labels are placed at ``upper_x_bound * (1 + x_offset / 10)``.
    show_xaxis : bool
        flag indicating whether or not the x-axis should be shown (default = False)
    savefig_settings : dict
        the dictionary of arguments for savefig(); only used when a new
        figure is created. Some notes below:

        - "path" must be specified; it can be any existing or non-existing path,
          with or without a suffix; if the suffix is not given in "path", it will follow "format"
        - "format" can be one of {"pdf", "eps", "png", "ps"}
    invert_xaxis : bool, optional
        if True, the x-axis of the plot will be inverted

    See Also
    --------
    https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.pcolormesh.html
    https://matplotlib.org/stable/tutorials/colors/colormapnorms.html

    Returns
    -------
    ax, or (fig, ax) if no axes were provided.
    '''
    # handle dict defaults
    savefig_settings = {} if savefig_settings is None else savefig_settings.copy()

    fig = None  # stays None when the caller supplied the axis
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    if label_size is None:
        label_size = mpl.rcParams['axes.labelsize']

    # Two identical rows of y over the vertical extent [0, 1] render as
    # full-height vertical stripes.
    ones = np.array([0, 1])
    ax.pcolormesh(x, ones, np.vstack([y, y]), cmap=cmap,
                  vmin=vmin, vmax=vmax, shading='auto')

    # hide y axis
    ax.get_yaxis().set_visible(False)
    ax.spines['left'].set_visible(False)

    # manage x axis
    ax.spines['bottom'].set_visible(show_xaxis)
    ax.get_xaxis().set_visible(show_xaxis)
    # Truthiness (not `is True`) so that flags such as numpy.bool_ are treated
    # the same way as in the two visibility calls above.
    if show_xaxis and xlabel is not None:
        ax.set_xlabel(xlabel)

    # parameters for label position (taken before xlim/inversion are applied)
    thickness = ax.get_ybound()[1]
    xmax = ax.get_xbound()[1] * (1 + x_offset / 10)
    ax.text(xmax, 0.5 * thickness, top_label, color=label_color,
            fontsize=label_size, fontweight='bold')
    ax.text(xmax, 0 * thickness, bottom_label, color=label_color,
            fontsize=label_size)

    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    if xlim is not None:
        ax.set_xlim(xlim)
    if invert_xaxis:
        ax.invert_xaxis()

    if fig is None:
        return ax

    if 'path' in savefig_settings:
        savefig(fig, settings=savefig_settings)
    return fig, ax
def closefig(fig=None):
    '''Close a figure

    Parameters
    ----------
    fig : matplotlib.pyplot.figure, optional
        The matplotlib figure object to close. When None, ``plt.close()``
        is called without an argument.
    '''
    if fig is None:
        plt.close()
        return

    plt.close(fig)
def savefig(fig, path=None, dpi=300, settings=None, verbose=True):
    ''' Save a figure to a path

    The parent directory is created if needed, the figure is written with
    ``fig.savefig`` and then closed with ``plt.close``.

    Parameters
    ----------
    fig : matplotlib.pyplot.figure
        the figure to save
    path : str
        the path to save the figure, can be ignored and specify in "settings" instead
    dpi : int
        resolution in dot (pixels) per inch. Default: 300.
    settings : dict
        the dictionary of arguments for plt.savefig(); entries here override
        the `path` and `dpi` keyword arguments. Some notes below:

        - "path" must be specified in settings if not assigned with the keyword argument;
          it can be any existing or non-existing path, with or without a suffix;
          if the suffix is not given in "path", it will follow "format"
          (or default to "pdf" when no "format" is given)
        - "format" can be one of {"pdf", "eps", "png", "ps"}
    verbose : bool, {True,False}
        If True, print the path of the saved file.

    Raises
    ------
    ValueError
        If no path is given, neither as keyword argument nor in `settings`.
    '''
    # `None` instead of a shared mutable `{}` default
    settings = {} if settings is None else settings

    if path is None and 'path' not in settings:
        raise ValueError('"path" must be specified, either with the keyword argument or be specified in `settings`!')

    savefig_args = {'bbox_inches': 'tight', 'path': path, 'dpi': dpi}
    savefig_args.update(settings)

    path = pathlib.Path(savefig_args.pop('path'))

    dirpath = path.parent
    if not dirpath.exists():
        dirpath.mkdir(parents=True, exist_ok=True)
        if verbose:
            print(f'Directory created at: "{dirpath}"')

    # Suffixes that are kept as given. Anything else is treated as "no suffix".
    # The list goes beyond the four documented formats so that paths such as
    # 'fig.svg' keep being written where they used to be.
    known_suffixes = ('.eps', '.pdf', '.png', '.ps', '.svg', '.svgz',
                      '.jpg', '.jpeg', '.tif', '.tiff')
    if path.suffix.lower() not in known_suffixes:
        # follow "format" when given, otherwise fall back to pdf
        extension = savefig_args.get('format') or 'pdf'
        path = pathlib.Path(f'{path}.{extension}')

    # Save to the completed path: the file previously went to the
    # pre-completion name, so the message below could name a different file.
    fig.savefig(str(path), **savefig_args)
    plt.close(fig)

    if verbose:
        print(f'Figure saved at: "{str(path)}"')
def set_style(style='journal', font_scale=1.0, dpi=300):
    ''' Modify the visualization style

    This function is inspired by `Seaborn <https://github.com/mwaskom/seaborn>`_.
    It updates ``matplotlib.rcParams`` in place.

    Parameters
    ----------
    style : {journal,web,matplotlib,_spines, _nospines,_grid,_nogrid}
        set the styles for the figure. The options are matched as substrings
        of `style`, so they can be combined (e.g. 'journal_nogrid'):

        - journal (default): fonts appropriate for paper
        - web: web-like font (e.g. ggplot)
        - matplotlib: no theme-specific settings are applied

        In addition, the following options are available:

        - _spines/_nospines: allow to show/hide spines
        - _grid/_nogrid: allow to show/hide gridlines
    font_scale : float
        Multiplier applied to the base font sizes
        (font.size 10, axes.labelsize 11, axes.titlesize 12,
        tick label sizes 10, legend.fontsize 9). Default is 1.
    dpi : int
        value assigned to the 'savefig.dpi' rcParam. Default is 300.

    Raises
    ------
    ValueError
        If `style` contains none of 'journal', 'web' or 'matplotlib'.
    '''
    base_font_sizes = {
        'font.size': 10,
        'axes.labelsize': 11,
        'axes.titlesize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 9,
    }

    # settings common to the 'journal' and 'web' themes
    shared_theme = {
        'axes.axisbelow': True,
        'axes.grid': True,
        'xtick.direction': 'out',
        'ytick.direction': 'out',
        'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans', 'Bitstream Vera Sans', 'sans-serif'],
        'axes.spines.right': False,
        'axes.spines.top': False,
        'legend.frameon': False,
        'axes.linewidth': 1,
        'grid.linewidth': 1,
        'lines.linewidth': 2,
        'lines.markersize': 6,
        'patch.linewidth': 1,
        'xtick.major.width': 1.25,
        'ytick.major.width': 1.25,
        'xtick.minor.width': 0,
        'ytick.minor.width': 0,
    }

    if 'journal' in style:
        theme = {
            **shared_theme,
            'axes.facecolor': 'white',
            'axes.edgecolor': 'black',
            'grid.color': 'lightgrey',
            'grid.linestyle': '--',
            'axes.spines.left': True,
            'axes.spines.bottom': True,
        }
    elif 'web' in style:
        theme = {
            **shared_theme,
            'figure.facecolor': 'white',
            'axes.facecolor': 'whitesmoke',
            'axes.edgecolor': 'lightgrey',
            'grid.color': 'white',
            'grid.linestyle': '-',
            'text.color': 'grey',
            'axes.labelcolor': 'grey',
            'xtick.color': 'grey',
            'ytick.color': 'grey',
            'axes.spines.left': False,
            'axes.spines.bottom': False,
        }
    elif 'matplotlib' in style:
        # nothing theme-specific; only the modifiers below, fonts and dpi apply
        theme = {}
    else:
        raise ValueError(f'Style [{style}] not availabel!')

    # optional modifiers: spines
    spine_keys = ('axes.spines.left', 'axes.spines.bottom',
                  'axes.spines.right', 'axes.spines.top')
    if '_spines' in style:
        theme.update(dict.fromkeys(spine_keys, True))
    elif '_nospines' in style:
        theme.update(dict.fromkeys(spine_keys, False))

    # optional modifiers: grid
    if '_grid' in style:
        theme['axes.grid'] = True
    elif '_nogrid' in style:
        theme['axes.grid'] = False

    # modify font size based on font scale
    scaled_fonts = {name: size * font_scale for name, size in base_font_sizes.items()}
    output_settings = {'savefig.dpi': dpi}

    for group in (theme, scaled_fonts, output_settings):
        mpl.rcParams.update(group)
def make_phantom_ax(ax):
    ''' Strip the visual annotation from an ax object

    Hides the left, right and bottom spines, removes the y ticks, sets the
    x tick length to zero, blanks the axis labels and tick labels, and
    turns the grid off. The ax is modified in place.

    Parameters
    ----------
    ax : matplotlib.axes.Axes object
        the axes object to clear

    Returns
    -------
    ax : matplotlib.axis
        the same axis object that was passed in

    See [matplotlib.axes](https://matplotlib.org/stable/api/axes_api.html) for details.
    '''
    for side in ('left', 'right', 'bottom'):
        ax.spines[side].set_visible(False)

    # ticks: none on y, zero-length on x
    ax.set_yticks([])
    ax.tick_params(axis='x', which='both', length=0)

    # axis labels
    ax.set_xlabel('')
    ax.set_ylabel('')

    ax.grid(False)

    # tick labels
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    return ax
def make_annotation_ax(fig, ax, loc='overlay',
                       ax_name='highlighted_intervals',
                       height=None, v_offset=0, b=None,
                       width=None, h_offset=0, l=None,
                       zorder=-1):
    ''' Makes a clean axis for adding annotation

    This function creates a new axes for adding annotation.
    If the bottom left corner is not specified, it is established based on the ax objects in ax.
    If there is only one ax object, this is overkill, but is helpful to introduce annotations that span multiple data axes.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        the figure to which the annotation ax is added
    ax : matplotlib.axes.Axes object or dict
        If ax is a dict, assumes data axes are assigned to integer keys and
        supplemental axes have string keys. A dict is updated in place.
    loc : string
        if "overlay", annotation ax will attempt to cover the area with data axes
        if "above", annotation ax will be located directly above the top data ax
        if "below", annotation ax will be located below the bottom data ax
    ax_name : string
        name associated with new ax object
    height : float
        height of annotation ax
        if loc = "above" or "below", height=.025 if not specified
        if loc = "overlay", height=vertical span of data axes, if not specified
    v_offset : float
        vertical offset between data plot area and annotation ax
        a positive v_offset will place the bottom corner higher
    width : float
        width of annotation ax
        horizontal span of data axes, if not specified
    b : float
        location of bottom corner of annotation ax
    h_offset : float
        horizontal offset from left corner
        a positive h_offset will place the left corner farther to the right
    l : float
        location of left corner of annotation ax
    zorder : numeric
        index of annotation ax layer in fig
        zorder = -1 will place the layer behind other layers
        zorder = 1000 will place the layer in front of other layers

    Returns
    -------
    ax_d : dict
        ax_d contains the original ax object(s) and new annotation ax assigned to specified ax_name

    Raises
    ------
    ValueError
        If no data axes (integer keys) are available, or if `b` is not given
        and `loc` is not one of "overlay", "above" or "below".

    See [matplotlib.axes](https://matplotlib.org/stable/api/axes_api.html) for details.
    '''
    ax_d = ax if isinstance(ax, dict) else {0: ax}

    # data axes live under (non-bool) integer keys
    data_keys = sorted(key for key in ax_d
                       if isinstance(key, int) and not isinstance(key, bool))
    if not data_keys:
        raise ValueError('`ax` must contain at least one data axes stored under an integer key')

    # positions are read through the public Bbox attributes (x0, y0, x1, y1)
    boxes = [ax_d[key].get_position() for key in data_keys]
    # the annotation ax shares the x-limits of the data axes with the largest key
    xlims = ax_d[data_keys[-1]].get_xlim()

    bottom = min(box.y0 for box in boxes)
    top = max(box.y1 for box in boxes)
    right = max(box.x1 for box in boxes)
    if l is None:
        l = min(box.x0 for box in boxes)

    if loc == 'overlay':
        if b is None:
            b = bottom
        if height is None:
            height = top - b
    else:
        if height is None:
            height = .025
        if b is None:
            if loc == 'above':
                b = top
            elif loc == 'below':
                b = bottom - height

    if b is None:
        raise ValueError(f'`loc` must be "overlay", "above" or "below" when `b` is not given; got {loc!r}')

    if width is None:
        width = right - l

    b += v_offset
    l += h_offset

    ax_d[ax_name] = fig.add_axes([l, b, width, height], zorder=zorder)
    ax_d[ax_name].set_xlim(xlims)
    ax_d[ax_name] = make_phantom_ax(ax_d[ax_name])
    # fully transparent background so underlying axes stay visible
    ax_d[ax_name].set_facecolor((1, 1, 1, 0))
    return ax_d
import matplotlib.patches as mpatches
def hightlight_intervals(ax, intervals, labels=None, color='g', alpha=.3, legend=True):
    ''' Hightlights intervals

    This function highlights intervals by shading a vertical span
    (``ax.axvspan``) for each of them.

    Parameters
    ----------
    ax : matplotlib.axes.Axes object
        the axes on which the spans are drawn
    intervals : list
        list of intervals to be highlighted, each given as a (start, end) pair.
        A single (start, end) pair is also accepted. Pairs may be lists,
        tuples or arrays. An empty list draws nothing.
    labels : list, optional
        Accepted for interface compatibility; currently not used.
    color : string or list
        If a string is passed, all intervals will be the specified color
        If a list is passed, the list is expected to be the same length as intervals
    alpha : float or list
        If a float is passed, all intervals will have the same specified alpha value
        If a list is passed, the list is expected to be the same length as intervals
    legend : bool, optional
        Accepted for interface compatibility; currently not used.

    Returns
    -------
    ax : matplotlib.axis
        the axis object from matplotlib

    See [matplotlib.axes](https://matplotlib.org/stable/api/axes_api.html) for details.

    Examples
    --------
    .. jupyter-execute::

        import pyleoclim as pyleo
        ts_18 = pyleo.utils.load_dataset('cenogrid_d18O')
        ts_13 = pyleo.utils.load_dataset('cenogrid_d13C')
        ms = pyleo.MultipleSeries([ts_18, ts_13], label='Cenogrid', time_unit='ma BP')
        fig, ax = ms.stackplot(linewidth=0.5, fill_between_alpha=0)
        ax=pyleo.utils.plotting.make_annotation_ax(fig, ax, ax_name = 'highlighted_intervals', zorder=-1)
        intervals = [[3, 8], [12, 18], [30, 31], [40,43], [49, 60], [60, 65]]
        ax['highlighted_intervals'] = pyleo.utils.plotting.hightlight_intervals(ax['highlighted_intervals'], intervals,
                                                                               color='g', alpha=.1)
    '''
    if len(intervals) == 0:
        return ax

    # A scalar first element means a single (start, end) pair was passed.
    # Testing the dimensionality (rather than `isinstance(..., list)`) keeps
    # pairs given as tuples or arrays from being wrapped a second time.
    if np.ndim(intervals[0]) == 0:
        intervals = [intervals]

    per_interval_color = isinstance(color, list)
    per_interval_alpha = isinstance(alpha, list)

    for ik, _ts in enumerate(intervals):
        c = color[ik] if per_interval_color else color
        a = alpha[ik] if per_interval_alpha else alpha
        ax.axvspan(_ts[0], _ts[1], facecolor=c, alpha=a)

    return ax
def get_label_width(ax, label, buffer=0., fontsize=10):
    """
    Helper function measuring how wide a label is once rendered in ax

    A temporary text artist is added to `ax` at (0, 0), measured with
    ``get_window_extent`` and removed again.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The Axes used to render the label.
    label : str
        The text to measure.
    buffer : float, optional
        Amount added to the measured width. Defaults to 0.
    fontsize : int, optional
        Font size used to render the label. Defaults to 10.

    Returns
    -------
    float
        The measured width plus `buffer`.
        NOTE(review): the width comes from ``get_window_extent``, so it is
        presumably in display units (pixels) rather than data units - confirm.
    """
    probe = ax.text(0, 0, label, size=fontsize)
    renderer = ax.figure.canvas.get_renderer()
    extent = probe.get_window_extent(renderer=renderer)
    measured = extent.width
    # the text was only needed for the measurement
    probe.remove()
    return measured + buffer
def calculate_overlapping_sets(fig, ax, labels, x_locs, fontsize, buffer=.1):
    """
    Group labels whose horizontal extents overlap.

    Each label is given an interval centred on its x location and as wide as
    the value returned by ``get_label_width``. Labels are then visited in
    order: a label joins the first existing group containing a label it
    overlaps, otherwise it starts a new group. Groups are never merged.

    Parameters:
    -----------
    fig : matplotlib.figure.Figure
        Not used by this function.
    ax : matplotlib.axes.Axes
        The Axes object on which the labels will be plotted.
    labels : list of str
        A list of label strings.
    x_locs : list of float
        A list of x-coordinates where the labels are to be positioned.
    fontsize : int
        The font size used for the labels.
    buffer : float, optional
        Additional space to consider around each label to prevent overlap.
        Defaults to 0.1.

    Returns:
    --------
    list of list of int: A list where each sublist contains the sorted indices of one group of overlapping labels.

    NOTE(review): the widths come from ``get_label_width`` (which measures a
    window extent) while `x_locs` look like data coordinates - confirm that
    both are expressed in the same units.
    """
    # Horizontal span occupied by each label
    spans = []
    for idx, text in enumerate(labels):
        half_width = get_label_width(ax, text, buffer=buffer, fontsize=fontsize) / 2
        center = x_locs[idx]
        spans.append(pd.Interval(left=center - half_width, right=center + half_width))

    # Assign every label to the first group it overlaps with
    groups = []
    for idx, span in enumerate(spans):
        home = next(
            (group for group in groups
             if any(span.overlaps(spans[member]) for member in group)),
            None,
        )
        if home is None:
            groups.append({idx})
        else:
            home.add(idx)

    return [sorted(group) for group in groups]
def label_intervals(fig, ax, labels, x_locs, orientation='north', overlapping_sets=None, baseline=0.5,
                    height=0.5, buffer=0.1, fontsize=10, linewidth=None, linestyle_kwargs=None,
                    text_kwargs=None
                    ):
    """
    Place labels on a plot with given orientations and style parameters, avoiding overlaps.

    This function positions labels at specified x-locations with adjustments to avoid overlaps.
    Labels can be oriented either above (north) or below (south) a baseline.
    Each label is connected to y=0 by a vertical line. Within a group of
    overlapping labels, the labels are stacked in slots that rise towards the
    middle of the group, and their horizontal alignment is adjusted.

    Parameters:
    --------
    fig : matplotlib.figure.Figure
        Only forwarded to `calculate_overlapping_sets` when `overlapping_sets` is None.
    ax : matplotlib.axes.Axes
        The Axes object where the labels are to be placed.
    labels : list of str
        A list of label strings.
    x_locs : list of float
        A list of x-coordinates for the labels.
    orientation : str, optional
        The vertical orientation of the labels, either 'north' or 'south'. Defaults to 'north'.
        Any value other than 'north' is treated as 'south'.
    overlapping_sets : list of list of int, optional
        Precomputed overlapping sets of labels (each label index is expected in exactly one set).
        If None, the function will compute them. Defaults to None.
    baseline : float, optional
        The baseline height for the first label slot. Defaults to 0.5.
    height : float, optional
        The vertical spacing between slots. Defaults to 0.5.
    buffer : float, optional
        Horizontal buffer space around labels to prevent overlap. Defaults to 0.1.
    fontsize : int, optional
        Font size for labels. Defaults to 10. A 'fontsize' entry in `text_kwargs` takes precedence.
    linewidth : float, optional
        Line width for connecting lines. If None, defaults to 1.
        A 'linewidth' entry in `linestyle_kwargs` takes precedence.
    linestyle_kwargs : dict, optional
        Additional keyword arguments for styling the connecting lines (per Matplotlib).
        Defaults are linestyle '--' and color 'gray'. The dict is not modified.
    text_kwargs : dict, optional
        Additional keyword arguments for styling the text labels (per Matplotlib).
        'ha' is always overridden by the computed alignment; 'va' defaults to
        'bottom' (north) or 'top' (south). The dict is not modified.

    Returns:
    --------
    matplotlib.axes.Axes: The modified Axes object with labels placed.

    Examples
    --------
    .. jupyter-execute::

        import pyleoclim as pyleo
        import numpy as np
        ts_18 = pyleo.utils.load_dataset('cenogrid_d18O')
        ts_13 = pyleo.utils.load_dataset('cenogrid_d13C')
        ms = pyleo.MultipleSeries([ts_18, ts_13], label='Cenogrid', time_unit='ma BP')
        fig, ax = ms.stackplot(linewidth=0.5, fill_between_alpha=0)
        ax=pyleo.utils.plotting.make_annotation_ax(fig, ax, ax_name = 'epochs', height=.03,
                                                   loc='above', v_offset=.015,zorder=-2)
        ax['epochs'].set_facecolor((1, 1, 1, 0))
        ceno_intervals_pairs = [[0.0, 0.01], [0.01, 1.6], [1.6, 5.3], [5.3, 23.7], [23.7, 36.6], [36.6, 57.8], [57.8, 66.4]]
        ceno_epoch_labels = ['Holocene', 'Pleistocene', 'Pliocene', 'Miocene', 'Oligocene', 'Eocene', 'Paleocene']
        ax['epochs'].set_ylim([-1,0])
        colors = ['r', 'm', 'orange', 'blue', 'green', 'aqua', 'navy', 'pink']
        ax['epochs'] = pyleo.utils.plotting.hightlight_intervals(ax['epochs'],
                                                                ceno_intervals_pairs, color=colors,
                                                                alpha=.1)
        ### EPOCHS (labels)
        ax=pyleo.utils.plotting.make_annotation_ax(fig, ax['epochs'], ax_name = 'epoch_annotation',
                                                   zorder=1, v_offset=0.01,
                                                   height=.25, loc='above')
        x_locs = [np.mean(interval) for interval in ceno_intervals_pairs]
        ax['epoch_annotation'].set_ylim([0,3])
        ax['epoch_annotation'] = pyleo.utils.plotting.label_intervals(fig, ax['epoch_annotation'], ceno_epoch_labels, x_locs,
                                                                     orientation='north', baseline=.45, height=0.35, buffer=0.1,
                                                                     linestyle_kwargs= {'color':'gray'}, text_kwargs={'fontsize':10, 'va':'bottom'}
                                                                     )
    """
    # Merge the defaults into fresh dictionaries: filling them into the
    # caller's dictionaries in place leaked the defaults back to the caller.
    linestyle_kwargs = {
        'linestyle': '--',
        'color': 'gray',
        'linewidth': 1 if linewidth is None else linewidth,
        **(linestyle_kwargs or {}),
    }
    text_kwargs = {
        'fontsize': 10 if fontsize is None else fontsize,
        'ha': 'center',
        **(text_kwargs or {}),
    }
    fontsize = text_kwargs['fontsize']

    # if overlapping sets aren't specified, calculate them
    if overlapping_sets is None:
        overlapping_sets = calculate_overlapping_sets(fig, ax, labels, x_locs, fontsize, buffer=buffer)

    label_alignments = ['center'] * len(labels)
    label_slots = [0] * len(labels)

    for overlap_set in overlapping_sets:
        if len(overlap_set) < 2:
            # a lone label keeps slot 0 and the centred alignment
            continue

        cluster = sorted(overlap_set, key=lambda i: x_locs[i])
        n_members = len(cluster)
        peak = n_members // 2         # slot index reached in the middle of the cluster
        middle = (n_members - 1) // 2  # rank of the label that stays centred
        cluster_min, cluster_max = x_locs[cluster[0]], x_locs[cluster[-1]]

        for rank, label_index in enumerate(cluster):
            # slots climb up to the peak and come back down afterwards
            label_slots[label_index] = rank if rank <= peak else 2 * peak - rank

            if rank == 0:
                alignment = 'right'
            elif n_members == 2 or rank == middle:
                alignment = 'center'
            elif rank > middle:
                alignment = 'left'
            else:
                # between the first and the middle label: decide from the label extent
                label_width = get_label_width(ax, labels[label_index], buffer=buffer, fontsize=fontsize)
                x_center = x_locs[label_index]
                if x_center - label_width / 2 < cluster_min:
                    alignment = 'right'
                elif x_center + label_width / 2 > cluster_max:
                    alignment = 'left'
                else:
                    alignment = 'center'
            label_alignments[label_index] = alignment

    northward = orientation == 'north'
    for i, label in enumerate(labels):
        if northward:
            slot_height = baseline + label_slots[i] * height
        else:
            slot_height = -baseline - label_slots[i] * height

        label_text_kwargs = {**text_kwargs, 'ha': label_alignments[i]}
        label_text_kwargs.setdefault('va', 'bottom' if northward else 'top')

        ax.text(x_locs[i], slot_height, label, **label_text_kwargs)
        ax.plot([x_locs[i], x_locs[i]], [0, slot_height], **linestyle_kwargs)

    return ax