Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 59 additions & 26 deletions realtime_decoder/encoder_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import fcntl
import numpy as np

import copy
from mpi4py import MPI
from typing import Sequence, List

Expand Down Expand Up @@ -81,6 +81,7 @@ def __init__(self, config, trode, pos_bin_struct):
N = self._config['encoder']['bufsize']
dim = self._config['encoder']['mark_dim']
self._marks = np.zeros((N, dim), dtype='<f8')
self._chosen_indices = 0
self._positions = np.zeros(N, dtype='<f4')
self._mark_idx = 0
self._occupancy = np.zeros(self._config['encoder']['position']['num_bins'])
Expand Down Expand Up @@ -110,9 +111,14 @@ def _load_model(self):
)
else:
with np.load(files[0]) as f:
self._marks = f['marks']
self._positions = f['positions']
self._mark_idx = f['mark_idx'][0]
self._marks = f['marks']
print(f['mark_idx'])
if f['mark_idx'][0] < self._config['encoder']['bufsize']: #NOTE(DS): it seem to be a offset of 1.
self._mark_idx = f['mark_idx'][0]-1
else:
self._mark_idx = self._config['encoder']['bufsize']-1
Comment on lines +116 to +120

self._occupancy = f['occupancy']
self._occupancy_ct = f['occupancy_ct'][0]
self.class_log.info(f"Loaded encoding model from {files[0]}")
Expand All @@ -132,13 +138,8 @@ def _init_params(self):
self.p['num_occupancy_points'] = self._config['display']['encoder']['occupancy']

def add_new_mark(self, mark):

'''
# this is where the mark_size increases over time
self._marks[self._mark_idx%self._marks.shape[0]] = mark
self._positions[self._mark_idx%self._marks.shape[0]] = self._position
self._mark_idx += 1
'''

# NOTE(DS): Having only the most recent spikes bias the encoding
if self._mark_idx < self._marks.shape[0]:
self._marks[self._mark_idx%self._marks.shape[0]] = mark
Expand All @@ -149,6 +150,7 @@ def add_new_mark(self, mark):
self._marks[self._mark_idx%self._marks.shape[0]] = mark
self._positions[self._mark_idx%self._marks.shape[0]] = self._position
self._mark_idx += 3
'''

'''
if self._mark_idx < self._marks.shape[0]:
Expand All @@ -166,16 +168,21 @@ def add_new_mark(self, mark):
)
'''

''' # NOTE(DS): This make buf_size meaningless
self._marks = np.vstack((
self._marks,
np.zeros_like(self._marks)
))
self._positions = np.hstack((
self._positions,
np.zeros_like(self._positions)
))
'''
if self._mark_idx == self._marks.shape[0]:
# NOTE(DS): This make buf_size meaningless
self._marks = np.vstack((
self._marks,
np.zeros_like(self._marks)
))
self._positions = np.hstack((
self._positions,
np.zeros_like(self._positions)
))

# this is where the mark_size increases over time
self._marks[self._mark_idx] = mark
self._positions[self._mark_idx] = self._position
self._mark_idx += 1
Comment on lines +171 to +185



Expand All @@ -190,15 +197,16 @@ def get_joint_prob(self, mark):
if self._mark_idx == 0:
return None

if self._mark_idx >= self._marks.shape[0]:
mark_idx = self._marks.shape[0]
else:
mark_idx = self._mark_idx

#NOTE(DS): if number of mark exceeds
if self._mark_idx >= self._config['encoder']['bufsize']:#self._marks.shape[0]:
mark_idx = self._config['encoder']['bufsize']

#print(mark)
else:
mark_idx = self._mark_idx

in_range = np.ones(mark_idx, dtype=bool)
#in_range = np.ones(self._marks.shape[0], dtype=bool)
if self.p['use_filter']:
std = self.p['filter_std']
n_std = self.p['filter_n_std']
Expand Down Expand Up @@ -280,6 +288,8 @@ def save(self):
np.savez(
filename,
marks=self._marks,
marks_indices = self._chosen_indices,
bufsize = self._config['encoder']['bufsize'],
positions=self._positions,
mark_idx=np.atleast_1d(self._mark_idx),
occupancy=self._occupancy,
Expand Down Expand Up @@ -518,8 +528,8 @@ def _process_spike(self, spike_msg):
t_start_kde, t_end_kde,
t_start_enc_send, t_end_enc_send
)

# record result

self.write_record(
binary_record.RecordIDs.ENCODER_OUTPUT,
spike_timestamp, elec_grp_id,
Expand All @@ -535,14 +545,17 @@ def _process_spike(self, spike_msg):
# either first spike or not enough neighboring spikes
# (assuming filter is on). still record result
else:
if len(mark_vec) is not 8:
print(f"******************mark_vec: {len(mark_vec)}*******************")
Comment on lines +548 to +549
self.write_record(
binary_record.RecordIDs.ENCODER_OUTPUT,
spike_timestamp, elec_grp_id,
self._current_pos, self._current_vel,
encoding_spike, -1, # since didn't compute credible interval
decoder_rank, False,
self.p['vel_thresh'], self.p['frozen_model'],
self._task_state, -1,
self._task_state,
-1,
*mark_vec, *np.zeros(self.p['num_bins'])
)

Expand Down Expand Up @@ -626,7 +639,27 @@ def _process_pos(self, pos_msg):
if self._task_state != 1 and self._save_early:
# we also save encoder models at the end of the program,
# but we do it here as well just to be safe

n_spikes_currently_in_buffer = np.min([encoder._mark_idx,encoder._marks.shape[0]-1])
print(f"n_spikes_current_in_buffer in encoder {encoder._trode}: {n_spikes_currently_in_buffer}")
n_spikes_capacity_buffer = self._config['encoder']['bufsize']
if n_spikes_currently_in_buffer > n_spikes_capacity_buffer:
encoder._chosen_indices = np.sort(np.random.choice(n_spikes_currently_in_buffer,n_spikes_currently_in_buffer,replace=False))
self.class_log.info(
f"in encoder {encoder._trode}: choosing {n_spikes_capacity_buffer} from {n_spikes_currently_in_buffer} spikes"
)
encoder._marks = copy.deepcopy(encoder._marks[encoder._chosen_indices])
encoder._positions = copy.deepcopy(encoder._positions[encoder._chosen_indices])

else:
encoder._chosen_indices = np.arange(encoder._mark_idx)

Comment on lines +643 to +656
encoder.save()
encoder._marks = encoder._marks[:np.min([n_spikes_currently_in_buffer,n_spikes_capacity_buffer])]
encoder._positions = encoder._positions[:np.min([n_spikes_currently_in_buffer,n_spikes_capacity_buffer])]
self.class_log.info(
f"encoder {encoder._trode} shape: {encoder._marks.shape}")

self._save_early = False

self._pos_counter += 1
Expand Down
2 changes: 2 additions & 0 deletions realtime_decoder/ripple_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def add_new_data(self, data):

self._x_ripple[ii, 1:] = self._x_ripple[ii, :-1]
if ii == 0: # new input is incoming data
#print(f"data shape : {data.shape}")
self._x_ripple[ii, 0] = data
else: # new input is IIR output of previous stage
self._x_ripple[ii, 0] = self._y_ripple[ii - 1, 0]
Expand Down Expand Up @@ -715,6 +716,7 @@ def _reset_stats(self, trodes:List):
deviation of the ripple envelope"""

num_signals = len(trodes)
print(f"num_signals: {len(trodes)}")
# stats for individual traces
Comment on lines 718 to 720
self._means = np.zeros(num_signals)
self._M2 = np.zeros(num_signals)
Expand Down
Loading