diff --git a/realtime_decoder/encoder_process.py b/realtime_decoder/encoder_process.py index d2161b6..ce03ddb 100644 --- a/realtime_decoder/encoder_process.py +++ b/realtime_decoder/encoder_process.py @@ -3,7 +3,7 @@ import time import fcntl import numpy as np - +import copy from mpi4py import MPI from typing import Sequence, List @@ -81,6 +81,7 @@ def __init__(self, config, trode, pos_bin_struct): N = self._config['encoder']['bufsize'] dim = self._config['encoder']['mark_dim'] self._marks = np.zeros((N, dim), dtype='= self._marks.shape[0]: - mark_idx = self._marks.shape[0] - else: - mark_idx = self._mark_idx + #NOTE(DS): if number of mark exceeds + if self._mark_idx >= self._config['encoder']['bufsize']:#self._marks.shape[0]: + mark_idx = self._config['encoder']['bufsize'] - #print(mark) + else: + mark_idx = self._mark_idx in_range = np.ones(mark_idx, dtype=bool) + #in_range = np.ones(self._marks.shape[0], dtype=bool) if self.p['use_filter']: std = self.p['filter_std'] n_std = self.p['filter_n_std'] @@ -280,6 +288,8 @@ def save(self): np.savez( filename, marks=self._marks, + marks_indices = self._chosen_indices, + bufsize = self._config['encoder']['bufsize'], positions=self._positions, mark_idx=np.atleast_1d(self._mark_idx), occupancy=self._occupancy, @@ -518,8 +528,8 @@ def _process_spike(self, spike_msg): t_start_kde, t_end_kde, t_start_enc_send, t_end_enc_send ) - # record result + self.write_record( binary_record.RecordIDs.ENCODER_OUTPUT, spike_timestamp, elec_grp_id, @@ -535,6 +545,8 @@ def _process_spike(self, spike_msg): # either first spike or not enough neighboring spikes # (assuming filter is on). still record result else: + if len(mark_vec) is not 8: + print(f"******************mark_vec: {len(mark_vec)}*******************") self.write_record( binary_record.RecordIDs.ENCODER_OUTPUT, spike_timestamp, elec_grp_id, @@ -542,7 +554,8 @@ def _process_spike(self, spike_msg): encoding_spike, -1, # since didn't compute credible interval decoder_rank, False, self.p['vel_thresh'], self.p['frozen_model'], - self._task_state, -1, + self._task_state, + -1, *mark_vec, *np.zeros(self.p['num_bins']) ) @@ -626,7 +639,27 @@ def _process_pos(self, pos_msg): if self._task_state != 1 and self._save_early: # we also save encoder models at the end of the program, # but we do it here as well just to be safe + + n_spikes_currently_in_buffer = np.min([encoder._mark_idx,encoder._marks.shape[0]-1]) + print(f"n_spikes_current_in_buffer in encoder {encoder._trode}: {n_spikes_currently_in_buffer}") + n_spikes_capacity_buffer = self._config['encoder']['bufsize'] + if n_spikes_currently_in_buffer > n_spikes_capacity_buffer: + encoder._chosen_indices = np.sort(np.random.choice(n_spikes_currently_in_buffer,n_spikes_currently_in_buffer,replace=False)) + self.class_log.info( + f"in encoder {encoder._trode}: choosing {n_spikes_capacity_buffer} from {n_spikes_currently_in_buffer} spikes" + ) + encoder._marks = copy.deepcopy(encoder._marks[encoder._chosen_indices]) + encoder._positions = copy.deepcopy(encoder._positions[encoder._chosen_indices]) + + else: + encoder._chosen_indices = np.arange(encoder._mark_idx) + encoder.save() + encoder._marks = encoder._marks[:np.min([n_spikes_currently_in_buffer,n_spikes_capacity_buffer])] + encoder._positions = encoder._positions[:np.min([n_spikes_currently_in_buffer,n_spikes_capacity_buffer])] + self.class_log.info( + f"encoder {encoder._trode} shape: {encoder._marks.shape}") + self._save_early = False self._pos_counter += 1 diff --git a/realtime_decoder/ripple_process.py b/realtime_decoder/ripple_process.py index dcc156e..bef03fa 100644 --- a/realtime_decoder/ripple_process.py +++ b/realtime_decoder/ripple_process.py @@ -108,6 +108,7 @@ def add_new_data(self, data): self._x_ripple[ii, 1:] = self._x_ripple[ii, :-1] if ii == 0: # new input is incoming data + #print(f"data shape : {data.shape}") self._x_ripple[ii, 0] = data else: # new input is IIR output of previous stage self._x_ripple[ii, 0] = self._y_ripple[ii - 1, 0] @@ -715,6 +716,7 @@ def _reset_stats(self, trodes:List): deviation of the ripple envelope""" num_signals = len(trodes) + print(f"num_signals: {len(trodes)}") # stats for individual traces self._means = np.zeros(num_signals) self._M2 = np.zeros(num_signals) diff --git a/realtime_decoder/stimulation.py b/realtime_decoder/stimulation.py index f9858ea..74e25cf 100644 --- a/realtime_decoder/stimulation.py +++ b/realtime_decoder/stimulation.py @@ -65,13 +65,15 @@ def __init__(self, comm, rank, config, trodes_client): avg_spike_rate_labels = [f'avg_spike_rate_{x}' for x in range(num_decoders)] credible_int_labels = [f'credible_int_{x}' for x in range(num_decoders)] - rls = ['region_box', 'region_arm1', 'region_arm2'] + arm_ids = config['encoder']['position']['arm_ids'] + + rls = ['region_box' if arm_id == 0 else f'region_arm{arm_id}' for arm_id in arm_ids] region_labels = [f'{rl}_{x}' for x in range(num_decoders) for rl in rls] - bls = ['base_box', 'base_arm1', 'base_arm2'] + bls = ['base_box' if arm_id == 0 else f'base_arm{arm_id}' for arm_id in arm_ids] base_labels = [f'{bl}_{x}' for x in range(num_decoders) for bl in bls] - armls = ['box', 'arm1', 'arm2'] + armls = ['box' if arm_id == 0 else f'arm{arm_id}' for arm_id in arm_ids] arm_labels = [f'{arml}_{x}' for x in range(num_decoders) for arml in armls] rtrodes = config['trode_selection']['ripples'] @@ -191,12 +193,22 @@ def __init__(self, comm, rank, config, trodes_client): self._automatic_threshold_update = self._config['stimulation']['automatic_threshold_update'] self._num_scm_each_arm_per_minute = self._config['stimulation']['num_each_arm_per_minute'] - self._arm_1_posterior = [0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.26,0.27] #NOTE(DS): in case the buffer is too small - self._arm_2_posterior = [0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.26,0.27] + self._arm_1_posterior = [0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.26,0.27, + 0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25, + 0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,] #NOTE(DS): in case the buffer is too small + self._arm_2_posterior = [0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.26,0.27, + 0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25, + 0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,] #NOTE(DS): in case the buffer is too small + self._initial_number_of_posterior_buffer_values = np.array(self._arm_1_posterior).shape[0] + print(f"initial number of posterior puffer values: {self._initial_number_of_posterior_buffer_values}") self._task_state_2_start_time = None self._timepoints_per_sec = self._config['sampling_rate']['spikes'] self._elapsed_minutes = 0 + self._arm_1_lower_threshold_but_many_spikes_event = 0 + self._arm_2_lower_threshold_but_many_spikes_event = 0 + + def handle_message(self, msg, mpi_status): """Process a (non neural data) received MPI message""" @@ -754,6 +766,12 @@ def _update_prob_sums(self, marginal_prob): ind = self._dec_ind arm_probs = self._compute_arm_probs(marginal_prob) + ''' + print(f"self._arm_ps_buff: {self._arm_ps_buff}") + print(f"ind: {ind}") + print(f"self._dd_ind: {self._dd_ind}") + print(f"arm_probs: {arm_probs}") + ''' self._arm_ps_buff[ind, self._dd_ind] = arm_probs ps_arm1, ps_arm2, ps_arm1_base, ps_arm2_base = self._compute_region_probs( @@ -789,27 +807,18 @@ def _compute_region_probs(self, prob): # this configurable if the position bin size changes, # for example #NOTE(DS): this is with 41 bins -- original - arm1_start_bin = 25 - int(self._well_angle_range) - arm2_start_bin = 41 - int(self._within_angle_range) - - ps_arm1 = prob[arm1_start_bin:25].sum() #NOTE(DS): originally, 20-25 - ps_arm2 = prob[arm2_start_bin:41].sum() #NOTE(DS): originally, 36-41 - ps_arm1_base = prob[13:arm1_start_bin].sum() - ps_arm2_base = prob[29:arm2_start_bin].sum() - - ''' - #NOTE(DS): This is with 21 bins -- modified - arm1_end_bin = 13 - arm2_end_bin = 21 - arm1_target_start_bin = arm1_end_bin - int(self._well_angle_range) - arm2_target_start_bin = arm2_end_bin - int(self._within_angle_range) - - - ps_arm1 = prob[arm1_target_start_bin:arm1_end_bin].sum() #NOTE(DS): originally, 20-25 - ps_arm2 = prob[arm2_target_start_bin:arm2_end_bin].sum() #NOTE(DS): originally, 36-41 - ps_arm1_base = prob[7:arm1_target_start_bin].sum() - ps_arm2_base = prob[15:arm2_target_start_bin].sum() - ''' + arm1_end = self.p['arm_coords'][1][1] + arm2_end = self.p['arm_coords'][2][1] + arm1_start = self.p['arm_coords'][1][0] + arm2_start = self.p['arm_coords'][2][0] + + arm1_detection_start_bin = arm1_end - int(self._well_angle_range) + 1 + arm2_detection_start_bin = arm2_end - int(self._within_angle_range) + 1 + + ps_arm1 = prob[arm1_detection_start_bin:(arm1_end+1)].sum() #NOTE(DS): originally, 20-25 + ps_arm2 = prob[arm2_detection_start_bin:(arm2_end+1)].sum() #NOTE(DS): originally, 36-41 + ps_arm1_base = prob[arm1_start:arm1_detection_start_bin].sum() + ps_arm2_base = prob[arm1_start:arm2_detection_start_bin].sum() return ps_arm1, ps_arm2, ps_arm1_base, ps_arm2_base @@ -958,6 +967,11 @@ def _handle_replay(self, arm, msg): else: if above_threshold == False: print(f"Replay arm {arm} detected with lower target posterior prob: {target_posterior_prob} than threshold: {arm_thresh}") + if num_spikes_in_event >= 6: + if arm == 1: + self._arm_1_lower_threshold_but_many_spikes_event += 1 + elif arm == 2: + self._arm_2_lower_threshold_but_many_spikes_event += 1 else: print(f" ") print(f" ") @@ -977,13 +991,17 @@ def _handle_replay(self, arm, msg): print(f"num spikes(TS{self._task_state}) : {num_spikes_in_event}, {trodes_of_spike}") print(f"Unique trodes(TS{self._task_state}): {num_unique}, {np.unique(trodes_of_spike)}") + potentially_duplicated_spikes = False + if num_unique == 2: + if np.abs(np.diff(np.unique(trodes_of_spike))) == 1: + potentially_duplicated_spikes = True + send_shortcut = self._check_send_shortcut( self.p_replay['enabled'] - ) and (above_threshold or num_spikes_in_event >= 6) # NOTE(DS): num_spikes_in_event >6 is to detect SWR + ) and (above_threshold or num_spikes_in_event >= 6) and (not potentially_duplicated_spikes) # NOTE(DS): num_spikes_in_event >6 is to detect SWR if num_unique >= self.p_replay['min_unique_trodes']: - if send_shortcut: if arm == 1: self._trodes_client.send_statescript_shortcut_message(14) @@ -1003,7 +1021,8 @@ def _handle_replay(self, arm, msg): print(f" ") print(f" ") - if (np.sum(self._num_rewards[1:]) in [5,10,20,40]) and self._automatic_threshold_update: + #NOTE(DS): now update everytime + if (np.sum(self._num_rewards[1:]) in np.arange(5,200)) and self._automatic_threshold_update: self._update_scm_threshold() self.write_record( @@ -1033,14 +1052,70 @@ def _handle_replay(self, arm, msg): ) def _update_scm_threshold(self): + + #NOTE(DS): This is old version -- for SC92 that I determine the number of events/minutess desired_number_of_scm = np.ceil(self._num_scm_each_arm_per_minute * self._elapsed_minutes) - index_for_desired_number_of_scm = -int(desired_number_of_scm + 1) - self.p_replay['primary_arm_threshold'] = np.sort(self._arm_1_posterior)[index_for_desired_number_of_scm] - self.p_replay['secondary_arm_threshold'] = np.sort(self._arm_2_posterior)[index_for_desired_number_of_scm] + desired_number_of_scm_arm1 = desired_number_of_scm - self._arm_1_lower_threshold_but_many_spikes_event + desired_number_of_scm_arm2 = desired_number_of_scm - self._arm_2_lower_threshold_but_many_spikes_event - print(f"new arm 1 thresh: {self.p_replay['primary_arm_threshold']} and arm 2 thresh: {self.p_replay['secondary_arm_threshold']}") + if desired_number_of_scm_arm1 < 1: + desired_number_of_scm_arm1 = 0 + + if desired_number_of_scm_arm2 < 1: + desired_number_of_scm_arm2 = 0 + + + index_for_desired_number_of_scm1 = -int(desired_number_of_scm_arm1 + 1) + index_for_desired_number_of_scm2 = -int(desired_number_of_scm_arm2 + 1) + + + + self.p_replay['primary_arm_threshold'] = np.sort(self._arm_1_posterior)[index_for_desired_number_of_scm1] + self.p_replay['secondary_arm_threshold'] = np.sort(self._arm_2_posterior)[index_for_desired_number_of_scm2] + + print(f"number of arm 1 detected events: {len(self._arm_1_posterior)- self._initial_number_of_posterior_buffer_values}" ) + print(f"number of arm 2 detected events:{len(self._arm_2_posterior)- self._initial_number_of_posterior_buffer_values}" ) + print(f"number of arm 1 below threshold but many cell events:{self._arm_1_lower_threshold_but_many_spikes_event}") + print(f"number of arm 2 below threshold but many cell events:{self._arm_2_lower_threshold_but_many_spikes_event}") + ''' + baseline_threshold = 0.25 + diff_num_detected_event_threshold = 2 + + + #desired_number_of_scm = np.min([self._num_rewards[1],self._num_rewards[2]]) + diff_num_detected_event_threshold + print(f"number of arm 1 detected events: {len(self._arm_1_posterior)- self._initial_number_of_posterior_buffer_values}" ) + print(f"number of arm 2 detected events:{len(self._arm_2_posterior)- self._initial_number_of_posterior_buffer_values}" ) + print(f"number of arm 1 below threshold but many cell events:{self._arm_1_lower_threshold_but_many_spikes_event}") + print(f"number of arm 2 below threshold but many cell events:{self._arm_2_lower_threshold_but_many_spikes_event}") + + desired_number_of_scm = np.min([len(self._arm_1_posterior),len(self._arm_2_posterior)]) -\ + self._initial_number_of_posterior_buffer_values + diff_num_detected_event_threshold + if len(self._arm_1_posterior) < len(self._arm_2_posterior): + desired_number_of_scm_arm1 = desired_number_of_scm + desired_number_of_scm_arm2 = desired_number_of_scm - self._arm_2_lower_threshold_but_many_spikes_event + + elif len(self._arm_1_posterior) > len(self._arm_2_posterior): + desired_number_of_scm_arm1 = desired_number_of_scm - self._arm_1_lower_threshold_but_many_spikes_event + desired_number_of_scm_arm2 = desired_number_of_scm + + else: + desired_number_of_scm_arm1 = desired_number_of_scm + desired_number_of_scm_arm2 = desired_number_of_scm + + if desired_number_of_scm_arm1 < 1: + desired_number_of_scm_arm1 = 0 + if desired_number_of_scm_arm2 < 1: + desired_number_of_scm_arm2 = 0 + + index_for_desired_number_of_scm_arm1 = -int(desired_number_of_scm_arm1 + 1) + index_for_desired_number_of_scm_arm2 = -int(desired_number_of_scm_arm2 + 1) + + self.p_replay['primary_arm_threshold'] = np.sort(self._arm_1_posterior)[index_for_desired_number_of_scm_arm1] + self.p_replay['secondary_arm_threshold'] = np.sort(self._arm_2_posterior)[index_for_desired_number_of_scm_arm2] + ''' + print(f"new arm 1 thresh: {self.p_replay['primary_arm_threshold']} and arm 2 thresh: {self.p_replay['secondary_arm_threshold']}") def _find_replay_instructive(self, msg): """Look for a potential replay event for an instructive task""" @@ -1261,7 +1336,7 @@ def _init_data_buffers(self): # dim 2 is 3 because 3 regions - box, arm1, arm2 # _ps_ - shorthand for probability sum - self._arm_ps_buff = np.zeros((num_decoders, N, 3)) + self._arm_ps_buff = np.zeros((num_decoders, N, len(self._config['encoder']['position']['arm_ids']))) self._region_ps_buff = np.zeros_like(self._arm_ps_buff) self._region_ps_base_buff = np.zeros_like(self._arm_ps_buff)