"""
.. autosummary::
    :toctree: _toctree/utils

    is_time_quantity
    get_common_start_stop_times
    check_neo_consistency
    check_same_units
    round_binning_errors
"""
from __future__ import division, print_function, unicode_literals
import ctypes
import logging
import warnings
from functools import wraps
from neo.core.spiketrainlist import SpikeTrainList
import numpy as np
import quantities as pq
from elephant.trials import Trials
# Names exported by ``from elephant.utils import *``. Note that
# get_cuda_capability_major, get_opencl_capability and
# trials_to_list_of_spiketrainlist, although defined below without a leading
# underscore, are not listed here.
__all__ = [
"deprecated_alias",
"is_binary",
"is_time_quantity",
"get_common_start_stop_times",
"check_neo_consistency",
"check_same_units",
"round_binning_errors",
]
# Module-level logger with a dedicated stream handler. The format prefix is
# the last dotted component of the module name (the whole name if it has no
# dot). Propagation is switched off so records are not also handed to the
# handlers of ancestor loggers.
log_handler = logging.StreamHandler()
log_handler.setFormatter(
    logging.Formatter(
        "[%(asctime)s] "
        + __name__.rpartition(".")[2]
        + " - %(levelname)s: %(message)s"
    )
)
logger = logging.getLogger(__file__)
logger.propagate = False
logger.addHandler(log_handler)
def is_binary(array):
    """
    Check whether every element of the input is either 0 or 1.

    Parameters
    ----------
    array : np.ndarray or list
        Input values; converted with :func:`numpy.asarray`.

    Returns
    -------
    bool
        Whether the input array is binary or not. An empty input is
        considered binary.
    """
    array = np.asarray(array)
    # ``np.all`` (rather than the ``.all`` method) also copes with a plain
    # Python bool coming out of the comparison, and ``bool()`` turns the
    # resulting ``numpy.bool_`` into the plain Python ``bool`` promised above.
    return bool(np.all((array == 0) | (array == 1)))
def deprecated_alias(**aliases):
    """
    A deprecation decorator constructor.

    Parameters
    ----------
    **aliases
        The key-value pairs of mapping old --> new argument names of a
        function.

    Returns
    -------
    callable
        A decorator for the specific mapping of deprecated argument names.

    Examples
    --------
    In the example below, `my_function(binsize)` signature is marked as
    deprecated (but still usable) and changed to `my_function(bin_size)`.

    >>> @deprecated_alias(binsize='bin_size')
    ... def my_function(bin_size):
    ...     pass
    """

    def deco(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Rename deprecated keyword arguments (in place) before the call.
            _rename_kwargs(func.__name__, kwargs, aliases)
            return func(*args, **kwargs)

        return wrapper

    return deco


def _rename_kwargs(func_name, kwargs, aliases):
    """
    Rename deprecated keyword arguments of `func_name` in place.

    Parameters
    ----------
    func_name : str
        Name of the decorated function, used in the error message.
    kwargs : dict
        Keyword arguments of the call; modified in place.
    aliases : dict
        Mapping of old (deprecated) --> new argument names.

    Raises
    ------
    TypeError
        If both the deprecated and the new name of an argument are given.

    Warns
    -----
    DeprecationWarning
        For every deprecated argument name found in `kwargs`.
    """
    for old, new in aliases.items():
        if old not in kwargs:
            continue
        if new in kwargs:
            raise TypeError(f"{func_name} received both '{old}' and '{new}'")
        # stacklevel=3 skips this helper and the decorator's ``wrapper`` so
        # the warning points at the caller's line. That also lets Python's
        # default warning filters display it when the caller is ``__main__``.
        warnings.warn(
            f"'{old}' is deprecated; use '{new}'",
            DeprecationWarning,
            stacklevel=3,
        )
        kwargs[new] = kwargs.pop(old)
[docs]
def is_time_quantity(*quantities, allow_none=False):
    """
    Check that every positional argument is a Quantity with time units.

    Parameters
    ----------
    *quantities : pq.Quantity
        Scalars or array-likes to check for being a Quantity with time units.
    allow_none : bool, optional
        If True, arguments that are None are accepted and skipped.
        Default: False

    Returns
    -------
    bool
        True if every argument is a time Quantity (or None while
        `allow_none` is True), False as soon as one argument fails the
        check. Calling it without arguments returns True.
    """

    def _is_accepted(candidate):
        # None passes only when explicitly allowed.
        if candidate is None and allow_none:
            return True
        if not isinstance(candidate, pq.Quantity):
            return False
        # Simplified dimensionality makes e.g. ms and s both count as time.
        if candidate.dimensionality.simplified != pq.s.dimensionality:
            return False
        return True

    return all(_is_accepted(candidate) for candidate in quantities)
[docs]
def get_common_start_stop_times(neo_objects):
    """
    Extracts the common `t_start` and the `t_stop` from the input neo objects.

    If a single neo object is given, its `t_start` and `t_stop` is returned.
    Otherwise, the aligned times are returned: the maximal `t_start` and
    minimal `t_stop` across `neo_objects`.

    Parameters
    ----------
    neo_objects : neo.SpikeTrain or neo.AnalogSignal or list
        A neo object, or a list (or any other iterable, including one-shot
        iterators) of neo objects that have `t_start` and `t_stop`
        attributes.

    Returns
    -------
    t_start, t_stop : pq.Quantity
        Shared start and stop times.

    Raises
    ------
    AttributeError
        If the input neo objects do not have `t_start` and `t_stop` attributes.
    ValueError
        If `neo_objects` is empty, or if there is no shared interval
        ``[t_start, t_stop]`` across the input neo objects.
    """
    if hasattr(neo_objects, "t_start") and hasattr(neo_objects, "t_stop"):
        return neo_objects.t_start, neo_objects.t_stop
    # Materialise once: the two passes below would otherwise exhaust a
    # generator in the first pass and fail on an empty second pass.
    neo_objects = list(neo_objects)
    if not neo_objects:
        raise ValueError("The input must contain at least one neo object")
    try:
        t_start = max(elem.t_start for elem in neo_objects)
        t_stop = min(elem.t_stop for elem in neo_objects)
    except AttributeError as err:
        raise AttributeError(
            "Input neo objects must have 't_start' and 't_stop' attributes"
        ) from err
    if t_stop < t_start:
        raise ValueError(
            f"t_stop ({t_stop}) is smaller than t_start ({t_start})"
        )
    return t_start, t_stop
[docs]
def check_neo_consistency(
    neo_objects, object_type, t_start=None, t_stop=None, tolerance=1e-8
):
    """
    Checks that all input neo objects share the same units, t_start, and
    t_stop.

    Parameters
    ----------
    neo_objects : list of neo.SpikeTrain or neo.AnalogSignal
        A list of neo spike trains or analog signals. An input that is not a
        list, tuple or SpikeTrainList is treated as a single object.
    object_type : type
        The common type.
    t_start, t_stop : pq.Quantity or None, optional
        If None, check that t_start (respectively t_stop) matches across the
        input within `tolerance`. If a time quantity is given, the
        corresponding check is skipped.
        Default: None
    tolerance : float or None, optional
        The absolute affordable tolerance for the discrepancies between
        t_start/stop magnitude values across trials, measured in the units
        of the first object's t_start/t_stop. None is treated as 0.
        Default: 1e-8

    Raises
    ------
    TypeError
        If input objects are not instances of the specified `object_type`,
        or if `t_start`/`t_stop` is neither None nor a time quantity.
    ValueError
        If input object units, t_start, or t_stop do not match across trials.
    """
    if not isinstance(neo_objects, (list, tuple, SpikeTrainList)):
        neo_objects = [neo_objects]
    try:
        first = neo_objects[0]
        units = first.units
        start_units = first.t_start.units
        stop_units = first.t_stop.units
        start = first.t_start.item()
        stop = first.t_stop.item()
    except (IndexError, AttributeError) as err:
        raise TypeError(
            f"The input must be a list of {object_type.__name__}"
        ) from err
    if not is_time_quantity(t_start, t_stop, allow_none=True):
        raise TypeError("'t_start' and 't_stop' must be time quantities.")
    if tolerance is None:
        tolerance = 0
    for neo_obj in neo_objects:
        if not isinstance(neo_obj, object_type):
            raise TypeError(
                "The input must be a list of "
                f"{object_type.__name__}. Got "
                f"{type(neo_obj).__name__}"
            )
        if neo_obj.units != units:
            raise ValueError("The input must have the same units.")
        # Compare magnitudes in the first object's time units, so that equal
        # times expressed in different units (e.g. s vs. ms) still match.
        # `units` alone does not guarantee this: for analog signals it is
        # the unit of the signal values, not of t_start/t_stop.
        if t_start is None:
            obj_start = neo_obj.t_start.rescale(start_units).item()
            if abs(obj_start - start) > tolerance:
                raise ValueError("The input must have the same t_start.")
        if t_stop is None:
            obj_stop = neo_obj.t_stop.rescale(stop_units).item()
            if abs(obj_stop - stop) > tolerance:
                raise ValueError("The input must have the same t_stop.")
[docs]
def check_same_units(quantities, object_type=pq.Quantity):
    """
    Verify that every input item is an instance of `object_type` and carries
    the same units as the first item; raise otherwise.

    Parameters
    ----------
    quantities : list of pq.Quantity or pq.Quantity
        A list or tuple of quantities / neo objects, or a single one (which
        is wrapped in a list).
    object_type : type, optional
        The type every item must be an instance of.
        Default: pq.Quantity

    Raises
    ------
    TypeError
        If the input is empty, its first item has no `units`, or an item is
        not an instance of `object_type`.
    ValueError
        If the items do not all share the units of the first item.
    """
    items = quantities if isinstance(quantities, (list, tuple)) else [quantities]
    try:
        reference_units = items[0].units
    except (IndexError, AttributeError):
        raise TypeError(f"The input must be a list of {object_type.__name__}")
    for item in items:
        if not isinstance(item, object_type):
            raise TypeError(
                f"The input must be a list of {object_type.__name__}. Got "
                f"{type(item).__name__}"
            )
        if item.units != reference_units:
            raise ValueError(
                "The input quantities must have the same units, "
                "which is achieved with object.rescale('ms') "
                "operation."
            )
[docs]
def round_binning_errors(values, tolerance=1e-8):
    """
    Convert `values` to integers, first pushing values that lie within
    `tolerance` below the next integer up to it, to compensate for machine
    floating point precision errors. An array input is modified in-place
    when a correction is applied.

    Parameters
    ----------
    values : np.ndarray or float
        An input array or a scalar.
    tolerance : float or None, optional
        The precision error absolute tolerance: a value is corrected when its
        fractional part ``values % 1`` is at least ``1 - tolerance``. If None
        (or 0), no correction is performed and the values are simply
        truncated.
        Default: 1e-8

    Returns
    -------
    values : np.ndarray or int
        Corrected integer values (an ``np.int32`` array for array input).

    Examples
    --------
    >>> from elephant.utils import round_binning_errors
    >>> round_binning_errors(0.999999, tolerance=None)
    0
    >>> round_binning_errors(0.999999, tolerance=1e-6)
    1
    """
    array_input = isinstance(values, np.ndarray)
    if tolerance is None or tolerance == 0:
        return values.astype(np.int32) if array_input else int(values)
    # Equivalent to '1 - (values % 1) <= tolerance', but faster.
    needs_shift = values % 1 >= 1 - tolerance
    if not array_input:
        if needs_shift:
            logger.warning(
                "Correcting a rounding error in the calculation "
                "of the number of bins by incrementing the value by 1. "
                "You can set tolerance=None to disable this "
                "behaviour."
            )
            values += 0.5
        return int(values)
    num_corrections = needs_shift.sum()
    if num_corrections > 0:
        logger.warning(
            f"Correcting {num_corrections} rounding errors by "
            "shifting the affected spikes into the following "
            "bin. You can set tolerance=None to disable this "
            "behaviour."
        )
        # For non-negative values whose fractional part is >= 1 - tolerance,
        # adding 0.5 makes the truncating cast below land on the next integer.
        values[needs_shift] += 0.5
    return values.astype(np.int32)
def get_cuda_capability_major():
    """
    Extracts CUDA capability major version of the first available Nvidia GPU
    card, if detected. Otherwise, return 0.

    The CUDA driver library is loaded with :mod:`ctypes`. Any failure
    (library not found, driver initialisation, device lookup or capability
    query) yields 0.

    Returns
    -------
    int
        CUDA capability major version, or 0 if no usable device is found.
    """
    cuda_success = 0  # CUDA_SUCCESS return code of the driver API
    # Candidate names of the CUDA *driver* library. On Linux the unversioned
    # "libcuda.so" may only come with development packages, so the versioned
    # "libcuda.so.1" is tried as well. On Windows the driver DLL is
    # "nvcuda.dll"; "cuda.dll" is kept for backward compatibility.
    libnames = (
        "libcuda.so",
        "libcuda.so.1",
        "libcuda.dylib",
        "nvcuda.dll",
        "cuda.dll",
    )
    for libname in libnames:
        try:
            cuda = ctypes.CDLL(libname)
        except OSError:
            continue
        break
    else:
        # not found
        return 0
    if cuda.cuInit(0) != cuda_success:
        return 0
    device = ctypes.c_int()
    # parse the first GPU card only
    if cuda.cuDeviceGet(ctypes.byref(device), 0) != cuda_success:
        return 0
    cc_major = ctypes.c_int()
    cc_minor = ctypes.c_int()
    result = cuda.cuDeviceComputeCapability(
        ctypes.byref(cc_major), ctypes.byref(cc_minor), device
    )
    if result != cuda_success:
        return 0
    return cc_major.value
def get_opencl_capability():
    """
    Check whether at least one OpenCL platform is available via ``pyopencl``.

    Returns
    -------
    bool
        True: if ``pyopencl`` is importable and reports at least one OpenCL
        platform,
        False: if ``pyopencl`` is not installed, if querying the platforms
        fails, or if no platforms are found.
    """
    try:
        import pyopencl
    except ImportError:
        return False
    try:
        platforms = pyopencl.get_platforms()
    except pyopencl.Error:
        # pyopencl can raise (e.g. PLATFORM_NOT_FOUND_KHR) instead of
        # returning an empty list when no OpenCL platform is installed.
        return False
    return len(platforms) > 0
def trials_to_list_of_spiketrainlist(method):
    """
    Decorator that replaces every `Trials` argument of the wrapped callable
    by a list with one entry per trial, each entry obtained from
    ``Trials.get_spiketrains_from_trial_as_list``. Positional and keyword
    arguments are both converted; all other arguments are passed through
    unchanged.

    Parameters
    ----------
    method: callable
        The method to be decorated.

    Returns
    -------
    callable:
        The decorated method.

    Examples
    --------
    The decorator can be used as follows:

    >>> @trials_to_list_of_spiketrainlist
    ... def process_data(self, spiketrains):
    ...     return None
    """

    def _expand(value):
        # Non-Trials values are forwarded as they are.
        if not isinstance(value, Trials):
            return value
        return [
            value.get_spiketrains_from_trial_as_list(trial_idx)
            for trial_idx in range(value.n_trials)
        ]

    @wraps(method)
    def wrapper(*args, **kwargs):
        converted_args = tuple(map(_expand, args))
        converted_kwargs = {name: _expand(val) for name, val in kwargs.items()}
        return method(*converted_args, **converted_kwargs)

    return wrapper