diff --git a/mart/configs/lightning.yaml b/mart/configs/lightning.yaml index 714250f3..0174bab3 100644 --- a/mart/configs/lightning.yaml +++ b/mart/configs/lightning.yaml @@ -38,8 +38,9 @@ tags: ["dev"] # Train it or not. fit: True -# check performance on test set, using the best model achieved during training +# check performance on validation and test set, using the best model achieved during training # lightning chooses best model based on metric specified in checkpoint callback +validate: False test: True # Whether to resume training using configuration and checkpoint in specified directory diff --git a/mart/tasks/lightning.py b/mart/tasks/lightning.py index 3539b81f..65795c8e 100644 --- a/mart/tasks/lightning.py +++ b/mart/tasks/lightning.py @@ -77,6 +77,11 @@ def lightning(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: train_metrics = trainer.callback_metrics + # Evaluate model on validation set, using the best model achieved during training + if cfg.get("validate"): + log.info("Starting validation!") + trainer.validate(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + # Evaluate model on test set, using the best model achieved during training if cfg.get("test"): log.info("Starting testing!")