"""
.. autoclass:: Peeler
   :members:
"""
import os
import json
from collections import OrderedDict, namedtuple
import time
import numpy as np
import scipy.signal
from .peeler_tools import _dtype_spike
from tqdm import tqdm
#~ from .peeler_engine_testing import PeelerEngineTesting
from .peeler_engine_geometry import PeelerEngineGeometrical
from .peeler_engine_geometry_cl import PeelerEngineGeometricalCl
# Registry of the available peeler engines, keyed by the name given to
# ``Peeler.change_params(engine=...)``.
# The 'testing' engine (PeelerEngineTesting) is currently disabled.
peeler_engines = dict(
    geometrical=PeelerEngineGeometrical,
    geometrical_opencl=PeelerEngineGeometricalCl,
)
class Peeler:
    """
    The peeler is the core of the spike sorting itself.

    It basically does a *template matching* on the signals.

    This class needs a *catalogue* constructed by :class:`CatalogueConstructor`.
    Then the computing is applied chunk by chunk on the raw signal itself.

    So this class is the same for both offline/online computing.

    At each chunk, the algorithm is basically this one:

      1. Apply the processing chain (filter, normalize, ...)
      2. Detect peaks
      3. Try to classify each peak and detect the *jitter*
      4. With the labeled peaks create a prediction for the chunk
      5. Subtract the prediction from the processed signals.
      6. Go back to **2** until there is no peak or only peaks that can't be labeled.
      7. Return labeled spikes from this or the previous chunk and the
         processed signals (for display or recording)

    The main difficulty in the implementation is to deal with edges, because
    spike waveforms can spread out in between 2 chunks.

    Note that the global latency depends on these parameters:

      * pad_width
      * chunksize
    """

    def __init__(self, dataio):
        # For online use, dataio is None.
        self.dataio = dataio

    def __repr__(self):
        # dataio is None in online mode, so do not assume it has a dirname.
        workdir = None if self.dataio is None else self.dataio.dirname
        t = "Peeler <id: {}> \n workdir: {}\n".format(id(self), workdir)
        return t

    def change_params(self, catalogue=None, engine='geometrical', internal_dtype='float32',
                      chunksize=1024, speed_test_mode=False, **params):
        """
        Select the engine and set the parameters of the peeler.

        catalogue: the catalogue to use, mandatory (must not be None)
        engine: key of ``peeler_engines`` selecting the engine class
        internal_dtype: forwarded to the engine; also used as dtype when
            the processed signals are reset in the offline loop
        chunksize: forwarded to the engine; also the chunk size of the
            offline loop
        speed_test_mode: only for offline mode, create a log file with
            the run time of each buffer
        params: extra keyword arguments forwarded to the engine's
            ``change_params``
        """
        assert catalogue is not None

        self.catalogue = catalogue
        self.engine_name = engine
        self.internal_dtype = internal_dtype
        self.chunksize = chunksize
        self.speed_test_mode = speed_test_mode

        self.peeler_engine = peeler_engines[engine]()
        self.peeler_engine.change_params(catalogue=catalogue, internal_dtype=internal_dtype,
                                         chunksize=chunksize, **params)

    def process_one_chunk(self, pos, sigs_chunk):
        """
        Process one chunk with the engine and return its result as is.

        This is for online use.
        """
        return self.peeler_engine.process_one_chunk(pos, sigs_chunk)

    def initialize_online_loop(self, sample_rate=None, nb_channel=None, source_dtype=None, geometry=None):
        """
        Initialize the engine for online use.

        The engine is initialized globally and then for a segment, both
        with ``already_processed=False``.
        """
        # global initialize
        self.peeler_engine.initialize(sample_rate=sample_rate, nb_channel=nb_channel,
                                      source_dtype=source_dtype, already_processed=False,
                                      geometry=geometry)
        self.peeler_engine.initialize_before_each_segment(already_processed=False)

    def run_offline_loop_one_segment(self, seg_num=0, duration=None, progressbar=True):
        """
        Run the peeler chunk by chunk on one segment of the dataio.

        seg_num: index of the segment to process
        duration: if not None, only ``int(duration * sample_rate)`` samples
            are processed, otherwise the full segment length is used
        progressbar: wrap the chunk iterator in a tqdm progress bar

        Detected spikes are appended to the dataio. The processed signals
        are written as well, unless the dataio reports that this length is
        already processed for this segment. With ``speed_test_mode`` the run
        time of each chunk is saved in the log path of the channel group.
        """
        chan_grp = self.catalogue['chan_grp']

        if duration is not None:
            length = int(duration * self.dataio.sample_rate)
        else:
            length = self.dataio.get_segment_length(seg_num)

        # check if the desired length is already computed or not for this particular segment
        already_processed = self.dataio.already_processed(seg_num=seg_num, chan_grp=chan_grp, length=length)
        self.peeler_engine.initialize_before_each_segment(already_processed=already_processed)

        if already_processed:
            # read from "processed"
            signal_type = 'processed'
        else:
            # read from "initial": the signals are going to be written again
            signal_type = 'initial'
            self.dataio.reset_processed_signals(seg_num=seg_num, chan_grp=chan_grp, dtype=self.internal_dtype)

        # spikes are appended in both cases, so they are always reset
        self.dataio.reset_spikes(seg_num=seg_num, chan_grp=chan_grp, dtype=_dtype_spike)

        iterator = self.dataio.iter_over_chunk(seg_num=seg_num, chan_grp=chan_grp, chunksize=self.chunksize,
                                               i_stop=length, signal_type=signal_type)
        if progressbar:
            iterator = tqdm(iterable=iterator, total=length // self.chunksize)

        if self.speed_test_mode:
            process_run_times = []

        # End index of the last chunk that was kept. It stays at 0 when the
        # iterator yields nothing, so the final flush never relies on a loop
        # variable that may be unbound or non-positive.
        processed_length = 0

        for pos, sigs_chunk in iterator:
            if self.speed_test_mode:
                t0 = time.perf_counter()

            sig_index, preprocessed_chunk, _total_spike, spikes = self.peeler_engine.process_one_chunk(pos, sigs_chunk)

            if self.speed_test_mode:
                t1 = time.perf_counter()
                process_run_times.append(t1 - t0)

            if sig_index <= 0:
                # nothing usable for this chunk
                continue

            processed_length = int(sig_index)

            if not already_processed:
                # save preprocessed_chunk to file
                self.dataio.set_signals_chunk(preprocessed_chunk, seg_num=seg_num, chan_grp=chan_grp,
                                              i_start=sig_index - preprocessed_chunk.shape[0], i_stop=sig_index,
                                              signal_type='processed')

            if spikes is not None and spikes.size > 0:
                self.dataio.append_spikes(seg_num=seg_num, chan_grp=chan_grp, spikes=spikes)

        # spikes still held by the engine once all chunks are consumed
        extra_spikes = self.peeler_engine.get_remaining_spikes()
        if extra_spikes is not None and extra_spikes.size > 0:
            self.dataio.append_spikes(seg_num=seg_num, chan_grp=chan_grp, spikes=extra_spikes)

        if not already_processed:
            self.dataio.flush_processed_signals(seg_num=seg_num, chan_grp=chan_grp,
                                                processed_length=processed_length)
        self.dataio.flush_spikes(seg_num=seg_num, chan_grp=chan_grp)

        if self.speed_test_mode:
            process_run_times = np.array(process_run_times, dtype='float64')
            log_path = self.dataio.get_log_path(chan_grp=chan_grp)
            filename = os.path.join(log_path, 'peeler_run_times_seg{}.npy'.format(seg_num))
            np.save(filename, process_run_times)

    def run(self, duration=None, progressbar=True):
        """
        Run the peeler offline on every segment of the dataio.

        duration: forwarded to ``dataio.get_duration_per_segments``, whose
            result gives the duration used for each segment
        progressbar: forwarded to :meth:`run_offline_loop_one_segment`

        ``change_params`` must have been called before.
        """
        assert hasattr(self, 'catalogue'), 'So peeler.change_params first'

        chan_grp = self.catalogue['chan_grp']

        duration_per_segment = self.dataio.get_duration_per_segments(duration)

        already_processed_segs = []
        for seg_num in range(self.dataio.nb_segment):
            length = int(duration_per_segment[seg_num] * self.dataio.sample_rate)
            # check if the desired length is already computed or not
            already_processed = self.dataio.already_processed(seg_num=seg_num, chan_grp=chan_grp, length=length)
            already_processed_segs.append(already_processed)

        kargs = {}
        kargs['sample_rate'] = self.dataio.sample_rate
        kargs['nb_channel'] = self.dataio.nb_channel(chan_grp)
        if any(already_processed_segs):
            kargs['source_dtype'] = self.internal_dtype
        else:
            kargs['source_dtype'] = self.dataio.source_dtype
        kargs['geometry'] = self.dataio.get_geometry(chan_grp)
        kargs['already_processed'] = all(already_processed_segs)
        self.peeler_engine.initialize(**kargs)

        for seg_num in range(self.dataio.nb_segment):
            self.run_offline_loop_one_segment(seg_num=seg_num, duration=duration_per_segment[seg_num],
                                              progressbar=progressbar)

    # old alias just in case
    run_offline_all_segment = run

    def get_run_times(self, chan_grp=0, seg_num=0):
        """
        Load the run time of each chunk saved by the offline loop.

        Needs speed_test_mode=True in params.
        """
        p = self.dataio.get_log_path(chan_grp=chan_grp)
        filename = os.path.join(p, 'peeler_run_times_seg{}.npy'.format(seg_num))
        run_times = np.load(filename)
        return run_times