# Source code for elephant.utils

"""
.. autosummary::
    :toctree: _toctree/utils

    is_time_quantity
    get_common_start_stop_times
    check_neo_consistency
    check_same_units
    round_binning_errors
"""

from __future__ import division, print_function, unicode_literals

import ctypes
import warnings
from functools import wraps

import neo
import numpy as np
import quantities as pq


# Public API of this module: the names exported by
# ``from elephant.utils import *``.
# NOTE(review): get_cuda_capability_major (defined below) is not listed;
# confirm whether it is intentionally kept out of the public API.
__all__ = [
    "deprecated_alias",
    "is_binary",
    "is_time_quantity",
    "get_common_start_stop_times",
    "check_neo_consistency",
    "check_same_units",
    "round_binning_errors"
]


def is_binary(array):
    """
    Check whether every element of the input equals 0 or 1.

    Parameters
    ----------
    array : np.ndarray or list
        The input values. Boolean input counts as binary because ``False``
        and ``True`` compare equal to 0 and 1.

    Returns
    -------
    bool
        Whether the input array is binary or not. An empty input is
        considered binary.

    """
    array = np.asarray(array)
    is_zero_or_one = (array == 0) | (array == 1)
    # ``np.all`` (instead of the ``.all()`` method) also copes with the case
    # where NumPy returns a plain scalar for an incomparable dtype; ``bool``
    # turns the resulting ``np.bool_`` into the documented builtin ``bool``.
    return bool(np.all(is_zero_or_one))


def deprecated_alias(**aliases):
    """
    Build a decorator that keeps deprecated keyword-argument names usable.

    Parameters
    ----------
    **aliases
        Keyword pairs mapping each old (deprecated) argument name to its new
        name.

    Returns
    -------
    callable
        A decorator. On every call of the decorated function, the keyword
        arguments are passed through ``_rename_kwargs`` together with
        `aliases` (renaming old names to new ones) before the original
        function is invoked.

    Examples
    --------
    In the example below, `my_function(binsize)` signature is marked as
    deprecated (but still usable) and changed to `my_function(bin_size)`.

    >>> @deprecated_alias(binsize='bin_size')
    ... def my_function(bin_size):
    ...     pass

    """
    def decorator(function):

        @wraps(function)
        def renaming_wrapper(*args, **kwargs):
            # rename deprecated keywords in-place, then forward the call
            _rename_kwargs(function.__name__, kwargs, aliases)
            return function(*args, **kwargs)

        return renaming_wrapper

    return decorator


def _rename_kwargs(func_name, kwargs, aliases):
    for old, new in aliases.items():
        if old in kwargs:
            if new in kwargs:
                raise TypeError(f"{func_name} received both '{old}' and "
                                f"'{new}'")
            warnings.warn(f"'{old}' is deprecated; use '{new}'",
                          DeprecationWarning)
            kwargs[new] = kwargs.pop(old)


def is_time_quantity(*quantities, allow_none=False):
    """
    Check whether every input is a Quantity carrying time units.

    Parameters
    ----------
    *quantities : pq.Quantity
        Scalars or arrays to check. A unit qualifies when its simplified
        dimensionality equals that of ``pq.s``.
    allow_none : bool, optional
        If True, ``None`` inputs are accepted and skipped.
        Default: False

    Returns
    -------
    bool
        True if every input is a time Quantity (or None while `allow_none`
        is True), False otherwise. With no inputs at all, returns True.
    """
    def _qualifies(candidate):
        if candidate is None and allow_none:
            return True
        if not isinstance(candidate, pq.Quantity):
            return False
        if candidate.dimensionality.simplified != pq.s.dimensionality:
            return False
        return True

    # ``all`` stops at the first non-qualifying input, like an early return
    return all(_qualifies(candidate) for candidate in quantities)


def get_common_start_stop_times(neo_objects):
    """
    Extracts the common `t_start` and the `t_stop` from the input neo objects.

    If a single neo object is given, its `t_start` and `t_stop` is returned.
    Otherwise, the aligned times are returned: the maximal `t_start` and
    minimal `t_stop` across `neo_objects`.

    Parameters
    ----------
    neo_objects : neo.SpikeTrain or neo.AnalogSignal or list
        A neo object, or a list (or any iterable) of neo objects that have
        `t_start` and `t_stop` attributes.

    Returns
    -------
    t_start, t_stop : pq.Quantity
        Shared start and stop times.

    Raises
    ------
    AttributeError
        If the input neo objects do not have `t_start` and `t_stop`
        attributes.
    ValueError
        If `neo_objects` is an empty collection, or if there is no shared
        interval ``[t_start, t_stop]`` across the input neo objects.
    """
    if hasattr(neo_objects, 't_start') and hasattr(neo_objects, 't_stop'):
        return neo_objects.t_start, neo_objects.t_stop
    # Materialize once: the objects are traversed twice below (max, then
    # min), which would silently exhaust a generator on the first pass.
    neo_objects = list(neo_objects)
    if not neo_objects:
        raise ValueError("Input neo objects must not be empty")
    try:
        t_start = max(elem.t_start for elem in neo_objects)
        t_stop = min(elem.t_stop for elem in neo_objects)
    except AttributeError as err:
        raise AttributeError("Input neo objects must have 't_start' and "
                             "'t_stop' attributes") from err
    if t_stop < t_start:
        raise ValueError(f"t_stop ({t_stop}) is smaller than t_start "
                         f"({t_start})")
    return t_start, t_stop


def check_neo_consistency(neo_objects, object_type, t_start=None,
                          t_stop=None, tolerance=1e-8):
    """
    Checks that all input neo objects share the same units, t_start, and
    t_stop.

    Parameters
    ----------
    neo_objects : list of neo.SpikeTrain or neo.AnalogSignal
        A list of neo spike trains or analog signals. Anything that is not a
        list or tuple is treated as a single object and wrapped in a list.
    object_type : type
        The common type.
    t_start, t_stop : pq.Quantity or None, optional
        If None, the t_start (respectively t_stop) magnitudes of all inputs
        must match the first input's within `tolerance`. If a time quantity
        is given, that particular check is skipped.
        Default: None
    tolerance : float or None, optional
        The absolute affordable tolerance for the discrepancies between
        t_start/stop magnitude values across trials. None is treated as 0,
        i.e. an exact match is required.
        Default: 1e-8

    Raises
    ------
    TypeError
        If `neo_objects` is empty or its first element lacks `units`,
        `t_start` or `t_stop`; if input objects are not instances of the
        specified `object_type`; or if `t_start`/`t_stop` are neither None
        nor time quantities.
    ValueError
        If input object units, t_start, or t_stop do not match across trials.
    """
    if not isinstance(neo_objects, (list, tuple)):
        neo_objects = [neo_objects]
    try:
        units = neo_objects[0].units
        start = neo_objects[0].t_start.item()
        stop = neo_objects[0].t_stop.item()
    except (IndexError, AttributeError) as err:
        raise TypeError(
            f"The input must be a list of {object_type.__name__}") from err
    if not is_time_quantity(t_start, t_stop, allow_none=True):
        raise TypeError("'t_start' and 't_stop' must be time quantities.")
    if tolerance is None:
        tolerance = 0
    for neo_obj in neo_objects:
        if not isinstance(neo_obj, object_type):
            raise TypeError("The input must be a list of "
                            f"{object_type.__name__}. Got "
                            f"{type(neo_obj).__name__}")
        if neo_obj.units != units:
            raise ValueError("The input must have the same units.")
        # ``.item()`` compares raw magnitudes, ignoring the units attached to
        # t_start/t_stop. NOTE(review): this assumes t_start/t_stop are
        # expressed in the same units for all inputs — confirm for all
        # supported object types.
        if t_start is None and abs(neo_obj.t_start.item() - start) > tolerance:
            raise ValueError("The input must have the same t_start.")
        if t_stop is None and abs(neo_obj.t_stop.item() - stop) > tolerance:
            raise ValueError("The input must have the same t_stop.")


def check_same_units(quantities, object_type=pq.Quantity):
    """
    Verify that all inputs are instances of `object_type` sharing one unit.

    Anything that is not a list or tuple is treated as a single item. Nothing
    is returned; an error is raised on the first failed check.

    Parameters
    ----------
    quantities : list of pq.Quantity or pq.Quantity
        A list of quantities, neo objects or a single neo object.
    object_type : type, optional
        The type every item must be an instance of.
        Default: pq.Quantity

    Raises
    ------
    TypeError
        If the input is empty, its first item has no `units` attribute, or
        any item is not an instance of `object_type`.
    ValueError
        If the items do not all share the units of the first item.
    """
    items = (quantities if isinstance(quantities, (list, tuple))
             else [quantities])
    try:
        reference_units = items[0].units
    except (IndexError, AttributeError):
        raise TypeError(f"The input must be a list of {object_type.__name__}")
    for item in items:
        if not isinstance(item, object_type):
            raise TypeError("The input must be a list of "
                            f"{object_type.__name__}. Got "
                            f"{type(item).__name__}")
        if item.units != reference_units:
            raise ValueError("The input quantities must have the same units, "
                             "which is achieved with object.rescale('ms') "
                             "operation.")


def round_binning_errors(values, tolerance=1e-8):
    """
    Convert `values` to integers, first correcting values that fall just
    below the next integer due to machine floating point precision errors.

    A value whose fractional part is at least ``1 - tolerance`` is treated as
    a rounding error and mapped to the next integer; all other values are
    truncated toward zero (``int`` / ``astype``). For array input, the
    corrected entries of `values` are modified in-place.

    Parameters
    ----------
    values : np.ndarray or float
        An input array or a scalar.
    tolerance : float or None, optional
        The precision error absolute tolerance; acts as ``atol`` in
        :func:`numpy.isclose` function. If None (or 0), no rounding is
        performed.
        Default: 1e-8

    Returns
    -------
    values : np.ndarray or int
        Corrected integer values (``np.int32`` array for array input).

    Warns
    -----
    UserWarning
        If any value is corrected. The warning is attributed to the caller
        of this function.

    Examples
    --------
    >>> from elephant.utils import round_binning_errors
    >>> round_binning_errors(0.999999, tolerance=None)
    0
    >>> round_binning_errors(0.999999, tolerance=1e-6)
    1
    """
    if tolerance is None or tolerance == 0:
        if isinstance(values, np.ndarray):
            return values.astype(np.int32)
        return int(values)  # a scalar

    # same as '1 - (values % 1) <= tolerance' but faster
    correction_mask = 1 - tolerance <= values % 1
    # Adding 0.5 (rather than 1) lifts e.g. 2.9999999 to 3.4999999, which the
    # subsequent truncation maps to 3, without risking an overshoot to 4.
    if isinstance(values, np.ndarray):
        num_corrections = correction_mask.sum()
        if num_corrections > 0:
            warnings.warn(f'Correcting {num_corrections} rounding errors by '
                          f'shifting the affected spikes into the following '
                          f'bin. You can set tolerance=None to disable this '
                          'behaviour.', stacklevel=2)
            values[correction_mask] += 0.5
        return values.astype(np.int32)

    if correction_mask:
        warnings.warn('Correcting a rounding error in the calculation '
                      'of the number of bins by incrementing the value by 1. '
                      'You can set tolerance=None to disable this '
                      'behaviour.', stacklevel=2)
        values += 0.5
    return int(values)


def get_cuda_capability_major():
    """
    Extracts CUDA capability major version of the first available Nvidia GPU
    card, if detected. Otherwise, return 0.

    The CUDA driver library is loaded via :mod:`ctypes`; the first candidate
    library name that loads successfully is used.

    Returns
    -------
    int
        CUDA capability major version, or 0 if no driver library could be
        loaded or any of the driver calls did not return success.
    """
    cuda_success = 0  # driver API return code meaning success
    # Besides the previously used names, also try 'libcuda.so.1' (the
    # versioned Linux driver library; the unversioned 'libcuda.so' symlink is
    # often absent, e.g. inside containers) and 'nvcuda.dll' (the Windows
    # driver library). The original names keep their precedence.
    candidate_names = ('libcuda.so', 'libcuda.dylib', 'cuda.dll',
                       'libcuda.so.1', 'nvcuda.dll')
    for libname in candidate_names:
        try:
            cuda = ctypes.CDLL(libname)
        except OSError:
            continue
        break
    else:
        # none of the candidate libraries could be loaded
        return 0

    result = cuda.cuInit(0)
    if result != cuda_success:
        return 0

    device = ctypes.c_int()
    # parse the first GPU card only
    result = cuda.cuDeviceGet(ctypes.byref(device), 0)
    if result != cuda_success:
        return 0

    cc_major = ctypes.c_int()
    cc_minor = ctypes.c_int()
    result = cuda.cuDeviceComputeCapability(ctypes.byref(cc_major),
                                            ctypes.byref(cc_minor),
                                            device)
    if result != cuda_success:
        return 0
    return cc_major.value