From 2b7e01337d6a970fc7d97282ec9796d24154d7d2 Mon Sep 17 00:00:00 2001 From: Seth Price Date: Wed, 14 Aug 2024 00:11:35 -0700 Subject: [PATCH] Malformed bbox and dedup --- yolo/tools/data_loader.py | 49 +++++++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/yolo/tools/data_loader.py b/yolo/tools/data_loader.py index c3555117..799611d5 100644 --- a/yolo/tools/data_loader.py +++ b/yolo/tools/data_loader.py @@ -21,6 +21,7 @@ VerticalFlip, ) from yolo.tools.dataset_preparation import prepare_dataset +from yolo.utils.bounding_box_utils import calculate_iou from yolo.utils.dataset_utils import ( create_image_metadata, locate_label_paths, @@ -105,12 +106,42 @@ def filter_data(self, dataset_path: Path, phase_name: str) -> list: labels = self.load_valid_labels(image_id, image_seg_annotations) + if labels is not None and len(labels) > 1: + labels = self.deduplicate_labels(labels) + img_path = images_path / image_name data.append((img_path, labels)) valid_inputs += 1 logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list)) return data + def deduplicate_labels(self, labels: Tensor) -> Tensor: + """ + Removes duplicate labels from a Tensor of bboxes. + + Parameters: + labels (Tensor): A tensor of all input bounding boxes. + + Returns: + Tensor: A tensor of all remaining bounding boxes. + """ + dedup_labels = [] + + for l in labels: + acceptable = True + for ddl in dedup_labels: + if int(l[0]) != int(ddl[0]): + continue + + if float(calculate_iou(l[1:], ddl[1:])) > .99: + acceptable = False + break + + if acceptable: + dedup_labels.append(l) + + return torch.stack(dedup_labels) + def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]: """ Loads and validates bounding box data is [0, 1] from a label file. @@ -122,13 +153,17 @@ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Te Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None. """ bboxes = [] - for seg_data in seg_data_one_img: - cls = seg_data[0] - points = np.array(seg_data[1:]).reshape(-1, 2) - valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2) - if valid_points.size > 1: - bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)]) - bboxes.append(bbox) + try: + for seg_data in seg_data_one_img: + cls = seg_data[0] + points = np.array(seg_data[1:]).reshape(-1, 2) + valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2) + if valid_points.size > 1: + bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)]) + bboxes.append(bbox) + except ValueError: + logger.warning("Invalid BBox in {}", label_path) + return torch.zeros((0, 5)) if bboxes: return torch.stack(bboxes)