Source code for tridesclous.gui.waveformhistviewer

from .myqt import QT
import pyqtgraph as pg

import numpy as np
import matplotlib.cm
import matplotlib.colors

import time

from .base import WidgetBase
from .tools import ParamDialog
from ..tools import median_mad
from .. import labelcodes

class MyViewBox(pg.ViewBox):
    doubleclicked = QT.pyqtSignal()
    gain_zoom = QT.pyqtSignal(float)
    def mouseDoubleClickEvent(self, ev):
        self.doubleclicked.emit()
        ev.accept()
    def wheelEvent(self, ev, axis=None):
        if ev.modifiers() == QT.Qt.ControlModifier:
            z = 10 if ev.delta()>0 else 1/10.
        else:
            z = 1.3 if ev.delta()>0 else 1/1.3
        self.gain_zoom.emit(z)
        ev.accept()
    def raiseContextMenu(self, ev):
        #for some reasons enableMenu=False is not taken (bug ????)
        pass
        

[docs]class WaveformHistViewer(WidgetBase): """ **Waveform histogram viewer** is also a important thing. It is equivalent to **Waveform veiwer** in **flatten** mode but with a 2d histogram that show the density (probability) of a cluster. So waveforms are flatten from (nb_peak, nb_sample, nb_channel) to (nb_peak, nb_channel*nb_sample) and are binarized on a 2d histogram. Then this is plotted as a map. The color code the density. This is the best friend to see if two cluster are well discrimitated somewhere or if one cluster must be split. Important: * use right click for X/Y zoom * use left clik to move * use **mouse wheel** for color zoom.Really important to play with this to discover low density * intentionnaly not all cluster are displayed other we see nothing. The best is to plot 2 by 2. Furthermore it faster to plot with few cluster. * don't forget to display the **noise snippet** to validate that the mad is 1 for all channel. Settings: * **colormap** hot is good because loaw density are black like background. * **data** choose waveforms or features * **bin_min** y limts of histogram * **bin_max** y limts of histogram * **bin_size** * **display_threshold** * **max_label** maximum number of labels displayed simulteneously (2 by default but you can set more) """ _params = [ {'name': 'colormap', 'type': 'list', 'values' : ['hot', 'viridis', 'jet', 'gray', ] }, {'name': 'data', 'type': 'list', 'values' : ['waveforms', 'features', ] }, {'name': 'bin_min', 'type': 'float', 'value' : -20. }, {'name': 'bin_max', 'type': 'float', 'value' : 8. }, {'name': 'bin_size', 'type': 'float', 'value' : .1 }, {'name': 'display_threshold', 'type': 'bool', 'value' : True }, {'name': 'max_label', 'type': 'int', 'value' : 2 }, ] def __init__(self, controller=None, parent=None): WidgetBase.__init__(self, parent=parent, controller=controller) self.layout = QT.QVBoxLayout() self.setLayout(self.layout) h = QT.QHBoxLayout() self.layout.addLayout(h) but = QT.QPushButton('Show 1D dist', checkable=True) h.addWidget(but) but.clicked.connect(self.show_hide_1d_dist) self.graphicsview = pg.GraphicsView() self.layout.addWidget(self.graphicsview) self.graphicsview2 = pg.GraphicsView() self.layout.addWidget(self.graphicsview2) self.graphicsview2.hide() self.create_settings() self.initialize_plot() self.similarity = None self.on_params_changed()#this do refresh def on_params_changed(self, ): #params, changes #~ for param, change, data in changes: #~ if change != 'value': continue #~ if param.name()=='data': N = 512 cmap_name = self.params['colormap'] cmap = matplotlib.cm.get_cmap(cmap_name , N) lut = [] for i in range(N): r,g,b,_ = matplotlib.colors.ColorConverter().to_rgba(cmap(i)) lut.append([r*255,g*255,b*255]) self.lut = np.array(lut, dtype='uint8') self._x_range = None self._y_range = None self.refresh() def initialize_plot(self): if self.controller.some_waveforms is None: return self.viewBox = MyViewBox() self.viewBox.doubleclicked.connect(self.open_settings) self.viewBox.gain_zoom.connect(self.gain_zoom) self.viewBox.disableAutoRange() self.plot = pg.PlotItem(viewBox=self.viewBox) self.graphicsview.setCentralItem(self.plot) self.plot.hideButtons() self.image = pg.ImageItem() self.plot.addItem(self.image) #~ self.curve1 = pg.PlotCurveItem() #~ self.plot.addItem(self.curve1) #~ self.curve2 = pg.PlotCurveItem() #~ self.plot.addItem(self.curve2) self.curves = [] thresh = self.controller.get_threshold() self.thresh_line = pg.InfiniteLine(pos=thresh, angle=0, movable=False, pen = pg.mkPen('w')) self.plot.addItem(self.thresh_line) self.params.blockSignals(True) #~ self.params['bin_min'] = np.min(self.controller.some_waveforms) #~ self.params['bin_max'] = np.max(self.controller.some_waveforms) #~ print(self.controller.some_waveforms.shape) #~ print(self.controller.some_peaks_index.shape) #~ print(self.controller.some_peaks_index) #~ print(self.controller.spike_label[self.controller.some_peaks_index].shape) keep = self.controller.spike_label[self.controller.some_peaks_index]>=0 wfs = self.controller.some_waveforms[keep, :, :] if wfs.shape[0]>0: self.params['bin_min'] = np.percentile(wfs, .001) self.params['bin_max'] = np.percentile(wfs, 99.999) self.params.blockSignals(False) def gain_zoom(self, v): #~ print('v', v) levels = self.image.getLevels() if levels is not None: self.image.setLevels(levels * v, update=True) def refresh(self): if not hasattr(self, 'viewBox'): self.initialize_plot() if not hasattr(self, 'viewBox'): return if self._x_range is not None: #~ self._x_range = self.plot.getXRange() #~ self._y_range = self.plot.getYRange() #this may change with pyqtgraph self._x_range = tuple(self.viewBox.state['viewRange'][0]) self._y_range = tuple(self.viewBox.state['viewRange'][1]) cluster_visible = self.controller.cluster_visible visibles = [k for k, v in cluster_visible.items() if v and k>=-1 ] #remove old curves for curve in self.curves: self.plot.removeItem(curve) self.curves = [] if len(visibles)>self.params['max_label'] or len(visibles)==0: self.image.hide() return #~ if len(visibles)==1: #~ self.curve2.hide() if self.controller.some_peaks_index is None: self.plot.clear() return labels = self.controller.spike_label[self.controller.some_peaks_index] keep = np.in1d(labels, visibles) if self.params['data']=='waveforms': wf = self.controller.some_waveforms if wf is None: self.plot.clear() return data_kept = wf[keep].copy() if data_kept.size == 0: self.plot.clear() return data_kept = data_kept.swapaxes(1,2).reshape(data_kept.shape[0], -1) elif self.params['data']=='features': data = self.controller.some_features data_kept = data[keep] if data is None: self.plot.clear() return #TODO change for PCA if self.params['data']=='waveforms': bin_min, bin_max = self.params['bin_min'], self.params['bin_max'] bin_size = self.params['bin_size'] bins = np.arange(bin_min, bin_max, self.params['bin_size']) elif self.params['data']=='features': bin_min, bin_max = np.min(data_kept), np.max(data_kept) #~ n = 500 bins = np.linspace(bin_min, bin_max, 500) bin_size = bins[1] - bins[0] #~ med, mad = median_mad(data_kept, axis=0) #~ min, max = np.min(med-10*mad), np.max(med+10*mad) #~ n = self.params['nb_bin'] #~ bin = (max-min)/(n-1) n = bins.size hist2d = np.zeros((data_kept.shape[1], bins.size)) indexes0 = np.arange(data_kept.shape[1]) data_bined = np.floor((data_kept-bin_min)/bin_size).astype('int32') data_bined = data_bined.clip(0, bins.size-1) for d in data_bined: hist2d[indexes0, d] += 1 if self.controller.cluster_visible[labelcodes.LABEL_NOISE] and self.controller.some_noise_snippet is not None: #~ print('labelcodes.LABEL_NOISE in cluster_visible', labelcodes.LABEL_NOISE in cluster_visible, cluster_visible) if self.params['data']=='waveforms': noise = self.controller.some_noise_snippet noise = noise.swapaxes(1,2).reshape(noise.shape[0], -1) noise_bined = np.floor((noise-bin_min)/bin_size).astype('int32') noise_bined = noise_bined.clip(0, bins.size-1) for d in noise_bined: hist2d[indexes0, d] += 1 #~ elif self.params['data']=='features': self.image.setImage(hist2d, lut=self.lut)#, levels=[0, self._max]) self.image.setRect(QT.QRectF(-0.5, bin_min, data_kept.shape[1], bin_max-bin_min)) self.image.show() #~ for k, curve in zip(visibles, [self.curve1, self.curve2]): for k in visibles: median = self.controller.get_waveform_centroid(k, 'median') if median is None: continue if self.params['data']=='waveforms': y = median.T.flatten() else: y = np.median(data[labels==k], axis=0) color = self.controller.qcolors.get(k, QT.QColor( 'white')) curve = pg.PlotCurveItem(x=indexes0, y=y, pen=pg.mkPen(color, width=2)) self.plot.addItem(curve) self.curves.append(curve) #~ curve.setData() #~ curve.setPen() #~ curve.show() if self.params['display_threshold'] and self.params['data']=='waveforms' : self.thresh_line.show() else: self.thresh_line.hide() if self._x_range is None: self._x_range = 0, indexes0[-1] #hist2d.shape[1] self._y_range = bin_min, bin_max self.plot.setXRange(*self._x_range, padding = 0.0) self.plot.setYRange(*self._y_range, padding = 0.0) def on_spike_selection_changed(self): pass def on_spike_label_changed(self): self.refresh() def on_colors_changed(self): self.refresh() def on_cluster_visibility_changed(self): self.refresh() def on_cluster_tag_changed(self): pass def show_hide_1d_dist(self, v=None): #~ print(v) if v: self.graphicsview2.show() else: self.graphicsview2.hide()