From 8f2ecc2af730bb1c358940c396e376375f33042d Mon Sep 17 00:00:00 2001 From: yongjian_zhang Date: Wed, 14 Aug 2024 03:44:18 +0000 Subject: [PATCH 1/3] [fix] automatically regenerate the cache when it doesn't match --- yolo/tools/data_loader.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/yolo/tools/data_loader.py b/yolo/tools/data_loader.py index 9ceb455f..bf31c37a 100644 --- a/yolo/tools/data_loader.py +++ b/yolo/tools/data_loader.py @@ -54,8 +54,33 @@ def load_data(self, dataset_path: Path, phase_name: str): else: data = torch.load(cache_path) logger.info("📦 Loaded {} cache", phase_name) + + # Validate cache + if self.validate_cache(dataset_path, phase_name, data): + logger.info("✅ Cache validation successful") + else: + logger.warning("⚠️ Cache validation failed, regenerating") + data = self.filter_data(dataset_path, phase_name) + torch.save(data, cache_path) + return data + def validate_cache(self, dataset_path: Path, phase_name: str, cached_data: list) -> bool: + """ + Validates if the cached data is consistent with the current dataset, comparing complete file paths + """ + images_path = dataset_path / "images" / phase_name + current_images = sorted([p.resolve() for p in images_path.iterdir() if p.is_file()]) + cached_images = sorted([Path(item[0]).resolve() for item in cached_data]) + + # Check if image file paths are completely consistent + if current_images != cached_images: + return False + + # Can add more validation steps, e.g. checking label file modification times + + return True + def filter_data(self, dataset_path: Path, phase_name: str) -> list: """ Filters and collects dataset information by pairing images with their corresponding labels. 
From d2c36b18c47c7cd2b60c3e386bf31666ebb571d5 Mon Sep 17 00:00:00 2001 From: yongjian_zhang Date: Wed, 14 Aug 2024 05:13:11 +0000 Subject: [PATCH 2/3] [fix] shorten the path of the cache file --- yolo/tools/data_loader.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/yolo/tools/data_loader.py b/yolo/tools/data_loader.py index bf31c37a..3d470f2a 100644 --- a/yolo/tools/data_loader.py +++ b/yolo/tools/data_loader.py @@ -32,7 +32,8 @@ def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()] self.transform = AugmentationComposer(transforms, self.image_size) self.transform.get_more_data = self.get_more_data - self.data = self.load_data(Path(dataset_cfg.path), phase_name) + self.data_root = Path(dataset_cfg.path) + self.data = self.load_data(self.data_root, phase_name) def load_data(self, dataset_path: Path, phase_name: str): """ @@ -54,9 +55,8 @@ def load_data(self, dataset_path: Path, phase_name: str): else: data = torch.load(cache_path) logger.info("📦 Loaded {} cache", phase_name) - # Validate cache - if self.validate_cache(dataset_path, phase_name, data): + if data[0][0].parent == Path("images")/phase_name: logger.info("✅ Cache validation successful") else: logger.warning("⚠️ Cache validation failed, regenerating") @@ -65,22 +65,6 @@ def load_data(self, dataset_path: Path, phase_name: str): return data - def validate_cache(self, dataset_path: Path, phase_name: str, cached_data: list) -> bool: - """ - Validates if the cached data is consistent with the current dataset, comparing complete file paths - """ - images_path = dataset_path / "images" / phase_name - current_images = sorted([p.resolve() for p in images_path.iterdir() if p.is_file()]) - cached_images = sorted([Path(item[0]).resolve() for item in cached_data]) - - # Check if image file paths are completely consistent - if current_images != cached_images: - return False - - # Can add 
more validation steps, e.g. checking label file modification times - - return True - def filter_data(self, dataset_path: Path, phase_name: str) -> list: """ Filters and collects dataset information by pairing images with their corresponding labels. @@ -125,7 +109,7 @@ def filter_data(self, dataset_path: Path, phase_name: str) -> list: labels = self.load_valid_labels(image_id, image_seg_annotations) - img_path = images_path / image_name + img_path = Path("images") / phase_name / image_name data.append((img_path, labels)) valid_inputs += 1 logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list)) @@ -158,6 +142,7 @@ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Te def get_data(self, idx): img_path, bboxes = self.data[idx] + img_path = self.data_root / Path(img_path) img = Image.open(img_path).convert("RGB") return img, bboxes, img_path From 86da903718367f5fe198cf9ab322c369ed564997 Mon Sep 17 00:00:00 2001 From: yongjian_zhang Date: Wed, 14 Aug 2024 05:13:43 +0000 Subject: [PATCH 3/3] [ci] trigger deploy workflow on pushes to TRAIN branch --- .github/workflows/deploy.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/deploy.yaml b/.github/workflows/deploy.yaml index 3e265d82..25693192 100644 --- a/.github/workflows/deploy.yaml +++ b/.github/workflows/deploy.yaml @@ -2,7 +2,7 @@ name: Deploy Mode Validation & Inference on: push: - branches: [main] + branches: [main,TRAIN] pull_request: branches: [main]