From 2d3a7560f550bee62a7b9b866d24f01f04fa68ef Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 27 Aug 2024 14:05:34 +0900 Subject: [PATCH 1/5] Enable onnx validation --- yolo/config/task/validation.yaml | 1 + yolo/lazy.py | 3 +++ yolo/tools/solver.py | 5 +++-- yolo/utils/deploy_utils.py | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/yolo/config/task/validation.yaml b/yolo/config/task/validation.yaml index 41ccbbad..3d0af3ea 100644 --- a/yolo/config/task/validation.yaml +++ b/yolo/config/task/validation.yaml @@ -1,5 +1,6 @@ task: validation +fast_inference: # onnx, trt, deploy or Empty data: batch_size: 16 image_size: ${image_size} diff --git a/yolo/lazy.py b/yolo/lazy.py index 1bc5577e..abb2ca68 100644 --- a/yolo/lazy.py +++ b/yolo/lazy.py @@ -6,6 +6,7 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) +from loguru import logger from yolo.config.config import Config from yolo.model.yolo import create_model from yolo.tools.data_loader import create_dataloader @@ -32,6 +33,8 @@ def main(cfg: Config): if cfg.task.task == "train": solver = ModelTrainer(cfg, model, converter, progress, device, use_ddp) if cfg.task.task == "validation": + if cfg.task.fast_inference in ["trt", "deploy"]: + logger.warning("โš ๏ธ ONNX is only tested, not responsible about using trt and deploy.") solver = ModelValidator(cfg.task, cfg.dataset, model, converter, progress, device) if cfg.task.task == "inference": solver = ModelTester(cfg, model, converter, progress, device) diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index 51ceffc0..c5d7311c 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -233,8 +233,9 @@ def __init__( self.coco_gt = COCO(json_path) def solve(self, dataloader, epoch_idx=1): - # logger.info("๐Ÿงช Start Validation!") - self.model.eval() + logger.info("๐Ÿงช Start Validation!") + if isinstance(self.model, torch.nn.Module): + self.model.eval() predict_json, mAPs = [], defaultdict(list) self.progress.start_one_epoch(len(dataloader), task="Validate") for batch_size, images, targets, rev_tensor, img_paths in dataloader: diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index 9ca709b0..329c9319 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -17,7 +17,7 @@ def __init__(self, cfg: Config): self._validate_compiler() if cfg.weight == True: cfg.weight = Path("weights") / f"{cfg.model.name}.pt" - self.model_path = f"{Path(cfg.weight).stem}.{self.compiler}" + self.model_path = Path("weights") / f"{Path(cfg.weight).stem}.{self.compiler}" def _validate_compiler(self): if self.compiler not in ["onnx", "trt", "deploy"]: From 700aaf8500006140f2d4de21f8561cfcb961eba3 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 27 Aug 2024 16:28:15 +0900 Subject: [PATCH 2/5] Enable openvino IR inference --- yolo/config/task/inference.yaml | 3 +- yolo/config/task/validation.yaml | 3 +- yolo/utils/deploy_utils.py | 50 ++++++++++++++++++++++++++++++-- 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/yolo/config/task/inference.yaml b/yolo/config/task/inference.yaml index 383d9cc7..88855c06 100644 --- a/yolo/config/task/inference.yaml +++ b/yolo/config/task/inference.yaml @@ -1,6 +1,7 @@ task: inference -fast_inference: # onnx, trt, deploy or Empty +fast_inference: # onnx, openvino, trt, deploy or Empty +precision: FP32 # for openvino data: source: demo/images/inference/image.png image_size: ${image_size} diff --git a/yolo/config/task/validation.yaml b/yolo/config/task/validation.yaml index 3d0af3ea..bee5a66f 100644 --- a/yolo/config/task/validation.yaml +++ b/yolo/config/task/validation.yaml @@ -1,6 +1,7 @@ task: validation -fast_inference: # onnx, trt, deploy or Empty +fast_inference: # onnx, openvino, trt, deploy or Empty +precision: FP32 # for openvino data: batch_size: 16 image_size: ${image_size} diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index 329c9319..0b80a36d 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -17,10 +17,14 @@ def __init__(self, cfg: Config): self._validate_compiler() if cfg.weight == True: cfg.weight = Path("weights") / f"{cfg.model.name}.pt" - self.model_path = Path("weights") / f"{Path(cfg.weight).stem}.{self.compiler}" + + if str(self.compiler).lower() == "openvino": + self.model_path: Path = Path("weights") / f"{Path(cfg.weight).stem}.xml" + else: + self.model_path: Path = Path("weights") / f"{Path(cfg.weight).stem}.{self.compiler}" def _validate_compiler(self): - if self.compiler not in ["onnx", "trt", "deploy"]: + if self.compiler not in ["onnx", "openvino", "trt", "deploy"]: logger.warning(f"โš ๏ธ Compiler '{self.compiler}' is not supported. Using original model.") self.compiler = None if self.cfg.device == "mps" and self.compiler == "trt": @@ -30,6 +34,8 @@ def _validate_compiler(self): def load_model(self, device): if self.compiler == "onnx": return self._load_onnx_model(device) + elif self.compiler == "openvino": + return self._load_openvino_model() elif self.compiler == "trt": return self._load_trt_model().to(device) elif self.compiler == "deploy": @@ -82,6 +88,46 @@ def _create_onnx_model(self, providers): logger.info(f"๐Ÿ“ฅ ONNX model saved to {self.model_path}") return InferenceSession(self.model_path, providers=providers) + def _load_openvino_model(self, device: str = "cpu"): + from openvino import Core, CompiledModel + + original_call = CompiledModel.__call__ + def openvino_call(self: CompiledModel, *args, **kwargs): + outputs = original_call(self, *args, **kwargs) + + model_outputs, layer_output = [], [] + for idx, (_, predict) in enumerate(outputs.items()): + layer_output.append(torch.from_numpy(predict).to(device)) + if idx % 3 == 2: + model_outputs.append(layer_output) + layer_output = [] + if len(model_outputs) == 6: + model_outputs = model_outputs[:3] + return {"Main": model_outputs} + + CompiledModel.__call__ = openvino_call + + try: + core = Core() + model_ov = core.read_model(str(self.model_path)) + logger.info("๐Ÿš€ Using OpenVINO as MODEL frameworks!") + except Exception as e: + logger.warning(f"๐Ÿˆณ Error loading OpenVINO model: {e}") + model_ov = self._create_openvino_model() + return core.compile_model(model_ov, "CPU") + + def _create_openvino_model(self): + import openvino as ov + + if not (onnx_model_path := self.model_path.with_suffix(".onnx")).exists(): + self._create_onnx_model(["CPUExecutionProvider"]) + + model_ov = ov.convert_model(onnx_model_path, input=(ov.runtime.PartialShape((-1, 3, *self.cfg.image_size)),)) + + ov.save_model(model_ov, self.model_path, compress_to_fp16=(self.cfg.task.precision == "FP16")) + logger.info(f"๐Ÿ“ฅ ONNX model saved to {self.model_path}") + return model_ov + def _load_trt_model(self): from torch2trt import TRTModule From 2c2955e8c277f1c1e5628fb450aa706b630e6d0e Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 27 Aug 2024 18:09:57 +0900 Subject: [PATCH 3/5] Update for FP16 compression --- yolo/utils/deploy_utils.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index 0b80a36d..0d2ff1ef 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -19,7 +19,10 @@ def __init__(self, cfg: Config): cfg.weight = Path("weights") / f"{cfg.model.name}.pt" if str(self.compiler).lower() == "openvino": - self.model_path: Path = Path("weights") / f"{Path(cfg.weight).stem}.xml" + if self.cfg.task.precision == "FP16": + self.model_path: Path = Path("weights") / f"{Path(cfg.weight).stem}_fp16.xml" + else: + self.model_path: Path = Path("weights") / f"{Path(cfg.weight).stem}.xml" else: self.model_path: Path = Path("weights") / f"{Path(cfg.weight).stem}.{self.compiler}" @@ -117,14 +120,24 @@ def openvino_call(self: CompiledModel, *args, **kwargs): return core.compile_model(model_ov, "CPU") def _create_openvino_model(self): - import openvino as ov + from openvino import convert_model, save_model, PartialShape - if not (onnx_model_path := self.model_path.with_suffix(".onnx")).exists(): + if "fp16" in str(self.model_path): + onnx_model_path = Path(str(self.model_path).replace("_fp16.xml", ".onnx")) + else: + onnx_model_path = self.model_path.with_suffix(".onnx") + if not onnx_model_path.exists(): self._create_onnx_model(["CPUExecutionProvider"]) - model_ov = ov.convert_model(onnx_model_path, input=(ov.runtime.PartialShape((-1, 3, *self.cfg.image_size)),)) + model_ov = convert_model(onnx_model_path, input=(PartialShape((-1, 3, *self.cfg.image_size)),)) + + save_model(model_ov, self.model_path, compress_to_fp16=(self.cfg.task.precision == "FP16")) + if self.cfg.task.precision == "FP16": + from openvino import Core + + core = Core() + model_ov = core.read_model(str(self.model_path)) - ov.save_model(model_ov, self.model_path, compress_to_fp16=(self.cfg.task.precision == "FP16")) logger.info(f"๐Ÿ“ฅ ONNX model saved to {self.model_path}") return model_ov From 37481868ca6e29862cb38fbc4f9ace547e216713 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Wed, 28 Aug 2024 12:36:41 +0900 Subject: [PATCH 4/5] Enable PTQ --- yolo/config/task/inference.yaml | 1 + yolo/config/task/validation.yaml | 1 + yolo/tools/data_loader.py | 9 ++++++--- yolo/utils/deploy_utils.py | 14 ++++++++++++++ 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/yolo/config/task/inference.yaml b/yolo/config/task/inference.yaml index 88855c06..6cc85944 100644 --- a/yolo/config/task/inference.yaml +++ b/yolo/config/task/inference.yaml @@ -2,6 +2,7 @@ task: inference fast_inference: # onnx, openvino, trt, deploy or Empty precision: FP32 # for openvino +ptq: False # for openvino data: source: demo/images/inference/image.png image_size: ${image_size} diff --git a/yolo/config/task/validation.yaml b/yolo/config/task/validation.yaml index bee5a66f..291a3638 100644 --- a/yolo/config/task/validation.yaml +++ b/yolo/config/task/validation.yaml @@ -2,6 +2,7 @@ task: validation fast_inference: # onnx, openvino, trt, deploy or Empty precision: FP32 # for openvino +ptq: False # for openvino data: batch_size: 16 image_size: ${image_size} diff --git a/yolo/tools/data_loader.py b/yolo/tools/data_loader.py index 785235a9..cd3b09fe 100644 --- a/yolo/tools/data_loader.py +++ b/yolo/tools/data_loader.py @@ -150,11 +150,12 @@ def __len__(self) -> int: class YoloDataLoader(DataLoader): - def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False): + def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False, is_ov_ptq: bool = False): """Initializes the YoloDataLoader with hydra-config files.""" dataset = YoloDataset(data_cfg, dataset_cfg, task) sampler = DistributedSampler(dataset, shuffle=data_cfg.shuffle) if use_ddp else None self.image_size = data_cfg.image_size[0] + self.is_ov_ptq = is_ov_ptq super().__init__( dataset, batch_size=data_cfg.batch_size, @@ -193,17 +194,19 @@ def collate_fn(self, batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[T batch_images = torch.stack(batch_images) batch_reverse = torch.stack(batch_reverse) + if self.is_ov_ptq: + return batch_images return batch_size, batch_images, batch_targets, batch_reverse, batch_path -def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False): +def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False, is_ov_ptq: bool = False): if task == "inference": return StreamDataLoader(data_cfg) if dataset_cfg.auto_download: prepare_dataset(dataset_cfg, task) - return YoloDataLoader(data_cfg, dataset_cfg, task, use_ddp) + return YoloDataLoader(data_cfg, dataset_cfg, task, use_ddp, is_ov_ptq) class StreamDataLoader: diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index 0d2ff1ef..c53ebd38 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -6,6 +6,7 @@ from yolo.config.config import Config from yolo.model.yolo import create_model +from yolo.tools.data_loader import create_dataloader class FastModelLoader: @@ -117,6 +118,19 @@ def openvino_call(self: CompiledModel, *args, **kwargs): except Exception as e: logger.warning(f"๐Ÿˆณ Error loading OpenVINO model: {e}") model_ov = self._create_openvino_model() + + if self.cfg.task.ptq: + if "optimized" in str(self.model_path): + logger.info("๐Ÿš€ PTQ Model is already loaded!") + else: + import nncf + from openvino.runtime import serialize + + train_dataloader = create_dataloader(self.cfg.task.data, self.cfg.dataset, "train", is_ov_ptq=True) + ptq_dataset = nncf.Dataset(train_dataloader, lambda x: x) + model_ov = nncf.quantize(model_ov, ptq_dataset, preset=nncf.QuantizationPreset.MIXED) + serialize(model_ov, str(self.model_path).replace(".xml", "_optimized.xml")) + return core.compile_model(model_ov, "CPU") def _create_openvino_model(self): From 8714954e13939313cdb89bd741ca65701e8245d6 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Wed, 28 Aug 2024 12:53:51 +0900 Subject: [PATCH 5/5] Update log --- yolo/lazy.py | 2 +- yolo/utils/deploy_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/yolo/lazy.py b/yolo/lazy.py index abb2ca68..4281f8ca 100644 --- a/yolo/lazy.py +++ b/yolo/lazy.py @@ -34,7 +34,7 @@ def main(cfg: Config): solver = ModelTrainer(cfg, model, converter, progress, device, use_ddp) if cfg.task.task == "validation": if cfg.task.fast_inference in ["trt", "deploy"]: - logger.warning("โš ๏ธ ONNX is only tested, not responsible about using trt and deploy.") + logger.warning("โš ๏ธ ONNX and OpenVINO are only tested, not responsible about using trt and deploy.") solver = ModelValidator(cfg.task, cfg.dataset, model, converter, progress, device) if cfg.task.task == "inference": solver = ModelTester(cfg, model, converter, progress, device) diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index c53ebd38..75a39156 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -152,7 +152,7 @@ def _create_openvino_model(self): core = Core() model_ov = core.read_model(str(self.model_path)) - logger.info(f"๐Ÿ“ฅ ONNX model saved to {self.model_path}") + logger.info(f"๐Ÿ“ฅ OpenVINO model saved to {self.model_path}") return model_ov def _load_trt_model(self):