diff --git a/neuroEncoder b/neuroEncoder index a97106c..b9ef0ae 100755 --- a/neuroEncoder +++ b/neuroEncoder @@ -551,13 +551,12 @@ def main(args): ): print("Testing sleep set") # Test sleep as NN - outputs_sleep = TrainerBayes.test_sleep_as_NN( + TrainerBayes.test_sleep_as_NN( DataHelper.fullBehavior, bayesMatrices, windowSizeMS=windowSizeMS, l_function=l_function, ) - print(outputs_sleep) # Save and create alignment tools from neuroencoders.importData.compareSpikeFiltering import WaveFormComparator diff --git a/neuroencoders/fullEncoder/an_network.py b/neuroencoders/fullEncoder/an_network.py index dd292af..4e81de7 100755 --- a/neuroencoders/fullEncoder/an_network.py +++ b/neuroencoders/fullEncoder/an_network.py @@ -22,11 +22,11 @@ import numpy as np import pandas as pd import tensorflow as tf +import wandb from keras import ops as kops from keras.layers import Lambda from tqdm import tqdm - -import wandb +from wandb.integration.keras import WandbMetricsLogger # Get utility functions from neuroencoders.fullEncoder import nnUtils @@ -45,7 +45,6 @@ ) from neuroencoders.importData.epochs_management import get_epochs_mask, inEpochsMask from neuroencoders.utils.global_classes import DataHelper, Params, Project -from wandb.integration.keras import WandbMetricsLogger # We generate a model with the functional Model interface in tensorflow @@ -1409,7 +1408,7 @@ def create_indices(vals): if kwargs.get("inference_mode", False) else {"train": totMask_backup} ) - if not isinstance(speedMask, dict): + if speedMask is not None and not isinstance(speedMask, dict): # it means we have just one set of keys speedMask_backup = speedMask.copy() speedMask = ( diff --git a/neuroencoders/importData/rawdata_parser.py b/neuroencoders/importData/rawdata_parser.py index 8379bb7..8e7e085 100755 --- a/neuroencoders/importData/rawdata_parser.py +++ b/neuroencoders/importData/rawdata_parser.py @@ -1571,11 +1571,15 @@ def update(val): try: ls[dim][2][iaxis].remove() except (AttributeError, KeyError): + # Scatter plot may not exist yet or may have been removed; + # safely ignore and continue to create new plot. pass else: try: ls[dim][2][iaxis].remove() except (AttributeError, KeyError): + # Scatter plot may not exist yet or may have been removed; + # safely ignore and continue to create new plot. pass if SetData["useLossPredTrainSet"]: ls[dim][2] = ax[dim].scatter( @@ -1621,6 +1625,8 @@ def update(val): try: ls[dim][2].remove() except (AttributeError, KeyError): + # Scatter plot may not exist yet or may have been removed; + # safely ignore and continue to create new plot. pass ls[dim][2] = ax[dim].scatter( timeToShow[ @@ -1636,6 +1642,8 @@ def update(val): try: l3.remove() except (AttributeError, KeyError): + # Scatter plot may not exist yet or may have been removed; + # safely ignore and continue to create new plot. pass # modify the xlim of the axes according to the changed epochs diff --git a/neuroencoders/transformData/linearizer.py b/neuroencoders/transformData/linearizer.py index aa7d45d..937e56c 100755 --- a/neuroencoders/transformData/linearizer.py +++ b/neuroencoders/transformData/linearizer.py @@ -362,6 +362,8 @@ def b1update(n): self.lPoints.remove() fig.canvas.draw() except (AttributeError, KeyError): + # Previous scatter points may not exist or may have already been removed; + # in that case there is nothing to clean up before redrawing. pass self.l0s = try_linearization(ax, self.l0s) self.lPoints = ax[0].scatter( @@ -438,6 +440,8 @@ def onclick(event): self.lPoints.remove() fig.canvas.draw() except (AttributeError, KeyError): + # Previous linearization points may not exist or may have already been + # removed; in that case there is nothing to clean up before re-drawing. pass if len(self.nnPoints) > 2: self.n_points = len(self.nnPoints) diff --git a/neuroencoders/utils/MOBS_Functions.py b/neuroencoders/utils/MOBS_Functions.py index 9833d65..40958bc 100755 --- a/neuroencoders/utils/MOBS_Functions.py +++ b/neuroencoders/utils/MOBS_Functions.py @@ -1392,7 +1392,7 @@ def load_results( useTrain=phase != self.phase, useTest=phase != "training", ) - timeStepPred = self.data_helper[epochMask] + timeStepPred = self.data_helper.fullBehavior["positionTime"][epochMask] outputs = self.bayes.test_as_NN( self.data_helper.fullBehavior, self.bayes_matrices, @@ -2025,7 +2025,6 @@ def convert_to_df(self, redo=False): phase_name = suffix.strip("_") if suffix else "all" for id, win in enumerate(self.windows_values): - str(win) data_helper_win = self.data_helper resultsNN_suffix = self.resultsNN_phase[suffix] @@ -2778,8 +2777,6 @@ def mean_error_matrix_linerrors_by_speed( linTrue_fast = [] for _, row in df.iterrows(): # get speed_mask from training Mouse_Results object - row["mouse"] - row["manipe"] speed_mask = ( self.results_df.query( "nameExp == @nameExp and phase == 'training' and winMS == @winMS and mouse == @mouse_val and manipe == @mouse_manipe" @@ -2840,8 +2837,6 @@ def mean_error_matrix_linerrors_by_speed( linTrue = [] for _, row in df.iterrows(): # get speed_mask from training Mouse_Results object - row["mouse"] - row["manipe"] speed_mask = ( self.results_df.query( "nameExp == @nameExp and phase == 'training' and winMS == @winMS and mouse == @mouse_val and manipe == @mouse_manipe" @@ -4076,16 +4071,6 @@ def correlation_per_mouse_spikes( with open(clusters_time_file, "rb") as f: clusters_time = pickle.load(f) except FileNotFoundError: - os.path.abspath( - os.path.join( - mouse_results.folderResult, - "..", - "..", - "last_bayes", - "results", - f"clusters_pre_wTrain_{'True' if row['phase'] == 'training' else 'False'}.pkl", - ) - ) clusters_time_file = os.path.abspath( os.path.join( mouse_results.folderResult, @@ -4290,16 +4275,6 @@ def barplot_correlation_spikes( with open(clusters_time_file, "rb") as f: clusters_time = pickle.load(f) except FileNotFoundError: - os.path.abspath( - os.path.join( - mouse_results.folderResult, - "..", - "..", - "last_bayes", - "results", - f"clusters_pre_wTrain_{'True' if row['phase'] == 'training' else 'False'}.pkl", - ) - ) clusters_time_file = os.path.abspath( os.path.join( mouse_results.folderResult, @@ -5277,7 +5252,8 @@ def get_true_train_mask(row, df): # --- outlier labeling --- df_winMS = err_df.copy() df_winMS = df_winMS.rename(columns={"mouse_manipe": "mouse"}) - df_winMS[ + # Filter the dataframe to relevant columns + df_winMS = df_winMS[ ["stride", "mouse", "mean_error" if reduce_fn == "mean" else "median_error"] ].dropna() diff --git a/neuroencoders/utils/Spike.py b/neuroencoders/utils/Spike.py index bacf1f5..627919c 100755 --- a/neuroencoders/utils/Spike.py +++ b/neuroencoders/utils/Spike.py @@ -95,23 +95,24 @@ def __init__(self, path, time_unit="us"): else: L = [] openstruc(spikes[k], k, L) - for shank in L: - if shank.endswith("_intset"): - shank = shank[0:-7] - start = np.squeeze(spikes[shank]["start"][:]) - stop = np.squeeze(spikes[shank]["stop"][:]) - self.info[shank] = nts.IntervalSet( + for shank_key in L: + if shank_key.endswith("_intset"): + # Remove the "_intset" suffix to get the actual shank name + shank_name = shank_key[0:-7] + start = np.squeeze(spikes[shank_name]["start"][:]) + stop = np.squeeze(spikes[shank_name]["stop"][:]) + self.info[shank_name] = nts.IntervalSet( start, stop, time_units=time_unit ) elif isinstance( - np.squeeze(spikes[shank][:])[0], h5py.h5r.Reference + np.squeeze(spikes[shank_key][:])[0], h5py.h5r.Reference ): - self.info[shank] = ref2str( - np.squeeze(spikes[shank][:]), spikes + self.info[shank_key] = ref2str( + np.squeeze(spikes[shank_key][:]), spikes ) else: - self.info[shank] = np.squeeze(spikes[shank][:]) + self.info[shank_key] = np.squeeze(spikes[shank_key][:]) def get_spikes(self, idx=None): import numpy as np diff --git a/runAllMice.py b/runAllMice.py index db09023..a1d0605 100755 --- a/runAllMice.py +++ b/runAllMice.py @@ -248,7 +248,6 @@ def process_directory(dir, win, force, redo, lstmAndTransfo=False): nbEpochs, "--target", target_bayes, - # "--flat_prior", "--striding", str(win), ] @@ -364,35 +363,35 @@ def run_commands_parallel(mouse_commands): # print(f"Found directories: {dirs}") mouse_commands = {} - for dir in dirs: - if any((mouse in dir or mouse[:-1] in dir) for mouse in mice_nb) or not mice_nb: - if "M1199_MFB" not in dir: - mouse_commands[dir] = [] + for directory in dirs: + if any((mouse in directory or mouse[:-1] in directory) for mouse in mice_nb) or not mice_nb: + if "M1199_MFB" not in directory: + mouse_commands[directory] = [] for win in win_values: if lstm: for lstmAndTransfo in [False, True]: # print( - # f"Processing {dir} with window {win} and lstmAndTransfo {lstmAndTransfo}" + # f"Processing {directory} with window {win} and lstmAndTransfo {lstmAndTransfo}" # ) cmd_ann, cmd_bayes = process_directory( - dir=dir, + dir=directory, win=win, force=force, redo=redo, lstmAndTransfo=lstmAndTransfo, ) if cmd_ann: - mouse_commands[dir].append(cmd_ann) + mouse_commands[directory].append(cmd_ann) if cmd_bayes: - mouse_commands[dir].append(cmd_bayes) + mouse_commands[directory].append(cmd_bayes) else: - cmd_ann, cmd_bayes = process_directory(dir, win, force, redo) + cmd_ann, cmd_bayes = process_directory(directory, win, force, redo) if cmd_ann: - mouse_commands[dir].append(cmd_ann) + mouse_commands[directory].append(cmd_ann) if cmd_bayes: - mouse_commands[dir].append(cmd_bayes) + mouse_commands[directory].append(cmd_bayes) else: - print(f"Processing M1199MFB mouse in directory: {dir}") + print(f"Processing M1199MFB mouse in directory: {directory}") list_exps = [] if any("MFB1" in mouse for mouse in mice_nb): list_exps.append("exp1") @@ -401,15 +400,15 @@ def run_commands_parallel(mouse_commands): if not list_exps: list_exps = ["exp1", "exp2"] for dirmfb in list_exps: - mouse_commands[os.path.join(dir, dirmfb)] = [] + mouse_commands[os.path.join(directory, dirmfb)] = [] for win in win_values: cmd_ann, cmd_bayes = process_directory( - os.path.join(dir, dirmfb), win, force, redo + os.path.join(directory, dirmfb), win, force, redo ) if cmd_ann: - mouse_commands[os.path.join(dir, dirmfb)].append(cmd_ann) + mouse_commands[os.path.join(directory, dirmfb)].append(cmd_ann) if cmd_bayes: - mouse_commands[os.path.join(dir, dirmfb)].append(cmd_bayes) + mouse_commands[os.path.join(directory, dirmfb)].append(cmd_bayes) if rsync: PathForExperiments["realPath"] = PathForExperiments["path"].apply( @@ -417,10 +416,10 @@ def run_commands_parallel(mouse_commands): ) try: Mouse = PathForExperiments[ - PathForExperiments["realPath"] == os.path.realpath(dir) + PathForExperiments["realPath"] == os.path.realpath(directory) ].iloc[0]["name"] print(f"Mouse: {Mouse} from PathForExperiments") - SOURCE = os.path.realpath(dir) + SOURCE = os.path.realpath(directory) DESTINATION = PathForExperiments[ PathForExperiments["name"] == Mouse ].iloc[0]["network_path"] @@ -432,16 +431,16 @@ def run_commands_parallel(mouse_commands): SOURCE, DESTINATION, "--force", - "--dry-run", ] - if "M1199_MFB" not in dir: - mouse_commands[dir].append(runNasCMD) + if "M1199_MFB" not in directory: + mouse_commands[directory].append(runNasCMD) else: for dirmfb in ["exp1", "exp2"]: - mouse_commands[os.path.join(dir, dirmfb)].append(runNasCMD) - except Exception as e: + mouse_commands[os.path.join(directory, dirmfb)].append(runNasCMD) + except (IndexError, KeyError) as e: + # Exception is expected when mouse directory structure is non-standard + # or when mouse is not found in PathForExperiments print(f"Error finding mouse in PathForExperiments: {e}") - pass if mode == "sequential": run_commands_sequentially(mouse_commands)