Source code for spykeutils.sorting_quality_assesment

""" Functions for estimating the quality of spike sorting results. These
functions estimate false positive and false negative fractions.
"""

from __future__ import division

import scipy as sp
from scipy.spatial.distance import cdist
import quantities as pq
import neo

from .progress_indicator import ProgressIndicator
from . import SpykeException
from .conversions import spikes_to_spike_train


def get_refperiod_violations(spike_trains, refperiod, progress=None):
    """ Return the refractory period violations in the given spike trains
    for the specified refractory period.

    :param dict spike_trains: Dictionary of lists of
        :class:`neo.core.SpikeTrain` objects.
    :param refperiod: The refractory period (time).
    :type refperiod: Quantity scalar
    :param progress: Set this parameter to report progress.
    :type progress: :class:`.progress_indicator.ProgressIndicator`
    :returns: Two values:

        * The total number of violations.
        * A dictionary (with the same indices as ``spike_trains``) of
          arrays with violation times (Quantity 1D with the same unit as
          ``refperiod``) for each spike train.
    :rtype: int, dict
    """
    if type(refperiod) != pq.Quantity or \
            refperiod.simplified.dimensionality != pq.s.dimensionality:
        raise ValueError('refperiod must be a time quantity!')

    if not progress:
        progress = ProgressIndicator()

    total_violations = 0
    violations = {}
    for u, tL in spike_trains.iteritems():
        violations[u] = []
        for i, t in enumerate(tL):
            st = t.copy()
            st.sort()
            isi = sp.diff(st)

            # The interspike interval mask is one element shorter than
            # ``st``; keep the first spike of each violating pair.
            violations[u].append(
                st[:-1][isi < refperiod].rescale(refperiod.units))
            total_violations += len(violations[u][i])

            progress.step()

    return total_violations, violations
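
# Illustrative usage sketch (not part of the original module); the unit key
# and spike times below are invented. With a 1 ms refractory period, only
# the 0.5 ms interval between the spikes at 2.0 ms and 2.5 ms is counted.
def _example_refperiod_violations():
    trains = {0: [neo.SpikeTrain([1.0, 2.0, 2.5, 10.0] * pq.ms,
                                 t_stop=20.0 * pq.ms)]}
    total, times = get_refperiod_violations(trains, 1.0 * pq.ms)
    return total, times  # total == 1, times[0][0] contains 2.0 ms
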
def calculate_refperiod_fp(num_spikes, refperiod, violations, total_time):
    """ Return the rate of false positives calculated from refractory
    period violations for each unit. The equation used is described in
    (Hill et al. The Journal of Neuroscience. 2011).

    :param dict num_spikes: Dictionary of total number of spikes,
        indexed by unit.
    :param refperiod: The refractory period (time). If the spike sorting
        algorithm includes a censored period (a time after a spike during
        which no new spikes can be found), subtract it from the
        refractory period before passing it to this function.
    :type refperiod: Quantity scalar
    :param dict violations: Dictionary of total number of violations,
        indexed the same as ``num_spikes``.
    :param total_time: The total time in which violations could have
        occurred.
    :type total_time: Quantity scalar
    :returns: A dictionary of false positive rates indexed by unit.
        Note that values above 0.5 cannot be directly interpreted as a
        false positive rate! These very high values can e.g. indicate
        that the generating processes are not independent.
    """
    if type(refperiod) != pq.Quantity or \
            refperiod.simplified.dimensionality != pq.s.dimensionality:
        raise ValueError('refperiod must be a time quantity!')

    fp = {}
    factor = total_time / (2 * refperiod)
    for u, n in num_spikes.iteritems():
        if n == 0:
            fp[u] = 0
            continue
        zw = (violations[u] * factor / n ** 2).simplified
        if zw > 0.25:
            fp[u] = 0.5 + sp.sqrt(0.25 - zw).imag
            continue
        fp[u] = 0.5 - sp.sqrt(0.25 - zw)

    return fp
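
# Worked example with invented numbers (not part of the original module):
# 10 violations among 10000 spikes recorded over 1000 s with a 1 ms
# refractory period give
#   zw = 10 * (1000 s / (2 * 1 ms)) / 10000 ** 2 = 0.05
#   fp = 0.5 - sqrt(0.25 - 0.05) ~= 0.053
def _example_refperiod_fp():
    return calculate_refperiod_fp(
        {0: 10000}, 1.0 * pq.ms, {0: 10}, 1000.0 * pq.s)  # {0: ~0.053}
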
def _multi_norm(x, mean):
    """ Evaluate the pdf of a multivariate normal distribution with unit
    covariance and the given mean at the rows of x, with high precision.
    """
    d = x.shape[1]
    fac = (2 * sp.pi) ** (-d / 2.0)
    y = cdist(x, sp.atleast_2d(mean), 'sqeuclidean') * -0.5
    return fac * sp.exp(sp.longdouble(y))


def _fast_overlap_whitened(spike_arrays, means):
    units = spike_arrays.keys()
    spikes = {u: spike_arrays[u].shape[1] for u in spike_arrays.iterkeys()}
    prior = {}
    total_spikes = 0
    for u, mean in means.iteritems():
        total_spikes += spikes[u]
    if total_spikes < 1:
        return {u: (0.0, 0.0) for u in units}, {}

    # Arrays of unnormalized posteriors (likelihood times prior)
    # for all units
    posterior = {}

    false_positive = {}
    false_negative = {}
    for u in units:
        prior[u] = spikes[u] / total_spikes
        false_positive[u] = 0
        false_negative[u] = 0

    # Calculate posteriors
    for u1 in units[:]:
        if not spikes[u1]:
            units.remove(u1)
            continue
        posterior[u1] = {}
        for u2, mean in means.iteritems():
            llh = _multi_norm(spike_arrays[u1].T, mean)
            posterior[u1][u2] = llh * prior[u2]

    # Calculate pairwise false positives/negatives
    singles = {u: {} for u in units}
    for i, u1 in enumerate(units):
        for u2 in units[i + 1:]:
            f1 = sp.sum(posterior[u1][u2] /
                        (posterior[u1][u1] + posterior[u1][u2]),
                        dtype=sp.double)
            f2 = sp.sum(posterior[u2][u1] /
                        (posterior[u2][u1] + posterior[u2][u2]),
                        dtype=sp.double)

            singles[u1][u2] = (f1 / spikes[u1] if spikes[u1] else 0,
                               f2 / spikes[u1] if spikes[u1] else 0)
            singles[u2][u1] = (f2 / spikes[u2] if spikes[u2] else 0,
                               f1 / spikes[u2] if spikes[u2] else 0)

    # Calculate complete false positives/negatives with extended Bayes
    for u1 in units:
        numerator = posterior[u1][u1]
        normalizer = sum(posterior[u1][u2] for u2 in units)
        false_positive[u1] = sp.sum((normalizer - numerator) / normalizer)

        other_units = units[:]
        other_units.remove(u1)
        numerator = sp.vstack((posterior[u][u1] for u in other_units))
        normalizer = sp.vstack(
            sum(posterior[u][u2] for u2 in units) for u in other_units)
        false_negative[u1] = sp.sum(numerator / normalizer)

    # Prepare return values, convert sums to means
    totals = {}
    for u, fp in false_positive.iteritems():
        fn = false_negative[u]
        if not spikes[u]:
            totals[u] = (0, 0)
        else:
            num = spikes[u]
            totals[u] = (fp / num, fn / num)
    return totals, singles


def _pair_overlap(waves1, waves2, mean1, mean2, cov1, cov2):
    """ Calculate FP/FN estimates for two Gaussian clusters. """
    from sklearn import mixture
    means = sp.vstack([[mean1], [mean2]])
    covars = sp.vstack([[cov1], [cov2]])
    weights = sp.array([waves1.shape[1], waves2.shape[1]], dtype=float)
    weights /= weights.sum()

    # Create mixture of two Gaussians from the existing estimates
    mix = mixture.GMM(n_components=2, covariance_type='full',
                      init_params='')
    mix.covars_ = covars
    mix.weights_ = weights
    mix.means_ = means

    posterior1 = mix.predict_proba(waves1.T)[:, 1]
    posterior2 = mix.predict_proba(waves2.T)[:, 0]

    return (posterior1.mean(), posterior2.sum() / len(posterior1),
            posterior2.mean(), posterior1.sum() / len(posterior2))


def _object_has_size(obj, size):
    """ Return whether the object, which can be either a neo.Spike or an
    ndarray, has the given size.
    """
    if isinstance(obj, neo.Spike):
        return obj.waveform.size == size
    return obj.size == size
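
# Small sanity check for _multi_norm (not part of the original module):
# for whitened data the helper evaluates the standard multivariate normal
# pdf, so a 2-D point placed exactly at the cluster mean yields
# (2 * pi) ** -1 ~= 0.159.
def _example_multi_norm():
    x = sp.zeros((1, 2))    # one 2-D sample at the origin
    mean = sp.zeros(2)      # cluster mean, also at the origin
    return _multi_norm(x, mean)
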
def overlap_fp_fn(spikes, means=None, covariances=None):
    """ Return dicts of tuples (false positive rate, false negative rate)
    indexed by unit. This function needs :mod:`sklearn` if ``covariances``
    is not set to ``'white'``.

    This function estimates the pairwise and total false positive and
    false negative rates for a number of waveform clusters. The results
    can be interpreted as follows: False positives are the fraction of
    spikes in a cluster that are estimated to belong to a different
    cluster (a specific cluster for pairwise results or any other cluster
    for total results). False negatives are the number of spikes from
    other clusters that are estimated to belong to a given cluster (also
    expressed as a fraction; this number can be larger than 1 in extreme
    cases).

    Details for the calculation can be found in
    (Hill et al. The Journal of Neuroscience. 2011). The calculation for
    total false positive and false negative rates does not follow
    Hill et al., who propose a simple addition of pairwise probabilities.
    Instead, the total error probabilities are estimated using all
    clusters at once.

    :param dict spikes: Dictionary, indexed by unit, of lists of spike
        waveforms as :class:`neo.core.Spike` objects or numpy arrays.
        If the waveforms have multiple channels, they will be flattened
        automatically. All waveforms need to have the same number of
        samples.
    :param dict means: Dictionary, indexed by unit, of mean spike
        waveforms as :class:`neo.core.Spike` objects or numpy arrays.
        Means for units that are not in this dictionary will be
        estimated using the spikes. Note that if you pass ``'white'``
        for ``covariances`` and you want to provide means, they have to
        be whitened in the same way as the spikes.
        Default: None, means will be estimated from data.
    :param covariances: Dictionary, indexed by unit, of covariance
        matrices. Covariances for units that are not in this dictionary
        will be estimated using the spikes. It is useful to give a
        covariance matrix if few spikes are present - consider using the
        noise covariance. If you use prewhitened spikes (i.e. all
        clusters are normally distributed, so their covariance matrix is
        the identity), you can pass ``'white'`` here. The calculation
        will be much faster in this case and the sklearn package is not
        required.
        Default: None, covariances will be estimated from data.
    :type covariances: dict or str
    :returns: Two values:

        * A dictionary (indexed by unit) of total
          (false positive rate, false negative rate) tuples.
        * A dictionary of dictionaries, both indexed by units, of
          pairwise (false positive rate, false negative rate) tuples.
    :rtype: dict, dict
    """
    units = spikes.keys()
    total_spikes = 0

    for spks in spikes.itervalues():
        total_spikes += len(spks)
    if total_spikes < 1:
        return {u: (0.0, 0.0) for u in units}, {}

    if means is None:
        means = {}

    white = False
    if covariances is None:
        covariances = {}
    elif covariances == 'white':
        white = True
        covariances = {}

    # Convert Spike objects to arrays
    dimensionality = None
    spike_arrays = {}
    for u, spks in spikes.iteritems():
        spikelist = []
        if not spks or (len(spks) < 2 and u not in covariances):
            units.remove(u)
            continue
        for s in spks:
            if isinstance(s, neo.Spike):
                spikelist.append(
                    sp.asarray(s.waveform.rescale(pq.uV)).T.flatten())
            else:
                spikelist.append(s)
        spike_arrays[u] = sp.array(spikelist).T
        if dimensionality is None:
            dimensionality = spike_arrays[u].shape[0]
        elif dimensionality != spike_arrays[u].shape[0]:
            raise SpykeException('All spikes need to have the same '
                                 'number of samples!')

    if not units:
        return {}, {}
    if len(units) == 1:
        return {units[0]: (0.0, 0.0)}, {}

    # Convert or calculate means and covariances
    shaped_means = {}
    covs = {}
    if white:
        cov = sp.eye(dimensionality)
        covariances = {u: cov for u in units}
    for u in units:
        if u in means and _object_has_size(means[u], dimensionality):
            mean = means[u]
            if isinstance(mean, neo.Spike):
                shaped_means[u] = sp.asarray(
                    mean.waveform.rescale(pq.uV)).T.flatten()
            else:
                shaped_means[u] = means[u].T.flatten()
        else:
            shaped_means[u] = spike_arrays[u].mean(axis=1)

    if white:
        return _fast_overlap_whitened(spike_arrays, shaped_means)

    for u in units:
        if u not in covariances:
            covs[u] = sp.cov(spike_arrays[u])
        else:
            covs[u] = covariances[u]

    # Calculate pairwise false positives/negatives
    singles = {u: {} for u in units}
    for i, u1 in enumerate(units):
        for u2 in units[i + 1:]:
            error_rates = _pair_overlap(
                spike_arrays[u1], spike_arrays[u2],
                shaped_means[u1], shaped_means[u2],
                covs[u1], covs[u2])
            singles[u1][u2] = error_rates[0:2]
            singles[u2][u1] = error_rates[2:4]

    # Calculate complete false positives/negatives using a mixture
    # with one component per unit
    from sklearn import mixture
    mix = mixture.GMM(n_components=len(units), covariance_type='full')
    mix_means = []
    mix_covars = []
    mix_weights = []
    for u in units:
        mix_means.append(shaped_means[u])
        mix_covars.append([covs[u]])
        mix_weights.append(spike_arrays[u].shape[1])
    mix.means_ = sp.vstack(mix_means)
    mix.covars_ = sp.vstack(mix_covars)
    mix_weights = sp.array(mix_weights, dtype=float)
    mix_weights /= mix_weights.sum()
    mix.weights_ = mix_weights

    # P(spikes of unit[i] in correct cluster)
    post_mean = sp.zeros(len(units))
    # sum(P(spikes of unit[i] in cluster[j]))
    post_sum = sp.zeros((len(units), len(units)))
    for i, u in enumerate(units):
        posterior = mix.predict_proba(spike_arrays[u].T)
        post_mean[i] = posterior[:, i].mean()
        post_sum[i, :] = posterior.sum(axis=0)

    totals = {}
    for i, u in enumerate(units):
        fp = 1.0 - post_mean[i]
        ind = range(len(units))
        ind.remove(i)
        fn = post_sum[ind, i].sum() / float(spike_arrays[u].shape[1])
        totals[u] = (fp, fn)
    return totals, singles
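
# Usage sketch with synthetic data (not part of the original module). The
# two clusters are drawn from unit-variance Gaussians four units apart, so
# covariances='white' applies and the estimated error rates should be tiny.
def _example_overlap_fp_fn():
    sp.random.seed(0)
    spikes = {'a': list(sp.random.randn(100, 2)),
              'b': list(sp.random.randn(100, 2) + sp.array([4.0, 0.0]))}
    totals, pairwise = overlap_fp_fn(spikes, covariances='white')
    return totals, pairwise
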
def variance_explained(spikes, means=None, noise=None):
    """ Return the fraction of variance in each channel that is explained
    by the means.

    Values below 0 or above 1 for large data sizes indicate that some
    assumptions were incorrect (e.g. about channel noise) and the results
    should not be trusted.

    :param dict spikes: Dictionary, indexed by unit, of
        :class:`neo.core.SpikeTrain` objects (where the ``waveforms``
        member includes the spike waveforms) or lists of
        :class:`neo.core.Spike` objects.
    :param dict means: Dictionary, indexed by unit, of mean spike
        waveforms as :class:`neo.core.Spike` objects. Means for units
        that are not in this dictionary will be estimated using the
        spikes.
        Default: None - means will be estimated from given spikes.
    :type noise: Quantity 1D
    :param noise: The known noise levels (as variance) per channel of the
        original data. This should be estimated from the signal periods
        that do not contain spikes, otherwise the explained variance
        could be overestimated. If None, the estimate of explained
        variance is done without regard for noise.
        Default: None
    :return dict: A dictionary of arrays, indexed by unit. If ``noise``
        is ``None``, the dictionary contains the fraction of explained
        variance per channel without taking noise into account. If
        ``noise`` is given, it contains the fraction of variance per
        channel explained by the means and the given noise level
        together.
    """
    ret = {}
    if means is None:
        means = {}
    for u, spks in spikes.iteritems():
        train = spks
        if not isinstance(train, neo.SpikeTrain):
            train = spikes_to_spike_train(spks)
        if u in means and \
                means[u].waveform.shape[0] == train.waveforms.shape[1]:
            spike = means[u]
        else:
            spike = neo.Spike(0)
            spike.waveform = sp.mean(train.waveforms, axis=0)

        orig = sp.mean(sp.var(train.waveforms, axis=1), axis=0)
        waves = train.waveforms - spike.waveform
        new = sp.mean(sp.var(waves, axis=1), axis=0)

        if noise is not None:
            ret[u] = sp.asarray(1 - (new - noise) / orig)
        else:
            ret[u] = sp.asarray(1 - new / orig)

    return ret
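
# Usage sketch with synthetic data (not part of the original module),
# assuming the (spike, sample, channel) waveform layout this function
# expects: 20 noisy copies of a sine template on a single channel, so most
# of the variance should be explained by the estimated mean.
def _example_variance_explained():
    sp.random.seed(0)
    train = neo.SpikeTrain(sp.arange(20) * pq.ms, t_stop=25.0 * pq.ms)
    template = sp.sin(sp.linspace(0, 2 * sp.pi, 32))[:, sp.newaxis]
    train.waveforms = (template +
                       0.1 * sp.random.randn(20, 32, 1)) * pq.uV
    return variance_explained({0: train})  # {0: array close to 1}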
