Source code for tridesclous.gui.silhouette
"""
This view is from taken from sklearn examples.
See http://scikit-learn.org: plot-kmeans-silhouette-analysis-py
"""
from .myqt import QT
import pyqtgraph as pg
import numpy as np
import matplotlib.cm
import matplotlib.colors
from .base import WidgetBase
from .tools import ParamDialog
class MyViewBox(pg.ViewBox):
doubleclicked = QT.pyqtSignal()
def mouseDoubleClickEvent(self, ev):
self.doubleclicked.emit()
ev.accept()
def raiseContextMenu(self, ev):
#for some reasons enableMenu=False is not taken (bug ????)
pass
[docs]class Silhouette(WidgetBase):
"""
**Silhouette** display the silhouette score.
Implemented with sklearn. Must compute metrics first.
See:
* `Silhouette wikipedia <https://en.wikipedia.org/wiki/Silhouette_(clustering)>`_
* `Silhouette sklearn <http://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_silhouette_analysis.html#sphx-glr-auto-examples-cluster-plot-kmeans-silhouette-analysis-py>`_
"""
_params = [
]
def __init__(self, controller=None, parent=None):
WidgetBase.__init__(self, parent=parent, controller=controller)
self.layout = QT.QVBoxLayout()
self.setLayout(self.layout)
h = QT.QHBoxLayout()
self.layout.addLayout(h)
h.addWidget(QT.QLabel('<b>Silhouette</b>') )
but = QT.QPushButton('settings')
but.clicked.connect(self.open_settings)
h.addWidget(but)
self.graphicsview = pg.GraphicsView()
self.layout.addWidget(self.graphicsview)
self.alpha = 60
self.initialize_plot()
self.refresh()
def on_params_changed(self):
self.compute_slihouette()
self.refresh()
def initialize_plot(self):
self.viewBox = MyViewBox()
self.viewBox.doubleclicked.connect(self.open_settings)
self.viewBox.disableAutoRange()
self.plot = pg.PlotItem(viewBox=self.viewBox)
self.graphicsview.setCentralItem(self.plot)
self.plot.hideButtons()
def refresh(self):
self.plot.clear()
silhouette_values = self.controller.spike_silhouette
if silhouette_values is None:
return
if silhouette_values.shape != self.controller.spike_label.shape:
return
silhouette_avg = np.mean(silhouette_values)
silhouette_by_labels = {}
labels = self.controller.spike_label
labels_list = np.unique(labels)
for k in labels_list:
v = silhouette_values[k==labels]
v.sort()
silhouette_by_labels[k] = v
self.vline = pg.InfiniteLine(pos=silhouette_avg, angle = 90, movable = False, pen = '#FF0000')
self.plot.addItem(self.vline)
y_lower = 10
cluster_visible = self.controller.cluster_visible
visibles = [c for c, v in self.controller.cluster_visible.items() if v and c>=0]
for k in visibles:
if k not in silhouette_by_labels:
continue
v = silhouette_by_labels[k]
color = self.controller.qcolors[k]
color2 = QT.QColor(color)
color2.setAlpha(self.alpha)
y_upper = y_lower + v.size
y_vect = np.arange(y_lower, y_upper)
curve1 = pg.PlotCurveItem(np.zeros(v.size), y_vect, pen=color)
curve2 = pg.PlotCurveItem(v, y_vect, pen=color)
self.plot.addItem(curve1)
self.plot.addItem(curve2)
fill = pg.FillBetweenItem(curve1=curve1, curve2=curve2, brush=color2)
self.plot.addItem(fill)
txt = pg.TextItem( text='{}'.format(k), color='#FFFFFF', anchor=(0, 0.5), border=None)#, fill=pg.mkColor((128,128,128, 180)))
self.plot.addItem(txt)
txt.setPos(0, (y_upper+y_lower)/2.)
y_lower = y_upper + 10
self.plot.setXRange(-.5, 1.)
self.plot.setYRange(0,y_lower)
def on_spike_selection_changed(self):
pass
def on_spike_label_changed(self):
#~ self.compute_slihouette()
self.refresh()
def on_colors_changed(self):
self.refresh()
def on_cluster_visibility_changed(self):
self.refresh()