""" Functions for estimating the quality of spike sorting results. These
functions estimate false positive and false negative fractions.
"""
from __future__ import division
import scipy as sp
from scipy.spatial.distance import cdist
import quantities as pq
import neo
from progress_indicator import ProgressIndicator
from . import SpykeException
from conversions import spikes_to_spike_train
[docs]def get_refperiod_violations(spike_trains, refperiod, progress=None):
""" Return the refractory period violations in the given spike trains
for the specified refractory period.
:param dict spike_trains: Dictionary of lists of
:class:`neo.core.SpikeTrain` objects.
:param refperiod: The refractory period (time).
:type refperiod: Quantity scalar
:param progress: Set this parameter to report progress.
:type progress: :class:`.progress_indicator.ProgressIndicator`
:returns: Two values:
* The total number of violations.
* A dictionary (with the same indices as ``spike_trains``) of
arrays with violation times (Quantity 1D with the same unit as
``refperiod``) for each spike train.
:rtype: int, dict """
if type(refperiod) != pq.Quantity or \
refperiod.simplified.dimensionality != pq.s.dimensionality:
raise ValueError('refperiod must be a time quantity!')
if not progress:
progress = ProgressIndicator()
total_violations = 0
violations = {}
for u, tL in spike_trains.iteritems():
violations[u] = []
for i, t in enumerate(tL):
st = t.copy()
st.sort()
isi = sp.diff(st)
violations[u].append(st[isi < refperiod].rescale(refperiod.units))
total_violations += len(violations[u][i])
progress.step()
return total_violations, violations
[docs]def calculate_refperiod_fp(num_spikes, refperiod, violations, total_time):
""" Return the rate of false positives calculated from refractory period
calculations for each unit. The equation used is described in
(Hill et al. The Journal of Neuroscience. 2011).
:param dict num_spikes: Dictionary of total number of spikes,
indexed by unit.
:param refperiod: The refractory period (time). If the spike sorting
algorithm includes a censored period (a time after a spike during
which no new spikes can be found), subtract it from the refractory
period before passing it to this function.
:type refperiod: Quantity scalar
:param dict violations: Dictionary of total number of violations,
indexed the same as num_spikes.
:param total_time: The total time in which violations could have occured.
:type total_time: Quantity scalar
:returns: A dictionary of false positive rates indexed by unit.
Note that values above 0.5 can not be directly interpreted as a
false positive rate! These very high values can e.g. indicate
that the generating processes are not independent.
"""
if type(refperiod) != pq.Quantity or \
refperiod.simplified.dimensionality != pq.s.dimensionality:
raise ValueError('refperiod must be a time quantity!')
fp = {}
factor = total_time / (2 * refperiod)
for u, n in num_spikes.iteritems():
if n == 0:
fp[u] = 0
continue
zw = (violations[u] * factor / n ** 2).simplified
if zw > 0.25:
fp[u] = 0.5 + sp.sqrt(0.25 - zw).imag
continue
fp[u] = 0.5 - sp.sqrt(0.25 - zw)
return fp
def _multi_norm(x, mean):
""" Evaluate pdf of multivariate normal distribution with a mean
at rows of x with high precision.
"""
d = x.shape[1]
fac = (2 * sp.pi) ** (-d / 2.0)
y = cdist(x, sp.atleast_2d(mean), 'sqeuclidean') * -0.5
return fac * sp.exp(sp.longdouble(y))
def _fast_overlap_whitened(spike_arrays, means):
units = spike_arrays.keys()
spikes = {u: spike_arrays[u].shape[1] for u in spike_arrays.iterkeys()}
prior = {}
total_spikes = 0
for u, mean in means.iteritems():
total_spikes += spikes[u]
if total_spikes < 1:
return {u: (0.0, 0.0) for u in units}, {}
# Arrays of unnormalized posteriors (likelihood times prior)
# for all units
posterior = {}
false_positive = {}
false_negative = {}
for u in units:
prior[u] = spikes[u] / total_spikes
false_positive[u] = 0
false_negative[u] = 0
# Calculate posteriors
for u1 in units[:]:
if not spikes[u1]:
units.remove(u1)
continue
posterior[u1] = {}
for u2, mean in means.iteritems():
llh = _multi_norm(spike_arrays[u1].T, mean)
posterior[u1][u2] = llh * prior[u2]
# Calculate pairwise false positives/negatives
singles = {u: {} for u in units}
for i, u1 in enumerate(units):
u1 = units[i]
for u2 in units[i + 1:]:
f1 = sp.sum(posterior[u1][u2] /
(posterior[u1][u1] + posterior[u1][u2]),
dtype=sp.double)
f2 = sp.sum(posterior[u2][u1] /
(posterior[u2][u1] + posterior[u2][u2]),
dtype=sp.double)
singles[u1][u2] = (f1 / spikes[u1] if spikes[u1] else 0,
f2 / spikes[u1] if spikes[u1] else 0)
singles[u2][u1] = (f2 / spikes[u2] if spikes[u2] else 0,
f1 / spikes[u2] if spikes[u2] else 0)
# Calculate complete false positives/negatives with extended bayes
for u1 in units:
numerator = posterior[u1][u1]
normalizer = sum(posterior[u1][u2] for u2 in units)
false_positive[u1] = sp.sum((normalizer - numerator) / normalizer)
other_units = units[:]
other_units.remove(u1)
numerator = sp.vstack((posterior[u][u1] for u in other_units))
normalizer = sp.vstack(sum(posterior[u][u2] for u2 in units) for u in other_units)
false_negative[u1] = sp.sum(numerator / normalizer)
# Prepare return values, convert sums to means
totals = {}
for u, fp in false_positive.iteritems():
fn = false_negative[u]
if not spikes[u]:
totals[u] = (0, 0)
else:
num = spikes[u]
totals[u] = (fp / num, fn / num)
return totals, singles
def _pair_overlap(waves1, waves2, mean1, mean2, cov1, cov2):
""" Calculate FP/FN estimates for two gaussian clusters
"""
from sklearn import mixture
means = sp.vstack([[mean1], [mean2]])
covars = sp.vstack([[cov1], [cov2]])
weights = sp.array([waves1.shape[1], waves2.shape[1]], dtype=float)
weights /= weights.sum()
# Create mixture of two Gaussians from the existing estimates
mix = mixture.GMM(n_components=2, covariance_type='full', init_params='')
mix.covars_ = covars
mix.weights_ = weights
mix.means_ = means
posterior1 = mix.predict_proba(waves1.T)[:, 1]
posterior2 = mix.predict_proba(waves2.T)[:, 0]
return (posterior1.mean(), posterior2.sum() / len(posterior1),
posterior2.mean(), posterior1.sum() / len(posterior2))
def _object_has_size(obj, size):
""" Return if the object, which could be either a neo.Spike or ndarray,
has the given size. """
if isinstance(obj, neo.Spike):
return obj.waveform.size == size
return obj.size == size
[docs]def overlap_fp_fn(spikes, means=None, covariances=None):
""" Return dicts of tuples (False positive rate, false negative rate)
indexed by unit. This function needs :mod:`sklearn` if
``covariances`` is not set to ``'white'``.
This function estimates the pairwise and total false positive and false
negative rates for a number of waveform clusters. The results can be
interpreted as follows: False positives are the fraction of spikes in a
cluster that is estimated to belong to a different cluster (a specific
cluster for pairwise results or any other cluster for total results).
False negatives are the number spikes from other clusters that are
estimated to belong to a given cluster (also expressed as fraction, this
number can be larger than 1 in extreme cases).
Details for the calculation can be found in
(Hill et al. The Journal of Neuroscience. 2011).
The calculation for total false positive and false negative rates does
not follow Hill et al., who propose a simple addition of pairwise
probabilities. Instead, the total error probabilities are estimated
using all clusters at once.
:param dict spikes: Dictionary, indexed by unit, of lists of
spike waveforms as :class:`neo.core.Spike` objects or numpy arrays.
If the waveforms have multiple channels, they will be flattened
automatically. All waveforms need to have the same number of samples.
:param dict means: Dictionary, indexed by unit, of lists of
spike waveforms as :class:`neo.core.Spike` objects or numpy arrays.
Means for units that are not in this dictionary will be estimated
using the spikes. Note that if you pass ``'white'`` for
``covariances`` and you want to provide means, they have to be
whitened in the same way as the spikes.
Default: None, means will be estimated from data.
:param covariances: Dictionary, indexed by unit, of lists of
covariance matrices. Covariances for units that are not in this
dictionary will be estimated using the spikes. It is useful to give
a covariance matrix if few spikes are present - consider using the
noise covariance. If you use prewhitened spikes (i.e. all clusters
are normal distributed, so their covariance matrix is the identity),
you can pass ``'white'`` here. The calculation will be much faster in
this case and the sklearn package is not required.
Default: None, covariances will estimated from data.
:type covariances: dict or str
:returns: Two values:
* A dictionary (indexed by unit) of total
(false positive rate, false negative rate) tuples.
* A dictionary of dictionaries, both indexed by units,
of pairwise (false positive rate, false negative rate) tuples.
:rtype: dict, dict
"""
units = spikes.keys()
total_spikes = 0
for spks in spikes.itervalues():
total_spikes += len(spks)
if total_spikes < 1:
return {u: (0.0, 0.0) for u in units}, {}
if means is None:
means = {}
white = False
if covariances is None:
covariances = {}
elif covariances == 'white':
white = True
covariances = {}
# Convert Spike objects to arrays
dimensionality = None
spike_arrays = {}
for u, spks in spikes.iteritems():
spikelist = []
if not spks or (len(spks) < 2 and u not in covariances):
units.remove(u)
continue
for s in spks:
if isinstance(s, neo.Spike):
spikelist.append(
sp.asarray(s.waveform.rescale(pq.uV)).T.flatten())
else:
spikelist.append(s)
spike_arrays[u] = sp.array(spikelist).T
if dimensionality is None:
dimensionality = spike_arrays[u].shape[0]
elif dimensionality != spike_arrays[u].shape[0]:
raise SpykeException('All spikes need to have the same number'
'of samples!')
if not units:
return {}, {}
if len(units) == 1:
return {units[0]: (0.0, 0.0)}, {}
# Convert or calculate means and covariances
shaped_means = {}
covs = {}
if white:
cov = sp.eye(dimensionality)
covariances = {u: cov for u in units}
for u in units:
if u in means and _object_has_size(means[u], dimensionality):
mean = means[u]
if isinstance(mean, neo.Spike):
shaped_means[u] = sp.asarray(
mean.waveform.rescale(pq.uV)).T.flatten()
else:
shaped_means[u] = means[u].T.flatten()
else:
shaped_means[u] = spike_arrays[u].mean(axis=1)
if white:
return _fast_overlap_whitened(spike_arrays, shaped_means)
for u in units:
if u not in covariances:
covs[u] = sp.cov(spike_arrays[u])
else:
covs[u] = covariances[u]
# Calculate pairwise false positives/negatives
singles = {u: {} for u in units}
for i, u1 in enumerate(units):
u1 = units[i]
for u2 in units[i + 1:]:
error_rates = _pair_overlap(
spike_arrays[u1], spike_arrays[u2],
shaped_means[u1], shaped_means[u2],
covs[u1], covs[u2])
singles[u1][u2] = error_rates[0:2]
singles[u2][u1] = error_rates[2:4]
# Calculate complete false positives/negatives
import sklearn
mix = sklearn.mixture.GMM(n_components=2, covariance_type='full')
mix_means = []
mix_covars = []
mix_weights = []
for u in units:
mix_means.append(shaped_means[u])
mix_covars.append([covs[u]])
mix_weights.append(spike_arrays[u].shape[1])
mix.means_ = sp.vstack(mix_means)
mix.covars_ = sp.vstack(mix_covars)
mix_weights = sp.array(mix_weights, dtype=float)
mix_weights /= mix_weights.sum()
mix.weights_ = mix_weights
# P(spikes of unit[i] in correct cluster)
post_mean = sp.zeros(len(units))
# sum(P(spikes of unit[i] in cluster[j])
post_sum = sp.zeros((len(units), len(units)))
for i, u in enumerate(units):
posterior = mix.predict_proba(spike_arrays[u].T)
post_mean[i] = posterior[:, i].mean()
post_sum[i, :] = posterior.sum(axis=0)
totals = {}
for i, u in enumerate(units):
fp = 1.0 - post_mean[i]
ind = range(len(units))
ind.remove(i)
fn = post_sum[ind, i].sum() / float(spike_arrays[u].shape[1])
totals[u] = (fp, fn)
return totals, singles
[docs]def variance_explained(spikes, means=None, noise=None):
""" Returns the fraction of variance in each channel that is explained
by the means.
Values below 0 or above 1 for large data sizes indicate
that some assumptions were incorrect (e.g. about channel noise) and
the results should not be trusted.
:param dict spikes: Dictionary, indexed by unit, of
:class:`neo.core.SpikeTrain` objects (where the ``waveforms``
member includes the spike waveforms) or lists of
:class:`neo.core.Spike` objects.
:param dict means: Dictionary, indexed by unit, of lists of
spike waveforms as :class:`neo.core.Spike` objects or numpy arrays.
Means for units that are not in this dictionary will be estimated
using the spikes.
Default: None - means will be estimated from given spikes.
:type noise: Quantity 1D
:param noise: The known noise levels (as variance) per channel of the
original data. This should be estimated from the signal periods
that do not contain spikes, otherwise the explained variance
could be overestimated. If None, the estimate of explained variance
is done without regard for noise.
Default: None
:return dict: A dictionary of arrays, both indexed by unit. If ``noise``
is ``None``, the dictionary contains
the fraction of explained variance per channel without taking noise
into account. If ``noise`` is given, it contains the fraction of
variance per channel explained by the means and given noise level
together.
"""
ret = {}
if means is None:
means = {}
for u, spks in spikes.iteritems():
train = spks
if not isinstance(train, neo.SpikeTrain):
train = spikes_to_spike_train(spks)
if u in means and means[u].waveform.shape[0] == train.waveforms.shape[1]:
spike = means[u]
else:
spike = neo.Spike(0)
spike.waveform = sp.mean(train.waveforms, axis=0)
orig = sp.mean(sp.var(train.waveforms, axis=1), axis=0)
waves = train.waveforms - spike.waveform
new = sp.mean(sp.var(waves, axis=1), axis=0)
if noise is not None:
ret[u] = sp.asarray(1 - (new - noise) / orig)
else:
ret[u] = sp.asarray(1 - new / orig)
return ret