diff --git a/mambular/data_utils/dataset.py b/mambular/data_utils/dataset.py index db6c63a7..e447f9a2 100644 --- a/mambular/data_utils/dataset.py +++ b/mambular/data_utils/dataset.py @@ -24,6 +24,8 @@ def __init__( labels=None, regression=True, ): + assert cat_features_list or num_features_list + self.cat_features_list = cat_features_list # Categorical features tensors self.num_features_list = num_features_list # Numerical features tensors self.embeddings_list = embeddings_list # Embeddings tensors (optional) @@ -44,7 +46,8 @@ def __init__( self.labels = None # No labels in prediction mode def __len__(self): - return len(self.num_features_list[0]) # Use numerical features length + _feats = self.num_features_list if self.num_features_list else self.cat_features_list + return len(_feats[0]) def __getitem__(self, idx): """Retrieves the features and label for a given index.