# Source code for elephant.phase_analysis
# -*- coding: utf-8 -*-
"""
Methods for performing phase analysis.
:copyright: Copyright 2014-2018 by the Elephant team, see `doc/authors.rst`.
:license: Modified BSD, see LICENSE.txt for details.
"""
from __future__ import division, print_function, unicode_literals
import numpy as np
import quantities as pq
def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
    """
    Calculate the set of spike-triggered phases of a `neo.AnalogSignal`.

    Parameters
    ----------
    hilbert_transform : neo.AnalogSignal or list of neo.AnalogSignal
        `neo.AnalogSignal` of the complex analytic signal (e.g., returned by
        the `elephant.signal_processing.hilbert` function).
        If `hilbert_transform` is only one signal, all spike trains are
        compared to this signal. Otherwise, length of `hilbert_transform` must
        match the length of `spiketrains`.
    spiketrains : neo.SpikeTrain or list of neo.SpikeTrain
        Spike trains on which to trigger `hilbert_transform` extraction.
        Only spikes with `t_start <= t < t_stop` of the associated signal
        are considered.
    interpolate : bool
        If True, the phases and amplitudes of `hilbert_transform` for spikes
        falling between two samples of signal are interpolated (phases are
        interpolated on the unit circle, amplitudes linearly). Spikes that
        fall after the last sample, where no right-hand neighbour exists,
        use the last sample without interpolation.
        If False, the sample of `hilbert_transform` at or immediately
        preceding the spike is used.

    Returns
    -------
    phases : list of np.ndarray
        Spike-triggered phases. Entries in the list correspond to the
        `neo.SpikeTrain`s in `spiketrains`. Each entry contains an array with
        the spike-triggered angles (in rad) of the signal. The array is
        empty if no spike of the spike train lies within the signal.
    amp : list of pq.Quantity
        Corresponding spike-triggered amplitudes.
    times : list of pq.Quantity
        A list of times corresponding to the signal. They correspond to the
        times of the `neo.SpikeTrain` referred by the list item.

    Raises
    ------
    ValueError
        If the number of spike trains and number of phase signals don't match,
        and neither of the two are a single signal.

    Examples
    --------
    Create a 20 Hz oscillatory signal sampled at 1 kHz and a random Poisson
    spike train, then calculate spike-triggered phases and amplitudes of the
    oscillation:

    >>> import neo
    >>> import elephant
    >>> import quantities as pq
    >>> import numpy as np
    ...
    >>> f_osc = 20. * pq.Hz
    >>> f_sampling = 1 * pq.ms
    >>> tlen = 100 * pq.s
    ...
    >>> time_axis = np.arange(
    ...     0, tlen.magnitude,
    ...     f_sampling.rescale(pq.s).magnitude) * pq.s
    >>> analogsignal = neo.AnalogSignal(
    ...     np.sin(2 * np.pi * (f_osc * time_axis).simplified.magnitude),
    ...     units=pq.mV, t_start=0*pq.ms, sampling_period=f_sampling)
    >>> spiketrain = (elephant.spike_train_generation.
    ...     homogeneous_poisson_process(
    ...     50 * pq.Hz, t_start=0.0*pq.ms, t_stop=tlen.rescale(pq.ms)))
    ...
    >>> phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
    ...     elephant.signal_processing.hilbert(analogsignal),
    ...     spiketrain,
    ...     interpolate=True)
    """
    # Convert inputs to lists
    if not isinstance(spiketrains, list):
        spiketrains = [spiketrains]
    if not isinstance(hilbert_transform, list):
        hilbert_transform = [hilbert_transform]

    # Number of signals
    num_spiketrains = len(spiketrains)
    num_phase = len(hilbert_transform)

    if num_spiketrains != 1 and num_phase != 1 and \
            num_spiketrains != num_phase:
        raise ValueError(
            "Number of spike trains and number of phase signals "
            "must match, or either of the two must be a single signal.")

    result_phases = []
    result_amps = []
    result_times = []

    # Step through each spike train
    for spiketrain_i, spiketrain in enumerate(spiketrains):
        # Check which hilbert_transform AnalogSignal to look at - if there is
        # only one then all spike trains relate to this one, otherwise the two
        # lists of spike trains and phases are matched up
        phase_i = spiketrain_i if num_phase > 1 else 0
        signal = hilbert_transform[phase_i]
        sampling_period = signal.sampling_period

        # Extract times once per signal for speed reasons
        times = signal.times
        num_samples = len(times)

        # Take only spikes which lie within the signal segment
        # (t_start inclusive, t_stop exclusive)
        sttimeind = np.where(np.logical_and(
            spiketrain >= signal.t_start, spiketrain < signal.t_stop))[0]

        # Nearest sample index for each spike. The ratio is simplified to a
        # plain number *before* rounding so that the result is also correct
        # if spike times and sampling period use different time units.
        ind_at_spike = np.round(
            ((spiketrain[sttimeind] - signal.t_start) /
             sampling_period).simplified.magnitude).astype(int)

        phases = []
        amps = []
        spike_times = []

        # Step through all spikes
        for spike_i, ind_at_spike_j in enumerate(ind_at_spike):
            spike_time = spiketrain[sttimeind[spike_i]]

            # A spike in the last half sampling interval rounds to an index
            # one past the final sample; fall back to the final sample.
            ind_at_spike_j = min(ind_at_spike_j, num_samples - 1)

            # Difference between actual spike time and sample point,
            # positive if spike time is later than sample point
            dv = spike_time - times[ind_at_spike_j]

            # Make sure ind_at_spike_j is to the left of the spike time
            if dv < 0 and ind_at_spike_j > 0:
                ind_at_spike_j = ind_at_spike_j - 1

            # Interpolation needs a right-hand neighbour sample
            if interpolate and ind_at_spike_j + 1 < num_samples:
                # Get relative spike occurrence between the two closest signal
                # sample points
                # if z->0 spike is more to the left sample
                # if z->1 more to the right sample
                z = ((spike_time - times[ind_at_spike_j]) /
                     sampling_period).simplified

                # Save hilbert_transform (interpolate on circle).
                # The builtin `complex` replaces the `np.complex` alias,
                # which was removed from NumPy.
                p1 = np.angle(signal[ind_at_spike_j])
                p2 = np.angle(signal[ind_at_spike_j + 1])
                phases.append(
                    np.angle(
                        (1 - z) * np.exp(complex(0, p1)) +
                        z * np.exp(complex(0, p2))))

                # Save amplitude (linear interpolation)
                amps.append(
                    (1 - z) * np.abs(signal[ind_at_spike_j]) +
                    z * np.abs(signal[ind_at_spike_j + 1]))
            else:
                phases.append(np.angle(signal[ind_at_spike_j]))

                # Save amplitude
                amps.append(np.abs(signal[ind_at_spike_j]))

            # Save time
            spike_times.append(spike_time)

        # Convert outputs to arrays. If no spike fell inside the signal there
        # is no first element to read the units from, so fall back to the
        # units of the inputs.
        amp_units = amps[0].units if amps else signal.units
        time_units = spike_times[0].units if spike_times \
            else spiketrain.units

        result_phases.append(np.array(phases).flatten())
        result_amps.append(pq.Quantity(amps, units=amp_units).flatten())
        result_times.append(
            pq.Quantity(spike_times, units=time_units).flatten())

    return result_phases, result_amps, result_times