"""
.. autoclass:: DataIO
   :members:
"""
import os, shutil
import json
from collections import OrderedDict
import numpy as np
import pandas as pd
from urllib.request import urlretrieve
import pickle
from .datasource import data_source_classes
from .iotools import ArrayCollection
from .tools import download_probe, create_prb_file_from_dict, fix_prb_file_py2
from .export import export_list, export_dict
_signal_types = ['initial', 'processed']
class DataIO:
    """
    Class to access the dataset (raw data, processed signals, catalogue,
    spikes) in read/write mode.

    All operations on the dataset are done through this class.

    The dataio:
      * works in a directory. Almost everything is persistent.
      * is needed by CatalogueConstructor and Peeler
      * has a datasource member that accesses the raw data
      * stores/loads processed signals
      * stores/loads spikes
      * stores/loads the catalogue
      * deals with several channel groups given a PRB file
      * deals with several segments of recording (aka several files for raw data)
      * exports the results (labeled spikes) to different formats

    The underlying data storage is a simple tree on the file system.
    Everything is organised as simply as possible in sub folders
    (by channel group then segment index).
    In each folder:
      * arrays.json describes the list of numpy arrays (name, dtype, shape)
      * XXX.raw are the raw numpy arrays, loaded with a simple memmap.
      * some arrays are struct arrays (aka array of struct)

    The datasource system is based on neo.rawio so all formats in neo.rawio are
    available in tridesclous. neo.rawio is able to read chunks of signals
    indexed on the time axis and the channel axis.

    The raw dataset does not need to be inside the working directory but can be
    somewhere outside. The info.json describes the link to the *datasource*
    (raw data).

    Many raw datasets are saved by the device with an underlying int16.
    DataIO saves the processed signals as float32 by default. So if
    you have a 10 GB raw dataset, tridesclous will need at least 20 GB more
    storage for the processed signals.

    **Usage**::

        # initialize a directory
        dataio = DataIO(dirname='/path/to/a/working/dir')

        # set a data source
        filenames = ['file1.raw', 'file2.raw']
        dataio.set_data_source(type='RawData', filenames=filenames,
                    sample_rate=10000, total_channel=16, dtype='int16')

        # set a PRB file
        dataio.set_probe_file('/path/to/a/file.prb')

        # or download it
        dataio.download_probe('kampff_128', origin='spyking-circus')

        # check length and channel groups
        print(dataio)

    """

    @staticmethod
    def check_initialized(dirname):
        """
        Return True if ``dirname`` exists and already holds an ``info.json``
        file (i.e. it has been used as a DataIO working dir), else False.
        """
        # a missing dirname implies a missing dirname/info.json
        return os.path.exists(os.path.join(dirname, 'info.json'))

    def __init__(self, dirname='test'):
        """
        Open (or create) the working directory ``dirname``.

        If ``dirname/info.json`` does not exist, an empty one is written and no
        datasource is set. Otherwise info.json is loaded and, when it is not
        empty, the channel groups, the datasource and the processed data
        are reloaded from it.
        """
        self.dirname = dirname
        # makedirs (not mkdir) so that nested, not yet existing paths work
        os.makedirs(dirname, exist_ok=True)
        self.info_filename = os.path.join(self.dirname, 'info.json')
        self.datasource = None

        if not os.path.exists(self.info_filename):
            # first init
            self.info = {}
            self.flush_info()
            return

        with open(self.info_filename, 'r', encoding='utf8') as f:
            self.info = json.load(f)

        if len(self.info) > 0:
            self._reload_channel_group()
            self._reload_data_source()
            self._open_processed_data()

    def __repr__(self):
        t = "DataIO <id: {}> \n workdir: {}\n".format(id(self), self.dirname)
        if len(self.info) == 0 or self.datasource is None:
            t += "\n Not datasource is set yet"
            return t
        t += " sample_rate: {}\n".format(self.sample_rate)
        t += " total_channel: {}\n".format(self.total_channel)
        if len(self.channel_groups) == 1:
            k0, cg0 = next(iter(self.channel_groups.items()))
            ch_names = np.array(self.datasource.get_channel_names())[cg0['channels']]
            if len(ch_names) > 8:
                chantxt = "[{} ... {}]".format(' '.join(ch_names[:4]), ' '.join(ch_names[-4:]))
            else:
                chantxt = "[{}]".format(' '.join(ch_names))
            t += " channel_groups: {} {}\n".format(k0, chantxt)
        else:
            t += " channel_groups: {}\n".format(', '.join(['{} ({}ch)'.format(cg, self.nb_channel(cg))
                                                           for cg in self.channel_groups.keys()]))
        t += " nb_segment: {}\n".format(self.nb_segment)
        if self.nb_segment < 5:
            lengths = [self.datasource.get_segment_shape(i)[0] for i in range(self.nb_segment)]
            t += ' length: ' + ' '.join('{}'.format(l) for l in lengths) + '\n'
            t += ' durations: ' + ' '.join('{:0.1f}'.format(l / self.sample_rate) for l in lengths) + ' s.\n'
        if t.endswith('\n'):
            t = t[:-1]
        return t

    def flush_info(self):
        """Write ``self.info`` to ``info.json`` in the working dir."""
        with open(self.info_filename, 'w', encoding='utf8') as f:
            json.dump(self.info, f, indent=4)

    def set_data_source(self, type='RawData', **kargs):
        """
        Set the datasource. Must be done only once, otherwise an error is raised.

        A default channel group 0 containing all channels is also created
        (written to ``default.prb``).

        Parameters
        ------------------
        type: str ('RawData', 'Blackrock', 'Neuralynx', ...)
            The name of the neo.rawio class used to open the dataset.
        kargs:
            Depend on the class used. They are transmitted to the neo class,
            so see the neo documentation for kargs.
        """
        assert type in data_source_classes, 'this source type do not exists yet!!'
        assert 'datasource_type' not in self.info, 'datasource is already set'

        # force absolute path names: kargs are persisted in info.json and
        # replayed by _reload_data_source, possibly from another cwd
        for key in ('filenames', 'dirnames'):
            if key in kargs:
                kargs[key] = [os.path.abspath(f) for f in kargs[key]]

        self.info['datasource_type'] = type
        self.info['datasource_kargs'] = kargs
        self._reload_data_source()

        # by default one channel group with all channels
        channel_groups = {0: {'channels': list(range(self.total_channel))}}
        self.set_channel_groups(channel_groups, probe_filename='default.prb')

        self.flush_info()
        # this creates the segment paths
        self._open_processed_data()

    def _reload_data_source(self):
        """Instantiate the datasource described in info.json and cache its main attributes."""
        assert 'datasource_type' in self.info
        kargs = self.info['datasource_kargs']
        self.datasource = data_source_classes[self.info['datasource_type']](**kargs)
        self.total_channel = self.datasource.total_channel
        self.nb_segment = self.datasource.nb_segment
        self.sample_rate = self.datasource.sample_rate
        self.source_dtype = self.datasource.dtype

    @staticmethod
    def _read_prb_channel_groups(probe_filename):
        """
        Execute a PRB file and return the ``channel_groups`` object it defines.

        PRB files are Python scripts: this executes arbitrary code, so only
        use it on trusted probe files.
        """
        d = {}
        with open(probe_filename) as f:
            # globals=None -> the script sees this module's globals;
            # the names it defines are collected in ``d``
            exec(f.read(), None, d)
        return d['channel_groups']

    def _normalize_channel_groups(self, channel_groups):
        """
        Return a normalized copy of ``channel_groups`` (the input is not modified).

        Every group must have a 'channels' entry, which is converted to a list.
        A missing or None 'geometry' is replaced by ``_make_fake_geometry``.
        Other keys of each group are kept as they are.
        """
        normalized = {}
        for chan_grp, channel_group in channel_groups.items():
            assert 'channels' in channel_group
            channel_group = dict(channel_group)
            channel_group['channels'] = list(channel_group['channels'])
            if channel_group.get('geometry', None) is None:
                channel_group['geometry'] = self._make_fake_geometry(channel_group['channels'])
            normalized[chan_grp] = channel_group
        return normalized

    def _reload_channel_group(self):
        """Load ``self.channel_groups`` from the probe file referenced in info.json."""
        # TODO test if prb is compatible with py3
        probe_filename = os.path.join(self.dirname, self.info['probe_filename'])
        channel_groups = self._read_prb_channel_groups(probe_filename)
        self.channel_groups = self._normalize_channel_groups(channel_groups)

    def _rm_old_probe_file(self, keep=None):
        """
        Remove the probe file currently referenced in info.json from the working dir.

        Nothing is removed when no probe is referenced or when its name equals
        ``keep`` (the name of the probe file about to replace it). A file that
        is already missing is ignored.
        """
        old_filename = self.info.get('probe_filename', None)
        if old_filename is None or old_filename == keep:
            return
        try:
            os.remove(os.path.join(self.dirname, old_filename))
        except FileNotFoundError:
            pass

    def set_probe_file(self, src_probe_filename):
        """
        Set the probe file.

        The probe file is copied inside the working dir (no copy if it is
        already there) and passed through ``fix_prb_file_py2``. Its geometry
        must be 2D. The previous probe file is removed only once the new one
        is in place and has been checked.
        """
        basename = os.path.basename(src_probe_filename)
        probe_filename = os.path.join(self.dirname, basename)
        try:
            shutil.copyfile(src_probe_filename, probe_filename)
        except shutil.SameFileError:
            # the probe file is already in the working dir
            pass
        fix_prb_file_py2(probe_filename)

        # check that the geometry is 2D
        channel_groups = self._read_prb_channel_groups(probe_filename)
        for channel_group in channel_groups.values():
            geometry = channel_group.get('geometry', None)
            if geometry is not None:
                for v in geometry.values():
                    assert len(v) == 2, 'Tridesclous need 2D geometry'

        # removing the old probe only now (and never when it has the same name)
        # avoids deleting the source before copying it, and avoids leaving
        # info.json pointing to a deleted file if something above fails
        self._rm_old_probe_file(keep=basename)
        self.info['probe_filename'] = basename
        self.flush_info()
        self._reload_channel_group()
        self._open_processed_data()

    def download_probe(self, probe_name, origin='kwikteam'):
        """
        Download a prb file from github into the working dir.

        The spyking-circus and kwikteam projects propose a list of prb files.
        See:
          * https://github.com/kwikteam/probes
          * https://github.com/spyking-circus/spyking-circus/tree/master/probes

        Parameters
        ------------------
        probe_name: str
            the name of the file in github
        origin: 'kwikteam' or 'spyking-circus'
            github project
        """
        probe_filename = download_probe(self.dirname, probe_name, origin=origin)
        fix_prb_file_py2(probe_filename)
        basename = os.path.basename(probe_filename)
        # remove the old probe only after a successful download
        self._rm_old_probe_file(keep=basename)
        self.info['probe_filename'] = basename
        self.flush_info()
        self._reload_channel_group()
        self._open_processed_data()

    def _make_fake_geometry(self, channels):
        """
        Build a placeholder geometry ``{channel: position}`` for ``channels``.

        4 channels are laid out as a tetrode (4 points at 50 um from the
        center). Any other count is laid out as a vertical line with 200 um
        spacing.
        """
        if len(channels) != 4:
            # assume that it is a linear probe with 200 um pitch
            geometry = {c: [0, i * 200] for i, c in enumerate(channels)}
        else:
            # except for tetrode
            geometry = dict(zip(channels, [(0., 50.), (50., 0.), (0., -50.), (-50, 0.)]))
        return geometry

    def set_channel_groups(self, channel_groups, probe_filename='channels.prb'):
        """
        Set manually the channel groups dictionary.

        The groups are normalized (see ``_normalize_channel_groups``), written
        to ``probe_filename`` inside the working dir and then reloaded.
        The caller's dict is not modified.
        """
        channel_groups = self._normalize_channel_groups(channel_groups)
        create_prb_file_from_dict(channel_groups, os.path.join(self.dirname, probe_filename))
        # the new file exists now: drop the previous probe if it had another name
        self._rm_old_probe_file(keep=probe_filename)
        self.info['probe_filename'] = probe_filename
        self.flush_info()
        self._reload_channel_group()
        self._open_processed_data()

    def add_one_channel_group(self, channels=(), chan_grp=0, geometry=None):
        """
        Add (or replace) the channel group ``chan_grp`` and rewrite the current probe file.

        A fake geometry is generated when ``geometry`` is None.
        """
        channels = list(channels)
        if geometry is None:
            geometry = self._make_fake_geometry(channels)
        self.channel_groups[chan_grp] = {'channels': channels, 'geometry': geometry}
        # rewrite with same name
        self.set_channel_groups(self.channel_groups, probe_filename=self.info['probe_filename'])

    def get_geometry(self, chan_grp=0):
        """
        Get the geometry for a given channel group as a float64 numpy array,
        one row per channel in the group's channel order.
        """
        channel_group = self.channel_groups[chan_grp]
        geometry = [channel_group['geometry'][chan] for chan in channel_group['channels']]
        return np.array(geometry, dtype='float64')

    def nb_channel(self, chan_grp=0):
        """
        Number of channels for a channel group.
        """
        return len(self.channel_groups[chan_grp]['channels'])

    def channel_group_label(self, chan_grp=0):
        """
        Label of channels for a group (abbreviated when the group has 8 channels or more).
        """
        label = 'chan_grp {} - '.format(chan_grp)
        channels = self.channel_groups[chan_grp]['channels']
        ch_names = np.array(self.datasource.get_channel_names())[channels]
        if len(ch_names) < 8:
            label += ' '.join(ch_names)
        else:
            label += ' '.join(ch_names[:3]) + ' ... ' + ' '.join(ch_names[-2:])
        return label

    def _open_processed_data(self):
        """
        Create (if needed) the ``channel_group_X/segment_Y`` folders and open
        one ArrayCollection per (channel group, segment), loading the
        'processed_signals' and 'spikes' arrays when they exist.
        """
        self.channel_group_path = {}
        self.segments_path = {}
        for chan_grp in self.channel_groups.keys():
            self.segments_path[chan_grp] = []
            cg_path = os.path.join(self.dirname, 'channel_group_{}'.format(chan_grp))
            self.channel_group_path[chan_grp] = cg_path
            if not os.path.exists(cg_path):
                os.mkdir(cg_path)
            for i in range(self.nb_segment):
                segment_path = os.path.join(cg_path, 'segment_{}'.format(i))
                if not os.path.exists(segment_path):
                    os.mkdir(segment_path)
                self.segments_path[chan_grp].append(segment_path)

        self.arrays = {}
        for chan_grp in self.channel_groups.keys():
            self.arrays[chan_grp] = []
            for i in range(self.nb_segment):
                arrays = ArrayCollection(parent=None, dirname=self.segments_path[chan_grp][i])
                self.arrays[chan_grp].append(arrays)
                for name in ['processed_signals', 'spikes']:
                    arrays.load_if_exists(name)

    def get_segment_length(self, seg_num):
        """
        Segment length (in samples) for a given segment index.
        """
        full_shape = self.datasource.get_segment_shape(seg_num)
        return full_shape[0]

    def get_segment_shape(self, seg_num, chan_grp=0):
        """
        Segment shape ``(length, nb_channel)`` for a given segment index and channel group.
        """
        full_shape = self.datasource.get_segment_shape(seg_num)
        return (full_shape[0], self.nb_channel(chan_grp))

    def get_signals_chunk(self, seg_num=0, chan_grp=0,
                          i_start=None, i_stop=None,
                          signal_type='initial'):
        """
        Get a chunk of signal for a given segment index and channel group.

        The signal can be the 'initial' one (aka raw signal, not filtered)
        or the 'processed' one.

        Parameters
        ------------------
        seg_num: int
            segment index
        chan_grp: int
            channel group key
        i_start: int or None
            start index (included)
        i_stop: int or None
            stop index (not included)
        signal_type: str
            'initial' or 'processed'

        Raises
        ------------------
        ValueError
            if signal_type is neither 'initial' nor 'processed'
        """
        channels = self.channel_groups[chan_grp]['channels']
        if signal_type == 'initial':
            data = self.datasource.get_signals_chunk(seg_num=seg_num, i_start=i_start, i_stop=i_stop)
            data = data[:, channels]
        elif signal_type == 'processed':
            data = self.arrays[chan_grp][seg_num].get('processed_signals')[i_start:i_stop, :]
        else:
            raise ValueError("signal_type is not valid: {!r} (expected 'initial' or 'processed')".format(signal_type))
        return data

    def iter_over_chunk(self, seg_num=0, chan_grp=0, i_stop=None, chunksize=1024, **kargs):
        """
        Create an iterable on signals ('initial' or 'processed').

        Yields ``(chunk_stop_index, sigs_chunk)`` for consecutive chunks of
        exactly ``chunksize`` samples, up to ``i_stop`` (or the segment end).
        A last incomplete chunk is deliberately not yielded.

        Usage
        ----------
        for ind, sig_chunk in dataio.iter_over_chunk(seg_num=0, chan_grp=0, chunksize=1024, signal_type='processed'):
            do_something_on_chunk(sig_chunk)
        """
        length = self.get_segment_shape(seg_num, chan_grp=chan_grp)[0]
        if i_stop is not None:
            length = min(length, i_stop)

        # no last partial chunk: it would break the fixed chunksize expected
        # in many places (OpenCL signal processor, Peeler, ...)
        for i in range(length // chunksize):
            chunk_stop = (i + 1) * chunksize
            chunk_start = chunk_stop - chunksize
            sigs_chunk = self.get_signals_chunk(seg_num=seg_num, chan_grp=chan_grp,
                                                i_start=chunk_start, i_stop=chunk_stop, **kargs)
            yield chunk_stop, sigs_chunk

    def reset_processed_signals(self, seg_num=0, chan_grp=0, dtype='float32'):
        """
        Reset processed signals: create a new memmap array with the segment shape.
        """
        self.arrays[chan_grp][seg_num].create_array('processed_signals', dtype,
                                                    self.get_segment_shape(seg_num, chan_grp=chan_grp), 'memmap')

    def set_signals_chunk(self, sigs_chunk, seg_num=0, chan_grp=0, i_start=None, i_stop=None, signal_type='processed'):
        """
        Set a signal chunk (only for 'processed').

        Raises ValueError for a signal_type other than 'initial'/'processed'
        ('initial' is read-only and is rejected by an assertion).
        """
        assert signal_type != 'initial'
        if signal_type != 'processed':
            raise ValueError("signal_type is not valid: {!r} (expected 'processed')".format(signal_type))
        data = self.arrays[chan_grp][seg_num].get('processed_signals')
        data[i_start:i_stop, :] = sigs_chunk

    def flush_processed_signals(self, seg_num=0, chan_grp=0):
        """
        Flush the underlying memmap for processed signals.
        """
        self.arrays[chan_grp][seg_num].flush_array('processed_signals')

    def reset_spikes(self, seg_num=0, chan_grp=0, dtype=None):
        """
        Reset spikes (``dtype`` is mandatory).
        """
        assert dtype is not None
        self.arrays[chan_grp][seg_num].initialize_array('spikes', 'memmap', dtype, (-1,))

    def append_spikes(self, seg_num=0, chan_grp=0, spikes=None):
        """
        Append spikes (no-op when ``spikes`` is None).
        """
        if spikes is None:
            return
        self.arrays[chan_grp][seg_num].append_chunk('spikes', spikes)

    def flush_spikes(self, seg_num=0, chan_grp=0):
        """
        Flush underlying memmap for spikes.
        """
        self.arrays[chan_grp][seg_num].finalize_array('spikes')

    def is_spike_computed(self, chan_grp=0):
        """Return True if every segment of ``chan_grp`` has a 'spikes' array."""
        return all(self.arrays[chan_grp][seg_num].has_key('spikes') for seg_num in range(self.nb_segment))

    def get_spikes(self, seg_num=0, chan_grp=0, i_start=None, i_stop=None):
        """
        Read spikes ``[i_start:i_stop]``, or None if they are not available.
        """
        arrays = self.arrays[chan_grp][seg_num]
        if not arrays.has_key('spikes'):
            return None
        spikes = arrays.get('spikes')
        if spikes is None:
            return None
        return spikes[i_start:i_stop]

    def save_catalogue(self, catalogue, name='initial'):
        """
        Save the catalogue made by `CatalogueConstructor` and needed
        by `Peeler` inside the working dir.

        Note that you can construct several catalogues for the same dataset
        to compare them: just change the name (one folder per name).

        numpy arrays of the catalogue are stored as memmap arrays, the rest of
        the dict is pickled. The caller's dict is not modified.
        """
        catalogue = dict(catalogue)
        chan_grp = catalogue['chan_grp']
        catalogue_dir = os.path.join(self.dirname, 'channel_group_{}'.format(chan_grp), 'catalogues', name)
        os.makedirs(catalogue_dir, exist_ok=True)

        arrays = ArrayCollection(parent=None, dirname=catalogue_dir)
        array_keys = [k for k, v in catalogue.items() if isinstance(v, np.ndarray)]
        for k in array_keys:
            arrays.add_array(k, catalogue.pop(k), 'memmap')

        # JSON is not possible for now because some keys in the catalogue are integers
        with open(os.path.join(catalogue_dir, 'catalogue.pickle'), 'wb') as f:
            pickle.dump(catalogue, f)

    def load_catalogue(self, name='initial', chan_grp=0):
        """
        Load the catalogue dict, or return None if it does not exist.

        numpy arrays are loaded as in-memory copies of the stored memmaps.
        """
        catalogue_dir = os.path.join(self.dirname, 'channel_group_{}'.format(chan_grp), 'catalogues', name)
        filename = os.path.join(catalogue_dir, 'catalogue.pickle')
        if not os.path.exists(filename):
            return None
        # NOTE: unpickling can execute arbitrary code: only open trusted working dirs
        with open(filename, 'rb') as f:
            catalogue = pickle.load(f)
        arrays = ArrayCollection(parent=None, dirname=catalogue_dir)
        arrays.load_all()
        for k in arrays.keys():
            catalogue[k] = np.array(arrays.get(k), copy=True)
        return catalogue

    def export_spikes(self, export_path=None,
                      split_by_cluster=False, use_cell_label=True, formats=None):
        """
        Export spikes to other formats (csv, matlab, excel, ...).

        Channel groups without an 'initial' catalogue or without computed
        spikes are skipped.

        Parameters
        ------------------
        export_path: str or None
            export path. If None (default) then inside the working dir.
        split_by_cluster: bool (default False)
            Each cluster is split to a different file or not.
        use_cell_label: bool (default True)
            if True cell_label is used, if False cluster_label is used
        formats: None, 'csv' or 'mat' or 'xlsx', or a list of them
            The output format(s). None means all available exporters.

        Raises
        ------------------
        TypeError
            if formats is not None, a str, a list or a tuple
        """
        if export_path is None:
            export_path = os.path.join(self.dirname, 'export')

        if formats is None:
            exporters = export_list
        elif isinstance(formats, str):
            assert formats in export_dict
            exporters = [export_dict[formats]]
        elif isinstance(formats, (list, tuple)):
            for fmt in formats:
                assert fmt in export_dict
            exporters = [export_dict[fmt] for fmt in formats]
        else:
            raise TypeError('formats must be None, a str or a list of str, not {!r}'.format(formats))

        for chan_grp in self.channel_groups.keys():
            catalogue = self.load_catalogue(chan_grp=chan_grp)
            if catalogue is None:
                continue
            if not self.is_spike_computed(chan_grp=chan_grp):
                continue
            for seg_num in range(self.nb_segment):
                spikes = self.get_spikes(seg_num=seg_num, chan_grp=chan_grp)
                if spikes is None:
                    continue
                args = (spikes, catalogue, seg_num, chan_grp, export_path,)
                kargs = dict(split_by_cluster=split_by_cluster, use_cell_label=use_cell_label)
                for exporter in exporters:
                    exporter(*args, **kargs)