# Source code for elephant.datasets

"""
This module provides a simple interface to load data used in some Elephant
tutorials and that can also be used as example neuronal activity data for
learning, testing and development purposes.

Each dataset can be loaded by the :func:`load_data` function by providing a
specific dataset name. Some are downloaded from the official `elephant-data`
G-Node GIN `repository <https://datasets.python-elephant.org>`_ and others are
generated by simulation. For downloaded files, the function also takes care of
any loading needed, directly returning the relevant data objects. For
generated data, the required parameterization is already defined.

Currently, the following data are available for loading:

    - `asset`: a dataset containing the activity of 500 parallel spike trains
      from a simulation of synfire chains that are activated repeatedly
      :cite:`Schrader08_2165`. The data is loaded as a single
      :class:`neo.core.Segment` object, containing the 500
      :class:`neo.core.SpikeTrain` objects with the spiking activity of each
      neuron (accessible by the `.spiketrains` attribute). This dataset is
      used in the :doc:`tutorial </tutorials/asset>` for the Analysis of
      Sequences of Synchronous EvenTs (ASSET) method :cite:`Torre16_e1004939`.

    - `granger_causality_indirect`: a dataset of three simulated time series
      X, Y, and Z, with indirect causal influence from Y to X through Z
      (Y -> Z -> X). Each time series has 10000 sample points. The data is
      loaded as a `(10000, 3)` :class:`numpy.ndarray` object. The second
      dimension is the time series dimension, where each column corresponds to
      one of the three time series, ordered as X, Y, Z. This data is used in
      the :doc:`tutorial </tutorials/granger_causality>` for the conditional
      Granger causality method and it is recreated from Example 2 section 5.2
      of :cite:t:`Ding06_0608035`.

    - `granger_causality_both`: a dataset of three simulated time series X, Y,
      and Z, with both direct and indirect causal influences from Y to X
      (Y -> X and Y -> Z -> X). Each time series has 10000 sample points. The
      data is loaded as a `(10000, 3)` :class:`numpy.ndarray` object. The
      second dimension is the time series dimension, where each column
      corresponds to one of the three time series, ordered as X, Y, Z. This
      data is used in the :doc:`tutorial </tutorials/granger_causality>` for
      the conditional Granger causality method and it is recreated from
      Example 2 section 5.2 of :cite:t:`Ding06_0608035`.

    - `unitary_events`: a dataset containing the simultaneously recorded
      activities of two neurons in the primary motor cortex of monkeys
      performing a delayed-pointing task :cite:`Riehle97_1950`. The neuronal
      activity is recorded across 36 trials. The dataset consists of a list
      with 36 inner lists (one per trial). Each trial list contains two
      :class:`neo.core.SpikeTrain` objects storing the spike times of each
      neuron. This dataset is used in the
      :doc:`tutorial </tutorials/unitary_event_analysis>` for the Unitary
      Event Analysis method :cite:`Gruen99_67`, which detects the coordinated
      spiking activity that occurs significantly more often than predicted by
      the firing rates of the neurons.

.. autosummary::
    :toctree: _toctree/datasets

    load_data

:copyright: Copyright 2014-2026 by the Elephant team, see `doc/authors.rst`.
:license: Modified BSD, see LICENSE.txt for details.
"""

import hashlib
import os
import shutil
import ssl
import tempfile
from urllib.parse import urlparse
import warnings
from os import getenv

from pathlib import Path
from urllib.error import HTTPError, URLError
from urllib.request import urlopen, urlretrieve
from zipfile import ZipFile
from functools import partial

from tqdm import tqdm

from elephant import _get_version
import numpy as np
import neo


# Names exported by `from elephant.datasets import *`.
__all__ = ["load_data"]


# Default folder for downloaded and extracted dataset files: an `elephant`
# subfolder of the system's temporary directory.
ELEPHANT_TMP_DIR = Path(tempfile.gettempdir()).joinpath("elephant")


# Data generation functions

def generate_conditional_granger_ground_truth(length_2d=30000,
                                              causality_type="indirect"):
    """
    Generate three time series with a known conditional Granger causality
    structure.

    Recreated from Example 2, section 5.2 of :cite:t:`Ding06_0608035`. The
    signals X, Y and Z follow an order-2 vector autoregressive model driven
    by independent Gaussian noise with variances 0.3, 1.0 and 0.2,
    respectively:

    .. code-block:: text

        X(t) = 0.8 X(t-1) - 0.5 X(t-2) + 0.4 Z(t-1) + c Y(t-2) + e_X(t)
        Y(t) = 0.9 Y(t-1) - 0.8 Y(t-2) + e_Y(t)
        Z(t) = 0.5 Z(t-1) - 0.2 Z(t-2) + 0.5 Y(t-1) + e_Z(t)

    The coefficient ``c`` is selected by `causality_type`:

    1. "indirect" (``c = 0``): there is no direct causal influence from Y
       to X; it is mediated through Z (i.e., Y -> Z -> X).
    2. "both" (``c = 0.2``): Y has both direct and indirect causal
       influences on X.

    The recursion starts from two zero-valued samples, which are discarded.
    The noise is drawn from NumPy's global random state (`np.random`), so
    results can be made reproducible with `np.random.seed`.

    Parameters
    ----------
    length_2d : int, optional
        Number of samples of each returned time series.
        Default: 30000
    causality_type : {'indirect', 'both'}, optional
        Causal structure of the generated data (see above).
        Default: 'indirect'

    Returns
    -------
    np.ndarray
        Array of shape `(length_2d, 3)` whose columns are X, Y and Z.

    Raises
    ------
    ValueError
        If `causality_type` is neither 'indirect' nor 'both', or if
        `length_2d` is negative.
    """
    if causality_type == "indirect":
        y_t_lag_2 = 0
    elif causality_type == "both":
        y_t_lag_2 = 0.2
    else:
        raise ValueError("causality_type should be either 'indirect' or "
                         "'both'")
    if length_2d < 0:
        raise ValueError("length_2d must be a non-negative integer")

    order = 2
    signal = np.zeros((3, length_2d + order))

    # Coefficients applied to the samples one step back (rows: X, Y, Z).
    weights_1 = np.array([[0.8, 0, 0.4],
                          [0, 0.9, 0],
                          [0., 0.5, 0.5]])

    # Coefficients applied to the samples two steps back (rows: X, Y, Z).
    weights_2 = np.array([[-0.5, y_t_lag_2, 0.],
                          [0., -0.8, 0],
                          [0, 0, -0.2]])

    weights = np.stack((weights_1, weights_2))

    noise_covariance = np.array([[0.3, 0.0, 0.0],
                                 [0.0, 1., 0.0],
                                 [0.0, 0.0, 0.2]])

    # Draw the innovations of all time steps at once: this consumes the same
    # standard-normal draws, in the same order, as one call per time step,
    # but factorizes the covariance matrix only once.
    noise = np.random.multivariate_normal([0, 0, 0], noise_covariance,
                                          size=length_2d)

    for i in range(length_2d):
        t = i + order
        for lag in range(order):
            # weights[lag] multiplies the sample `lag + 1` steps back.
            signal[:, t] += np.dot(weights[lag], signal[:, t - 1 - lag])
        signal[:, t] += noise[i]

    # Drop the zero-valued initial conditions; return signals as Nx3.
    return signal[:, order:].T

def generate_pairwise_granger_ground_truth(length_2d=30000):
    """
    Generate two time series in which the first one Granger-causes the
    second one, but not vice versa.

    The signals X and Y follow an order-2 vector autoregressive model driven
    by independent Gaussian noise with unit variance:

    .. code-block:: text

        X(t) = 0.9 X(t-1) - 0.5 X(t-2) + e_X(t)
        Y(t) = 0.9 X(t-1) - 0.8 Y(t-1) - 0.2 X(t-2) - 0.5 Y(t-2) + e_Y(t)

    The recursion starts from two zero-valued samples, which are discarded.
    The noise is drawn from NumPy's global random state (`np.random`), so
    results can be made reproducible with `np.random.seed`.

    Parameters
    ----------
    length_2d : int, optional
        Number of samples of each returned time series.
        Default: 30000

    Returns
    -------
    np.ndarray
        Array of shape `(length_2d, 2)` whose columns are X and Y.

    Raises
    ------
    ValueError
        If `length_2d` is negative.
    """
    if length_2d < 0:
        raise ValueError("length_2d must be a non-negative integer")

    order = 2
    signal = np.zeros((2, length_2d + order))

    # Coefficients applied to the samples one and two steps back
    # (rows: X, Y).
    weights_1 = np.array([[0.9, 0], [0.9, -0.8]])
    weights_2 = np.array([[-0.5, 0], [-0.2, -0.5]])

    weights = np.stack((weights_1, weights_2))

    noise_covariance = np.array([[1., 0.0], [0.0, 1.]])

    # Draw the innovations of all time steps at once: this consumes the same
    # standard-normal draws, in the same order, as one call per time step,
    # but factorizes the covariance matrix only once.
    noise = np.random.multivariate_normal([0, 0], noise_covariance,
                                          size=length_2d)

    for i in range(length_2d):
        t = i + order
        for lag in range(order):
            # weights[lag] multiplies the sample `lag + 1` steps back.
            signal[:, t] += np.dot(weights[lag], signal[:, t - 1 - lag])
        signal[:, t] += noise[i]

    # Drop the zero-valued initial conditions; return signals as Nx2.
    return signal[:, order:].T


# Mapping data names to either data generation functions or dataset
# download information

ELEPHANT_DATA = {
    # Downloadable datasets are dictionaries: "repo_path" and "checksum"
    # (MD5 hash) are passed as keyword arguments to `download_datasets`,
    # while the optional "loader" callable is applied by `load_data` to the
    # neo Block read from the downloaded NIX file.
    "asset": {
        "repo_path": "tutorials/tutorial_asset/data/asset_showcase_500.nix",
        "checksum": "d42201b83a14d85988b1a53c654472c8",
        "loader": lambda block: block.segments[0],
    },
    # Generated data are callables that `load_data` invokes without
    # arguments.
    "granger_causality_indirect": partial(
        generate_conditional_granger_ground_truth,
        length_2d=10000,
        causality_type="indirect",
    ),
    "granger_causality_both": partial(
        generate_conditional_granger_ground_truth,
        length_2d=10000,
        causality_type="both",
    ),
    "unitary_events": {
        "repo_path": "tutorials/tutorial_unitary_event_analysis/data/dataset-1.nix",
        "checksum": "6449d2f4b8ae5beb1439d2b5dd03b078",
        # One list of spike trains per segment (trial) of the block.
        "loader": lambda block: [list(segment.spiketrains)
                                 for segment in block.segments],
    },
}


# Helper functions for downloading datasets from the `elephant-data` GIN
# repository, which is accessible at https://datasets.python-elephant.org.

class TqdmUpTo(tqdm):
    """
    Progress bar that can be used as the `reporthook` of
    :func:`urllib.request.urlretrieve`.

    It adds :meth:`update_to`, which converts the absolute progress reported
    by the hook (blocks transferred times block size) into the relative
    increment expected by :meth:`tqdm.update`.

    Adapted from the tqdm example
    https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """
        Move the progress bar to ``b * bsize`` units.

        Parameters
        ----------
        b : int, optional
            Number of blocks transferred so far.
            Default: 1
        bsize : int, optional
            Size of each block (in tqdm units).
            Default: 1
        tsize : int, optional
            Total size (in tqdm units). If None, the current total is kept.
            Default: None
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        # `update` takes an increment; afterwards `self.n == transferred`.
        self.update(transferred - self.n)


def calculate_md5(filepath, chunk_size=1024 * 1024):
    """
    Compute the MD5 hex digest of the file at `filepath`.

    The file is read in binary mode, `chunk_size` bytes at a time, so large
    files are never loaded into memory at once.
    """
    digest = hashlib.md5()
    with open(filepath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()


def check_integrity(filepath, md5):
    """
    Tell whether the file at `filepath` exists and has the MD5 hash `md5`.

    Returns False without hashing anything when the file does not exist or
    when no reference hash is given (`md5` is None).
    """
    if Path(filepath).exists() and md5 is not None:
        return calculate_md5(filepath) == md5
    return False


def _urlretrieve_unverified(url, filepath, reporthook=None):
    """
    Download `url` to `filepath` without verifying the server's SSL
    certificate.

    Mirrors the block-wise progress reporting of
    :func:`urllib.request.urlretrieve`, but applies an unverified SSL
    context to this single request instead of replacing the process-wide
    default HTTPS context.
    """
    from urllib.error import ContentTooShortError

    block_size = 1024 * 8
    context = ssl._create_unverified_context()
    with urlopen(url, context=context) as response:
        headers = response.headers
        size = int(headers.get("Content-Length", -1))
        read = 0
        block_num = 0
        with open(filepath, 'wb') as out_file:
            if reporthook:
                reporthook(block_num, block_size, size)
            while True:
                block = response.read(block_size)
                if not block:
                    break
                out_file.write(block)
                read += len(block)
                block_num += 1
                if reporthook:
                    reporthook(block_num, block_size, size)

    if size >= 0 and read < size:
        raise ContentTooShortError(
            f"retrieval incomplete: got only {read} out of {size} bytes",
            (filepath, headers))


def download(url, filepath=None, checksum=None, verbose=True):
    """
    Download the file at `url` to `filepath`, optionally checking its MD5.

    Parameters
    ----------
    url : str
        URL of the file to download.
    filepath : str or pathlib.Path, optional
        Local destination. If None, the file is stored in
        `ELEPHANT_TMP_DIR` with the same name as in the URL.
        Default: None
    checksum : str, optional
        Expected MD5 hash. If a file with this hash already exists at
        `filepath`, the download is skipped. If given, the downloaded file is
        verified against it.
        Default: None
    verbose : bool, optional
        If False, the progress bar is disabled.
        Default: True

    Returns
    -------
    pathlib.Path
        Path to the (possibly pre-existing) local file.

    Raises
    ------
    urllib.error.HTTPError
        If the server answers with an HTTP error (e.g., 404).
    ValueError
        If `checksum` is given and the downloaded file does not match it.
    """
    # If not explicitly given, store the file in the system's temporary
    # directory with the same name as in the URL.
    if filepath is None:
        filename = url.split('/')[-1]
        filepath = ELEPHANT_TMP_DIR / filename
    filepath = Path(filepath)

    # Skip the download if a file with the provided checksum already exists.
    # If the file does not exist or the checksum does not match/is not
    # given, the requested file is downloaded.
    if check_integrity(filepath, md5=checksum):
        return filepath

    # Create the destination folder, including any missing parents.
    filepath.absolute().parent.mkdir(parents=True, exist_ok=True)
    desc = f"Downloading {url} to '{filepath}'"
    with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1,
                  desc=desc, disable=not verbose) as t:
        try:
            urlretrieve(url, filename=filepath, reporthook=t.update_to)
        except HTTPError:
            # The server was reached and answered with an error (e.g. 404);
            # retrying without certificate verification cannot help.
            raise
        except URLError as error:
            # Connection-level failure (e.g. SSL certificate verification).
            # Retry this single request without verifying the certificate,
            # without touching the process-wide default HTTPS context.
            warnings.warn(f"Could not download {url} ({error}). Retrying "
                          f"without SSL certificate verification.")
            _urlretrieve_unverified(url, filepath, reporthook=t.update_to)

    # Check integrity of the downloaded file if checksum is given
    if checksum and not check_integrity(filepath, checksum):
        raise ValueError(f"Data at {url} does not agree with MD5 hash "
                         f"{checksum}.")
    return filepath


def _default_data_location():
    """
    Return the root URL of the `elephant-data` release that matches the
    installed Elephant version, falling back to the `master` branch.

    The version-specific URL is probed by requesting its `README.md`. If the
    server answers with an HTTP error (expected for development versions of
    Elephant without a matching `elephant-data` release), a warning is issued
    and the URL of the `master` branch is returned. If the request fails at
    the connection level (e.g., SSL certificate verification), the probe is
    repeated once without certificate verification.
    """
    url_to_root = "https://datasets.python-elephant.org/"

    # Version of Elephant is equal to version of `elephant-data`.
    elephant_version = _get_version()
    data_location = url_to_root + f"raw/v{elephant_version}"
    readme_url = data_location + '/README.md'

    try:
        # Only the status matters; close the response right away.
        with urlopen(readme_url):
            pass
    except HTTPError as error:
        # If the corresponding `elephant-data` version is not found, use the
        # latest commit of `elephant-data` (`master` branch).
        warnings.warn(f"No corresponding version of 'elephant-data' "
                      f"found.\nElephant version: {elephant_version}. "
                      f"Data URL:{error.url}, error: {error}.\n"
                      f"Using 'elephant-data' latest instead (This is "
                      f"expected for Elephant development versions).")
        return url_to_root + "raw/master"
    except URLError:
        # Retry without certificate verification. `check_hostname` must stay
        # False here: enabling it switches `verify_mode` back to
        # CERT_REQUIRED (Python >= 3.7), which would defeat the retry.
        ctx = ssl._create_unverified_context()
        try:
            with urlopen(readme_url, context=ctx):
                pass
        except HTTPError as unverified_ctx_error:  # e.g. 404
            # If it still fails, use latest commit of `elephant-data` in
            # the `master` branch.
            warnings.warn(f"Data URL: {unverified_ctx_error.url}, "
                          f"error: {unverified_ctx_error}. "
                          f"{unverified_ctx_error.reason}")
            return url_to_root + "raw/master"

    return data_location


def download_datasets(repo_path, filepath=None, checksum=None,
                      verbose=True):
    r"""
    This function can be used to download files from the `elephant-data`
    repository using only the path relative to the repository root.

    The default repository URL points to Elephant's corresponding release of
    `elephant-data` at
    `https://datasets.python-elephant.org/raw/v{version}`, where `{version}`
    is the current version of Elephant. If that release is not available
    (e.g., for development versions of Elephant), the `master` branch is used
    instead and a warning is issued.

    Different versions of the Elephant package may require different versions
    of `elephant-data`, which can be defined using different repository URLs.
    For example:
    -  https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/v1.0.0
       points to release v1.0.0.
    -  https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/master
       always points to the latest state of `elephant-data`.

    To change this URL, use the environment variable `ELEPHANT_DATA_LOCATION`.
    This variable should be used to change the default URL when using data
    that is still not stored in the `master` branch or from a development
    branch of `elephant-data`.

    For example, to use data on branch `multitaper`, change
    `ELEPHANT_DATA_LOCATION` to
    https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/multitaper.

    To use a local copy of `elephant-data`, change `ELEPHANT_DATA_LOCATION`
    to define a local path in the system (e.g., `/home/user/elephant-data`).

    For a complete example, see Examples section.

    Parameters
    ----------
    repo_path : str
        Path to the dataset, relative to the `elephant-data` repository root
        (either the default URL or the location defined by the
        `ELEPHANT_DATA_LOCATION` environment variable).
    filepath : str or pathlib.Path, optional
        Path in the local system where the downloaded file will be stored.
        If None, the file will be stored within the system's current
        temporary directory with the same name as in `elephant-data`. If
        `ELEPHANT_DATA_LOCATION` points to a local copy and `filepath` is
        None, the path of the file inside that local copy is returned
        without copying it.
        Default: None
    checksum : str, optional
        MD5 hash of the file to use as checksum to verify data integrity of a
        file in a local folder or after its download. If None, no integrity
        check is performed. If defined and the check fails, an exception is
        raised.
        Default: None
    verbose : bool, optional
        If set to False, disable the progress bar.
        Default: True

    Returns
    -------
    filepath : pathlib.Path
        Path to access the dataset file.

    Raises
    ------
    ValueError
        If `checksum` is given and the MD5 hash of the downloaded file or
        the local copy available in a local folder is different.

        If the environment variable `ELEPHANT_DATA_LOCATION` is set but does
        not point to a valid URL or an existing file system path.

        If the environment variable `ELEPHANT_DATA_LOCATION` points to a local
        path in the system but the requested file does not exist in that path.

    Notes
    -----
    Without `ELEPHANT_DATA_LOCATION`, the data are taken from the release of
    `elephant-data` matching the installed Elephant version, or from its
    `master` branch if no such release exists. Any changes needed for
    development purposes should use the environment variable
    'ELEPHANT_DATA_LOCATION' to define the new root.

    Examples
    --------
    The following example downloads a file from branch `multitaper` in the
    `elephant-data` repository by setting the environment variable to the
    correct branch URL:

    >>> import os
    >>> from elephant.datasets import download_datasets
    >>> os.environ["ELEPHANT_DATA_LOCATION"] = "https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/multitaper" # noqa
    >>> download_datasets("unittest/spectral/multitaper_psd/data/time_series.npy") # doctest: +SKIP
    PosixPath('/tmp/elephant/time_series.npy')
    """
    # Try to get any user-defined data location from the environment variable.
    env_var = 'ELEPHANT_DATA_LOCATION'
    data_location = getenv(env_var)

    if not data_location:
        # No user-defined location: use the release matching this version of
        # Elephant (or `master` as a fallback).
        data_location = _default_data_location()

    elif os.path.exists(data_location):
        # `data_location` is an existing local path: `repo_path` must be an
        # existing file relative to it.
        local_file = Path(data_location) / repo_path
        if not local_file.is_file():
            raise ValueError(f"The environment variable {env_var} is set to "
                             f"the local path '{data_location}', but the file "
                             f"'{repo_path}' does not exist in that path.")

        # If a checksum is given, check integrity of the local file
        if checksum and not check_integrity(local_file, checksum):
            raise ValueError(f"Local file at {local_file.as_posix()} does "
                             "not agree with MD5 hash "
                             f"{checksum}.")

        if filepath is None:
            return local_file

        # If a specific path is requested, copy the file there.
        filepath = Path(filepath)
        shutil.copy2(local_file, filepath)
        return filepath

    elif urlparse(data_location).scheme not in ('http', 'https'):
        # Neither an existing path nor a valid URL.
        raise ValueError(f"The environment variable {env_var} must be set "
                         "to either an existing file system path or a "
                         f"valid URL. The given value '{data_location}' "
                         "is neither.")

    else:
        # User-defined URL: avoid a double slash when joining `repo_path`.
        data_location = data_location.rstrip('/')

    # Get the final URL to the dataset file and download it.
    # If a checksum is given, the integrity of the downloaded file will be
    # verified after download. If the file already exists and the checksum
    # matches, the download will be skipped and the path to the existing file
    # will be returned.
    url = f"{data_location}/{repo_path}"
    return download(url, filepath, checksum, verbose)


def unzip(filepath, outdir=ELEPHANT_TMP_DIR, verbose=True):
    """
    Extract all members of the zip archive at `filepath` into `outdir`.

    Parameters
    ----------
    filepath : str or pathlib.Path
        Path to the zip archive (anything accepted by `zipfile.ZipFile`).
    outdir : str or pathlib.Path, optional
        Destination folder of the extracted files.
        Default: ELEPHANT_TMP_DIR
    verbose : bool, optional
        If True, print a message once the extraction is complete.
        Default: True
    """
    with ZipFile(filepath) as archive:
        archive.extractall(path=outdir)
    if not verbose:
        return
    print(f"Extracted {filepath} to {outdir}")


# Main function to load data for tutorials and examples.
# It checks if the requested data is defined in the `ELEPHANT_DATA`
# dictionary. If it is, it either generates the data (if the value is a
# callable function) or downloads and loads the data from the `elephant-data`
# repository.

def load_data(name):
    """
    This function loads example data used in Elephant tutorials and examples.

    If the data is contained in a dataset file stored in the `elephant-data`
    repository (accessible at https://datasets.python-elephant.org), the
    correct file is automatically downloaded and loaded. Data that is not
    stored in datasets is generated.

    Parameters
    ----------
    name : str
        The name of the data to load. The available data names and their
        content are defined in the main
        :doc:`datasets documentation </reference/datasets>` page.

    Returns
    -------
    object
        The return type will vary according to the format and contents of
        the requested data. For example, `asset` returns a
        :class:`neo.core.Segment` object while `unitary_events` returns a
        list of lists with :class:`neo.core.SpikeTrain` objects. The detailed
        description is available in the main
        :doc:`datasets documentation </reference/datasets>` page.

    Raises
    ------
    ValueError
        If `name` is not the name of an available dataset or generated data.
    """
    elephant_data = ELEPHANT_DATA.get(name)
    if not elephant_data:
        raise ValueError(f"Data '{name}' not available as downloadable "
                         f"datasets or generated data.")

    # If data generation function, run and return
    if callable(elephant_data):
        return elephant_data()

    # Extract data loading function, if defined in the dataset dictionary.
    # Make a copy to avoid changing the default dictionary in `ELEPHANT_DATA`.
    elephant_data = elephant_data.copy()
    loader = elephant_data.pop("loader", None)

    # Download the dataset, ignoring version warnings
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        dataset_path = download_datasets(**elephant_data)

    # Process each file format. The loader, if any, is applied to whatever
    # was read from the file, regardless of the format.
    if dataset_path.suffix == ".nix":
        # Apply the loader while the NIX file is still open.
        with neo.NixIO(str(dataset_path), 'ro') as f:
            block = f.read_block()
            return loader(block) if loader else block

    data = np.load(dataset_path)
    return loader(data) if loader else data