Source code for elephant.phase_analysis

# -*- coding: utf-8 -*-
"""
Methods for performing phase analysis.

.. autosummary::
    :toctree: _toctree/phase_analysis

    spike_triggered_phase
    phase_locking_value
    mean_phase_vector
    phase_difference
    weighted_phase_lag_index

References
----------

.. bibliography:: ../bib/elephant.bib
   :labelprefix: ph
   :keyprefix: phase-
   :style: unsrt

:copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`.
:license: Modified BSD, see LICENSE.txt for details.
"""

from __future__ import division, print_function, unicode_literals

import numpy as np
import quantities as pq
import neo

# Public API of this module: the names exported by a star-import.
__all__ = [
    "spike_triggered_phase",
    "phase_locking_value",
    "mean_phase_vector",
    "phase_difference",
    "weighted_phase_lag_index"
]


def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
    """
    Calculate the set of spike-triggered phases of a `neo.AnalogSignal`.

    Parameters
    ----------
    hilbert_transform : neo.AnalogSignal or list of neo.AnalogSignal
        `neo.AnalogSignal` of the complex analytic signal (e.g., returned by
        the `elephant.signal_processing.hilbert` function).
        If `hilbert_transform` is only one signal, all spike trains are
        compared to this signal. Otherwise, length of `hilbert_transform` must
        match the length of `spiketrains`.
    spiketrains : neo.SpikeTrain or list of neo.SpikeTrain
        Spike trains on which to trigger `hilbert_transform` extraction.
    interpolate : bool
        If True, the phases and amplitudes of `hilbert_transform` for spikes
        falling between two samples of signal is interpolated. A spike that
        falls after the last sample of the signal has no right-hand
        neighbour and is treated as if `interpolate` were False.
        If False, the sample at or immediately preceding the spike is used
        (the fractional sample index is truncated, not rounded).

    Returns
    -------
    phases : list of np.ndarray
        Spike-triggered phases. Entries in the list correspond to the
        `neo.SpikeTrain`s in `spiketrains`. Each entry contains an array with
        the spike-triggered angles (in rad) of the signal.
    amp : list of pq.Quantity
        Corresponding spike-triggered amplitudes.
    times : list of pq.Quantity
        A list of times corresponding to the signal. They correspond to the
        times of the `neo.SpikeTrain` referred by the list item.

        Only spikes with `t_start <= spike < t_stop` of the matching signal
        are considered. A spike train without any spike in that window
        yields empty arrays in all three outputs.

    Raises
    ------
    ValueError
        If the number of spike trains and number of phase signals don't match,
        and neither of the two are a single signal.

    Examples
    --------
    Create a 20 Hz oscillatory signal sampled at 1 kHz and a random Poisson
    spike train, then calculate spike-triggered phases and amplitudes of the
    oscillation:

    >>> import neo
    >>> import elephant
    >>> import quantities as pq
    >>> import numpy as np
    ...
    >>> f_osc = 20. * pq.Hz
    >>> f_sampling = 1 * pq.ms
    >>> tlen = 100 * pq.s
    ...
    >>> time_axis = np.arange(
    ...     0, tlen.magnitude,
    ...     f_sampling.rescale(pq.s).magnitude) * pq.s
    >>> analogsignal = neo.AnalogSignal(
    ...     np.sin(2 * np.pi * (f_osc * time_axis).simplified.magnitude),
    ...     units=pq.mV, t_start=0*pq.ms, sampling_period=f_sampling)
    >>> spiketrain = (elephant.spike_train_generation.
    ...     homogeneous_poisson_process(
    ...     50 * pq.Hz, t_start=0.0*pq.ms, t_stop=tlen.rescale(pq.ms)))
    ...
    >>> phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
    ...     elephant.signal_processing.hilbert(analogsignal),
    ...     spiketrain,
    ...     interpolate=True)
    >>> phases  # doctest: +SKIP
    [array([-0.57890515,  1.03105904, -0.82241075, ...,  0.90023903,
             2.23702263,  2.93744259])]
    >>> amps  # doctest: +SKIP
    [array([0.86117412, 1.08918248, 0.98256318, ..., 1.05760518, 1.08407016,
            1.01927305]) * dimensionless]
    >>> times  # doctest: +SKIP
    [array([6.41327152e+00, 2.02715221e+01, 1.05827312e+02, ...,
            9.99692942e+04, 9.99808429e+04, 9.99870120e+04]) * ms]

    """
    # Convert inputs to lists
    if not isinstance(spiketrains, list):
        spiketrains = [spiketrains]
    if not isinstance(hilbert_transform, list):
        hilbert_transform = [hilbert_transform]

    # Number of signals
    num_spiketrains = len(spiketrains)
    num_phase = len(hilbert_transform)

    if num_spiketrains != 1 and num_phase != 1 and \
            num_spiketrains != num_phase:
        raise ValueError(
            "Number of spike trains and number of phase signals "
            "must match, or either of the two must be a single signal.")

    result_phases = []
    result_amps = []
    result_times = []

    # Step through each spike train
    for spiketrain_i, spiketrain in enumerate(spiketrains):
        # Pick the analytic signal belonging to this spike train: with a
        # single signal every spike train refers to it, otherwise spike
        # trains and signals are matched up by position.
        phase_i = spiketrain_i if num_phase > 1 else 0
        signal = hilbert_transform[phase_i]
        sampling_period = signal.sampling_period

        # Take only spikes which lie within the signal segment; spikes at
        # or after t_stop are ignored.
        sttimeind = np.where(np.logical_and(
            spiketrain >= signal.t_start,
            spiketrain < signal.t_stop))[0]

        # Extract sample times once, for speed reasons
        times = signal.times

        # Index of the sample at or directly before each spike
        ind_at_spike = (
            (spiketrain[sttimeind] - signal.t_start) / sampling_period
        ).simplified.magnitude.astype(int)

        phases = []
        amps = []
        spike_times = []

        # Step through all spikes
        for spike_index, ind_at_spike_j in zip(sttimeind, ind_at_spike):
            spike_time = spiketrain[spike_index]

            if interpolate and ind_at_spike_j + 1 < len(times):
                # Relative spike position between the two closest samples:
                # z -> 0 spike is closer to the left sample
                # z -> 1 spike is closer to the right sample
                z = (spike_time - times[ind_at_spike_j]) / sampling_period

                # Interpolate the phase on the unit circle so that the
                # wrap-around at +/- pi is handled correctly
                p1 = np.angle(signal[ind_at_spike_j]).item()
                p2 = np.angle(signal[ind_at_spike_j + 1]).item()
                interpolation = (1 - z) * np.exp(complex(0, p1)) \
                    + z * np.exp(complex(0, p2))
                phases.append(np.angle([interpolation]))

                # Linear interpolation of the amplitude
                amps.append(
                    (1 - z) * np.abs(signal[ind_at_spike_j])
                    + z * np.abs(signal[ind_at_spike_j + 1]))
            else:
                phases.append(np.angle(signal[ind_at_spike_j]))
                amps.append(np.abs(signal[ind_at_spike_j]))

            spike_times.append(spike_time)

        # Convert outputs to arrays
        result_phases.append(np.array(phases).flatten())
        if spike_times:
            result_amps.append(
                pq.Quantity(amps, units=amps[0].units).flatten())
            result_times.append(
                pq.Quantity(spike_times,
                            units=spike_times[0].units).flatten())
        else:
            # No spike inside the signal window: there is no first element
            # to read the units from, so take them from the inputs.
            result_amps.append(pq.Quantity([], units=signal.units))
            result_times.append(pq.Quantity([], units=spiketrain.units))

    return result_phases, result_amps, result_times
def phase_locking_value(phases_i, phases_j):
    r"""
    Calculates the phase locking value (PLV) :cite:`phase-Lachaux99_194`.

    This function expects the phases of two signals (each containing multiple
    trials). For each trial pair, it calculates the phase difference at each
    time point. Then it calculates the mean vectors of those phase differences
    across all trials. The PLV at time `t` is the length of the corresponding
    mean vector.

    Parameters
    ----------
    phases_i, phases_j : (n, t) np.ndarray
        Phases (in radians) of the first and second signals, with `n` trials
        along the first axis and `t` time points along the second axis.
        The mean vector is taken along axis 0, so trials must be stacked
        along that axis.

    Returns
    -------
    plv : (t,) np.ndarray
        Vector of floats with the phase-locking value at each time point.
        Range: :math:`[0, 1]`

    Raises
    ------
    ValueError
        If the shapes of `phases_i` and `phases_j` are different.

    Notes
    -----
    This implementation is based on the formula taken from
    :cite:`phase-Lachaux99_194` (pp. 195):

    .. math::
        PLV_t = \frac{1}{N} \left |
        \sum_{n=1}^N \exp(i \cdot \theta(t, n)) \right |

    where :math:`\theta(t, n) = \phi_x(t, n) - \phi_y(t, n)`
    is the phase difference at time `t` for trial `n`.

    """
    if np.shape(phases_i) != np.shape(phases_j):
        raise ValueError("trial number and trial length of signal x and y "
                         "must be equal")

    # Trial-by-trial, time-resolved phase differences, wrapped to [-pi, pi]
    phase_diff = phase_difference(phases_i, phases_j)

    # Only the length of the mean vector across trials is needed; its
    # direction is discarded.
    _, plv = mean_phase_vector(phase_diff, axis=0)
    return plv
def mean_phase_vector(phases, axis=0):
    r"""
    Calculates the mean vector of phases.

    Every phase (in radians) is mapped onto the unit circle in the complex
    plane; the mean of those unit vectors is returned as its direction
    :math:`\theta` and its length `r`.

    Parameters
    ----------
    phases : np.ndarray
        Phases in radians.
    axis : int, optional
        Axis along which the mean vector will be calculated.
        If None, it will be computed across the flattened array.
        Default: 0

    Returns
    -------
    z_mean_theta : np.ndarray
        Angle of the mean vector.
        Range: :math:`(-\pi, \pi]`
    z_mean_r : np.ndarray
        Length of the mean vector.
        Range: :math:`[0, 1]`

    """
    # exp(i * phi) = cos(phi) + i * sin(phi): one unit vector per phase
    unit_vectors = np.exp(1j * np.asarray(phases))
    resultant = np.mean(unit_vectors, axis=axis)
    return np.angle(resultant), np.abs(resultant)
def phase_difference(alpha, beta):
    r"""
    Calculates the difference between a pair of phases.

    The plain difference `alpha - beta` is wrapped so that the output lies
    in the range from :math:`-\pi` to :math:`\pi`.

    Parameters
    ----------
    alpha : np.ndarray
        Phases in radians.
    beta : np.ndarray
        Phases in radians.

    Returns
    -------
    phase_diff : np.ndarray
        Difference between phases `alpha` and `beta`.
        Range: :math:`[-\pi, \pi]`

    Notes
    -----
    The usage of `np.arctan2` ensures that the range of the phase difference
    is :math:`[-\pi, \pi]` and is located in the correct quadrant.

    """
    raw_difference = alpha - beta
    # Passing sine and cosine through arctan2 folds the raw difference back
    # into a single turn around the circle.
    return np.arctan2(np.sin(raw_difference), np.cos(raw_difference))
def weighted_phase_lag_index(signal_i, signal_j, sampling_frequency=None,
                             absolute_value=True):
    r"""
    Calculates the Weighted Phase-Lag Index (WPLI) :cite:`phase-Vinck11_1548`.

    This function estimates the WPLI, which is a measure of phase-synchrony. It
    describes for two given signals i and j, which is leading/lagging the other
    signal in the frequency domain across multiple trials.

    Parameters
    ----------
    signal_i, signal_j : np.array, pq.quantity.Quantity, neo.AnalogSignal
        Time-series of the first and second signals, given as 2-D arrays
        with `n` trials along the first axis and `t` time points along the
        last axis (the Fourier transform is taken along the last axis and
        the average along axis 0).
    sampling_frequency : pq.quantity.Quantity (default: None)
        Sampling frequency of the signals in Hz. Not needed if signal i and j
        are neo.AnalogSignals, in which case the sampling rate of `signal_i`
        is used.
    absolute_value : boolean (default: True)
        Takes the absolute value of the numerator in the WPLI-formula.
        When set to `False`, the WPLI contains additional directionality
        information about which signal leads/lags the other signal:

        * wpli > 0 : first signal i leads second signal j
        * wpli < 0 : first signal i lags second signal j

    Returns
    -------
    freqs : pq.quantity.Quantity
        Positive frequencies in Hz associated with the estimates of `wpli`.
        Range: :math:`[0, sampling frequency/2]`
    wpli : np.ndarray with dtype=float
        Weighted phase-lag index of `signal_i` and `signal_j` across trials.
        Range: :math:`[0, 1]` if `absolute_value` is True,
        :math:`[-1, 1]` otherwise.
        Frequency bins in which the imaginary part of the cross-spectrum is
        zero in every trial (e.g. the zero-frequency bin of real-valued
        signals) have an undefined WPLI and are returned as NaN.

    Raises
    ------
    ValueError
        If both signals are neo.AnalogSignals with different sampling rates.

        If `sampling_frequency` is None and the signals are not both
        neo.AnalogSignals.

        If trial number or trial length are different for signal i and j.

    Notes
    -----
    This implementation is based on the formula taken from
    :cite:`phase-Vinck11_1548` (pp. 1550, equation (8)) :

    .. math::
        WPLI = \frac{| E( |Im(X)| * sgn(Im(X)) ) |}{E( |Im(X)| )}

    with:

    * :math:`E{...}` : expected value operator (mean across trials)
    * :math:`Im{X}` : imaginary component of the cross-spectrum
    * :math:`X = Z_i Z_{j}^{*}` : cross-spectrum
    * :math:`Z_i, Z_j`: complex-valued matrix, representing the Fourier
      spectra of a particular frequency of the signals i and j.

    """
    if isinstance(signal_i, neo.AnalogSignal) and \
            isinstance(signal_j, neo.AnalogSignal):  # neo.AnalogSignal input
        if signal_i.sampling_rate.rescale("Hz") != \
                signal_j.sampling_rate.rescale("Hz"):
            raise ValueError("sampling rate of signal i and j must be equal")
        sampling_frequency = signal_i.sampling_rate
        signal_i = signal_i.magnitude
        signal_j = signal_j.magnitude
    elif sampling_frequency is None:  # np.array() or Quantity input
        raise ValueError("sampling frequency must be given for np.array or "
                         "Quantity input")

    if np.shape(signal_i) != np.shape(signal_j):
        if len(signal_i) != len(signal_j):
            raise ValueError("trial number of signal i and j must be equal")
        raise ValueError("trial length of signal i and j must be equal")

    # calculate Fourier transforms (along the last axis, i.e. per trial)
    fft1 = np.fft.rfft(signal_i)
    fft2 = np.fft.rfft(signal_j)
    freqs = np.fft.rfftfreq(np.shape(signal_i)[1], d=1.0 / sampling_frequency)

    # imaginary part of the per-trial cross-spectrum
    cs_imag = np.imag(fft1 * np.conjugate(fft2))

    # calculate WPLI; |Im(X)| * sgn(Im(X)) is Im(X) itself
    wpli_num = np.mean(cs_imag, axis=0)
    if absolute_value:
        wpli_num = np.abs(wpli_num)
    wpli_den = np.mean(np.abs(cs_imag), axis=0)

    # 0 / 0 occurs wherever Im(X) vanishes in all trials; the resulting NaN
    # is intended ("undefined"), so do not emit a RuntimeWarning for it.
    with np.errstate(divide="ignore", invalid="ignore"):
        wpli = wpli_num / wpli_den

    return freqs, wpli