diff --git a/.github/workflows/deploy.yaml b/.github/workflows/deploy.yaml index 3e265d82..25693192 100644 --- a/.github/workflows/deploy.yaml +++ b/.github/workflows/deploy.yaml @@ -2,7 +2,7 @@ name: Deploy Mode Validation & Inference on: push: - branches: [main] + branches: [main,TRAIN] pull_request: branches: [main] diff --git a/yolo/tools/data_loader.py b/yolo/tools/data_loader.py index 9ceb455f..3d470f2a 100644 --- a/yolo/tools/data_loader.py +++ b/yolo/tools/data_loader.py @@ -32,7 +32,8 @@ def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()] self.transform = AugmentationComposer(transforms, self.image_size) self.transform.get_more_data = self.get_more_data - self.data = self.load_data(Path(dataset_cfg.path), phase_name) + self.data_root = Path(dataset_cfg.path) + self.data = self.load_data(self.data_root, phase_name) def load_data(self, dataset_path: Path, phase_name: str): """ @@ -54,6 +55,14 @@ def load_data(self, dataset_path: Path, phase_name: str): else: data = torch.load(cache_path) logger.info("📦 Loaded {} cache", phase_name) + # Validate cache + if data[0][0].parent == Path("images")/phase_name: + logger.info("✅ Cache validation successful") + else: + logger.warning("⚠️ Cache validation failed, regenerating") + data = self.filter_data(dataset_path, phase_name) + torch.save(data, cache_path) + return data def filter_data(self, dataset_path: Path, phase_name: str) -> list: @@ -100,7 +109,7 @@ def filter_data(self, dataset_path: Path, phase_name: str) -> list: labels = self.load_valid_labels(image_id, image_seg_annotations) - img_path = images_path / image_name + img_path = Path("images") / phase_name / image_name data.append((img_path, labels)) valid_inputs += 1 logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list)) @@ -133,6 +142,7 @@ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Te def get_data(self, idx): img_path, bboxes = self.data[idx] + img_path = self.data_root / Path(img_path) img = Image.open(img_path).convert("RGB") return img, bboxes, img_path