From 7c110c96e0792c707e015edfee66305ea2278eb4 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 20 Jul 2023 11:36:29 -0700 Subject: [PATCH 1/3] Add option to run model over validation set --- mart/configs/lightning.yaml | 4 +++- mart/tasks/lightning.py | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mart/configs/lightning.yaml b/mart/configs/lightning.yaml index 714250f3..31d4397e 100644 --- a/mart/configs/lightning.yaml +++ b/mart/configs/lightning.yaml @@ -38,8 +38,10 @@ tags: ["dev"] # Train it or not. fit: True -# check performance on test set, using the best model achieved during training + +# check performance on validation and test set, using the best model achieved during training # lightning chooses best model based on metric specified in checkpoint callback +validate: True test: True # Whether to resume training using configuration and checkpoint in specified directory diff --git a/mart/tasks/lightning.py b/mart/tasks/lightning.py index 3539b81f..65795c8e 100644 --- a/mart/tasks/lightning.py +++ b/mart/tasks/lightning.py @@ -77,6 +77,11 @@ def lightning(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: train_metrics = trainer.callback_metrics + # Evaluate model on validation set, using the best model achieved during training + if cfg.get("validate"): + log.info("Starting validation!") + trainer.validate(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + # Evaluate model on test set, using the best model achieved during training if cfg.get("test"): log.info("Starting testing!") From d329a79614d292cdbb7ddd43c40bc8d6a94852a8 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 20 Jul 2023 11:39:02 -0700 Subject: [PATCH 2/3] style --- mart/configs/lightning.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/mart/configs/lightning.yaml b/mart/configs/lightning.yaml index 31d4397e..7fbd76cd 100644 --- a/mart/configs/lightning.yaml +++ b/mart/configs/lightning.yaml @@ -38,7 +38,6 @@ tags: ["dev"] # Train it or not. fit: True - # check performance on validation and test set, using the best model achieved during training # lightning chooses best model based on metric specified in checkpoint callback validate: True From 991e8f8623072b39654382bc3a68c871afd92a9c Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 20 Jul 2023 11:43:54 -0700 Subject: [PATCH 3/3] Default validate to false --- mart/configs/lightning.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/configs/lightning.yaml b/mart/configs/lightning.yaml index 7fbd76cd..0174bab3 100644 --- a/mart/configs/lightning.yaml +++ b/mart/configs/lightning.yaml @@ -40,7 +40,7 @@ fit: True # check performance on validation and test set, using the best model achieved during training # lightning chooses best model based on metric specified in checkpoint callback -validate: True +validate: False test: True # Whether to resume training using configuration and checkpoint in specified directory