From 45d08f2351e1cbbc69ee18e7a935c4718013c69c Mon Sep 17 00:00:00 2001 From: ry-immr Date: Wed, 22 Jan 2025 01:21:07 +0900 Subject: [PATCH] Fix point filtering logic in load_valid_labels() --- yolo/tools/data_loader.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/yolo/tools/data_loader.py b/yolo/tools/data_loader.py index c44f00c6..72bda9f0 100644 --- a/yolo/tools/data_loader.py +++ b/yolo/tools/data_loader.py @@ -140,10 +140,14 @@ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Te for seg_data in seg_data_one_img: cls = seg_data[0] points = np.array(seg_data[1:]).reshape(-1, 2) - valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2) - if valid_points.size > 1: - bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)]) - bboxes.append(bbox) + + if not np.any((points >= 0) & (points <= 1)): + continue + + valid_points = np.clip(points, 0, 1) + + bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)]) + bboxes.append(bbox) if bboxes: return torch.stack(bboxes)