Source code for tridesclous.gui.waveformviewer

from .myqt import QT
import pyqtgraph as pg

import numpy as np
import pandas as pd

from .base import WidgetBase


class MyViewBox(pg.ViewBox):
    """ViewBox that forwards double clicks and wheel moves as Qt signals.

    Signals:
        doubleclicked: emitted, without argument, on each mouse double click.
        gain_zoom: emitted on each wheel event with a multiplicative factor
            (float).
    """
    doubleclicked = QT.pyqtSignal()
    gain_zoom = QT.pyqtSignal(float)

    def __init__(self, *args, **kwds):
        """Forward every argument unchanged to ``pg.ViewBox.__init__``."""
        pg.ViewBox.__init__(self, *args, **kwds)

    def mouseClickEvent(self, ev):
        """Accept the click without taking any action."""
        ev.accept()

    def mouseDoubleClickEvent(self, ev):
        """Emit ``doubleclicked`` and accept the event."""
        self.doubleclicked.emit()
        ev.accept()

    def wheelEvent(self, ev, axis=None):
        """Translate a wheel move into a ``gain_zoom`` factor.

        The step is 10 when Control is the only active modifier and 1.3
        otherwise; a positive wheel delta emits the step itself, any other
        delta emits its inverse.
        """
        coarse = ev.modifiers() == QT.Qt.ControlModifier
        step = 10. if coarse else 1.3
        factor = step if ev.delta() > 0 else 1. / step
        self.gain_zoom.emit(factor)
        ev.accept()



class WaveformViewerBase(WidgetBase):
    #base for both WaveformViewer (Catalogue) and PeelerWaveformViewer
    

    
    def __init__(self, controller=None, parent=None):
        """Build the widget: toolbar on top, graphics view below.

        Parameters
        ----------
        controller :
            Forwarded to ``WidgetBase.__init__``.
        parent :
            Forwarded to ``WidgetBase.__init__``.
        """
        WidgetBase.__init__(self, parent=parent, controller=controller)

        # Vertical stacking: toolbar first, then the plotting area.
        vbox = QT.QVBoxLayout()
        self.layout = vbox
        self.setLayout(vbox)

        self.create_toolbar()
        vbox.addWidget(self.toolbar)

        view = pg.GraphicsView()
        self.graphicsview = view
        vbox.addWidget(view)
        self.initialize_plot()

        # Alpha given to the colour of the fill-between band in flatten mode.
        self.alpha = 60
        self.refresh()
    
    def create_toolbar(self):
        """Create ``self.toolbar`` with the mode selector and three buttons.

        The display mode starts as 'geometry'. The buttons 'settings', 'scale'
        and 'refresh' are wired to ``open_settings``, ``zoom_range`` and
        ``refresh`` respectively.
        """
        toolbar = QT.QToolBar()
        self.toolbar = toolbar

        # Display mode selector: geometry (default) or flatten.
        mode_box = QT.QComboBox()
        self.combo_mode = mode_box
        toolbar.addWidget(mode_box)
        self.mode = 'geometry'
        mode_box.addItems([ 'geometry', 'flatten'])
        # Connected only after the items are added.
        mode_box.currentIndexChanged.connect(self.on_combo_mode_changed)
        toolbar.addSeparator()

        actions = (
            ('settings', self.open_settings),
            ('scale', self.zoom_range),
            ('refresh', self.refresh),
        )
        for label, slot in actions:
            button = QT.QPushButton(label)
            button.clicked.connect(slot)
            toolbar.addWidget(button)
    
    def on_combo_mode_changed(self):
        """Adopt the mode currently shown in the combo box and redraw."""
        selected = self.combo_mode.currentText()
        self.mode = str(selected)
        # The plot layout depends on the mode, so it is rebuilt before drawing.
        self.initialize_plot()
        self.refresh()
    
    def on_params_changed(self, params, changes):
        """React to a batch of parameter changes.

        ``changes`` is iterated as ``(param, change, data)`` triplets. The
        plot is rebuilt for every 'value' change of 'flip_bottom_up'; the view
        is then refreshed once, whatever the changes were.
        """
        for param, change, data in changes:
            if change == 'value' and param.name()=='flip_bottom_up':
                self.initialize_plot()
        self.refresh()

    def initialize_plot(self):
        """(Re)build the plot items for the current ``self.mode``.

        Returns immediately, leaving any previous state untouched, when the
        controller reports no waveform shape.

        Otherwise a fresh view box / plot pair is installed in the graphics
        view, together with the curve used for the single selected spike.

        * 'flatten' mode adds a second plot below the first one, x-linked to
          it, and sets the gain ``factor_y`` to 1.
        * 'geometry' mode sets ``factor_y`` to .3 and, when the channel group
          has a geometry, computes ``arr_geometry`` (one x, y row per channel),
          the channel spacings ``delta_x`` / ``delta_y`` and ``xvect`` (the x
          coordinates of every waveform sample). Without geometry ``xvect``
          is None.

        Finally the remembered view ranges are reset and the view box signals
        are connected.
        """
        shape = self.controller.get_waveforms_shape()
        if shape is None:
            return

        self.viewBox1 = MyViewBox()
        self.viewBox1.disableAutoRange()

        grid = pg.GraphicsLayout(border=(100,100,100))
        self.graphicsview.setCentralItem(grid)

        self.plot1 = grid.addPlot(row=0, col=0, rowspan=2, viewBox=self.viewBox1)
        self.plot1.hideButtons()
        self.plot1.showAxis('left', True)

        self.curve_one_waveform = pg.PlotCurveItem([], [], pen=pg.mkPen(QT.QColor( 'white'), width=1), connect='finite')
        self.plot1.addItem(self.curve_one_waveform)

        if self.mode=='flatten':
            grid.nextRow()
            grid.nextRow()
            self.viewBox2 = MyViewBox()
            self.viewBox2.disableAutoRange()
            self.plot2 = grid.addPlot(row=2, col=0, rowspan=1, viewBox=self.viewBox2)
            self.plot2.hideButtons()
            self.plot2.showAxis('left', True)
            self.viewBox2.setXLink(self.viewBox1)
            self.factor_y = 1.

        elif self.mode=='geometry':
            self.plot2 = None
            # Set before looking at the geometry so that gain_zoom() always
            # finds the attribute, even when there is no geometry.
            self.factor_y = .3

            chan_grp = self.controller.chan_grp
            channel_group = self.controller.dataio.channel_groups[chan_grp]
            geometry = channel_group['geometry']
            if geometry is None:
                print('no geometry')
                self.xvect = None
            else:
                width = shape[0]

                self.xvect = np.zeros(shape[0]*shape[1], dtype='float32')

                # The unpacking keeps the requirement of exactly two
                # coordinates per channel.
                self.arr_geometry = np.array(
                    [[x, y] for x, y in (geometry[chan] for chan in self.controller.channel_indexes)],
                    dtype='float64')

                if self.params['flip_bottom_up']:
                    self.arr_geometry[:, 1] *= -1.

                # np.unique returns sorted values, no extra sort is needed.
                unique_x = np.unique(self.arr_geometry[:,0])
                unique_y = np.unique(self.arr_geometry[:,1])

                if unique_x.size>1:
                    self.delta_x = np.min(np.diff(unique_x))
                else:
                    self.delta_x = unique_x[0]

                if unique_y.size>1:
                    self.delta_y = np.min(np.diff(unique_y))
                else:
                    # delta_y multiplies the waveform amplitudes and pads the
                    # y range: with a single row it must stay strictly
                    # positive, otherwise the waveforms vanish (y == 0) or are
                    # drawn upside down (y < 0).
                    self.delta_y = abs(unique_y[0]) or 1.

                if self.delta_x>0.:
                    espx = self.delta_x/2.5
                else:
                    espx = .5
                for i, chan in enumerate(channel_group['channels']):
                    x, y = geometry[chan]
                    self.xvect[i*width:(i+1)*width] = np.linspace(x-espx, x+espx, num=width)

        self.wf_min, self.wf_max = self.controller.get_min_max_centroids()

        # None means "recompute the default ranges at the next refresh".
        self._x_range = None
        self._y1_range = None
        self._y2_range = None

        self.viewBox1.gain_zoom.connect(self.gain_zoom)
        self.viewBox1.doubleclicked.connect(self.open_settings)
    

    def gain_zoom(self, factor_ratio):
        """Multiply the vertical gain ``factor_y`` by ``factor_ratio`` and redraw."""
        self.factor_y = self.factor_y * factor_ratio
        self.refresh()
    
    def zoom_range(self):
        """Forget the remembered view ranges and redraw.

        With the ranges set to None the refresh methods fall back to their
        default ranges.
        """
        self._x_range = self._y1_range = self._y2_range = None
        self.refresh()

    def refresh(self):
        """Redraw the centroids and, if requested, the selected spike.

        The plot is initialized lazily; nothing is drawn while initialization
        does not succeed. When 'show_only_selected_cluster' is on and exactly
        one spike is selected, only the cluster of that spike is shown.
        """
        if not hasattr(self, 'viewBox1'):
            self.initialize_plot()
            if not hasattr(self, 'viewBox1'):
                # Initialization bailed out: nothing to draw on yet.
                return

        n_selected = np.sum(self.controller.spike_selection)

        only_selected = self.params['show_only_selected_cluster'] and n_selected==1
        if only_selected:
            # Hide everything, then re-enable the cluster of the selected spike.
            cluster_visible = dict.fromkeys(self.controller.cluster_visible, False)
            selected_inds, = np.nonzero(self.controller.spike_selection)
            label = self.controller.spikes[selected_inds[0]]['cluster_label']
            cluster_visible[label] = True
        else:
            cluster_visible = self.controller.cluster_visible

        if self.mode=='flatten':
            self.refresh_mode_flatten(cluster_visible)
        elif self.mode=='geometry':
            self.refresh_mode_geometry(cluster_visible)

        self._refresh_one_spike(n_selected)
    
    
    def refresh_mode_flatten(self, cluster_visible):
        """Redraw both plots in 'flatten' mode.

        The top plot receives one centroid curve per visible cluster, with the
        channels laid side by side, optionally surrounded by a dispersion
        band. The bottom plot receives the dispersion curve itself.

        Parameters
        ----------
        cluster_visible : dict
            Maps a cluster label to a boolean; only clusters mapped to a true
            value are drawn.
        """
        if self._x_range is not None:
            # Remember the ranges currently shown so that the redraw keeps them.
            # This reads ViewBox internals and may change with pyqtgraph.
            view_range1 = self.viewBox1.state['viewRange']
            self._x_range = tuple(view_range1[0])
            self._y1_range = tuple(view_range1[1])
            self._y2_range = tuple(self.viewBox2.state['viewRange'][1])

        self.plot1.clear()
        self.plot2.clear()
        self.plot1.addItem(self.curve_one_waveform)

        # len() works for lists and arrays alike; comparing an array with []
        # is not a valid emptiness test.
        spike_index = self.controller.spike_index
        if spike_index is not None and len(spike_index) == 0:
            return

        nb_channel = self.controller.nb_channel
        n_left, n_right = self.controller.get_waveform_left_right()
        width = n_right - n_left

        def addSpan(plot):
            # Shade every other channel and mark sample 0 of each channel.
            white = pg.mkColor(255, 255, 255, 20)
            for i in range(nb_channel):
                if i%2==1:
                    region = pg.LinearRegionItem([width*i, width*(i+1)-1], movable = False, brush = white)
                    plot.addItem(region, ignoreBounds=True)
                    for l in region.lines:
                        l.setPen(white)
                vline = pg.InfiniteLine(pos = -n_left + width*i, angle=90, movable=False, pen = pg.mkPen('w'))
                plot.addItem(vline)

        if self.params['plot_limit_for_flatten']:
            addSpan(self.plot1)
            addSpan(self.plot2)

        if self.params['display_threshold']:
            thresh = self.controller.get_threshold()
            thresh_line = pg.InfiniteLine(pos=thresh, angle=0, movable=False, pen = pg.mkPen('w'))
            self.plot1.addItem(thresh_line)

        # Centroid key and dispersion key requested from the controller.
        if self.params['metrics']=='median/mad':
            key1, key2 = 'median', 'mad'
        elif self.params['metrics']=='mean/std':
            key1, key2 = 'mean', 'std'

        shape = self.controller.get_waveforms_shape()
        if shape is None:
            return
        xvect = np.arange(shape[0]*shape[1])

        for k in cluster_visible:
            if not cluster_visible[k]:
                continue

            wf0 = self.controller.get_waveform_centroid(k, key1)
            if wf0 is None:
                continue
            wf0 = wf0.T.flatten()

            mad = self.controller.get_waveform_centroid(k, key2)
            if mad is not None:
                # Flatten once, whatever the 'fillbetween' setting: both the
                # band and the bottom curve are plotted against the 1d xvect.
                mad = mad.T.flatten()

            color = self.controller.qcolors.get(k, QT.QColor( 'white'))
            curve = pg.PlotCurveItem(xvect, wf0, pen=pg.mkPen(color, width=2))
            self.plot1.addItem(curve)

            if self.params['fillbetween'] and mad is not None:
                color2 = QT.QColor(color)
                color2.setAlpha(self.alpha)
                curve1 = pg.PlotCurveItem(xvect, wf0+mad, pen=color2)
                curve2 = pg.PlotCurveItem(xvect, wf0-mad, pen=color2)
                self.plot1.addItem(curve1)
                self.plot1.addItem(curve2)

                fill = pg.FillBetweenItem(curve1=curve1, curve2=curve2, brush=color2)
                self.plot1.addItem(fill)

            if mad is not None:
                curve = pg.PlotCurveItem(xvect, mad, pen=color)
                self.plot2.addItem(curve)

        if self.params['show_channel_num']:
            for i, (chan, name) in enumerate(self.controller.channel_indexes_and_names):
                itemtxt = pg.TextItem('{}: {}'.format(i, name), anchor=(.5,.5), color='#FFFF00')
                itemtxt.setFont(QT.QFont('', pointSize=12))
                self.plot1.addItem(itemtxt)
                itemtxt.setPos(width*i-n_left, 0)

        if self._x_range is None:
            # Default ranges: full x extent, centroid extrema plus 10 %.
            self._x_range = xvect[0], xvect[-1]
            self._y1_range = self.wf_min*1.1, self.wf_max*1.1
            self._y2_range = 0., 5.

        self.plot1.setXRange(*self._x_range, padding = 0.0)
        self.plot1.setYRange(*self._y1_range, padding = 0.0)
        self.plot2.setYRange(*self._y2_range, padding = 0.0)


        

    def refresh_mode_geometry(self, cluster_visible):
        """Redraw the plot in 'geometry' mode.

        One curve is drawn per visible cluster: the centroid of each channel
        is scaled by ``factor_y * delta_y`` and shifted to the y position of
        that channel.

        Parameters
        ----------
        cluster_visible : dict
            Maps a cluster label to a boolean; only clusters mapped to a true
            value are drawn.
        """
        if self._x_range is not None:
            # Remember the ranges currently shown so that the redraw keeps them.
            # This reads ViewBox internals and may change with pyqtgraph.
            view_range = self.viewBox1.state['viewRange']
            self._x_range = tuple(view_range[0])
            self._y1_range = tuple(view_range[1])

        self.plot1.clear()

        # xvect only exists once the plot has been built in geometry mode, and
        # it is None when the channel group has no geometry.
        if getattr(self, 'xvect', None) is None:
            return

        shape = self.controller.get_waveforms_shape()
        if shape is None:
            return

        # If n_left/n_right have changed, a new xvect is needed.
        if self.xvect.shape[0] != shape[0] * shape[1]:
            self.initialize_plot()

        self.plot1.addItem(self.curve_one_waveform)

        # Only the centroid is drawn in this mode, no dispersion.
        if self.params['metrics']=='median/mad':
            key1 = 'median'
        elif self.params['metrics']=='mean/std':
            key1 = 'mean'

        ypos = self.arr_geometry[:,1]

        for k in cluster_visible:
            if not cluster_visible[k]:
                continue

            wf = self.controller.get_waveform_centroid(k, key1)
            if wf is None:
                continue

            wf = wf*self.factor_y*self.delta_y + ypos[None, :]
            # A NaN in the first sample of every channel breaks the line
            # between consecutive channels (curve drawn with connect='finite').
            wf[0,:] = np.nan
            wf = wf.T.reshape(-1)

            color = self.controller.qcolors.get(k, QT.QColor( 'white'))
            curve = pg.PlotCurveItem(self.xvect, wf, pen=pg.mkPen(color, width=2), connect='finite')
            self.plot1.addItem(curve)

        if self.params['show_channel_num']:
            for i, (chan, name) in enumerate(self.controller.channel_indexes_and_names):
                x, y = self.arr_geometry[i, : ]
                itemtxt = pg.TextItem('{}: {}'.format(i, name), anchor=(.5,.5), color='#FFFF00')
                itemtxt.setFont(QT.QFont('', pointSize=12))
                self.plot1.addItem(itemtxt)
                itemtxt.setPos(x, y)

        if self._x_range is None:
            # Default ranges: full x extent, y positions padded by 2 * delta_y.
            self._x_range = np.min(self.xvect), np.max(self.xvect)
            self._y1_range = np.min(ypos)-self.delta_y*2, np.max(ypos)+self.delta_y*2

        self.plot1.setXRange(*self._x_range, padding = 0.0)
        self.plot1.setYRange(*self._y1_range, padding = 0.0)

        
    
    def _refresh_one_spike(self, n_selected):
        """Update the curve showing the waveform of the single selected spike.

        The curve is emptied unless exactly one spike is selected, the
        'plot_selected_spike' setting is on, the current mode can place the
        waveform, and the whole waveform window could be read.

        Parameters
        ----------
        n_selected : int
            Number of currently selected spikes.
        """
        #TODO peak the selected peak if only one

        if n_selected!=1 or not self.params['plot_selected_spike']:
            self.curve_one_waveform.setData([], [])
            return

        if self.mode=='geometry' and getattr(self, 'xvect', None) is None:
            # Without geometry there is no xvect/arr_geometry to place the
            # waveform on.
            self.curve_one_waveform.setData([], [])
            return

        ind, = np.nonzero(self.controller.spike_selection)
        ind = ind[0]
        seg_num = self.controller.spike_segment[ind]
        peak_ind = self.controller.spike_index[ind]

        n_left, n_right = self.controller.get_waveform_left_right()

        wf = self.controller.dataio.get_signals_chunk(seg_num=seg_num, chan_grp=self.controller.chan_grp,
                i_start=peak_ind+n_left, i_stop=peak_ind+n_right,
                signal_type='processed')

        if wf.shape[0]!=(n_right-n_left):
            # Incomplete chunk (this avoids border bugs): show nothing rather
            # than leaving the waveform of a previously selected spike.
            self.curve_one_waveform.setData([], [])
            return

        if self.mode=='flatten':
            wf = wf.T.flatten()
            xvect = np.arange(wf.size)
            self.curve_one_waveform.setData(xvect, wf)
        elif self.mode=='geometry':
            ypos = self.arr_geometry[:,1]
            wf = wf*self.factor_y*self.delta_y + ypos[None, :]
            # NaN in the first sample of every channel breaks the line between
            # channels (curve created with connect='finite').
            wf[0,:] = np.nan
            wf = wf.T.reshape(-1)
            self.curve_one_waveform.setData(self.xvect, wf)
    
    def on_spike_selection_changed(self):
        """Redraw the whole view.

        Going by its name, this is the hook called when the spike selection
        changes; a full refresh is done rather than updating only the
        selected spike curve.
        """
        self.refresh()




class WaveformViewer(WaveformViewerBase):
    """
    **Waveform viewer** is undoubtedly the view to inspect waveforms.

    Note that in some aspects **Waveform hist viewer** can be a better friend.

    All centroids (median or mean) of visible clusters are plotted here.

    2 main modes:
      * **geometry** waveforms are organized with the 2d geometry given by the PRB file.
      * **flatten** each chunk of each channel is put side by side in channel order
        so that it can be plotted in 1d. The bottom view is the mad. On a good cluster
        the mad must be as close as possible to the value 1 because 1 is the
        normalized noise.

    The **geometry** mode is more intuitive and helps users with spatial information.
    But the **flatten** mode is really important because it gives information about
    the variance (mad or std) for each point and about peak alignment.

    The centroid is defined by median+mad but you can also check with mean+std.
    For a healthy cluster they should be more or less the same.

    Important for zooming:
      * **geometry** : zoomXY geometry = right click, move = left click and
        mouse wheel = zoom waveforms
      * **flatten**: zoomXY = right click and move = left click

    Settings:
      * **plot_selected_spike**: superimpose the one selected peak on the centroid
      * **show_only_selected_cluster**: this auto hides all clusters except the one
        of the selected spike
      * **plot_limit_for_flatten**: for flatten mode this plots lines delimiting
        channels. Plotting is important but it slows down the zoom.
      * **metrics**: choose median+mad or mean+std.
      * **show_channel_num**: what could it be ?
      * **flip_bottom_up**: in geometry this flips the channel geometry bottom up.
      * **display_threshold**: what could it be ?
    """

    # Settings of this viewer: one dict per parameter, with its name, type and
    # default value (or allowed values for 'list').
    _params = [{'name': 'plot_selected_spike', 'type': 'bool', 'value': False },
                      {'name': 'show_only_selected_cluster', 'type': 'bool', 'value': False},
                      {'name': 'plot_limit_for_flatten', 'type': 'bool', 'value': True },
                      {'name': 'metrics', 'type': 'list', 'values': ['median/mad', 'mean/std'] },
                      {'name': 'fillbetween', 'type': 'bool', 'value': True },
                      {'name': 'show_channel_num', 'type': 'bool', 'value': False},
                      {'name': 'flip_bottom_up', 'type': 'bool', 'value': False},
                      {'name': 'display_threshold', 'type': 'bool', 'value' : True },
                      ]


class PeelerWaveformViewer(WaveformViewerBase):
    """
    **Waveform viewer**
    """

    # Settings of this viewer: the selected spike and its cluster are shown by
    # default, and 'median/mad' is the only metric offered.
    _params = [{'name': 'plot_selected_spike', 'type': 'bool', 'value': True },
                      {'name': 'show_only_selected_cluster', 'type': 'bool', 'value': True},
                      {'name': 'plot_limit_for_flatten', 'type': 'bool', 'value': True },
                      {'name': 'metrics', 'type': 'list', 'values': ['median/mad'] },
                      {'name': 'fillbetween', 'type': 'bool', 'value': True },
                      {'name': 'show_channel_num', 'type': 'bool', 'value': False},
                      {'name': 'flip_bottom_up', 'type': 'bool', 'value': False},
                      {'name': 'display_threshold', 'type': 'bool', 'value' : True },
                      ]