61 changes: 38 additions & 23 deletions yolo/tools/data_loader.py
@@ -34,7 +34,7 @@ def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str
self.transform.get_more_data = self.get_more_data
self.data = self.load_data(Path(dataset_cfg.path), phase_name)

def load_data(self, dataset_path: Path, phase_name: str):
def load_data(self, dataset_path: Path, phase_name: str) -> list:
"""
Loads data from a cache or generates a new cache for a specific dataset phase.

@@ -43,7 +43,7 @@ def load_data(self, dataset_path: Path, phase_name: str):
phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

Returns:
dict: The loaded data from the cache for the specified phase.
list: The loaded data from the cache for the specified phase.
"""
cache_path = dataset_path / f"{phase_name}.cache"

@@ -58,38 +58,48 @@ def load_data(self, dataset_path: Path, phase_name: str):

def filter_data(self, dataset_path: Path, phase_name: str) -> list:
"""
Filters and collects dataset information by pairing images with their corresponding labels.
Filters and collects dataset information by pairing images with
their corresponding labels.

Parameters:
images_path (Path): Path to the directory containing image files.
labels_path (str): Path to the directory containing label files.
dataset_path (Path): The root path to the dataset directory.
phase_name (str): The specific phase of the dataset
(e.g., 'train', 'test') to load or generate data for.

Returns:
list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
list: A list of tuples, each containing the image id, the path to an
image file, and its associated segmentation as a tensor. For
COCO-formatted .json files, the image id is the `int` `id` of the image
entry in the json file (the `image_id` its annotations reference).
For YOLO-formatted .txt files, the image id is the image file name
without the extension.
"""
images_path = dataset_path / "images" / phase_name
labels_path, data_type = locate_label_paths(dataset_path, phase_name)
images_list = sorted([p.name for p in Path(images_path).iterdir() if p.is_file()])
if data_type == "json":
annotations_index, image_info_dict = create_image_metadata(labels_path)

(
annotations_dict,
image_info_dict,
image_name_to_id_dict
) = create_image_metadata(labels_path)
data = []
valid_inputs = 0
for image_name in track(images_list, description="Filtering data"):
if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
continue
image_id = Path(image_name).stem

if data_type == "json":
image_id = image_name_to_id_dict[image_name]
image_info = image_info_dict.get(image_id, None)
if image_info is None:
continue
annotations = annotations_index.get(image_info["id"], [])
annotations = annotations_dict.get(image_id, [])
image_seg_annotations = scale_segmentation(annotations, image_info)
if not image_seg_annotations:
continue

elif data_type == "txt":
image_id = Path(image_name).stem
label_path = labels_path / f"{image_id}.txt"
if not label_path.is_file():
continue
@@ -99,19 +109,24 @@ def filter_data(self, dataset_path: Path, phase_name: str) -> list:
image_seg_annotations = []

labels = self.load_valid_labels(image_id, image_seg_annotations)

img_path = images_path / image_name
data.append((img_path, labels))
image_path = images_path / image_name
data.append((image_id, image_path, labels))
valid_inputs += 1
logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
return data

def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
def load_valid_labels(
self,
image_id: Union[int, str],
seg_data_one_img: list
) -> Union[Tensor, None]:
"""
Validates bounding box data for one image, ensuring coordinates lie within [0, 1].

Parameters:
label_path (str): The filepath to the label file containing bounding box data.
image_id (int | str): Image id. If a COCO .json file is used, the
image id is an `int`; if YOLO .txt files are used, it is a `str`.
seg_data_one_img (list): Scaled segmentation annotations for a single image.

Returns:
Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
@@ -128,22 +143,22 @@ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Te
if bboxes:
return torch.stack(bboxes)
else:
logger.warning("No valid BBox in {}", label_path)
logger.warning("No valid BBox in image id:{}", image_id)
return torch.zeros((0, 5))

def get_data(self, idx):
img_path, bboxes = self.data[idx]
image_id, img_path, bboxes = self.data[idx]
img = Image.open(img_path).convert("RGB")
return img, bboxes, img_path
return img, bboxes, image_id

def get_more_data(self, num: int = 1):
indices = torch.randint(0, len(self), (num,))
return [self.get_data(idx)[:2] for idx in indices]

def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
img, bboxes, img_path = self.get_data(idx)
img, bboxes, image_id = self.get_data(idx)
img, bboxes, rev_tensor = self.transform(img, bboxes)
return img, bboxes, rev_tensor, img_path
return img, bboxes, rev_tensor, image_id

def __len__(self) -> int:
return len(self.data)
@@ -189,11 +204,11 @@ def collate_fn(self, batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[T
batch_targets[idx, : min(target_size, 100)] = batch[idx][1][:100]
batch_targets[:, :, 1:] *= self.image_size

batch_images, _, batch_reverse, batch_path = zip(*batch)
batch_images, _, batch_reverse, batch_image_ids = zip(*batch)
batch_images = torch.stack(batch_images)
batch_reverse = torch.stack(batch_reverse)

return batch_size, batch_images, batch_targets, batch_reverse, batch_path
return batch_size, batch_images, batch_targets, batch_reverse, batch_image_ids


def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
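For illustration, a minimal, self-contained sketch of the reshaped per-item contract: self.data entries are now (image_id, image_path, labels) triples rather than (image_path, labels) pairs. The ids, paths, and label shapes below are hypothetical stand-ins.

import torch
from pathlib import Path

# Hypothetical entries: a COCO-backed dataset stores int ids, a YOLO
# txt-backed dataset stores the file stem as a str id.
coco_item = (42, Path("images/train/000000000042.jpg"), torch.zeros((3, 5)))
yolo_item = ("img_0001", Path("images/train/img_0001.png"), torch.zeros((2, 5)))

for image_id, image_path, labels in (coco_item, yolo_item):
    # labels is an (n, 5) tensor: class index followed by four bbox values
    print(type(image_id).__name__, image_path.name, labels.shape)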
4 changes: 2 additions & 2 deletions yolo/tools/solver.py
@@ -237,7 +237,7 @@ def solve(self, dataloader, epoch_idx=1):
self.model.eval()
predict_json, mAPs = [], defaultdict(list)
self.progress.start_one_epoch(len(dataloader), task="Validate")
for batch_size, images, targets, rev_tensor, img_paths in dataloader:
for batch_size, images, targets, rev_tensor, image_ids in dataloader:
images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
with torch.no_grad():
predicts = self.model(images)
@@ -250,7 +250,7 @@ def solve(self, dataloader, epoch_idx=1):
avg_mAPs = {key: 100 * torch.mean(torch.stack(val)) for key, val in mAPs.items()}
self.progress.one_batch(avg_mAPs)

predict_json.extend(predicts_to_json(img_paths, predicts, rev_tensor))
predict_json.extend(predicts_to_json(image_ids, predicts, rev_tensor))
self.progress.finish_one_epoch(avg_mAPs, epoch_idx=epoch_idx)
self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)

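For reference, a self-contained sketch of the five-element batch the validation loop now unpacks; the tensors are dummies shaped like collate_fn's output, and the 640x640 input size is assumed purely for illustration.

import torch

fake_batch = (
    2,                            # batch_size
    torch.zeros(2, 3, 640, 640),  # images (assumed input size)
    torch.zeros(2, 100, 5),       # padded targets, capped at 100 per image
    torch.zeros(2, 5),            # rev_tensor: scale + xyxy shift per image
    (42, 7),                      # image_ids: ints for COCO, strs for YOLO txt
)

batch_size, images, targets, rev_tensor, image_ids = fake_batch
print(batch_size, images.shape, image_ids)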
74 changes: 47 additions & 27 deletions yolo/utils/dataset_utils.py
@@ -37,47 +37,67 @@ def locate_label_paths(dataset_path: Path, phase_name: Path) -> Tuple[Path, Path
return [], None


def create_image_metadata(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]:
def create_image_metadata(
labels_path: str
) -> Tuple[Dict[int, List], Dict[int, Dict], Dict[str, int]]:
"""
Create a dictionary containing image information and annotations indexed by image ID.
Returns three dictionaries: one mapping image id to a list of annotations,
one mapping image id to image information, and one mapping image name to
image id. The image id is the `int` `id` assigned to an image in the
COCO-formatted .json file.

Args:
labels_path (str): The path to the annotation json file.

Returns:
- annotations_index: A dictionary where keys are image IDs and values are lists of annotations.
- image_info_dict: A dictionary where keys are image file names without extension and values are image information dictionaries.
(annotations_dict, image_info_dict, image_name_to_id_dict):
annotations_dict is a dictionary where keys are image ids and values
are lists of annotation dictionaries.
image_info_dict is a dictionary where keys are image ids and
values are image information dictionaries.
image_name_to_id_dict is a dictionary with the image file name
(including extension) as key and the `int` image id as value.
"""
with open(labels_path, "r") as file:
labels_data = json.load(file)
id_to_idx = discretize_categories(labels_data.get("categories", [])) if "categories" in labels_data else None
annotations_index = organize_annotations_by_image(labels_data, id_to_idx) # check lookup is a good name?
image_info_dict = {Path(img["file_name"]).stem: img for img in labels_data["images"]}
return annotations_index, image_info_dict
json_data = json.load(file)
image_name_to_id_dict = {
Path(img["file_name"]).name: img['id'] for img in json_data["images"]
}
id_to_idx = discretize_categories(json_data.get("categories", [])) if "categories" in json_data else None
annotations_dict = organize_annotations_by_image(json_data, id_to_idx)
image_info_dict = {img["id"]: img for img in json_data["images"]}
return annotations_dict, image_info_dict, image_name_to_id_dict


def organize_annotations_by_image(
json_data: Dict[str, Any],
category_id_to_idx: Optional[Dict[int, int]],
) -> Dict[int, List[Dict[str, Any]]]:
"""
Returns a dict mapping image id to a list of all corresponding annotations.

Annotations with "iscrowd" set to True, are excluded. Image id is the `int`
`image_id` in the corresponding annotation dict stored in the
COCO formatted .json file.

def organize_annotations_by_image(data: Dict[str, Any], id_to_idx: Optional[Dict[int, int]]):
"""
Use image index to lookup every annotations
Args:
data (Dict[str, Any]): A dictionary containing annotation data.

json_data: Data read from a COCO json file.
category_id_to_idx: For the COCO dataset, a dict mapping each
category_id to (category_id - 1).
Returns:
Dict[int, List[Dict[str, Any]]]: A dictionary where keys are image IDs and values are lists of annotations.
Annotations with "iscrowd" set to True are excluded from the index.

image_id_to_annotation_dict_list: A dictionary where keys are image
ids and values are lists of annotation dictionaries.
annotation_lookup = {}
for anno in data["annotations"]:
if anno["iscrowd"]:
image_id_to_annotation_dict_list = {}
for annotation_dict in json_data["annotations"]:
if annotation_dict["iscrowd"]:
continue
image_id = anno["image_id"]
if id_to_idx:
anno["category_id"] = id_to_idx[anno["category_id"]]
if image_id not in annotation_lookup:
annotation_lookup[image_id] = []
annotation_lookup[image_id].append(anno)
return annotation_lookup
image_id = annotation_dict["image_id"]
if category_id_to_idx:
annotation_dict["category_id"] = category_id_to_idx[annotation_dict["category_id"]]
if image_id not in image_id_to_annotation_dict_list:
image_id_to_annotation_dict_list[image_id] = []
image_id_to_annotation_dict_list[image_id].append(annotation_dict)
return image_id_to_annotation_dict_list


def scale_segmentation(
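As a quick, hedged check of the new mapping (assuming the package is importable as yolo.utils.dataset_utils), the toy COCO payload below exercises the iscrowd exclusion; all ids and boxes are made up.

from yolo.utils.dataset_utils import organize_annotations_by_image

json_data = {
    "annotations": [
        {"image_id": 7, "category_id": 1, "iscrowd": 0, "bbox": [0, 0, 10, 10]},
        {"image_id": 7, "category_id": 2, "iscrowd": 1, "bbox": [5, 5, 2, 2]},
        {"image_id": 9, "category_id": 1, "iscrowd": 0, "bbox": [1, 1, 4, 4]},
    ]
}

# With category_id_to_idx=None, category ids pass through unchanged.
by_image = organize_annotations_by_image(json_data, category_id_to_idx=None)

assert set(by_image) == {7, 9}
assert len(by_image[7]) == 1  # the iscrowd annotation was dropped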
23 changes: 18 additions & 5 deletions yolo/utils/model_utils.py
@@ -160,19 +160,32 @@ def collect_prediction(predict_json: List, local_rank: int) -> List:
return predict_json


def predicts_to_json(img_paths, predicts, rev_tensor):
def predicts_to_json(
image_ids: Union[tuple[int, ...], tuple[str, ...]],
predicts: list[Tensor],
rev_tensor: Tensor
) -> list[dict[str, Any]]:
"""
TODO: function document
turn a batch of imagepath and predicts(n x 6 for each image) to a List of diction(Detection output)
Returns a list of prediction dictionaries. Each dict contains image_id,
category_id, bbox, and score.

Args:
image_ids: Tuple of image ids.
When using a COCO .json annotation file, image ids are ints.
When using YOLO .txt annotation files, image ids are strings.
predicts: One tensor of shape (n, 6) per image,
where n is the number of detected bboxes in that image.
rev_tensor: A tensor of shape (m, 5), where m is the number of images;
each row holds the scale and the xyxy shift applied during
preprocessing, used here to map boxes back to original-image coordinates.
"""
batch_json = []
for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
for image_id, bboxes, box_reverse in zip(image_ids, predicts, rev_tensor):
scale, shift = box_reverse.split([1, 4])
bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
for cls, *pos, conf in bboxes:
bbox = {
"image_id": int(Path(img_path).stem),
"image_id": image_id,
"category_id": IDX_TO_ID[int(cls)],
"bbox": [float(p) for p in pos],
"score": float(conf),
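To make the rev_tensor handling concrete, here is a self-contained sketch of the per-image reversal predicts_to_json applies before converting to xywh; the numbers are made up, and it assumes each rev_tensor row is [scale, shift_x1, shift_y1, shift_x2, shift_y2].

import torch

bboxes = torch.tensor([[0.0, 120.0, 60.0, 220.0, 160.0, 0.9]])  # cls, x1, y1, x2, y2, conf
box_reverse = torch.tensor([0.5, 20.0, 10.0, 20.0, 10.0])       # scale, then xyxy shift

scale, shift = box_reverse.split([1, 4])
# Undo the letterbox shift and resize to recover original-image coordinates.
bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
print(bboxes[:, 1:5])  # tensor([[200., 100., 400., 300.]])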