diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py
index f06f9ca2c..c3c852485 100644
--- a/basicsr/models/base_model.py
+++ b/basicsr/models/base_model.py
@@ -7,19 +7,28 @@
 
 from basicsr.models import lr_scheduler as lr_scheduler
 from basicsr.utils import get_root_logger
+from basicsr.utils import accelerator_util
 from basicsr.utils.dist_util import master_only
 
+try:
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    pass
 
 class BaseModel():
     """Base model."""
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+        self.device = accelerator_util.default_device(opt)
         self.is_train = opt['is_train']
         self.schedulers = []
         self.optimizers = []
 
+    @property
+    def accelerator(self):
+        return accelerator_util.accelerator_name(self.opt)
+
     def feed_data(self, data):
         pass
 
@@ -85,13 +94,18 @@ def get_current_log(self):
         return self.log_dict
 
     def model_to_device(self, net):
-        """Model to device. It also warps models with DistributedDataParallel
+        """Model to device. It also wraps models with DistributedDataParallel
         or DataParallel.
 
         Args:
             net (nn.Module)
         """
         net = net.to(self.device)
+
+        if self.accelerator == 'xla':
+            # No need to use DataParallel or DistributedDataParallel with xmp
+            return net
+
         if self.opt['dist']:
             find_unused_parameters = self.opt.get('find_unused_parameters', False)
             net = DistributedDataParallel(
@@ -107,6 +121,13 @@ def get_optimizer(self, optim_type, params, lr, **kwargs):
             raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.')
         return optimizer
 
+    def optimizer_step(self, optimizer):
+        if self.accelerator == 'xla':
+            xm.optimizer_step(optimizer)
+            xm.mark_step()
+        else:
+            optimizer.step()
+
     def setup_schedulers(self):
         """Set up schedulers."""
         train_opt = self.opt['train']
@@ -361,7 +382,7 @@ def reduce_loss_dict(self, loss_dict):
             loss_dict (OrderedDict): Loss dict.
""" with torch.no_grad(): - if self.opt['dist']: + if self.opt['dist'] and self.accelerator != 'xla': keys = [] losses = [] for name, value in loss_dict.items(): diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 54c80bd6d..784a3cc31 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -111,7 +111,7 @@ def optimize_parameters(self, current_iter): loss_dict['l_style'] = l_style l_total.backward() - self.optimizer_g.step() + self.optimizer_step(self.optimizer_g) self.log_dict = self.reduce_loss_dict(loss_dict) diff --git a/basicsr/train.py b/basicsr/train.py index f63149c64..1e4649ff5 100644 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -11,8 +11,14 @@ from basicsr.models import build_model from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str, init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir) -from basicsr.utils.options import copy_opt_file, dict2str, parse_options +from basicsr.utils.options import copy_opt_file, dict2str, parse_options, preflight_options +from basicsr.utils import accelerator_util +from basicsr.utils import dist_util +try: + import torch_xla.core.xla_model as xm +except: + pass def init_tb_loggers(opt): # initialize wandb logger before tensorboard logger to allow proper sync @@ -88,13 +94,22 @@ def load_resume_state(opt): return resume_state -def train_pipeline(root_path): +def print_xla_device_info(opt): + if accelerator_util.accelerator_name(opt) == 'xla': + import torch_xla.core.xla_model as xm + print(f"XLA device: {xm.xla_device()} [{xm.xla_real_devices([xm.xla_device()])[0]}], ordinal: {xm.get_ordinal()}, replicas: {xm.xrt_world_size()}, is master: {xm.is_master_ordinal()}") + + # These are expected to be the same + rank, world_size = dist_util.get_dist_info() + assert(rank == xm.get_ordinal()) + assert(world_size == xm.xrt_world_size()) + +def _mp_train_pipeline(root_path): # parse options, set distributed setting, set ramdom seed opt, args = parse_options(root_path, is_train=True) opt['root_path'] = root_path - torch.backends.cudnn.benchmark = True - # torch.backends.cudnn.deterministic = True + print_xla_device_info(opt) # load resume states if necessary resume_state = load_resume_state(opt) @@ -139,6 +154,8 @@ def train_pipeline(root_path): if prefetch_mode is None or prefetch_mode == 'cpu': prefetcher = CPUPrefetcher(train_loader) elif prefetch_mode == 'cuda': + if accelerator_util.accelerator_name(opt) != 'cuda': + raise ValueError(f"prefetch_mode cuda is not compatible with accelerator {accelerator_util.accelerator_name(opt)}.") prefetcher = CUDAPrefetcher(train_loader, opt) logger.info(f'Use {prefetch_mode} prefetch dataloader') if opt['datasets']['train'].get('pin_memory') is not True: @@ -196,7 +213,6 @@ def train_pipeline(root_path): iter_timer.start() train_data = prefetcher.next() # end of iter - # end of epoch consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) @@ -210,6 +226,34 @@ def train_pipeline(root_path): tb_logger.close() +def _mp_train(rank, root_path): + _mp_train_pipeline(root_path) + +def train_pipeline(root_path): + """ + Determines whether multiple processes need to be spawned, and then invoke _mp_train_pipeline. + This mode is meant to be used with XLA multiprocessing (but it should work in other environments). + However, it is incompatible with command-line multi-process launchers. + This is also an appropriate entry point to be invoked from Real-ESRGAN for XLA MP support. 
+ """ + # Initial parse to determine whether we need to run under xmp multiprocessing + opt, args = preflight_options() + + if opt.get('accelerator', 'cuda') == 'xla': + # We can't get the number of XLA devices because it would cause a replication error. + # We just assume xla multiprocessing is required, except if a launcher was used. + if args.launcher != "none": + raise ValueError(f"Launcher {args.launcher} is incompatible with XLA multiprocessing.") + + import torch_xla.distributed.xla_multiprocessing as xmp + xmp.spawn(_mp_train, args=(root_path,), start_method='fork') + else: + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + _mp_train_pipeline(root_path) + + if __name__ == '__main__': root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) train_pipeline(root_path) diff --git a/basicsr/utils/accelerator_util.py b/basicsr/utils/accelerator_util.py new file mode 100644 index 000000000..9290d2c90 --- /dev/null +++ b/basicsr/utils/accelerator_util.py @@ -0,0 +1,35 @@ +import torch +try: + import torch_xla.core.xla_model as xm +except: + pass + +def accelerator_name(opt): + if opt['num_gpu'] == 0: + return 'cpu' + return opt.get('accelerator', 'cuda') + +def default_device(opt): + accelerator = accelerator_name(opt) + if accelerator == 'xla': + return xm.xla_device() + return accelerator + +def device_count(opt): + accelerator = accelerator_name(opt) + if accelerator == 'xla': + # Devices of the same hw family. + # Note: returns 1 when replication is in place! + # device = xm.xla_device() + # devices = xm.get_xla_supported_devices(xm.xla_device_hw(device)) + # return len(devices) + + # This works when replication is active + return xm.xrt_world_size() + return torch.cuda.device_count() + +def use_xmp(opt): + accelerator = accelerator_name(opt) + if accelerator != 'xla': + return False + return device_count(opt) > 1 \ No newline at end of file diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py index 0fab887b2..684749238 100644 --- a/basicsr/utils/dist_util.py +++ b/basicsr/utils/dist_util.py @@ -6,6 +6,16 @@ import torch.distributed as dist import torch.multiprocessing as mp +_xmp_dist = None + +class XMPDist(): + def __init__(self): + super().__init__() + + import torch_xla.core.xla_model as xm + self.rank = xm.get_ordinal() + self.world_size = xm.xrt_world_size() + self.is_master = xm.is_master_ordinal() def init_dist(launcher, backend='nccl', **kwargs): if mp.get_start_method(allow_none=True) is None: @@ -17,6 +27,9 @@ def init_dist(launcher, backend='nccl', **kwargs): else: raise ValueError(f'Invalid launcher type: {launcher}') +def init_xmp(): + global _xmp_dist + _xmp_dist = XMPDist() def _init_dist_pytorch(backend, **kwargs): rank = int(os.environ['RANK']) @@ -58,6 +71,9 @@ def _init_dist_slurm(backend, port=None): def get_dist_info(): + if _xmp_dist is not None: + return _xmp_dist.rank, _xmp_dist.world_size + if dist.is_available(): initialized = dist.is_initialized() else: @@ -75,6 +91,9 @@ def master_only(func): @functools.wraps(func) def wrapper(*args, **kwargs): + if _xmp_dist is not None and _xmp_dist.is_master: + return func(*args, **kwargs) + rank, _ = get_dist_info() if rank == 0: return func(*args, **kwargs) diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index 09bfa5a5b..6f0ef1823 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -6,8 +6,8 @@ from os import path as osp from basicsr.utils import set_random_seed -from basicsr.utils.dist_util import get_dist_info, 
-
+from basicsr.utils.accelerator_util import use_xmp, device_count
+from basicsr.utils.dist_util import get_dist_info, init_dist, init_xmp, master_only
 
 def ordered_yaml():
     """Support OrderedDict for yaml.
@@ -79,7 +79,17 @@ def _postprocess_yml_value(value):
     return value
 
 
-def parse_options(root_path, is_train=True):
+def preflight_options():
+    """
+    Parse all the options for initial verification only.
+
+    Attempting to access XLA devices (for instance, to determine the number of
+    available devices) results in xmp not being able to create the
+    replication devices.
+
+    We use this function so the main training script can determine whether
+    we need to use XLA MP replication before proceeding further.
+    """
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
     parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
@@ -94,8 +104,16 @@ def parse_options(root_path, is_train=True):
     with open(args.opt, mode='r') as f:
         opt = yaml.load(f, Loader=ordered_yaml()[0])
 
+    return opt, args
+
+def parse_options(root_path, is_train=True):
+    opt, args = preflight_options()
+
     # distributed settings
-    if args.launcher == 'none':
+    if use_xmp(opt):
+        opt['dist'] = True
+        init_xmp()
+    elif args.launcher == 'none':
         opt['dist'] = False
         print('Disable distributed.', flush=True)
     else:
@@ -135,7 +153,7 @@ def parse_options(root_path, is_train=True):
         opt['name'] = 'debug_' + opt['name']
 
     if opt['num_gpu'] == 'auto':
-        opt['num_gpu'] = torch.cuda.device_count()
+        opt['num_gpu'] = device_count(opt)
 
     # datasets
     for phase, dataset in opt['datasets'].items():
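
Usage note (illustrative, not part of the patch): the XLA path is driven entirely by the options YAML. A minimal sketch of the two keys this patch actually reads, assuming an otherwise standard BasicSR training config, could look like this:

```yaml
# Hypothetical excerpt of a BasicSR training options file.
# `accelerator` is the new key read via opt.get('accelerator', 'cuda');
# `num_gpu: auto` now resolves through accelerator_util.device_count(),
# which returns xm.xrt_world_size() when the accelerator is 'xla'.
accelerator: xla
num_gpu: auto
```

With these settings and the default `--launcher none`, `train_pipeline` spawns workers across the visible XLA devices through `xmp.spawn`, and `parse_options` in each worker enables distributed mode via `use_xmp`/`init_xmp` rather than `init_dist`.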