"""
.. autoclass:: Peeler
   :members:
"""
import os
import json
from collections import OrderedDict, namedtuple
import time
import numpy as np
import scipy.signal
from .peeler_tools import _dtype_spike
from tqdm import tqdm
from .peeler_engine_classic import PeelerEngineClassic
from .peeler_engine_oldclassic import PeelerEngineOldClassic
from .peeler_engine_classic_cl import PeelerEngineClassicOpenCl
from .peeler_engine_testing import PeelerEngineTesting
from .peeler_engine_parallel import PeelerEngineParallel
from .peeler_engine_geometry import PeelerEngineGeometrical
#~ from .peeler_engine_strict import PeelerEngineStrict
# Registry of available engines: the key is the ``engine`` name accepted by
# ``Peeler.change_params`` and the value is the class instantiated for it.
# Disabled entries are kept as comments to record why they are absent.
peeler_engines = dict(
    classic=PeelerEngineClassic,
    # classic_opencl=PeelerEngineClassicOpenCl,  ### not working
    classic_old=PeelerEngineOldClassic,
    testing=PeelerEngineTesting,
    # classic_parallel=PeelerEngineParallel,  ### not working
    geometrical=PeelerEngineGeometrical,
)
class Peeler:
    """
    The peeler is the core of spike sorting itself.
    It basically does a *template matching* on the signals.

    This class needs a *catalogue* constructed by :class:`CatalogueConstructor`.
    Then the computing is applied chunk by chunk on the raw signal itself.

    So this class is the same for both offline/online computing.

    At each chunk, the algorithm is basically this one:
      1. apply the processing chain (filter, normalize, ....)
      2. Detect peaks
      3. Try to classify peaks and detect the *jitter*
      4. With labeled peaks create a prediction for the chunk
      5. Subtract the prediction from the processed signals.
      6. Go back to **2** until there is no peak or only peaks that can't be labeled.
      7. Return labeled spikes from this or previous chunk and the processed
         signals (for display or recording)

    The main difficulty in the implementation is to deal with edges because
    spike waveforms can spread out in between 2 chunks.

    Note that the global latency depends on these 2 parameters:
      * lostfront_chunksize
      * chunksize
    """

    def __init__(self, dataio):
        """Store the data source; ``dataio`` is None for online use."""
        # for online dataio is None
        self.dataio = dataio

    def __repr__(self):
        """Return a short description with the object id and the workdir.

        The workdir is reported as ``None`` when no ``dataio`` is attached
        (online use), instead of failing on the missing attribute.
        """
        workdir = None if self.dataio is None else self.dataio.dirname
        return "Peeler <id: {}> \n workdir: {}\n".format(id(self), workdir)

    def change_params(self, catalogue=None, engine='classic', internal_dtype='float32', chunksize=1024, **params):
        """Select the engine and forward the parameters to it.

        Parameters
        ----------
        catalogue:
            The catalogue to match against. Mandatory despite the ``None``
            default (an ``AssertionError`` is raised otherwise).
        engine: str
            Key of ``peeler_engines``; an unknown name raises ``KeyError``.
        internal_dtype: str
            Stored on the peeler and forwarded to the engine.
        chunksize: int
            Stored on the peeler and forwarded to the engine.
        **params:
            Extra engine-specific options, forwarded untouched to the
            engine's ``change_params``.
        """
        assert catalogue is not None

        self.catalogue = catalogue
        self.internal_dtype = internal_dtype
        self.chunksize = chunksize
        self.engine_name = engine

        # A fresh engine instance is created on every call.
        self.peeler_engine = peeler_engines[engine]()
        self.peeler_engine.change_params(
            catalogue=catalogue,
            internal_dtype=internal_dtype,
            chunksize=chunksize,
            **params
        )

    def process_one_chunk(self, pos, sigs_chunk):
        """Delegate one chunk to the engine and return its result unchanged.

        The offline loop unpacks that result as
        ``(sig_index, preprocessed_chunk, total_spike, spikes)``.
        """
        return self.peeler_engine.process_one_chunk(pos, sigs_chunk)

    def initialize_online_loop(self, sample_rate=None, nb_channel=None, source_dtype=None):
        """Prepare the engine for an online run from explicit stream settings."""
        self.peeler_engine.initialize_before_each_segment(
            sample_rate=sample_rate,
            nb_channel=nb_channel,
            source_dtype=source_dtype,
        )

    def run_offline_loop_one_segment(self, seg_num=0, duration=None, progressbar=True):
        """Run the peeler over one segment and persist the results via ``dataio``.

        Parameters
        ----------
        seg_num: int
            Index of the segment to process.
        duration: float or None
            When given, it is multiplied by the sample rate to get the number
            of samples to process; otherwise the full segment length is used.
        progressbar: bool
            Wrap the chunk iterator in ``tqdm`` when True.
        """
        # Same guard as run_offline_all_segment, so that calling this method
        # directly fails with the same explicit message.
        assert hasattr(self, 'catalogue'), 'So peeler.change_params first'

        chan_grp = self.catalogue['chan_grp']

        # The engine is (re)initialized for each segment with the settings
        # read from the dataio.
        kargs = {
            'sample_rate': self.dataio.sample_rate,
            'nb_channel': self.dataio.nb_channel(chan_grp),
            'source_dtype': self.dataio.source_dtype,
            'geometry': self.dataio.get_geometry(chan_grp),
        }
        self.peeler_engine.initialize_before_each_segment(**kargs)

        if duration is not None:
            length = int(duration * self.dataio.sample_rate)
        else:
            length = self.dataio.get_segment_length(seg_num)

        # Reset the output stores for this segment before writing into them.
        self.dataio.reset_processed_signals(seg_num=seg_num, chan_grp=chan_grp, dtype=self.internal_dtype)
        self.dataio.reset_spikes(seg_num=seg_num, chan_grp=chan_grp, dtype=_dtype_spike)

        iterator = self.dataio.iter_over_chunk(
            seg_num=seg_num,
            chan_grp=chan_grp,
            chunksize=self.chunksize,
            i_stop=length,
            signal_type='initial',
        )
        if progressbar:
            iterator = tqdm(iterable=iterator, total=length // self.chunksize)

        for pos, sigs_chunk in iterator:
            sig_index, preprocessed_chunk, total_spike, spikes = self.peeler_engine.process_one_chunk(pos, sigs_chunk)

            # Nothing is written for a chunk whose sig_index is not positive.
            if sig_index <= 0:
                continue

            # save preprocessed_chunk to file; sig_index is the end of the
            # written range.
            self.dataio.set_signals_chunk(
                preprocessed_chunk,
                seg_num=seg_num,
                chan_grp=chan_grp,
                i_start=sig_index - preprocessed_chunk.shape[0],
                i_stop=sig_index,
                signal_type='processed',
            )

            if spikes is not None and spikes.size > 0:
                self.dataio.append_spikes(seg_num=seg_num, chan_grp=chan_grp, spikes=spikes)

        # Spikes the engine still holds once the last chunk has been consumed.
        extra_spikes = self.peeler_engine.get_remaining_spikes()
        if extra_spikes is not None and extra_spikes.size > 0:
            self.dataio.append_spikes(seg_num=seg_num, chan_grp=chan_grp, spikes=extra_spikes)

        self.dataio.flush_processed_signals(seg_num=seg_num, chan_grp=chan_grp)
        self.dataio.flush_spikes(seg_num=seg_num, chan_grp=chan_grp)

    def run_offline_all_segment(self, **kargs):
        """Run :meth:`run_offline_loop_one_segment` on every segment of the dataio.

        ``**kargs`` are forwarded to each per-segment call. ``change_params``
        must have been called first (``AssertionError`` otherwise).
        """
        assert hasattr(self, 'catalogue'), 'So peeler.change_params first'

        for seg_num in range(self.dataio.nb_segment):
            self.run_offline_loop_one_segment(seg_num=seg_num, **kargs)

    # Short alias kept for callers using ``peeler.run()``.
    run = run_offline_all_segment