# Source code for pyleoclim.core.coherences

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
The Coherence class stores the result of Series.wavelet_coherence(), whether WWZ or CWT.
It includes wavelet transform coherency and cross-wavelet transform.
"""
import multiprocessing

import dill

from ..core.scalograms import MultipleScalogram, Scalogram
from ..utils import plotting
from ..utils import wavelet as waveutils

# Worker start method for the process pools used by the significance tests.
# NOTE: ``force=True`` overrides any start method already chosen by the host
# application, interpreter-wide, as an import-time side effect. It is kept
# because other parts of the package presumably rely on it; this module itself
# always requests an explicit "spawn" context in ``_get_process_pool`` below.
multiprocessing.set_start_method("spawn", force=True)
# Multiprocessing contexts provide no ``reduce``/``rebuild`` hooks for swapping
# the pickler, so dill is not installed as the pickler here. Objects sent to
# workers go through multiprocessing's own pickler, which is why the worker
# helpers below are plain module-level functions.

import warnings
from concurrent.futures import ProcessPoolExecutor  # parallel processing library
from contextlib import contextmanager
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np

# from matplotlib import cm
from matplotlib import gridspec
from matplotlib.ticker import FormatStrFormatter, ScalarFormatter
from scipy.stats.mstats import mquantiles
from tqdm import tqdm


def _run_wavelet_coherence(args):
    """Helper function for parallel wavelet coherence computation."""
    surr1_series, surr2_series, wave_method, wave_args = args
    return surr1_series.wavelet_coherence(
        surr2_series, method=wave_method, settings=wave_args
    )


def _run_global_coherence(args):
    """Helper function for computing global coherence between surrogate series."""
    surr_series1, surr_series2, wavelet_kwargs = args
    return surr_series1.global_coherence(
        surr_series2, wavelet_kwargs=wavelet_kwargs
    ).global_coh


@contextmanager
def _get_process_pool():
    ctx = multiprocessing.get_context("spawn")
    with ProcessPoolExecutor(mp_context=ctx) as executor:
        yield executor


class Coherence:
    """Coherence object, meant to receive the WTC and XWT part of Series.wavelet_coherence()

    Parameters
    ----------
    frequency, scale, time : array-like
        Frequency, scale and time axes of the transform (stored as numpy arrays).
    wtc, xwt, phase : array-like
        Wavelet transform coherency, cross-wavelet transform and phase fields.
        ``plot`` indexes them as ``[time, frequency]``.
    coi : array-like, optional
        Cone of influence. If None, it is computed with
        ``pyleoclim.utils.wavelet.make_coi(time, Neff_threshold=Neff_threshold)``.
    wave_method : str, optional
        Wavelet method used (e.g. 'cwt' or 'wwz').
    wave_args : dict, optional
        Wavelet settings. A shallow copy is stored, with its 'freq' and 'tau'
        entries (if any) converted to numpy arrays; the caller's dict is not modified.
    timeseries1, timeseries2 : Series, optional
        The two input series. They are used to infer ``scale_unit`` and ``time_label``.
    signif_qs : list, optional
        Significance quantiles, as produced by ``signif_test``.
    signif_method : str, optional
        Method used to generate surrogates for ``signif_qs``.
    qs : list, optional
        Quantile levels matching ``signif_qs``.
    freq_method : str, optional
        Method used to build the frequency vector.
    freq_kwargs : dict, optional
        Arguments of that frequency method.
    Neff_threshold : int, optional
        Only used when computing the cone of influence. Default is 3.
    scale_unit : str, optional
        Unit of the scale axis. If None, it is inferred from the time unit of the
        first available series.
    time_label : str, optional
        Label of the time axis. If None, it is built from the time name/unit of the
        first available series.

    See also
    --------
    pyleoclim.core.series.Series.wavelet_coherence : Wavelet coherence method

    """

    def __init__(
        self,
        frequency,
        scale,
        time,
        wtc,
        xwt,
        phase,
        coi=None,
        wave_method=None,
        wave_args=None,
        timeseries1=None,
        timeseries2=None,
        signif_qs=None,
        signif_method=None,
        qs=None,
        freq_method=None,
        freq_kwargs=None,
        Neff_threshold=3,
        scale_unit=None,
        time_label=None,
    ):
        self.frequency = np.array(frequency)
        self.time = np.array(time)
        self.scale = np.array(scale)
        self.wtc = np.array(wtc)
        self.xwt = np.array(xwt)
        if coi is not None:
            self.coi = np.array(coi)
        else:
            self.coi = waveutils.make_coi(self.time, Neff_threshold=Neff_threshold)
        self.phase = np.array(phase)
        self.timeseries1 = timeseries1
        self.timeseries2 = timeseries2
        self.signif_qs = signif_qs
        self.signif_method = signif_method
        self.freq_method = freq_method
        self.freq_kwargs = freq_kwargs
        self.wave_method = wave_method
        if wave_args is not None:
            # work on a copy so the caller's settings dict is left untouched
            wave_args = dict(wave_args)
            for key in ("freq", "tau"):
                if key in wave_args:
                    wave_args[key] = np.array(wave_args[key])
        self.wave_args = wave_args
        self.qs = qs

        # metadata is taken from the first series that is available
        ref_series = timeseries1 if timeseries1 is not None else timeseries2

        if scale_unit is not None:
            self.scale_unit = scale_unit
        elif ref_series is not None:
            self.scale_unit = plotting.infer_period_unit_from_time_unit(
                ref_series.time_unit
            )
        else:
            self.scale_unit = None

        if time_label is not None:
            self.time_label = time_label
        elif ref_series is not None:
            if ref_series.time_unit is not None:
                self.time_label = f"{ref_series.time_name} [{ref_series.time_unit}]"
            else:
                self.time_label = f"{ref_series.time_name}"
        else:
            self.time_label = None

    def copy(self):
        """Copy object"""
        return deepcopy(self)

    @staticmethod
    def _split_seed(seed):
        """Derive two distinct, reproducible seeds from a single user seed.

        Surrogates of the two series must not share random draws.
        SurrogateSeries presumably seeds its generator directly from ``seed``.
        If so, a shared seed would build both surrogate sets from the same random
        numbers; for phase randomization this even reproduces the original
        coherence. Either way the null distribution would be biased.

        Parameters
        ----------
        seed : int, None or other
            ``None`` gives ``(None, None)`` (fresh entropy for both series).
            An integer is expanded into two child seeds with
            ``numpy.random.SeedSequence.spawn``. Anything else is passed through
            unchanged to both series.

        Returns
        -------
        tuple
            (seed for the first series, seed for the second series)
        """
        if seed is None:
            return None, None
        if isinstance(seed, (int, np.integer)):
            children = np.random.SeedSequence(int(seed)).spawn(2)
            return tuple(int(child.generate_state(1)[0]) for child in children)
        return seed, seed

    def plot(
        self,
        var="wtc",
        xlabel=None,
        ylabel=None,
        title="auto",
        figsize=(10, 8),
        ylim=None,
        xlim=None,
        in_scale=True,
        yticks=None,
        contourf_style=None,
        phase_style=None,
        cbar_style=None,
        savefig_settings=None,
        ax=None,
        signif_clr="white",
        signif_linestyles="-",
        signif_linewidths=1,
        signif_thresh=0.95,
        under_clr="ivory",
        over_clr="black",
        bad_clr="dimgray",
    ):
        """Plot the cross-wavelet results

        Parameters
        ----------
        var : str {'wtc', 'xwt'}
            Variable to be plotted as color field. Default: 'wtc', the wavelet transform coherency.
            'xwt' plots the cross-wavelet transform instead.
        xlabel : str, optional
            x-axis label. The default is None, which uses the object's time label.
        ylabel : str, optional
            y-axis label. The default is None.
        title : str, optional
            Title of the plot. The default is 'auto', where it is made from object metadata.
            To mute, pass title = None.
        figsize : list, optional
            Figure size. The default is [10, 8].
        ylim : list, optional
            y-axis limits. The default is None.
        xlim : list, optional
            x-axis limits. The default is None.
        in_scale : bool, optional
            Plots scales instead of frequencies. The default is True.
        yticks : list, optional
            y-ticks label. The default is None.
        contourf_style : dict, optional
            Arguments for the contour plot. The default is {}.
        phase_style : dict, optional
            Arguments for the phase arrows. The default is {}. It includes:

            - 'pt': the default threshold above which phase arrows will be plotted
            - 'skip_x': the number of points to skip between phase arrows along the x-axis
            - 'skip_y': the number of points to skip between phase arrows along the y-axis
            - 'scale': number of data units per arrow length unit (see matplotlib.pyplot.quiver)
            - 'width': shaft width in arrow units (see matplotlib.pyplot.quiver)
            - 'color': arrow color (see matplotlib.pyplot.quiver)

        cbar_style : dict, optional
            Arguments for the color bar. The default is {}.
        savefig_settings : dict, optional
            The default is {}.
            The dictionary of arguments for plt.savefig(); some notes below:

            - "path" must be specified; it can be any existing or non-existing path,
              with or without a suffix; if the suffix is not given in "path", it will follow "format"
            - "format" can be one of {"pdf", "eps", "png", "ps"}

            The figure is only saved when this method creates it (i.e. ``ax`` is None).
        ax : ax, optional
            Matplotlib axis on which to return the figure. The default is None.
        signif_thresh : float in [0, 1]
            Significance threshold. Default is 0.95. If this quantile is not found in the qs field
            of the Coherence object, the closest quantile will be picked.
        signif_clr : str, optional
            Color of the significance line. The default is 'white'.
        signif_linestyles : str, optional
            Style of the significance line. The default is '-'.
        signif_linewidths : float, optional
            Width of the significance line. The default is 1.
        under_clr : str, optional
            Color for under 0. The default is 'ivory'.
        over_clr : str, optional
            Color for over 1. The default is 'black'.
        bad_clr : str, optional
            Color for missing values. The default is 'dimgray'.

        Returns
        -------
        fig, ax
            ``(fig, ax)`` if a new figure was created, otherwise just ``ax``.

        Raises
        ------
        ValueError
            If ``signif_thresh`` is outside [0, 1] or ``var`` is not 'wtc' or 'xwt'.

        See also
        --------
        pyleoclim.core.coherence.Coherence.dashboard : plots a dashboard showing the coherence and the cross-wavelet transform.

        pyleoclim.core.series.Series.wavelet_coherence : computes the coherence from two timeseries.

        matplotlib.pyplot.quiver : quiver plot

        Examples
        --------

        Calculate the wavelet coherence of NINO3 and All India Rainfall and plot it:

        .. jupyter-execute::

            ts_air = pyleo.utils.load_dataset('AIR')
            ts_nino = pyleo.utils.load_dataset('NINO3')

            coh = ts_air.wavelet_coherence(ts_nino)
            coh.plot()

        Establish significance against an AR(1) benchmark:

        .. jupyter-execute::

            coh_sig = coh.signif_test(number=20, qs=[.9,.95,.99])
            coh_sig.plot()

        Note that specifying 3 significance thresholds does not take any more time as the quantiles are
        simply estimated from the same ensemble. By default, the plot function looks
        for the closest quantile to 0.95, but this is easy to adjust, e.g. for the 99th percentile:

        .. jupyter-execute::

            coh_sig.plot(signif_thresh = 0.99)

        By default, the function plots the wavelet transform coherency (WTC), which quantifies where
        two timeseries exhibit similar behavior in time-frequency space, regardless of whether this
        corresponds to regions of high common power. To visualize the latter, you want to plot the
        cross-wavelet transform (XWT) instead, like so:

        .. jupyter-execute::

            coh_sig.plot(var='xwt')

        """
        contourf_style = {} if contourf_style is None else dict(contourf_style)
        phase_style = {} if phase_style is None else dict(phase_style)
        cbar_style = {} if cbar_style is None else dict(cbar_style)
        savefig_settings = {} if savefig_settings is None else dict(savefig_settings)

        # validate inputs before any figure gets created
        if signif_thresh > 1 or signif_thresh < 0:
            raise ValueError("The significance threshold must be in [0, 1] ")
        if var not in ("wtc", "xwt"):
            raise ValueError("Unknown variable; please choose either 'wtc' or 'xwt'")

        fig = None
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)

        # handling NaNs: drop frequencies where the WTC is NaN at every time step
        mask_freq = ~np.all(np.isnan(self.wtc), axis=0)

        if in_scale:
            y_axis = self.scale[mask_freq]
            if ylabel is None:
                ylabel = (
                    f"Scale [{self.scale_unit}]"
                    if self.scale_unit is not None
                    else "Scale"
                )

            if yticks is None:
                yticks_default = np.array(
                    [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000,
                     2000, 5000, 1e4, 2e4, 5e4, 1e5, 2e5, 5e5, 1e6]
                )
                mask = (yticks_default >= np.min(y_axis)) & (
                    yticks_default <= np.max(y_axis)
                )
                yticks = yticks_default[mask]
        else:
            y_axis = self.frequency[mask_freq]
            if ylabel is None:
                ylabel = (
                    f"Frequency [1/{self.scale_unit}]"
                    if self.scale_unit is not None
                    else "Frequency"
                )

        # plot color field for WTC or XWT
        contourf_args = {
            "cmap": "magma",
            "origin": "lower",
        }
        contourf_args.update(contourf_style)

        cmap = plt.get_cmap(contourf_args["cmap"])
        cmap.set_under(under_clr)
        cmap.set_over(over_clr)
        cmap.set_bad(bad_clr)
        contourf_args["cmap"] = cmap

        color_field = (self.wtc if var == "wtc" else self.xwt)[:, mask_freq]
        # WTC lives in [0, 1], so use fixed levels; for XWT just ask for 11 contours
        levels = np.linspace(0, 1, 11) if var == "wtc" else 11
        cont = ax.contourf(
            self.time, y_axis, color_field.T, levels=levels, **contourf_args
        )

        # plot significance levels
        signif_boundary = None
        if self.signif_qs is not None:
            qs_list = list(self.qs)  # tolerate tuples/arrays as well as lists
            if signif_thresh in qs_list:
                isig = qs_list.index(signif_thresh)
            else:
                isig = int(np.abs(np.array(qs_list) - signif_thresh).argmin())
                print(
                    "Significance threshold {:3.2f} not found in qs. Picking the closest, which is {:3.2f}".format(
                        signif_thresh, qs_list[isig]
                    )
                )
            # signif_qs[0] holds the WTC quantiles, signif_qs[1] the XWT quantiles
            signif_coh = self.signif_qs[0 if var == "wtc" else 1].scalogram_list[isig]
            # ratio > 1 means the field exceeds the chosen quantile
            signif_boundary = color_field.T / signif_coh.amplitude[:, mask_freq].T

            ax.contour(
                self.time,
                y_axis,
                signif_boundary,
                [-99, 1],
                colors=signif_clr,
                linestyles=signif_linestyles,
                linewidths=signif_linewidths,
            )
            if title is not None:
                ax.set_title("Lines:" + str(round(qs_list[isig] * 100)) + "% threshold")

        # plot colorbar
        cbar_args = {
            "label": var.upper(),
            "drawedges": False,
            "orientation": "vertical",
            "fraction": 0.15,
            "pad": 0.05,
            "ticks": cont.levels,
        }
        cbar_args.update(cbar_style)
        # assign colorbar to axis (instead of fig) : https://matplotlib.org/stable/gallery/subplots_axes_and_figures/colorbar_placement.html
        plt.colorbar(cont, ax=ax, **cbar_args)

        # plot cone of influence
        ax.set_yscale("log")
        ax.plot(self.time, self.coi, "k--")

        if ylim is None:
            ylim = [np.min(y_axis), np.min([np.max(y_axis), np.max(self.coi)])]

        ax.fill_between(self.time, self.coi, np.max(self.coi), color="white", alpha=0.5)

        if yticks is not None:
            ax.set_yticks(yticks)
            ax.yaxis.set_major_formatter(FormatStrFormatter("%g"))

        if xlabel is None:
            xlabel = self.time_label
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)

        # plot phase
        phase_args = {
            "pt": 0.5,
            "skip_x": max(int(np.size(self.time) // 20), 1),
            "skip_y": max(int(np.size(y_axis) // 20), 1),
            "scale": 30,
            "width": 0.004,
        }
        phase_args.update(phase_style)

        pt = phase_args["pt"]
        skip_x = phase_args["skip_x"]
        skip_y = phase_args["skip_y"]
        scale = phase_args["scale"]
        width = phase_args["width"]
        color = phase_style.get("color", "black")

        phase = np.array(self.phase[:, mask_freq], dtype=float, copy=True)
        if signif_boundary is not None:
            # only draw arrows where the field is significant
            phase[signif_boundary.T < 1] = np.nan
        elif var == "wtc":
            phase[color_field < pt] = np.nan
        else:
            # nanmax: a plain max would be NaN if any XWT value is NaN,
            # and then no arrow would ever be masked
            phase[color_field < pt * np.nanmax(color_field)] = np.nan

        X, Y = np.meshgrid(self.time, y_axis)
        U, V = np.cos(phase).T, np.sin(phase).T
        thin = (slice(None, None, skip_y), slice(None, None, skip_x))

        ax.quiver(
            X[thin],
            Y[thin],
            U[thin],
            V[thin],
            scale=scale,
            width=width,
            zorder=99,
            color=color,
        )

        ax.set_ylim(ylim)
        if xlim is not None:
            ax.set_xlim(xlim)

        if fig is None:
            return ax

        if title is not None and title != "auto":
            fig.suptitle(title)
        elif title == "auto":
            lbl1 = getattr(self.timeseries1, "label", None)
            lbl2 = getattr(self.timeseries2, "label", None)
            if lbl1 is not None and lbl2 is not None:
                method_tag = f" ({self.wave_method.upper()})" if self.wave_method else ""
                fig.suptitle(f"Wavelet coherency{method_tag} between {lbl1} and {lbl2}")

        # save last, so that the saved file includes the title
        if "path" in savefig_settings:
            plotting.savefig(fig, settings=savefig_settings)
        return fig, ax

    def dashboard(
        self,
        title=None,
        figsize=(9, 12),
        overlap=True,
        phase_style=None,
        line_colors=("tab:blue", "tab:orange"),
        savefig_settings=None,
        ts_plot_kwargs=None,
        wavelet_plot_kwargs=None,
    ):
        """Cross-wavelet dashboard, including the two series, their WTC and XWT.

        Note: this design balances many considerations, and is not easily customizable.

        Parameters
        ----------
        title : str, optional
            Title of the figure (set with ``fig.suptitle``). The default is None (no title).
        figsize : list, optional
            Figure size. The default is [9, 12], as this is an information-rich figure.
        overlap : boolean, optional
            Whether to restrict the timeseries panels to the time span of the coherence
            (the period of overlap between the series). Defaults to True.
        phase_style : dict, optional
            Arguments for the phase arrows, applied to both the WTC and the XWT panels
            unless ``wavelet_plot_kwargs`` supplies its own 'phase_style'. On the XWT panel
            arrows default to 'lightgray'. The default is {}. It includes:

            - 'pt': the default threshold above which phase arrows will be plotted
            - 'skip_x': the number of points to skip between phase arrows along the x-axis
            - 'skip_y': the number of points to skip between phase arrows along the y-axis
            - 'scale': number of data units per arrow length unit (see matplotlib.pyplot.quiver)
            - 'width': shaft width in arrow units (see matplotlib.pyplot.quiver)
            - 'color': arrow color (see matplotlib.pyplot.quiver)

        line_colors : list, optional
            Colors for the 2 traces. For nomenclature, see https://matplotlib.org/stable/gallery/color/named_colors.html
        savefig_settings : dict, optional
            The default is {}.
            The dictionary of arguments for plt.savefig(); some notes below:

            - "path" must be specified; it can be any existing or non-existing path,
              with or without a suffix; if the suffix is not given in "path", it will follow "format"
            - "format" can be one of {"pdf", "eps", "png", "ps"}

        ts_plot_kwargs : dict
            Arguments to be passed to the timeseries subplot, see pyleoclim.core.series.Series.plot for details.
        wavelet_plot_kwargs : dict
            Arguments passed in full to the WTC subplot (see pyleoclim.core.coherence.Coherence.plot);
            only its 'phase_style' entry, if any, is also used for the XWT subplot.

        Returns
        -------
        fig, ax
            ``ax`` is a dict of axes with keys 'ts1', 'ts2', 'wtc' and 'xwt'.

        See also
        --------
        pyleoclim.core.coherence.Coherence.plot : creates a coherence plot

        pyleoclim.core.series.Series.wavelet_coherence : computes the coherence between two timeseries.

        pyleoclim.core.series.Series.plot: plots a timeseries

        matplotlib.pyplot.quiver: makes a quiver plot

        Examples
        --------

        Calculate the coherence of NINO3 and All India Rainfall and plot it as a dashboard:

        .. jupyter-execute::

            ts_air = pyleo.utils.load_dataset('AIR')
            ts_nino = pyleo.utils.load_dataset('NINO3')

            coh = ts_air.wavelet_coherence(ts_nino)
            coh_sig = coh.signif_test(number=10)
            coh_sig.dashboard()

        You may customize colors like so:

        .. jupyter-execute::

            coh_sig.dashboard(line_colors=['teal','gold'])

        To export the figure, use `savefig_settings`:

        .. jupyter-execute::

            coh_sig.dashboard(savefig_settings={'path':'./coh_dash.png','dpi':300})
        """
        # prepare options dictionaries
        savefig_settings = {} if savefig_settings is None else savefig_settings.copy()
        wavelet_plot_kwargs = (
            {} if wavelet_plot_kwargs is None else wavelet_plot_kwargs.copy()
        )
        ts_plot_kwargs = {} if ts_plot_kwargs is None else ts_plot_kwargs.copy()
        phase_style = {} if phase_style is None else dict(phase_style)

        # create figure
        fig = plt.figure(figsize=figsize)
        gs = gridspec.GridSpec(8, 1)
        gs.update(wspace=0, hspace=0.5)  # add some breathing room
        ax = {}

        # assess period of overlap
        xlims = np.min(self.time), np.max(self.time)

        # 1) plot timeseries (series 2 on a twin y-axis)
        ax["ts1"] = plt.subplot(gs[0:2, 0])
        self.timeseries1.plot(
            ax=ax["ts1"], color=line_colors[0], **ts_plot_kwargs, legend=False
        )
        ax["ts1"].yaxis.label.set_color(line_colors[0])
        ax["ts1"].tick_params(axis="y", colors=line_colors[0], labelsize=8)
        ax["ts1"].spines["left"].set_color(line_colors[0])
        ax["ts1"].spines["bottom"].set_visible(False)
        ax["ts1"].grid(False)
        ax["ts1"].set_xlabel("")
        if overlap:
            ax["ts1"].set_xlim(xlims)

        ax["ts2"] = ax["ts1"].twinx()
        self.timeseries2.plot(
            ax=ax["ts2"], color=line_colors[1], **ts_plot_kwargs, legend=False
        )
        ax["ts2"].yaxis.label.set_color(line_colors[1])
        ax["ts2"].tick_params(axis="y", colors=line_colors[1], labelsize=8)
        ax["ts2"].spines["right"].set_color(line_colors[1])
        ax["ts2"].spines["right"].set_visible(True)
        ax["ts2"].spines["left"].set_visible(False)
        ax["ts2"].grid(False)
        if overlap:
            ax["ts2"].set_xlim(xlims)

        # 2) plot WTC
        ax["wtc"] = plt.subplot(gs[2:5, 0], sharex=ax["ts1"])
        wtc_kwargs = dict(wavelet_plot_kwargs)
        wtc_kwargs.setdefault(
            "cbar_style", {"orientation": "horizontal", "pad": 0.15, "aspect": 60}
        )
        if phase_style:
            wtc_kwargs.setdefault("phase_style", phase_style)
        self.plot(var="wtc", ax=ax["wtc"], title=None, **wtc_kwargs)
        ax["wtc"].set_xlabel("")

        # 3) plot XWT
        ax["xwt"] = plt.subplot(gs[5:8, 0], sharex=ax["ts1"])
        xwt_phase_style = wavelet_plot_kwargs.get(
            "phase_style", {"color": "lightgray", **phase_style}
        )
        self.plot(
            var="xwt",
            ax=ax["xwt"],
            title=None,
            contourf_style={"cmap": "viridis"},
            cbar_style={"orientation": "horizontal", "pad": 0.2, "aspect": 60},
            phase_style=xwt_phase_style,
        )

        if title is not None:
            fig.suptitle(title)

        if "path" in savefig_settings:
            plotting.savefig(fig, settings=savefig_settings)
        return fig, ax

    def signif_test(
        self,
        number=200,
        method="ar1sim",
        seed=None,
        qs=(0.95,),
        settings=None,
        mute_pbar=False,
    ):
        """Significance testing for Coherence objects

        The method obtains quantiles `qs` of the distribution of coherence between
        `number` pairs of Monte Carlo simulations of a process that resembles the original series.
        Surrogates are generated with pyleoclim.core.surrogateseries.SurrogateSeries using `method`.

        Parameters
        ----------
        number : int, optional
            Number of surrogate series to create for significance testing. The default is 200.
            If 0, the object itself is returned unchanged.
        method : {'ar1sim','phaseran','CN'}, optional
            Method through which to generate the surrogate series. The default is 'ar1sim'.
        seed : int, optional
            Fixes the seed for NumPy's random number generator.
            Useful for reproducibility. The default is None, so fresh, unpredictable
            entropy will be pulled from the operating system.
            An integer seed is split into two distinct child seeds (one per series),
            so that the surrogates of the two series do not share random draws.
        qs : list, optional
            Significance levels to return. The default is [0.95].
        settings : dict, optional
            Parameters for surrogate model. The default is None.
            NOTE: currently not forwarded to the surrogate generator.
        mute_pbar : bool, optional
            Mute the progress bar. The default is False.

        Returns
        -------
        new : pyleoclim.core.coherence.Coherence
            Original Coherence object augmented with significance levels signif_qs,
            a list with the following `MultipleScalogram` objects:

            * 0: MultipleScalogram for the wavelet transform coherency (WTC)
            * 1: MultipleScalogram for the cross-wavelet transform (XWT)

            Each object contains as many Scalogram objects as qs contains values.

        See also
        --------
        pyleoclim.core.series.Series.wavelet_coherence : Wavelet coherence

        pyleoclim.core.scalograms.Scalogram : Scalogram object

        pyleoclim.core.scalograms.MultipleScalogram : Multiple Scalogram object

        pyleoclim.core.coherence.Coherence.plot : plotting method for Coherence objects

        Examples
        --------

        Calculate the coherence of NINO3 and All India Rainfall and assess significance:

        .. jupyter-execute::

            ts_air = pyleo.utils.load_dataset('AIR')
            ts_nino = pyleo.utils.load_dataset('NINO3')

            coh = ts_air.wavelet_coherence(ts_nino)
            coh_sig = coh.signif_test(number=20)
            coh_sig.plot()

        By default, significance is assessed against a 95% benchmark derived from
        an AR(1) process fit to the data, using 200 Monte Carlo simulations.
        To customize, one can increase the number of simulations
        (more reliable, but slower), and the quantile levels.

        .. jupyter-execute::

            coh_sig2 = coh.signif_test(number=100, qs=[.9,.95,.99])
            coh_sig2.plot()

        The plot() function will represent the 95% level as contours by default.
        If you need to show 99%, say, use the `signif_thresh` argument:

        .. jupyter-execute::

            coh_sig2.plot(signif_thresh=0.99)

        Note that if the 99% quantile is not present, the plot method will look
        for the closest match, but lines are always labeled appropriately.
        For reproducibility purposes, it may be good to specify the (pseudo)random number
        generator's seed, like so:

        .. jupyter-execute::

            coh_sig27 = coh.signif_test(number=20, seed=27)

        This will generate exactly the same set of draws from the
        (pseudo)random number at every execution, which may be important for marginal features
        in small ensembles. In general, however, we recommend increasing the
        number of draws to check that features are robust.

        One can also specify a different method to obtain surrogates, e.g. phase randomization:

        .. jupyter-execute::

            coh.signif_test(method='phaseran').plot()

        """
        from ..core.series import (  # noqa: F401
            Series,  # This is necessary for the multiprocessing pickling process!!! DO NOT REMOVE!!!!!
        )
        from ..core.surrogateseries import SurrogateSeries

        if number == 0:
            return self

        qs = list(qs)  # own copy; also lets plot() use list.index
        new = self.copy()

        # independent seeds so the two surrogate sets do not share random draws
        seed1, seed2 = self._split_seed(seed)
        surr1 = SurrogateSeries(method=method, number=number, seed=seed1)
        surr1.from_series(self.timeseries1)
        surr2 = SurrogateSeries(method=method, number=number, seed=seed2)
        surr2.from_series(self.timeseries2)

        # Prepare arguments for parallel processing
        args = [
            (s1, s2, self.wave_method, self.wave_args)
            for s1, s2 in zip(surr1.series_list, surr2.series_list)
        ]

        # Perform wavelet coherence calculations in parallel
        with _get_process_pool() as executor:
            results = list(
                tqdm(
                    executor.map(_run_wavelet_coherence, args),
                    total=len(args),
                    desc="Performing wavelet coherence on surrogate pairs",
                    disable=mute_pbar,
                )
            )

        wtcs = np.array([res.wtc for res in results])
        xwts = np.array([res.xwt for res in results])
        n_sim = wtcs.shape[0]
        field_shape = wtcs.shape[1:]
        nq = len(qs)

        # mquantiles only accepts inputs of at most 2D: flatten each field,
        # take quantiles across the ensemble, then restore the field shape
        wtc_qs = np.reshape(
            mquantiles(np.reshape(wtcs, (n_sim, -1)), qs, axis=0), (nq, *field_shape)
        )
        xwt_qs = np.reshape(
            mquantiles(np.reshape(xwts, (n_sim, -1)), qs, axis=0), (nq, *field_shape)
        )

        def _to_scalograms(quantile_fields):
            # one Scalogram per requested quantile, labeled e.g. "95%"
            return [
                Scalogram(
                    frequency=self.frequency,
                    time=self.time,
                    amplitude=quantile_fields[i],
                    coi=self.coi,
                    scale=self.scale,
                    freq_method=self.freq_method,
                    freq_kwargs=self.freq_kwargs,
                    label=f"{q*100:g}%",
                )
                for i, q in enumerate(qs)
            ]

        new.signif_qs = [
            MultipleScalogram(scalogram_list=_to_scalograms(wtc_qs)),  # Export WTC quantiles
            MultipleScalogram(scalogram_list=_to_scalograms(xwt_qs)),  # Export XWT quantiles
        ]
        new.signif_method = method
        new.qs = qs

        return new

    def phase_stats(self, scales, number=1000, level=0.05):
        """Estimate phase angle statistics of a Coherence object

        As per [1], the strength (consistency) of a phase relationship is assessed using:

            * sigma, the circular standard deviation
            * kappa, an estimate of the Von Mises distribution's concentration parameter.
              It is a reciprocal measure of dispersion (so 1/kappa is analogous to the variance) [2].

        Because of inherent persistence of geophysical signals and of the reproducing kernel
        of the continuous wavelet transform [3], phase statistics are assessed relative to
        an AR(1) model fit to the angle deviations observed at the requested scale(s).
        Specifically, if `number` is specified, the method simulates `number` Monte Carlo
        realizations of an AR(1) process fit to fluctuations around the mean angle.
        This ensemble is used to obtain the confidence limits: `sigma_lo` (`level` quantile)
        and `kappa_hi` (1-`level` quantile). These correspond to 1-tailed tests of the
        strength of the relationship.

        Parameters
        ----------
        scales : float or list of 2 floats
            Scale at which to evaluate the phase angle, or a [min, max] interval of
            scales over which the phase is averaged (both ends included).
        number : int, optional
            Number of AR(1) series to create for significance testing. The default is 1000.
        level : float, optional
            Significance level against which to gauge sigma and kappa. Default: 0.05

        Returns
        -------
        result
            Output of pyleoclim.utils.wavelet.angle_sig: contains the mean angle for those
            scales (``mean_angle``), sigma (the circular standard deviation), kappa,
            sigma_lo (alpha-level quantile for sigma) and kappa_hi, the (1-alpha)-level
            quantile for kappa.

        Raises
        ------
        ValueError
            If `scales` has neither 1 nor 2 elements, or if both ends of the interval
            map to the same scale of the object.

        See also
        --------
        pyleoclim.core.series.Series.wavelet_coherence : Wavelet coherence

        pyleoclim.core.scalograms.Scalogram : Scalogram object

        pyleoclim.core.scalograms.MultipleScalogram : Multiple Scalogram object

        pyleoclim.core.coherence.Coherence.plot : plotting method for Coherence objects

        pyleoclim.utils.wavelet.angle_sig : significance of phase angle statistics

        pyleoclim.utils.wavelet.angle_stats : phase angle statistics

        References
        ----------
        [1] Grinsted, A., J. C. Moore, and S. Jevrejeva (2004), Application of the cross wavelet
        transform and wavelet coherence to geophysical time series, Nonlinear Processes in Geophysics, 11, 561–566.

        [2] Huber, R., Dutra, L. V., & da Costa Freitas, C. (2001). SAR interferogram phase filtering
        based on the Von Mises distribution. In IGARSS 2001. Scanning the Present and Resolving the Future.
        Proceedings. IEEE 2001 International Geoscience and Remote Sensing Symposium (Cat. No. 01CH37217)
        (Vol. 6, pp. 2816-2818). IEEE.

        [3] Farge, M. and Schneider, K. (2006): Wavelets: application to turbulence
        Encyclopedia of Mathematical Physics (Eds. J.-P. Françoise, G. Naber and T.S. Tsun) pp 408-420.

        Examples
        --------

        Calculate the phase angle between NINO3 and All India Rainfall at 5y scales:

        .. jupyter-execute::

            ts_air = pyleo.utils.load_dataset('AIR')
            ts_nino = pyleo.utils.load_dataset('NINO3')

            coh = ts_air.wavelet_coherence(ts_nino)
            coh.phase_stats(scales=5)

        One may also obtain phase angle statistics over an interval, like the 2-8y ENSO band:

        .. jupyter-execute::

            phase = coh.phase_stats(scales=[2,8])
            print("The mean angle is {:4.2f}°".format(phase.mean_angle/np.pi*180))
            print(phase)

        From this example, one diagnoses a strong anti-phased relationship in the ENSO band,
        with high von Mises concentration (kappa ~ 3.35 >> kappa_hi) and low circular
        dispersion (sigma ~ 0.6 << sigma_lo). This would be strong evidence of a consistent
        anti-phasing between NINO3 and AIR at those scales.

        """
        scales = np.atleast_1d(np.asarray(scales, dtype=float))
        if scales.size not in (1, 2):
            raise ValueError(
                "scales must be a single scale or a [min, max] pair of scales"
            )

        if scales.max() > self.scale.max():
            warnings.warn(
                "Requested scale exceeds largest scale in object. Truncating to "
                + str(self.scale.max())
            )

        if scales.size == 1:
            scale_idx = np.argmin(np.abs(self.scale - scales[0]))
            return waveutils.angle_sig(self.phase[:, scale_idx], nMC=number, level=level)

        # interval of scales: locate both ends, whatever the ordering of self.scale
        idx_a = int(np.argmin(np.abs(self.scale - scales.min())))
        idx_b = int(np.argmin(np.abs(self.scale - scales.max())))
        lo, hi = sorted((idx_a, idx_b))
        if lo == hi:
            raise ValueError(
                "Insufficiently spaced scales. Please pick a single one, or a wider interval"
            )

        # average phase over those scales (both ends included) at each time step
        band = self.phase[:, lo:hi + 1]
        phase = np.array([waveutils.angle_stats(row)[0] for row in band], dtype=float)
        return waveutils.angle_sig(phase, nMC=number, level=level)  # assess significance
class GlobalCoherence:
    """Class to store the results of cross spectral analysis

    Parameters
    ----------
    global_coh : numpy array
        Global coherence values, one per scale of ``coh``.
    coh : Coherence
        Original coherence object (provides the scales, series and wavelet settings).
    signif_qs : numpy array, optional
        Coherence quantiles of the surrogate ensemble, one row per level in ``qs``
        (filled in by ``signif_test``).
    signif_method : str, optional
        Surrogate method used for ``signif_qs``.
    qs : list, optional
        Quantile levels matching ``signif_qs``.
    label : str, optional
        Label of the coherence curve. Default is 'Coherence'.

    See Also
    --------
    pyleoclim.core.series.Series.global_coherence : method to compute the spectral coherence
    """

    def __init__(
        self,
        global_coh,
        coh,
        signif_qs=None,
        signif_method=None,
        qs=None,
        label="Coherence",
    ):
        self.global_coh = global_coh
        self.label = label
        self.coh = coh
        self.signif_qs = signif_qs
        self.signif_method = signif_method
        self.qs = qs

    def copy(self):
        """Copy object"""
        return deepcopy(self)

    def signif_test(self, method="ar1sim", number=200, qs=(0.95,), mute_pbar=False):
        """Perform a significance test on the coherence values

        Parameters
        ----------
        method : str; {'ar1sim','CN','phaseran'}
            Method to use for the surrogate test. Default is 'ar1sim'.
        number : int
            Number of surrogates to generate. Default is 200.
            If 0, the object itself is returned unchanged.
        qs : list
            List of quantiles to compute. Default is [.95].
        mute_pbar : bool, optional
            Mute the progress bar. The default is False.

        Returns
        -------
        new : GlobalCoherence
            Copy of this object with ``signif_qs`` (array of shape (len(qs), len(global_coh))),
            ``signif_method`` and ``qs`` filled in.

        Examples
        --------

        .. jupyter-execute::

            soi = pyleo.utils.load_dataset('SOI')
            nino3 = pyleo.utils.load_dataset('NINO3')
            gcoh = soi.global_coherence(nino3)
            gcoh_sig = gcoh.signif_test(number=10)
            gcoh_sig.plot()
        """
        if number == 0:
            return self

        from ..core.surrogateseries import SurrogateSeries

        qs = list(qs)  # own copy, never the shared default
        new = self.copy()

        surr1 = SurrogateSeries(method=method, number=number)
        surr2 = SurrogateSeries(method=method, number=number)
        surr1.from_series(self.coh.timeseries1)
        surr2.from_series(self.coh.timeseries2)

        wavelet_kwargs = {
            "freq": self.coh.frequency,  # pass the frequency axis directly
            "settings": self.coh.wave_args,
            "method": self.coh.wave_method,
        }

        # Prepare arguments for parallel processing
        args = [
            (s1, s2, wavelet_kwargs)
            for s1, s2 in zip(surr1.series_list, surr2.series_list)
        ]

        with _get_process_pool() as executor:
            results = list(
                tqdm(
                    executor.map(_run_global_coherence, args),
                    total=len(args),
                    desc="Computing global coherence for surrogate pairs",
                    disable=mute_pbar,
                )
            )

        # one row per surrogate pair, one column per scale
        coh_array = np.empty((len(results), len(self.global_coh)))
        for row, result in enumerate(results):
            coh_array[row, :] = result

        new.signif_qs = mquantiles(coh_array, qs, axis=0).data
        new.signif_method = method
        new.qs = qs

        return new

    def plot(
        self,
        figsize=(8, 8),
        xlim=None,
        xlabel=None,
        label=None,
        psd_y_label="PSD",
        coh_y_label="Coherence",
        coh_line_color="grey",
        ax=None,
        coh_ylim=(0.4, 1),
        fill_alpha=0.3,
        fill_color="grey",
        coh_plot_kwargs=None,
        savefig_settings=None,
        spectral_kwargs=None,
        legend=True,
        legend_kwargs=None,
        spec1_plot_kwargs=None,
        spec2_plot_kwargs=None,
    ):
        """Plot the coherence as a function of scale or frequency, alongside the spectrum of the two timeseries
        (using the same method used for the coherence).

        The spectra are drawn on ``ax`` (left axis); the coherence and its significance
        thresholds (if any) are drawn on a twin axis (right axis).

        Parameters
        ----------
        figsize : tuple
            Size of the figure. Default is (8,8). Only used if ax is None.
        xlim : tuple
            x limits for the plot. Default is None.
        label : str
            Label of the coherence curve. Default is None, which uses ``self.label``.
        xlabel : str
            x label of the plot.
        psd_y_label : str
            y label of the power spectral density plot (left hand side).
        coh_y_label : str
            y label of the coherence plot (right hand side).
        coh_line_color : str
            Color of the coherence line.
        coh_ylim : tuple
            y limits for the coherence plot. Default is (.4,1).
        fill_alpha : float
            Alpha value for the fill_between plot. Default is .3.
        fill_color : str
            Color of the fill_between plot.
        coh_plot_kwargs : dict
            Additional arguments passed to matplotlib's ``Axes.plot`` for the coherence line.
        savefig_settings : dict
            Settings to pass to the pyleoclim.utils.plotting.savefig function
            (only used when this method creates the figure).
        spectral_kwargs : dict
            Additional arguments to pass to the pyleo.Series.spectral method.
        spec1_plot_kwargs : dict
            Additional arguments passed to the plot method of the first series' spectrum.
        spec2_plot_kwargs : dict
            Additional arguments passed to the plot method of the second series' spectrum.
        legend : bool
            Whether to include a legend or not.
        legend_kwargs : dict
            Additional arguments to pass to ax.legend.
        ax : matplotlib axis
            Axis to plot on.

        Returns
        -------
        fig, ax
            ``(fig, ax)`` if a new figure was created, otherwise just ``ax``.

        Examples
        --------

        .. jupyter-execute::

            soi = pyleo.utils.load_dataset('SOI')
            nino3 = pyleo.utils.load_dataset('NINO3')
            gcoh = soi.global_coherence(nino3)
            gcoh.plot()"""

        coh_plot_kwargs = {} if coh_plot_kwargs is None else coh_plot_kwargs.copy()
        savefig_settings = {} if savefig_settings is None else savefig_settings.copy()
        spectral_kwargs = {} if spectral_kwargs is None else spectral_kwargs.copy()
        legend_kwargs = {} if legend_kwargs is None else legend_kwargs.copy()
        spec1_plot_kwargs = (
            {} if spec1_plot_kwargs is None else spec1_plot_kwargs.copy()
        )
        spec2_plot_kwargs = (
            {} if spec2_plot_kwargs is None else spec2_plot_kwargs.copy()
        )

        fig = None
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)

        coh = self.coh

        # by default, compute the spectra the same way the coherence was computed
        spectral_kwargs.setdefault("method", coh.wave_method)
        spectral_kwargs.setdefault("freq", coh.freq_method)
        spectral_kwargs.setdefault("freq_kwargs", coh.freq_kwargs)
        if spectral_kwargs["method"] == coh.wave_method and coh.wave_args:
            for key, value in coh.wave_args.items():
                spectral_kwargs.setdefault(key, value)

        ts1 = coh.timeseries1
        ts2 = coh.timeseries2

        spec1 = ts1.spectral(label=ts1.label, **spectral_kwargs)
        spec2 = ts2.spectral(label=ts2.label, **spectral_kwargs)

        spec1.plot(ax=ax, **spec1_plot_kwargs)
        spec2.plot(ax=ax, **spec2_plot_kwargs)

        if xlim is not None:
            ax.set_xlim(xlim)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if psd_y_label is not None:
            ax.set_ylabel(psd_y_label)

        # coherence lives on its own (right-hand) axis
        ax2 = ax.twinx()

        if coh_line_color is not None:
            coh_plot_kwargs.update({"color": coh_line_color})
        if coh_y_label is not None:
            ax2.set_ylabel(coh_y_label)
        if coh_ylim is not None:
            ax2.set_ylim(coh_ylim)
        if label is None:
            label = self.label
        coh_plot_kwargs.update({"label": label})

        scale = coh.scale
        ax2.plot(scale, self.global_coh, **coh_plot_kwargs)
        ax2.fill_between(scale, 0, self.global_coh, color=fill_color, alpha=fill_alpha)
        ax2.grid(False)

        # plot significance levels if present; they are coherence quantiles,
        # so they belong on the coherence axis, not the PSD axis
        if self.signif_qs is not None:
            signif_method_label = {
                "ar1sim": "AR(1) simulations (MoM)",
                "phaseran": "Phase Randomization",
                "CN": "Colored Noise",
            }
            method_label = signif_method_label.get(
                self.signif_method, str(self.signif_method)
            )
            for q_level, q_curve in zip(self.qs, self.signif_qs):
                ax2.plot(
                    scale,
                    q_curve,
                    label=f"{method_label}, {q_level} threshold",
                    color="red",
                    linestyle="dashed",
                    linewidth=0.8,
                )

        # formatting: merge the handles of both axes into a single legend
        if legend:
            lines, labels = ax.get_legend_handles_labels()
            lines2, labels2 = ax2.get_legend_handles_labels()
            if len(legend_kwargs) == 0:
                ax.legend().set_visible(False)
                ax2.legend(lines + lines2, labels + labels2)
            else:
                if "handles" not in legend_kwargs:
                    legend_kwargs.update({"handles": lines + lines2})
                if "labels" not in legend_kwargs:
                    legend_kwargs.update({"labels": labels + labels2})
                ax.legend(**legend_kwargs)
                ax2.legend().set_visible(False)
        else:
            ax.legend().set_visible(False)
            ax2.legend().set_visible(False)

        if fig is None:
            return ax
        if "path" in savefig_settings:
            plotting.savefig(fig, settings=savefig_settings)
        return fig, ax