Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions neuroEncoder
Original file line number Diff line number Diff line change
Expand Up @@ -551,13 +551,12 @@ def main(args):
):
print("Testing sleep set")
# Test sleep as NN
outputs_sleep = TrainerBayes.test_sleep_as_NN(
TrainerBayes.test_sleep_as_NN(
DataHelper.fullBehavior,
bayesMatrices,
windowSizeMS=windowSizeMS,
l_function=l_function,
)
print(outputs_sleep)

# Save and create alignment tools
from neuroencoders.importData.compareSpikeFiltering import WaveFormComparator
Expand Down
7 changes: 3 additions & 4 deletions neuroencoders/fullEncoder/an_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import numpy as np
import pandas as pd
import tensorflow as tf
import wandb
from keras import ops as kops
from keras.layers import Lambda
from tqdm import tqdm

import wandb
from wandb.integration.keras import WandbMetricsLogger

# Get utility functions
from neuroencoders.fullEncoder import nnUtils
Expand All @@ -45,7 +45,6 @@
)
from neuroencoders.importData.epochs_management import get_epochs_mask, inEpochsMask
from neuroencoders.utils.global_classes import DataHelper, Params, Project
from wandb.integration.keras import WandbMetricsLogger


# We generate a model with the functional Model interface in tensorflow
Expand Down Expand Up @@ -1409,7 +1408,7 @@ def create_indices(vals):
if kwargs.get("inference_mode", False)
else {"train": totMask_backup}
)
if not isinstance(speedMask, dict):
if speedMask is not None and not isinstance(speedMask, dict):
# it means we have just one set of keys
speedMask_backup = speedMask.copy()
speedMask = (
Expand Down
8 changes: 8 additions & 0 deletions neuroencoders/importData/rawdata_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,11 +1571,15 @@ def update(val):
try:
ls[dim][2][iaxis].remove()
except (AttributeError, KeyError):
# Scatter plot may not exist yet or may have been removed;
# safely ignore and continue to create new plot.
pass
else:
try:
ls[dim][2][iaxis].remove()
except (AttributeError, KeyError):
# Scatter plot may not exist yet or may have been removed;
# safely ignore and continue to create new plot.
pass
if SetData["useLossPredTrainSet"]:
ls[dim][2] = ax[dim].scatter(
Expand Down Expand Up @@ -1621,6 +1625,8 @@ def update(val):
try:
ls[dim][2].remove()
except (AttributeError, KeyError):
# Scatter plot may not exist yet or may have been removed;
# safely ignore and continue to create new plot.
pass
ls[dim][2] = ax[dim].scatter(
timeToShow[
Expand All @@ -1636,6 +1642,8 @@ def update(val):
try:
l3.remove()
except (AttributeError, KeyError):
# Scatter plot may not exist yet or may have been removed;
# safely ignore and continue to create new plot.
pass

# modify the xlim of the axes according to the changed epochs
Expand Down
4 changes: 4 additions & 0 deletions neuroencoders/transformData/linearizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ def b1update(n):
self.lPoints.remove()
fig.canvas.draw()
except (AttributeError, KeyError):
# Previous scatter points may not exist or may have already been removed;
# in that case there is nothing to clean up before redrawing.
pass
self.l0s = try_linearization(ax, self.l0s)
self.lPoints = ax[0].scatter(
Expand Down Expand Up @@ -438,6 +440,8 @@ def onclick(event):
self.lPoints.remove()
fig.canvas.draw()
except (AttributeError, KeyError):
# Previous linearization points may not exist or may have already been
# removed; in that case there is nothing to clean up before re-drawing.
pass
if len(self.nnPoints) > 2:
self.n_points = len(self.nnPoints)
Expand Down
30 changes: 3 additions & 27 deletions neuroencoders/utils/MOBS_Functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,7 @@ def load_results(
useTrain=phase != self.phase,
useTest=phase != "training",
)
timeStepPred = self.data_helper[epochMask]
timeStepPred = self.data_helper.fullBehavior["positionTime"][epochMask]
outputs = self.bayes.test_as_NN(
self.data_helper.fullBehavior,
self.bayes_matrices,
Expand Down Expand Up @@ -2025,7 +2025,6 @@ def convert_to_df(self, redo=False):
phase_name = suffix.strip("_") if suffix else "all"

for id, win in enumerate(self.windows_values):
str(win)
data_helper_win = self.data_helper
resultsNN_suffix = self.resultsNN_phase[suffix]

Expand Down Expand Up @@ -2778,8 +2777,6 @@ def mean_error_matrix_linerrors_by_speed(
linTrue_fast = []
for _, row in df.iterrows():
# get speed_mask from training Mouse_Results object
row["mouse"]
row["manipe"]
speed_mask = (
self.results_df.query(
"nameExp == @nameExp and phase == 'training' and winMS == @winMS and mouse == @mouse_val and manipe == @mouse_manipe"
Expand Down Expand Up @@ -2840,8 +2837,6 @@ def mean_error_matrix_linerrors_by_speed(
linTrue = []
for _, row in df.iterrows():
# get speed_mask from training Mouse_Results object
row["mouse"]
row["manipe"]
speed_mask = (
self.results_df.query(
"nameExp == @nameExp and phase == 'training' and winMS == @winMS and mouse == @mouse_val and manipe == @mouse_manipe"
Expand Down Expand Up @@ -4076,16 +4071,6 @@ def correlation_per_mouse_spikes(
with open(clusters_time_file, "rb") as f:
clusters_time = pickle.load(f)
except FileNotFoundError:
os.path.abspath(
os.path.join(
mouse_results.folderResult,
"..",
"..",
"last_bayes",
"results",
f"clusters_pre_wTrain_{'True' if row['phase'] == 'training' else 'False'}.pkl",
)
)
clusters_time_file = os.path.abspath(
os.path.join(
mouse_results.folderResult,
Expand Down Expand Up @@ -4290,16 +4275,6 @@ def barplot_correlation_spikes(
with open(clusters_time_file, "rb") as f:
clusters_time = pickle.load(f)
except FileNotFoundError:
os.path.abspath(
os.path.join(
mouse_results.folderResult,
"..",
"..",
"last_bayes",
"results",
f"clusters_pre_wTrain_{'True' if row['phase'] == 'training' else 'False'}.pkl",
)
)
clusters_time_file = os.path.abspath(
os.path.join(
mouse_results.folderResult,
Expand Down Expand Up @@ -5277,7 +5252,8 @@ def get_true_train_mask(row, df):
# --- outlier labeling ---
df_winMS = err_df.copy()
df_winMS = df_winMS.rename(columns={"mouse_manipe": "mouse"})
df_winMS[
# Filter the dataframe to relevant columns
df_winMS = df_winMS[
["stride", "mouse", "mean_error" if reduce_fn == "mean" else "median_error"]
].dropna()

Expand Down
21 changes: 11 additions & 10 deletions neuroencoders/utils/Spike.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,24 @@ def __init__(self, path, time_unit="us"):
else:
L = []
openstruc(spikes[k], k, L)
for shank in L:
if shank.endswith("_intset"):
shank = shank[0:-7]
start = np.squeeze(spikes[shank]["start"][:])
stop = np.squeeze(spikes[shank]["stop"][:])
self.info[shank] = nts.IntervalSet(
for shank_key in L:
if shank_key.endswith("_intset"):
# Remove the "_intset" suffix to get the actual shank name
shank_name = shank_key[0:-7]
start = np.squeeze(spikes[shank_name]["start"][:])
stop = np.squeeze(spikes[shank_name]["stop"][:])
self.info[shank_name] = nts.IntervalSet(
start, stop, time_units=time_unit
)
elif isinstance(
np.squeeze(spikes[shank][:])[0], h5py.h5r.Reference
np.squeeze(spikes[shank_key][:])[0], h5py.h5r.Reference
):
self.info[shank] = ref2str(
np.squeeze(spikes[shank][:]), spikes
self.info[shank_key] = ref2str(
np.squeeze(spikes[shank_key][:]), spikes
)

else:
self.info[shank] = np.squeeze(spikes[shank][:])
self.info[shank_key] = np.squeeze(spikes[shank_key][:])

def get_spikes(self, idx=None):
import numpy as np
Expand Down
49 changes: 24 additions & 25 deletions runAllMice.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,6 @@ def process_directory(dir, win, force, redo, lstmAndTransfo=False):
nbEpochs,
"--target",
target_bayes,
# "--flat_prior",
"--striding",
str(win),
]
Expand Down Expand Up @@ -364,35 +363,35 @@ def run_commands_parallel(mouse_commands):

# print(f"Found directories: {dirs}")
mouse_commands = {}
for dir in dirs:
if any((mouse in dir or mouse[:-1] in dir) for mouse in mice_nb) or not mice_nb:
if "M1199_MFB" not in dir:
mouse_commands[dir] = []
for directory in dirs:
if any((mouse in directory or mouse[:-1] in directory) for mouse in mice_nb) or not mice_nb:
if "M1199_MFB" not in directory:
mouse_commands[directory] = []
for win in win_values:
if lstm:
for lstmAndTransfo in [False, True]:
# print(
# f"Processing {dir} with window {win} and lstmAndTransfo {lstmAndTransfo}"
# f"Processing {directory} with window {win} and lstmAndTransfo {lstmAndTransfo}"
# )
cmd_ann, cmd_bayes = process_directory(
dir=dir,
dir=directory,
win=win,
force=force,
redo=redo,
lstmAndTransfo=lstmAndTransfo,
)
if cmd_ann:
mouse_commands[dir].append(cmd_ann)
mouse_commands[directory].append(cmd_ann)
if cmd_bayes:
mouse_commands[dir].append(cmd_bayes)
mouse_commands[directory].append(cmd_bayes)
else:
cmd_ann, cmd_bayes = process_directory(dir, win, force, redo)
cmd_ann, cmd_bayes = process_directory(directory, win, force, redo)
if cmd_ann:
mouse_commands[dir].append(cmd_ann)
mouse_commands[directory].append(cmd_ann)
if cmd_bayes:
mouse_commands[dir].append(cmd_bayes)
mouse_commands[directory].append(cmd_bayes)
else:
print(f"Processing M1199MFB mouse in directory: {dir}")
print(f"Processing M1199MFB mouse in directory: {directory}")
list_exps = []
if any("MFB1" in mouse for mouse in mice_nb):
list_exps.append("exp1")
Expand All @@ -401,26 +400,26 @@ def run_commands_parallel(mouse_commands):
if not list_exps:
list_exps = ["exp1", "exp2"]
for dirmfb in list_exps:
mouse_commands[os.path.join(dir, dirmfb)] = []
mouse_commands[os.path.join(directory, dirmfb)] = []
for win in win_values:
cmd_ann, cmd_bayes = process_directory(
os.path.join(dir, dirmfb), win, force, redo
os.path.join(directory, dirmfb), win, force, redo
)
if cmd_ann:
mouse_commands[os.path.join(dir, dirmfb)].append(cmd_ann)
mouse_commands[os.path.join(directory, dirmfb)].append(cmd_ann)
if cmd_bayes:
mouse_commands[os.path.join(dir, dirmfb)].append(cmd_bayes)
mouse_commands[os.path.join(directory, dirmfb)].append(cmd_bayes)

if rsync:
PathForExperiments["realPath"] = PathForExperiments["path"].apply(
lambda x: os.path.realpath(x)
)
try:
Mouse = PathForExperiments[
PathForExperiments["realPath"] == os.path.realpath(dir)
PathForExperiments["realPath"] == os.path.realpath(directory)
].iloc[0]["name"]
print(f"Mouse: {Mouse} from PathForExperiments")
SOURCE = os.path.realpath(dir)
SOURCE = os.path.realpath(directory)
DESTINATION = PathForExperiments[
PathForExperiments["name"] == Mouse
].iloc[0]["network_path"]
Expand All @@ -432,16 +431,16 @@ def run_commands_parallel(mouse_commands):
SOURCE,
DESTINATION,
"--force",
"--dry-run",
]
if "M1199_MFB" not in dir:
mouse_commands[dir].append(runNasCMD)
if "M1199_MFB" not in directory:
mouse_commands[directory].append(runNasCMD)
else:
for dirmfb in ["exp1", "exp2"]:
mouse_commands[os.path.join(dir, dirmfb)].append(runNasCMD)
except Exception as e:
mouse_commands[os.path.join(directory, dirmfb)].append(runNasCMD)
except (IndexError, KeyError) as e:
# Exception is expected when mouse directory structure is non-standard
# or when mouse is not found in PathForExperiments
print(f"Error finding mouse in PathForExperiments: {e}")
pass

if mode == "sequential":
run_commands_sequentially(mouse_commands)
Expand Down