From ae59513e200ffa57a4b3aa2ac1437178f091581f Mon Sep 17 00:00:00 2001 From: Neerav Kaushal Date: Mon, 12 Jul 2021 12:41:32 -0400 Subject: [PATCH] update train.py for bypass=False in VNet input needs to be added to target and output for bypass=False, so that the eulerian loss for bypass and no bypass equals. --- map2map/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/map2map/train.py b/map2map/train.py index c97d0c4..abafaf2 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -264,6 +264,7 @@ def train(epoch, loader, model, criterion, lag_out, lag_tgt = output, target eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs) + #eul_out, eul_tgt = lag2eul([lag_out+input, lag_tgt+input], **args.misc_kwargs) #----for bypass=False if batch <= 5 and rank == 0: print('Eulerian shape :', eul_out.shape, flush=True) @@ -357,6 +358,7 @@ def validate(epoch, loader, model, criterion, logger, device, args): lag_out, lag_tgt = output, target eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs) + #eul_out, eul_tgt = lag2eul([lag_out+input, lag_tgt+input], **args.misc_kwargs) #----for bypass=False lag_loss = criterion(lag_out, lag_tgt) eul_loss = criterion(eul_out, eul_tgt)