diff --git a/mart/models/modular.py b/mart/models/modular.py index 141461bc..a27c6867 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -134,10 +134,6 @@ def training_step(self, batch, batch_idx): for log_name, output_key in self.training_step_log.items(): self.log(f"training/{log_name}", output[output_key]) - assert "loss" in output - return output - - def training_step_end(self, output): if self.training_metrics is not None: # Some models only return loss in the training mode. if self.output_preds_key not in output or self.output_target_key not in output: @@ -145,8 +141,8 @@ def training_step_end(self, output): f"You have specified training_metrics, but the model does not return {self.output_preds_key} or {self.output_target_key} during training. You can either nullify training_metrics or configure the model to return {self.output_preds_key} and {self.output_target_key} in the training output." ) self.training_metrics(output[self.output_preds_key], output[self.output_target_key]) - loss = output.pop(self.output_loss_key) - return loss + + return output[self.output_loss_key] def training_epoch_end(self, outputs): if self.training_metrics is not None: @@ -168,13 +164,9 @@ def validation_step(self, batch, batch_idx): for log_name, output_key in self.validation_step_log.items(): self.log(f"validation/{log_name}", output[output_key]) - return output - - def validation_step_end(self, output): self.validation_metrics(output[self.output_preds_key], output[self.output_target_key]) - # I don't know why this is required to prevent CUDA memory leak in validaiton and test. (Not required in training.) - output.clear() + return None def validation_epoch_end(self, outputs): metrics = self.validation_metrics.compute() @@ -194,13 +186,9 @@ def test_step(self, batch, batch_idx): for log_name, output_key in self.test_step_log.items(): self.log(f"test/{log_name}", output[output_key]) - return output - - def test_step_end(self, output): self.test_metrics(output[self.output_preds_key], output[self.output_target_key]) - # I don't know why this is required to prevent CUDA memory leak in validaiton and test. (Not required in training.) - output.clear() + return None def test_epoch_end(self, outputs): metrics = self.test_metrics.compute()