"""
.. autoclass:: CatalogueConstructor
:members:
"""
import os
import json
from collections import OrderedDict
import time
import pickle
import itertools
import datetime
import shutil
import numpy as np
import scipy.signal
import scipy.interpolate
import seaborn as sns
sns.set_style("white")
import sklearn
from . import signalpreprocessor
from . import peakdetector
from . import decomposition
from . import cluster
from . import metrics
from .tools import median_mad, get_pairs_over_threshold, int32_to_rgba, rgba_to_int32, make_color_dict
from .iotools import ArrayCollection
import matplotlib.pyplot as plt
from . import labelcodes
_persistent_metrics = ('spike_waveforms_similarity', 'cluster_similarity',
'cluster_ratio_similarity', 'spike_silhouette')
_centroids_arrays = ('centroids_median', 'centroids_mad', 'centroids_mean', 'centroids_std',)
_reset_after_waveforms_arrays = ('some_features', 'channel_to_features', 'some_noise_snippet',
'some_noise_index', 'some_noise_features',) + _persistent_metrics + _centroids_arrays
#~ _reset_after_peak_arrays = ('some_peaks_index', 'some_waveforms', 'some_features',
#~ 'channel_to_features',
#~ 'some_noise_index', 'some_noise_snippet', 'some_noise_features',
#~ ) + _persistent_metrics + _centroids_arrays
_reset_after_peak_arrays = ('some_peaks_index', 'some_waveforms',) + _reset_after_waveforms_arrays
_persitent_arrays = ('all_peaks', 'signals_medians','signals_mads', 'clusters') + \
_reset_after_peak_arrays
_dtype_peak = [('index', 'int64'), ('cluster_label', 'int64'), ('segment', 'int64'),]
_dtype_cluster = [('cluster_label', 'int64'), ('cell_label', 'int64'),
('max_on_channel', 'int64'), ('max_peak_amplitude', 'float64'),
('waveform_rms', 'float64'), ('nb_peak', 'int64'),
('tag', 'U16'), ('annotations', 'U32'), ('color', 'uint32')]
_keep_cluster_attr_on_new = ['cell_label', 'tag','annotations', 'color']
class CatalogueConstructor:
__doc__ = """
The goal of CatalogueConstructor is to construct a catalogue of templates (centroids)
for the Peeler.
To do so, the CatalogueConstructor will:
* preprocess a duration of data
* detect peaks
* extract some waveforms
* compute some features from these waveforms
* find clusters
* compute some metrics on these clusters
* enable some manual trash/merge/split
At the end we can call **make_catalogue_for_peeler**, which constructs everything useful for
the Peeler: the centroids (the median waveform of each cluster) and their first and second
derivatives, the median and MAD of the noise, ...
You need to have one CatalogueConstructor for each channel_group.
Since all operations are more or less slow, each internal array is by default persistent
on disk (memory_mode='memmap'). This way, when you instantiate a CatalogueConstructor,
you load all existing arrays already computed. For that, tridesclous mainly uses numpy
structured arrays (arrays of structs) with memmap. So hacking the internal arrays of the
CatalogueConstructor should be easy even outside python and tridesclous.
You can explore/save/reload the internal state by browsing this directory:
*path_of_dataset/channel_group_XXX/catalogue_constructor*
**Usage**::
from tridesclous import DataIO, CatalogueConstructor
dirname = '/path/to/dataset'
dataio = DataIO(dirname=dirname)
cc = CatalogueConstructor(dataio, chan_grp=0)
# preprocessing
cc.set_preprocessor_params(chunksize=1024,
highpass_freq=300.,
lowpass_freq=5000.,
common_ref_removal=True,
lostfront_chunksize=64,
peak_sign='-',
relative_threshold=5.5,
peak_span_ms=0.3,
)
cc.estimate_signals_noise(seg_num=0, duration=10.)
cc.run_signalprocessor(duration=60.)
# waveform/feature/cluster
cc.extract_some_waveforms(n_left=-25, n_right=40, mode='rand')
cc.clean_waveforms(alien_value_threshold=55.)
cc.project(method='global_pca', n_components=7)
cc.find_clusters(method='kmeans', n_clusters=7)
# manual stuff
cc.trash_small_cluster(n=5)
cc.order_clusters()
# do this before peeler
cc.make_catalogue_for_peeler()
For more examples, see the examples section.
**Persistent attributes**:
CatalogueConstructor has a mechanism to store/load some attributes in the
dirname of the dataio. This is useful to continue previous work on a catalogue.
The attributes are almost all numpy arrays stored through the numpy.memmap
mechanism, so prefer fast local storage like an SSD.
Here is the list of these attributes with shape and dtype. **N** is the total
number of detected peaks. **M** is the number of peaks selected for
waveform/feature/cluster. **C** is the number of clusters.
* all_peaks (N, ) dtype = {0}
* signals_medians (nb_channel, ) float32
* signals_mads (nb_channel, ) float32
* clusters (C, ) dtype = {1}
* some_peaks_index (M) int64
* some_waveforms (M, width, nb_channel) float32
* some_features (M, nb_feature) float32
* channel_to_features (nb_chan, nb_component) bool
* some_noise_snippet (nb_noise, width, nb_channel) float32
* some_noise_index (nb_noise, ) int64
* some_noise_features (nb_noise, nb_feature) float32
* centroids_median (C, width, nb_channel) float32
* centroids_mad (C, width, nb_channel) float32
* centroids_mean (C, width, nb_channel) float32
* centroids_std (C, width, nb_channel) float32
* spike_waveforms_similarity (M, M) float32
* cluster_similarity (C, C) float32
* cluster_ratio_similarity (C, C) float32
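These arrays are plain memmaps, so they can also be inspected from outside
tridesclous. A minimal sketch (it assumes the default layout of one raw binary
file per array; check the actual filenames in the directory above)::

    import numpy as np
    # hypothetical path, adjust to your dataset
    path = '/path/to/dataset/channel_group_0/catalogue_constructor/all_peaks.raw'
    dtype = [('index', 'int64'), ('cluster_label', 'int64'), ('segment', 'int64')]
    all_peaks = np.memmap(path, dtype=dtype, mode='r')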
""".format(_dtype_peak, _dtype_cluster)
def __init__(self, dataio, chan_grp=None, name='catalogue_constructor'):
"""
Parameters
----------
dataio: tridesclous.DataIO
the DataIO object describing the dataset
chan_grp: int
The channel group key. See the PRB file. None by default = take the first
channel group (often 0 but sometimes 1!!)
name: str
The name of the CatalogueConstructor. You can have several names for the
same dataset, i.e. several catalogues, to try different options.
"""
self.dataio = dataio
if chan_grp is None:
chan_grp = min(self.dataio.channel_groups.keys())
self.chan_grp = chan_grp
self.nb_channel = self.dataio.nb_channel(chan_grp=self.chan_grp)
self.geometry = self.dataio.get_geometry(chan_grp=self.chan_grp)
self.catalogue_path = os.path.join(self.dataio.channel_group_path[chan_grp], name)
if not os.path.exists(self.catalogue_path):
os.mkdir(self.catalogue_path)
self.arrays = ArrayCollection(parent=self, dirname=self.catalogue_path)
self.info_filename = os.path.join(self.catalogue_path, 'info.json')
if not os.path.exists(self.info_filename):
#first init
self.info = {}
self.flush_info()
else:
with open(self.info_filename, 'r', encoding='utf8') as f:
self.info = json.load(f)
for name in _persitent_arrays:
# this sets the attribute on the instance if it exists
self.arrays.load_if_exists(name)
if self.all_peaks is not None:
self.memory_mode='memmap'
self.projector = None
def flush_info(self):
""" Flush info (mainly parameters) to json files.
"""
with open(self.info_filename, 'w', encoding='utf8') as f:
json.dump(self.info, f, indent=4)
def __repr__(self):
t = "CatalogueConstructor\n"
t += ' ' + self.dataio.channel_group_label(chan_grp=self.chan_grp) + '\n'
if self.all_peaks is None:
t += ' Signal pre-processing not done yet'
return t
#~ t += " nb_peak: {}\n".format(self.nb_peak)
nb_peak_by_segment = [ np.sum(self.all_peaks['segment']==i) for i in range(self.dataio.nb_segment)]
t += ' nb_peak_by_segment: '+', '.join('{}'.format(n) for n in nb_peak_by_segment)+'\n'
if self.some_waveforms is not None:
t += ' some_waveforms.shape: {}\n'.format(self.some_waveforms.shape)
if self.some_features is not None:
t += ' some_features.shape: {}\n'.format(self.some_features.shape)
if hasattr(self, 'cluster_labels'):
t += ' cluster_labels {}\n'.format(self.cluster_labels)
return t
@property
def nb_peak(self):
if self.all_peaks is None:
return 0
return self.all_peaks.size
@property
def cluster_labels(self):
if self.clusters is not None:
return self.clusters['cluster_label']
else:
return np.array([], dtype='int64')
@property
def positive_cluster_labels(self):
return self.cluster_labels[self.cluster_labels>=0]
def reload_data(self):
"""Reload persitent arrays.
"""
if not hasattr(self, 'memory_mode') or not self.memory_mode=='memmap':
return
for name in _persitent_arrays:
# this sets the attribute on the instance if it exists
self.arrays.load_if_exists(name)
def _reset_arrays(self, list_arrays):
#TODO fix this: these arrays should really be deleted, but size=0 is buggy
for name in list_arrays:
self.arrays.detach_array(name)
setattr(self, name, None)
def set_preprocessor_params(self, chunksize=1024,
memory_mode='memmap',
internal_dtype = 'float32',
#signal preprocessor
signalpreprocessor_engine='numpy',
highpass_freq=300.,
lowpass_freq=None,
smooth_size=0,
common_ref_removal=False,
lostfront_chunksize=None,
#peak detector
peakdetector_engine='numpy',
peak_sign='-', relative_threshold=7, peak_span_ms=0.3,
):
"""
Set parameters for the preprocessor engine
Parameters
----------
chunksize: int. default 1024
Length of each chunk for processing.
For real time, the latency is between chunksize and 2*chunksize.
memory_mode: 'memmap' or 'ram' default 'memmap'
By default all arrays are persistent with memmap but you can
also have everything in ram.
internal_dtype: 'float32' (or 'float64')
Internal dtype for signals/waveforms/features.
Support for integer signals and waveforms is planned for one day and should
speed up the processing!!
signalpreprocessor_engine: 'numpy' or 'opencl'
If you have pyopencl and a correct ICD installed, you can try
'opencl'; for high channel counts some critical parts of the processing are
done on the GPU.
highpass_freq: float, default 300
High-pass cutoff frequency of the filter. Can be None if the raw
dataset is already filtered.
lowpass_freq: float, default None
Low-pass cutoff frequency of the filter. None by default.
For high sampling rates, a low pass at 5~8 kHz can be a good idea
to smooth the waveforms a bit and avoid high-frequency noise.
smooth_size: int, default 0
A simple smoothing convolution kernel that acts more or less
like a low-pass filter. Can be used instead of lowpass_freq.
common_ref_removal: bool, False by default
Remove the median across all channels, sample by sample.
lostfront_chunksize: int, default None
Size in samples of the margin at the front edge of each chunk to avoid border effects in the backward filter.
If you don't know, use None; then lostfront_chunksize will be int(sample_rate/highpass_freq*3), which is quite robust (<5% error)
compared to a true offline filtfilt.
peakdetector_engine: 'numpy' or 'opencl'
Engine for peak detection.
peak_sign: '-' or '+'
Sign of the peak.
relative_threshold: int, default 7
Threshold for peak detection. The preprocessed signal has units
expressed in MAD (robust STD), so 7 means 7*MAD.
peak_span_ms: float, default 0.3
Peak span to avoid double detection. In milliseconds.
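For instance, assuming a hypothetical sample_rate of 20 kHz and the default
highpass_freq=300., the margin computed when lostfront_chunksize is None is::

    lostfront_chunksize = int(20000. / 300. * 3)   # -> 200 samples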
"""
self._reset_arrays(_persitent_arrays)
self.chunksize = chunksize
self.memory_mode = memory_mode
# set default lostfront_chunksize if none is provided
if lostfront_chunksize is None or lostfront_chunksize==0:
assert highpass_freq is not None, 'lostfront_chunksize=None needs a highpass_freq'
lostfront_chunksize = int(self.dataio.sample_rate/highpass_freq*3)
#~ self.arrays.initialize_array('all_peaks', self.memory_mode, _dtype_peak, (-1, ))
self.signal_preprocessor_params = dict(highpass_freq=highpass_freq, lowpass_freq=lowpass_freq,
smooth_size=smooth_size, common_ref_removal=common_ref_removal,
lostfront_chunksize=lostfront_chunksize, output_dtype=internal_dtype,
signalpreprocessor_engine=signalpreprocessor_engine)
SignalPreprocessor_class = signalpreprocessor.signalpreprocessor_engines[signalpreprocessor_engine]
self.signalpreprocessor = SignalPreprocessor_class(self.dataio.sample_rate, self.nb_channel, chunksize, self.dataio.source_dtype)
self.peak_detector_params = dict(peak_sign=peak_sign, relative_threshold=relative_threshold, peak_span_ms=peak_span_ms)
PeakDetector_class = peakdetector.peakdetector_engines[peakdetector_engine]
self.peakdetector = PeakDetector_class(self.dataio.sample_rate, self.nb_channel,
self.chunksize, internal_dtype)
#TODO make processed data as int32 ???
for i in range(self.dataio.nb_segment):
self.dataio.reset_processed_signals(seg_num=i, chan_grp=self.chan_grp, dtype=internal_dtype)
#~ self.nb_peak = 0
# put all params in info
self.info['internal_dtype'] = internal_dtype
self.info['chunksize'] = chunksize
self.info['signal_preprocessor_params'] = self.signal_preprocessor_params
self.info['peak_detector_params'] = self.peak_detector_params
self.flush_info()
def estimate_signals_noise(self, seg_num=0, duration=10.):
"""
This estimates the median and MAD of the processed signals over
a short duration. This is necessary for the normalisation
in the next steps.
Note that if the noise is stable, even a short duration is OK.
Parameters
----------
seg_num: int
segment index
duration: float
duration in seconds
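As a sketch, the per-channel estimators computed by this method are the
classic robust ones::

    med = np.median(filtered_sigs, axis=0)                         # per-channel median
    mad = np.median(np.abs(filtered_sigs - med), axis=0) * 1.4826  # MAD scaled to STD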
"""
length = int(duration*self.dataio.sample_rate)
length -= length%self.chunksize
assert length<self.dataio.get_segment_length(seg_num), 'duration exceeds segment size'
name = 'filtered_sigs_for_noise_estimation_seg_{}'.format(seg_num)
shape=(length - self.signal_preprocessor_params['lostfront_chunksize'], self.nb_channel)
filtered_sigs = self.arrays.create_array(name, self.info['internal_dtype'], shape, 'memmap')
params2 = dict(self.signal_preprocessor_params)
params2.pop('signalpreprocessor_engine')
params2['normalize'] = False
self.signalpreprocessor.change_params(**params2)
iterator = self.dataio.iter_over_chunk(seg_num=seg_num, chan_grp=self.chan_grp, chunksize=self.chunksize, i_stop=length,
signal_type='initial')
for pos, sigs_chunk in iterator:
pos2, preprocessed_chunk = self.signalpreprocessor.process_data(pos, sigs_chunk)
if preprocessed_chunk is not None:
filtered_sigs[pos2-preprocessed_chunk.shape[0]:pos2, :] = preprocessed_chunk
#create persistent arrays
self.arrays.create_array('signals_medians', self.info['internal_dtype'], (self.nb_channel,), 'memmap')
self.arrays.create_array('signals_mads', self.info['internal_dtype'], (self.nb_channel,), 'memmap')
self.signals_medians[:] = signals_medians = np.median(filtered_sigs[:pos2], axis=0)
self.signals_mads[:] = np.median(np.abs(filtered_sigs[:pos2]-signals_medians),axis=0)*1.4826
#detach the filtered signals even if the file remains on disk.
self.arrays.detach_array(name)
def signalprocessor_one_chunk(self, pos, sigs_chunk, seg_num, detect_peak=True):
pos2, preprocessed_chunk = self.signalpreprocessor.process_data(pos, sigs_chunk)
if preprocessed_chunk is None:
return
self.dataio.set_signals_chunk(preprocessed_chunk, seg_num=seg_num, chan_grp=self.chan_grp,
i_start=pos2-preprocessed_chunk.shape[0], i_stop=pos2, signal_type='processed')
if detect_peak:
n_peaks, chunk_peaks = self.peakdetector.process_data(pos2, preprocessed_chunk)
if chunk_peaks is not None:
peaks = np.zeros(chunk_peaks.size, dtype=_dtype_peak)
peaks['index'] = chunk_peaks
peaks['segment'][:] = seg_num
peaks['cluster_label'][:] = labelcodes.LABEL_NO_WAVEFORM
self.arrays.append_chunk('all_peaks', peaks)
def run_signalprocessor_loop_one_segment(self, seg_num=0, duration=60., detect_peak=True):
length = int(duration*self.dataio.sample_rate)
length = min(length, self.dataio.get_segment_length(seg_num))
length -= length%self.chunksize
#TODO make this by segment
self.info['processed_length'] = length
self.flush_info()
#initialize engines
p = dict(self.signal_preprocessor_params)
p.pop('signalpreprocessor_engine')
p['normalize'] = True
p['signals_medians'] = self.signals_medians
p['signals_mads'] = self.signals_mads
self.signalpreprocessor.change_params(**p)
self.peakdetector.change_params(**self.peak_detector_params)
iterator = self.dataio.iter_over_chunk(seg_num=seg_num, chan_grp=self.chan_grp, chunksize=self.chunksize, i_stop=length,
signal_type='initial')
for pos, sigs_chunk in iterator:
#~ print(seg_num, pos, sigs_chunk.shape)
self.signalprocessor_one_chunk(pos, sigs_chunk, seg_num, detect_peak=detect_peak)
#maybe flush at each loop to limit memory usage, but that makes it slower
#~ self.dataio.flush_processed_signals(seg_num=seg_num, chan_grp=self.chan_grp)
def finalize_signalprocessor_loop(self):
self.arrays.finalize_array('all_peaks')
#~ self._reset_waveform_and_features()
self._reset_arrays(_reset_after_peak_arrays)
self.on_new_cluster()
def run_signalprocessor(self, duration=60., detect_peak=True):
"""
This runs, chunk by chunk, the signal preprocessing chain on
all segments.
The duration can be clipped for very long recordings. For catalogue
construction the user must have an intuition of how much signal is needed
to get enough spikes to detect clusters. If the duration is too short, clusters
will not be visible. If too long, the processing time will be unacceptable.
This totally depends on the dataset (nb channels, spike rate, ...).
This also detects peaks, to avoid useless extra access to the storage.
Parameters
----------
duration: float
duration in seconds for each segment
detect_peak: bool (default True)
Also detect peaks.
"""
self.arrays.initialize_array('all_peaks', self.memory_mode, _dtype_peak, (-1, ))
#~ for i in range(self.dataio.nb_segment):
#~ self.dataio.reset_processed_signals(seg_num=i, chan_grp=self.chan_grp, dtype=internal_dtype)
for seg_num in range(self.dataio.nb_segment):
self.run_signalprocessor_loop_one_segment(seg_num=seg_num, duration=duration, detect_peak=detect_peak)
self.dataio.flush_processed_signals(seg_num=seg_num, chan_grp=self.chan_grp)
self.finalize_signalprocessor_loop()
def re_detect_peak(self, peakdetector_engine='numpy', peak_sign='-', relative_threshold=7, peak_span_ms=0.3):
"""
Peaks are detected during **run_signalprocessor**.
But in some cases, e.g. to test another threshold, we can **re-detect peaks** without re-running the signal processing.
Parameters
----------
peakdetector_engine: 'numpy' or 'opencl'
Engine for peak detection.
peak_sign: '-' or '+'
Sign of the peak.
relative_threshold: int, default 7
Threshold for peak detection. The preprocessed signal has units
expressed in MAD (robust STD), so 7 means 7*MAD.
peak_span_ms: float, default 0.3
Peak span to avoid double detection. In milliseconds.
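For example, to try a lower threshold on the already processed signals
(a sketch, the values are illustrative)::

    cc.re_detect_peak(peakdetector_engine='numpy', peak_sign='-',
                      relative_threshold=5., peak_span_ms=0.3)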
"""
#TODO if not peak detector in class
self.peak_detector_params = dict(peak_sign=peak_sign, relative_threshold=relative_threshold, peak_span_ms=peak_span_ms)
PeakDetector_class = peakdetector.peakdetector_engines[peakdetector_engine]
self.peakdetector = PeakDetector_class(self.dataio.sample_rate, self.nb_channel,
self.info['chunksize'], self.info['internal_dtype'])
self.peakdetector.change_params(**self.peak_detector_params)
self.info['peak_detector_params'] = self.peak_detector_params
self.flush_info()
self.arrays.initialize_array('all_peaks', self.memory_mode, _dtype_peak, (-1, ))
#TODO clip i_stop with duration ???
for seg_num in range(self.dataio.nb_segment):
self.peakdetector.change_params(**self.peak_detector_params) # this resets the fifo index
iterator = self.dataio.iter_over_chunk(seg_num=seg_num, chan_grp=self.chan_grp,
chunksize=self.info['chunksize'], i_stop=None, signal_type='processed')
for pos, preprocessed_chunk in iterator:
n_peaks, chunk_peaks = self.peakdetector.process_data(pos, preprocessed_chunk)
if chunk_peaks is not None:
peaks = np.zeros(chunk_peaks.size, dtype=_dtype_peak)
peaks['index'] = chunk_peaks
peaks['segment'][:] = seg_num
peaks['cluster_label'][:] = labelcodes.LABEL_NO_WAVEFORM
self.arrays.append_chunk('all_peaks', peaks)
self.arrays.finalize_array('all_peaks')
#~ self._reset_waveform_and_features()
self._reset_arrays(_reset_after_peak_arrays)
self.on_new_cluster()
def find_good_limits(self, mad_threshold = 1.1, channel_percent=0.3, extract=True, min_left=-5, max_right=5):
"""
Find good limits for the waveforms, i.e. where the MAD is
above the noise level (=1.).
The technique consists in finding continuous samples that are
above the background noise (mad_threshold=1.1 means 10% above noise)
on a minimum fraction of channels (channel_percent).
Parameters
----------
mad_threshold: float (default 1.1)
threshold on the MAD, slightly above the noise level
channel_percent: float (default 0.3)
fraction of channels that must be above this threshold
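A sketch of the per-sample criterion implemented below (the names follow the code)::

    median, mad = median_mad(self.some_waveforms, axis=0)  # shapes (width, nb_channel)
    nb_above = np.sum(mad >= mad_threshold, axis=1)        # channels above, per sample
    above = nb_above >= self.nb_channel * channel_percent  # boolean, per time sample

The new (n_left, n_right) window is the longest run of True in *above*.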
"""
old_n_left = self.info['waveform_extractor_params']['n_left']
old_n_right = self.info['waveform_extractor_params']['n_right']
median, mad = median_mad(self.some_waveforms, axis = 0)
# number of channels whose MAD is above mad_threshold, for each time sample
nb_above = np.sum(mad>=mad_threshold, axis=1)
#~ print('nb_above', nb_above)
#~ print('self.nb_channel*channel_percent', self.nb_channel*channel_percent)
above = nb_above>=self.nb_channel*channel_percent
#find the longest consecutive run of samples that are True
#~ print('above', above)
up, = np.where(np.diff(above.astype(int))==1)
down, = np.where(np.diff(above.astype(int))==-1)
if len(up)==0 or len(down)==0:
return None, None
else:
up = up[up<max(down)]
down = down[down>min(up)]
if len(up)==0 or len(down)==0:
return None, None
else:
best = np.argmax(down-up)
n_left = int(self.info['waveform_extractor_params']['n_left'] + up[best])
n_right = int(self.info['waveform_extractor_params']['n_left'] + down[best]+1)
#~ print(old_n_left, old_n_right)
#~ print(n_left, n_right)
n_left = min(n_left, min_left)
n_right = max(n_right, max_right)
#~ print(n_left, n_right)
if extract:
self.projector = None
self.extract_some_waveforms(n_left=n_left, n_right=n_right,
index=self.some_peaks_index.copy(), # copy is to avoid reference loop
align_waveform=self.info['waveform_extractor_params']['align_waveform'])
return n_left, n_right
#~ n +=1
#~ print('extract_some_features', self.some_features.shape)
#ALIAS TODO remove it
project = extract_some_features
def apply_projection(self):
assert self.projector is not None
features = self.projector.transform(self.some_waveforms)
#trick to make it persistent
#~ self.arrays.create_array('some_features', self.info['internal_dtype'], features.shape, self.memory_mode)
#~ self.some_features[:] = features
self.arrays.add_array('some_features', features.astype(self.info['internal_dtype']), self.memory_mode)
def find_clusters(self, method='kmeans', selection=None, **kargs):
"""
Find clusters for the peaks that have a waveform and features.
If *selection* is given (a boolean mask over all_peaks), only that
subset is re-clustered and the other labels are kept.
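A sketch of re-clustering only a subset, here the peaks with the
hypothetical label 3 (this is what **split_cluster** does)::

    mask = cc.all_peaks['cluster_label'] == 3
    cc.find_clusters(method='kmeans', n_clusters=2, selection=mask)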
"""
#done in a separate module cluster.py
if selection is not None:
old_labels = np.unique(self.all_peaks['cluster_label'][selection])
#~ print(old_labels)
labels = cluster.find_clusters(self, method=method, selection=selection, **kargs)
if selection is None:
self.on_new_cluster()
self.compute_all_centroid()
#maybe remove this but this is a good practice
self.order_clusters(by='waveforms_rms')
else:
new_labels = np.unique(labels)
for new_label in new_labels:
if new_label not in self.clusters['cluster_label'] and new_label>=0:
self.add_one_cluster(new_label)
if new_label>=0:
self.compute_one_centroid(new_label)
for old_label in old_labels:
ind = self.index_of_label(old_label)
nb_peak = np.sum(self.all_peaks['cluster_label']==old_label)
if nb_peak == 0:
self.pop_labels_from_cluster([old_label])
else:
self.clusters['nb_peak'][ind] = nb_peak
self.compute_one_centroid(old_label)
def on_new_cluster(self):
#~ print('cc.on_new_cluster')
if self.all_peaks is None:
return
cluster_labels = np.unique(self.all_peaks['cluster_label'])
clusters = np.zeros(cluster_labels.shape, dtype=_dtype_cluster)
clusters['cluster_label'][:] = cluster_labels
clusters['cell_label'][:] = cluster_labels
clusters['max_on_channel'][:] = -1
clusters['max_peak_amplitude'][:] = np.nan
clusters['waveform_rms'][:] = np.nan
for i, k in enumerate(cluster_labels):
clusters['nb_peak'][i] = np.sum(self.all_peaks['cluster_label']==k)
if self.clusters is not None:
#get previous _keep_cluster_attr_on_new
for i, c in enumerate(clusters):
#~ print(i, c)
if c['cluster_label'] in self.clusters['cluster_label']:
j = np.nonzero(c['cluster_label']==self.clusters['cluster_label'])[0][0]
for attr in _keep_cluster_attr_on_new:
#~ self.clusters[j]['cell_label'] in cluster_labels
#~ clusters[i]['cell_label'] = self.clusters[j]['cell_label']
clusters[attr][i] = self.clusters[attr][j]
#~ print('j', j)
#~ if clusters.size>0:
self.arrays.add_array('clusters', clusters, self.memory_mode)
def add_one_cluster(self, label):
assert label not in self.clusters['cluster_label']
clusters = np.zeros(self.clusters.size+1, dtype=_dtype_cluster)
if label>=0:
pos_insert = -1
clusters[:-1] = self.clusters
else:
pos_insert = 0
clusters[1:] = self.clusters
clusters['cluster_label'][pos_insert] = label
clusters['cell_label'][pos_insert] = label
clusters['max_on_channel'][pos_insert] = -1
clusters['max_peak_amplitude'][pos_insert] = np.nan
clusters['waveform_rms'][pos_insert] = np.nan
clusters['nb_peak'][pos_insert] = np.sum(self.all_peaks['cluster_label']==label)
self.arrays.add_array('clusters', clusters, self.memory_mode)
for name in _centroids_arrays:
arr = getattr(self, name).copy()
new_arr = np.zeros((arr.shape[0]+1, arr.shape[1], arr.shape[2]), dtype=arr.dtype)
if label>=0:
new_arr[:-1, :, :] = arr
else:
new_arr[1:, :, :] = arr
self.arrays.add_array(name, new_arr, self.memory_mode)
#TODO set one color
self.refresh_colors(reset=False)
self.compute_one_centroid(label)
def index_of_label(self, label):
ind = np.nonzero(self.clusters['cluster_label']==label)[0][0]
return ind
def remove_one_cluster(self, label):
print('WARNING remove_one_cluster')
# This should not be called any more
# because the name is ambiguous
self.pop_labels_from_cluster([label])
def pop_labels_from_cluster(self, labels):
# this reduces the clusters array by removing some labels
# warning: all_peaks is not relabelled here, callers must do it themselves
if isinstance(labels, int):
labels = [labels]
keep = np.ones(self.clusters.size, dtype='bool')
for k in labels:
ind = self.index_of_label(k)
keep[ind] = False
clusters = self.clusters[keep].copy()
self.arrays.add_array('clusters', clusters, self.memory_mode)
for name in _centroids_arrays:
new_arr = getattr(self, name)[keep, :, :].copy()
self.arrays.add_array(name, new_arr, self.memory_mode)
def move_cluster_to_trash(self, labels):
if isinstance(labels, int):
labels = [labels]
mask = np.zeros(self.all_peaks.size, dtype='bool')
for k in labels:
mask |= self.all_peaks['cluster_label']== k
self.change_spike_label(mask, labelcodes.LABEL_TRASH)
def compute_one_centroid(self, k, flush=True):
#~ t1 = time.perf_counter()
ind = self.index_of_label(k)
n_left = int(self.info['waveform_extractor_params']['n_left'])
wf = self.some_waveforms[self.all_peaks['cluster_label'][self.some_peaks_index]==k]
median, mad = median_mad(wf, axis = 0)
mean, std = np.mean(wf, axis=0), np.std(wf, axis=0)
max_on_channel = np.argmax(np.abs(median[-n_left,:]), axis=0)
# write to persistent arrays
self.centroids_median[ind, :, :] = median
self.centroids_mad[ind, :, :] = mad
self.centroids_mean[ind, :, :] = mean
self.centroids_std[ind, :, :] = std
self.clusters['max_on_channel'][ind] = max_on_channel
self.clusters['max_peak_amplitude'][ind] = median[-n_left, max_on_channel]
self.clusters['waveform_rms'][ind] = np.sqrt(np.mean(median**2))
if flush:
for name in ('clusters',) + _centroids_arrays:
self.arrays.flush_array(name)
#~ t2 = time.perf_counter()
#~ print('compute_one_centroid',k, t2-t1)
def compute_all_centroid(self):
#~ t1 = time.perf_counter()
if self.some_waveforms is None:
for name in _centroids_arrays:
self.arrays.detach_array(name)
setattr(self, name, None)
return
n_left = int(self.info['waveform_extractor_params']['n_left'])
n_right = int(self.info['waveform_extractor_params']['n_right'])
for name in _centroids_arrays:
empty = np.zeros((self.cluster_labels.size, n_right - n_left, self.nb_channel), dtype=self.info['internal_dtype'])
self.arrays.add_array(name, empty, self.memory_mode)
#~ t1 = time.perf_counter()
for k in self.cluster_labels:
if k <0: continue
self.compute_one_centroid(k, flush=False)
for name in ('clusters',) + _centroids_arrays:
self.arrays.flush_array(name)
#~ self.arrays.flush_array('clusters')
#~ t2 = time.perf_counter()
#~ print('compute_all_centroid', t2-t1)
def refresh_colors(self, reset=True, palette='husl', interleaved=True):
labels = self.positive_cluster_labels
if reset:
n = labels.size
if interleaved and n>1:
n1 = np.floor(np.sqrt(n))
n2 = np.ceil(n/n1)
n = int(n1*n2)
n1, n2 = int(n1), int(n2)
else:
n = np.sum((self.clusters['cluster_label']>=0) & (self.clusters['color']==0))
if n>0:
colors_int32 = np.array([rgba_to_int32(r,g,b) for r,g,b in sns.color_palette(palette, n)])
if reset and interleaved and n>1:
colors_int32 = colors_int32.reshape(n1, n2).T.flatten()
colors_int32 = colors_int32[:labels.size]
if reset:
mask = self.clusters['cluster_label']>=0
self.clusters['color'][mask] = colors_int32
else:
mask = (self.clusters['cluster_label']>=0) & (self.clusters['color']==0)
self.clusters['color'][mask] = colors_int32
#Make colors accessible by key
self.colors = make_color_dict(self.clusters)
def change_spike_label(self, mask, label):
is_new = label not in self.clusters['cluster_label']
label_changed = np.unique(self.all_peaks['cluster_label'][mask]).tolist()
self.all_peaks['cluster_label'][mask] = label
if is_new:
self.add_one_cluster(label) # this also compute centroid
else:
label_changed = label_changed + [label]
to_remove = []
for k in label_changed:
ind = self.index_of_label(k)
nb_peak = np.sum(self.all_peaks['cluster_label']==k)
self.clusters['nb_peak'][ind] = nb_peak
if k>=0:
if nb_peak>0:
self.compute_one_centroid(k)
else:
to_remove.append(k)
self.pop_labels_from_cluster(to_remove)
self.arrays.flush_array('all_peaks')
self.arrays.flush_array('clusters')
def split_cluster(self, label, method='kmeans', **kargs):
"""
This splits one cluster by applying a new clustering method to that subset only.
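For example, to split the hypothetical cluster 3 in two (extra kargs
are passed to the clustering method)::

    cc.split_cluster(3, method='kmeans', n_clusters=2)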
"""
mask = self.all_peaks['cluster_label']==label
self.find_clusters(method=method, selection=mask, **kargs)
def trash_small_cluster(self, n=10):
to_remove = []
for k in list(self.cluster_labels):
mask = self.all_peaks['cluster_label']==k
if np.sum(mask)<=n:
self.all_peaks['cluster_label'][mask] = -1
to_remove.append(k)
self.pop_labels_from_cluster(to_remove)
def compute_cluster_similarity(self, method='cosine_similarity_with_max'):
if self.centroids_median is None:
self.compute_all_centroid()
#~ t1 = time.perf_counter()
labels = self.cluster_labels
mask = labels>=0
wfs = self.centroids_median[mask, :, :]
wfs = wfs.reshape(wfs.shape[0], -1)
if wfs.size == 0:
cluster_similarity = None
else:
cluster_similarity = metrics.cosine_similarity_with_max(wfs)
if cluster_similarity is None:
self.arrays.detach_array('cluster_similarity')
self.cluster_similarity = None
else:
self.arrays.add_array('cluster_similarity', cluster_similarity.astype('float32'), self.memory_mode)
#~ t2 = time.perf_counter()
#~ print('compute_cluster_similarity', t2-t1)
def detect_high_similarity(self, threshold=0.95):
if self.cluster_similarity is None:
self.compute_cluster_similarity()
pairs = get_pairs_over_threshold(self.cluster_similarity, self.positive_cluster_labels, threshold)
return pairs
def auto_merge_high_similarity(self, threshold=0.95):
"""Recursively merge all pairs with similarity hihger that a given threshold
"""
pairs = self.detect_high_similarity(threshold=threshold)
already_merge = {}
for k1, k2 in pairs:
# merge if k2 still exists
if k1 in already_merge:
k1 = already_merge[k1]
print('auto_merge', k1, 'with', k2)
mask = self.all_peaks['cluster_label'] == k2
self.all_peaks['cluster_label'][mask] = k1
already_merge[k2] = k1
self.pop_labels_from_cluster([k2])
def compute_cluster_ratio_similarity(self, method='cosine_similarity_with_max'):
#~ print('compute_cluster_ratio_similarity')
if self.centroids_median is None:
self.compute_all_centroid()
#~ if not hasattr(self, 'centroids'):
#~ self.compute_centroid()
#~ t1 = time.perf_counter()
labels = self.positive_cluster_labels
#TODO: this is pointless because cosine similarity is invariant to scaling!!!
#so this is identical to compute_cluster_similarity()
#find something else
wf_normed = []
for ind, k in enumerate(self.clusters['cluster_label']):
if k<0: continue
#~ chan = self.centroids[k]['max_on_channel']
chan = self.clusters['max_on_channel'][ind]
#~ median = self.centroids[k]['median']
median = self.centroids_median[ind, :, :]
n_left = int(self.info['waveform_extractor_params']['n_left'])
wf_normed.append(median/np.abs(median[-n_left, chan]))
wf_normed = np.array(wf_normed)
if wf_normed.size == 0:
cluster_ratio_similarity = None
else:
wf_normed_flat = wf_normed.swapaxes(1, 2).reshape(wf_normed.shape[0], -1)
#~ cluster_ratio_similarity = metrics.compute_similarity(wf_normed_flat, 'cosine_similarity')
cluster_ratio_similarity = metrics.cosine_similarity_with_max(wf_normed_flat)
if cluster_ratio_similarity is None:
self.arrays.detach_array('cluster_ratio_similarity')
self.cluster_ratio_similarity = None
else:
self.arrays.add_array('cluster_ratio_similarity', cluster_ratio_similarity.astype('float32'), self.memory_mode)
#~ return labels, ratio_similarity, wf_normed_flat
#~ t2 = time.perf_counter()
#~ print('compute_cluster_ratio_similarity', t2-t1)
def detect_similar_waveform_ratio(self, threshold=0.9):
if self.cluster_ratio_similarity is None:
self.compute_cluster_ratio_similarity()
pairs = get_pairs_over_threshold(self.cluster_ratio_similarity, self.positive_cluster_labels, threshold)
return pairs
def clean_cluster(self, too_small=10):
self.trash_small_cluster(n=too_small)
def compute_spike_silhouette(self, size_max=1e7):
#~ t1 = time.perf_counter()
spike_silhouette = None
wf = self.some_waveforms
if wf is not None:
wf = wf.reshape(wf.shape[0], -1)
labels = self.all_peaks['cluster_label'][self.some_peaks_index]
if wf.size<size_max:
spike_silhouette = metrics.compute_silhouette(wf, labels, metric='euclidean')
if spike_silhouette is None:
self.arrays.detach_array('spike_silhouette')
self.spike_silhouette = None
else:
self.arrays.add_array('spike_silhouette', spike_silhouette.astype('float32'), self.memory_mode)
#~ t2 = time.perf_counter()
#~ print('compute_spike_silhouette', t2-t1)
def tag_same_cell(self, labels_to_group):
"""
In some situations, e.g. spike bursts, the amplitude changes abruptly.
In these cases we can have 2 distinct clusters, but the user suspects
that they are the same cell. In such cases the user must not merge
the 2 clusters, because the centroid would represent nothing and
the Peeler would fail.
Instead, we tag the 2 different clusters as the "same cell",
i.e. they get the same cell_label.
This is a manual action.
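For example, to tag the hypothetical clusters 3 and 5 as the same cell
(both get cell_label 3, the minimum of the group)::

    cc.tag_same_cell([3, 5])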
"""
inds, = np.nonzero(np.in1d(self.clusters['cluster_label'], labels_to_group))
self.clusters['cell_label'][inds] = min(labels_to_group)
def order_clusters(self, by='waveforms_rms'):
"""
This reorders labels from highest RMS to lowest RMS:
the higher the RMS, the smaller the label.
Negative labels are not reassigned.
"""
if self.centroids_median is None:
self.compute_all_centroid()
if np.any(np.isnan(self.clusters[self.clusters['cluster_label']>=0]['waveform_rms'])):
# MUST compute centroids
#~ print('MUST compute centroids because nan')
#~ self.compute_centroid()
for ind, k in enumerate(self.clusters['cluster_label']):
if k<0:
continue
if np.isnan(self.clusters[ind]['waveform_rms']):
self.compute_one_centroid(k)
clusters = self.clusters.copy()
neg_clusters = clusters[clusters['cluster_label']<0]
pos_clusters = clusters[clusters['cluster_label']>=0]
#~ print('order_clusters', by)
if by=='waveforms_rms':
order = np.argsort(pos_clusters['waveform_rms'])[::-1]
elif by=='max_peak_amplitude':
order = np.argsort(np.abs(pos_clusters['max_peak_amplitude']))[::-1]
else:
raise(NotImplementedError)
sorted_labels = pos_clusters['cluster_label'][order]
#reassign labels for peaks and clusters
if len(sorted_labels)>0:
N = int(max(sorted_labels)*10)
else:
N = 0
self.all_peaks['cluster_label'][self.all_peaks['cluster_label']>=0] += N
for new, old in enumerate(sorted_labels+N):
self.all_peaks['cluster_label'][self.all_peaks['cluster_label']==old] = new
pos_clusters = pos_clusters[order].copy()
n = pos_clusters.size
pos_clusters['cluster_label'] = np.arange(n)
#this should preserve previously identical cell_labels
pos_clusters['cell_label'] += N
for i in range(n):
k = pos_clusters['cell_label'][i]
inds, = np.nonzero(pos_clusters['cell_label']==k)
if (len(inds)>=1) and (inds[0] == i):
pos_clusters['cell_label'][inds] = i
new_cluster = np.concatenate((neg_clusters, pos_clusters))
self.clusters[:] = new_cluster
# reassign centroids
mask_pos= (self.clusters['cluster_label']>=0)
for name in _centroids_arrays:
arr = getattr(self, name)
arr_pos = arr[mask_pos, :, :].copy()
arr[mask_pos, :, :] = arr_pos[order, :, :]
self.refresh_colors(reset=True)
#~ self._reset_metrics()
self._reset_arrays(_persistent_metrics)
def make_catalogue(self):
#TODO: offer possibility to resample some waveforms or choose the number
#~ t1 = time.perf_counter()
self.catalogue = {}
self.catalogue['chan_grp'] = self.chan_grp
n_left = self.catalogue['n_left'] = int(self.info['waveform_extractor_params']['n_left'] +2)
self.catalogue['n_right'] = int(self.info['waveform_extractor_params']['n_right'] -2)
self.catalogue['peak_width'] = self.catalogue['n_right'] - self.catalogue['n_left']
#for colors
self.refresh_colors(reset=False)
keep = self.cluster_labels>=0
cluster_labels = np.array(self.cluster_labels[keep], copy=True)
# TODO this is redundant with clusters; it should be removed but that implies some changes in the peeler.
self.catalogue['cluster_labels'] = cluster_labels
self.catalogue['clusters'] = self.clusters[keep].copy()
n, full_width, nchan = self.some_waveforms.shape
centers0 = np.zeros((len(cluster_labels), full_width - 4, nchan), dtype=self.info['internal_dtype'])
centers1 = np.zeros_like(centers0)
centers2 = np.zeros_like(centers0)
self.catalogue['centers0'] = centers0 # median of waveforms
self.catalogue['centers1'] = centers1 # median of first derivative of waveforms
self.catalogue['centers2'] = centers2 # median of second derivative of waveforms
subsample = np.arange(1.5, full_width-2.5, 1/20.)
self.catalogue['subsample_ratio'] = 20
interp_centers0 = np.zeros((len(cluster_labels), subsample.size, nchan), dtype=self.info['internal_dtype'])
self.catalogue['interp_centers0'] = interp_centers0
#~ print('peak_width', self.catalogue['peak_width'])
self.catalogue['label_to_index'] = {}
for i, k in enumerate(cluster_labels):
self.catalogue['label_to_index'][k] = i
#print('construct_catalogue', k)
# take the waveforms of this cluster
# shape (nb_peak, width, nb_channel)
#~ wf = self.some_waveforms[self.all_peaks['cluster_label']==k]
wf0 = self.some_waveforms[self.all_peaks['cluster_label'][self.some_peaks_index]==k]
#~ wf0 = wf0.copy()
#~ print(wf0.shape, wf0.size)
# compute first and second derivative on dim=1 (time)
# Note: this was the old implementation, but it was too slow
#~ kernel = np.array([1,0,-1])/2.
#~ kernel = kernel[None, :, None]
#~ wf1 = scipy.signal.fftconvolve(wf0,kernel,'same') # first derivative
#~ wf2 = scipy.signal.fftconvolve(wf1,kernel,'same') # second derivative
# this is the new one: central differences, equivalent to the [1,0,-1]/2. kernel above
wf1 = np.zeros_like(wf0)
wf1[:, 1:-1, :] = (wf0[:, 2:,: ] - wf0[:, :-2,: ])/2.
wf2 = np.zeros_like(wf1)
wf2[:, 1:-1, :] = (wf1[:, 2:,: ] - wf1[:, :-2,: ])/2.
#take the median and
#eliminate the margin, because of the border effect of the derivative
center0 = np.median(wf0, axis=0)
centers0[i,:,:] = center0[2:-2, :]
centers1[i,:,:] = np.median(wf1, axis=0)[2:-2, :]
centers2[i,:,:] = np.median(wf2, axis=0)[2:-2, :]
#~ center0 = np.mean(wf0, axis=0)
#~ centers0[i,:,:] = center0[2:-2, :]
#~ centers1[i,:,:] = np.mean(wf1, axis=0)[2:-2, :]
#~ centers2[i,:,:] = np.mean(wf2, axis=0)[2:-2, :]
#interpolate centers0 for reconstruction in between samples when the jitter is estimated
f = scipy.interpolate.interp1d(np.arange(full_width), center0, axis=0, kind='cubic')
oversampled_center = f(subsample)
interp_centers0[i, :, :] = oversampled_center
#~ fig, ax = plt.subplots()
#~ ax.plot(np.arange(full_width-4), center0[2:-2, :], color='b', marker='o')
#~ ax.plot(subsample-2.,oversampled_center, color='c')
#~ plt.show()
#find the max channel of each cluster for peak alignment
self.catalogue['max_on_channel'] = np.zeros_like(self.catalogue['cluster_labels'])
for i, k in enumerate(cluster_labels):
center = self.catalogue['centers0'][i,:,:]
self.catalogue['max_on_channel'][i] = np.argmax(np.abs(center[-n_left,:]), axis=0)
#colors
#~ self.catalogue['cluster_colors'] = {}
#~ self.catalogue['cluster_colors'].update(self.colors)
#params
self.catalogue['signal_preprocessor_params'] = dict(self.info['signal_preprocessor_params'])
self.catalogue['peak_detector_params'] = dict(self.info['peak_detector_params'])
self.catalogue['clean_waveforms_params'] = dict(self.info['clean_waveforms_params'])
self.catalogue['signals_medians'] = np.array(self.signals_medians, copy=True)
self.catalogue['signals_mads'] = np.array(self.signals_mads, copy=True)
#~ t2 = time.perf_counter()
#~ print('make_catalogue', t2-t1)
return self.catalogue
def make_catalogue_for_peeler(self):
"""
Make and save catalogue in the working dir for the Peeler.
"""
self.make_catalogue()
self.dataio.save_catalogue(self.catalogue, name='initial')
def create_savepoint(self):
"""this create a copy of the entire catalogue_constructor subdir
Usefull for the UI when the user wants to snapshot and try tricky merge/split.
"""
copy_path = self.catalogue_path + '_SAVEPOINT_{:%Y-%m-%d_%Hh%Mm%S}'.format(datetime.datetime.now())
if not os.path.exists(copy_path):
shutil.copytree(self.catalogue_path, copy_path)
return copy_path