diff --git a/medsegpy/config.py b/medsegpy/config.py index 3215c095..10e87d71 100644 --- a/medsegpy/config.py +++ b/medsegpy/config.py @@ -758,7 +758,7 @@ def summary(self, additional_vars=None): super().summary(summary_vars) -class ContextUNetConfig(Config): +class ContextUNetConfig(ContextEncoderConfig): """ Configuration for the ContextUNet model. @@ -769,16 +769,6 @@ class ContextUNetConfig(Config): """ MODEL_NAME = "ContextUNet" - NUM_FILTERS = [[32, 32], [64, 64], [128, 128], [256, 256]] - - def __init__(self, state="training", create_dirs=True): - super().__init__(self.MODEL_NAME, state, create_dirs=create_dirs) - - def summary(self, additional_vars=None): - summary_vars = ["NUM_FILTERS"] - if additional_vars: - summary_vars.extend(additional_vars) - super().summary(summary_vars) class ContextInpaintingConfig(ContextUNetConfig): diff --git a/medsegpy/cross_validation/cv_util.py b/medsegpy/cross_validation/cv_util.py index f754a536..13aa6c23 100644 --- a/medsegpy/cross_validation/cv_util.py +++ b/medsegpy/cross_validation/cv_util.py @@ -147,12 +147,9 @@ def init_cv_experiments(self, num_valid_bins=1, num_test_bins=1): for i in range(len(temp)): for j in range(i + 1, len(temp)): - assert ( - len(set(temp[i]) & set(temp[j])) == 0 - ), "Test bins %d and %d not mutually exclusive - %d overlap" % ( - i, - j, - len(set(temp[i]) & set(temp[j])), + assert len(set(temp[i]) & set(temp[j])) == 0, ( + "Test bins %d and %d not mutually exclusive - %d overlap" + % (i, j, len(set(temp[i]) & set(temp[j]))) ) self.num_valid_bins = num_valid_bins diff --git a/medsegpy/data/data_utils.py b/medsegpy/data/data_utils.py index 32593829..adb530ed 100644 --- a/medsegpy/data/data_utils.py +++ b/medsegpy/data/data_utils.py @@ -3,6 +3,7 @@ from typing import Sequence, Union import numpy as np +from numba import njit def collect_mask(mask: np.ndarray, index: Sequence[Union[int, Sequence[int], int]]): @@ -195,7 +196,7 @@ def generate_poisson_disc_mask( x /= x.max() y = np.maximum(abs(y - img_shape[-2] / 2), 0) y /= y.max() - r = np.sqrt(x**2 + y**2) + r = np.sqrt(x ** 2 + y ** 2) # Quick checks assert int(num_samples) == num_samples, ( @@ -233,6 +234,7 @@ def generate_poisson_disc_mask( return mask, patch_mask +@njit def _poisson(nx, ny, K, R, num_samples=None, patch_size=0.0, seed=None): mask = np.zeros((ny, nx)) patch_mask = np.zeros((ny, nx)) diff --git a/medsegpy/data/datasets/abct.py b/medsegpy/data/datasets/abct.py index b3c8425f..e5f29c0a 100644 --- a/medsegpy/data/datasets/abct.py +++ b/medsegpy/data/datasets/abct.py @@ -153,4 +153,6 @@ def register_all_abct(): txt_file_or_scan_root = os.path.join( Cluster.working_cluster().data_dir, txt_file_or_scan_root ) + if not os.path.exists(txt_file_or_scan_root): + continue register_abct(dataset_name, txt_file_or_scan_root) diff --git a/medsegpy/data/datasets/oai.py b/medsegpy/data/datasets/oai.py index 9388c766..9298b510 100644 --- a/medsegpy/data/datasets/oai.py +++ b/medsegpy/data/datasets/oai.py @@ -176,4 +176,6 @@ def register_all_oai(): for dataset_name, scan_root in _DATA_CATALOG.items(): if not os.path.isabs(scan_root): scan_root = os.path.join(Cluster.working_cluster().data_dir, scan_root) + if not os.path.exists(scan_root): + continue register_oai(dataset_name, scan_root) diff --git a/medsegpy/data/datasets/qdess_mri.py b/medsegpy/data/datasets/qdess_mri.py index b60628bc..9b1bfb51 100644 --- a/medsegpy/data/datasets/qdess_mri.py +++ b/medsegpy/data/datasets/qdess_mri.py @@ -4,6 +4,7 @@ import re from medsegpy.data.catalog import DatasetCatalog, MetadataCatalog +from medsegpy.utils.cluster import Cluster logger = logging.getLogger(__name__) @@ -161,6 +162,7 @@ def load_2d_from_filepaths(filepaths: list, source_path: str, dataset_name: str corresponding ground truth segmentations. total_num_slices: The total number of slices for this dataset. dataset_name: The name of the dataset. + Returns: dataset_dicts: A list of dictionaries, described above in the docstring. @@ -336,4 +338,8 @@ def register_all_qdess_datasets(): Registers all qDESS MRI datasets listed in _DATA_CATALOG. """ for dataset_name, scan_root in _DATA_CATALOG.items(): + if not os.path.isabs(scan_root): + scan_root = os.path.join(Cluster.working_cluster().data_dir, scan_root) + if not os.path.exists(scan_root): + continue register_qdess_dataset(scan_root=scan_root, dataset_name=dataset_name) diff --git a/medsegpy/data/im_gens.py b/medsegpy/data/im_gens.py index 7ee3cf29..d87af41e 100644 --- a/medsegpy/data/im_gens.py +++ b/medsegpy/data/im_gens.py @@ -1234,11 +1234,9 @@ def __validate_img_size__(self, total_volume_shape): # this means shape of total volume must be perfectly divisible into # cubes of size IMG_SIZE for dim in range(3): - assert ( - total_volume_shape[dim] % self.config.IMG_SIZE[dim] == 0 - ), "Cannot divide volume of size %s to blocks of size %s" % ( - total_volume_shape, - self.config.IMG_SIZE, + assert total_volume_shape[dim] % self.config.IMG_SIZE[dim] == 0, ( + "Cannot divide volume of size %s to blocks of size %s" + % (total_volume_shape, self.config.IMG_SIZE) ) def img_generator_test(self, model=None): diff --git a/medsegpy/data/transforms/transform.py b/medsegpy/data/transforms/transform.py index 5d6dec2b..4d6fb924 100644 --- a/medsegpy/data/transforms/transform.py +++ b/medsegpy/data/transforms/transform.py @@ -67,6 +67,7 @@ def apply_image(self, img: np.ndarray): img (ndarray): of shape NxHxWxC, or HxWxC or HxW. The array can be of type uint8 in range [0, 255], or floating point in range [0, 1] or [0, 255]. + Returns: ndarray: image after apply the transformation. """ @@ -147,6 +148,7 @@ def _apply(self, x: _T, meth: str) -> _T: Args: x: input to apply the transform operations. meth (str): meth. + Returns: x: after apply the transformation. """ @@ -167,6 +169,7 @@ def __add__(self, other: "TransformList") -> "TransformList": """ Args: other (TransformList): transformation to add. + Returns: TransformList: list of transforms. """ @@ -177,6 +180,7 @@ def __iadd__(self, other: "TransformList") -> "TransformList": """ Args: other (TransformList): transformation to add. + Returns: TransformList: list of transforms. """ @@ -188,6 +192,7 @@ def __radd__(self, other: "TransformList") -> "TransformList": """ Args: other (TransformList): transformation to add. + Returns: TransformList: list of transforms. """ diff --git a/medsegpy/data/transforms/transform_gen.py b/medsegpy/data/transforms/transform_gen.py index dd457af8..86fb9ada 100644 --- a/medsegpy/data/transforms/transform_gen.py +++ b/medsegpy/data/transforms/transform_gen.py @@ -52,10 +52,9 @@ def check_dtype(img: np.ndarray): assert isinstance(img, np.ndarray), "[TransformGen] Needs an numpy array, but got a {}!".format( type(img) ) - assert ( - not isinstance(img.dtype, np.integer) or img.dtype == np.uint8 - ), "[TransformGen] Got image of type {}, " "use uint8 or floating points instead!".format( - img.dtype + assert not isinstance(img.dtype, np.integer) or img.dtype == np.uint8, ( + "[TransformGen] Got image of type {}, " + "use uint8 or floating points instead!".format(img.dtype) ) assert img.ndim > 2, img.ndim @@ -221,20 +220,20 @@ def __init__( max_height: The maximum height of each hole. max_width: The maximum width of each hole. min_holes: The minimum number of holes to drop out. - If None, the maximum number of holes will - be dropped out. + If None, the maximum number of holes will + be dropped out. min_height: The minimum height of each hole. - If None, each hole will have maximum height. + If None, each hole will have maximum height. min_width: The minimum width of each hole. - If None, each hole will have maximum width. + If None, each hole will have maximum width. img_shape: The height and width of each image. max_perc_area_to_remove: The maximum fraction of the image that - will be removed and filled with a constant - value. This is useful to prevent too - much of the image from being modified. + will be removed and filled with a constant + value. This is useful to prevent too + much of the image from being modified. fill_value: The value used to fill in each hole. sampling_pattern: The type of sampling pattern to use when - selecting random patches. + selecting random patches. Possible values: "uniform", "poisson" num_precompute: The number of masks to precompute. """ @@ -250,76 +249,84 @@ def __init__( # Precompute masks self.precomputed_masks = [] if sampling_pattern == "poisson": - assert min_height == max_height == min_width == max_width, ( - "Only square patches are allowed if sampling pattern is " "'poisson'" - ) - - hole_size = min_width - # Check max_perc_area_to_remove - # - Assuming a hexagonal packing of circles will output the - # most number of samples when using Poisson disc sampling. - # - From https://mathworld.wolfram.com/CirclePacking.html: - # Max packing density when using hexagonal packing is - # pi / (2 * sqrt(3)) - max_pos_area = (img_shape[0] * img_shape[1]) * (np.pi / (2 * np.sqrt(3))) - max_num_patches = max_pos_area // ((np.pi / 2) * (hole_size**2)) - max_pos_perc_area = np.round( - (max_num_patches * (hole_size**2)) / (img_shape[0] * img_shape[1]), decimals=3 + self.create_poisson_disc_masks() + else: + self.create_random_masks() + + def create_poisson_disc_masks(self): + """Precomputes masks using poisson-disc sampling""" + + if not (self.min_height == self.max_height == self.min_width == self.max_width): + raise ValueError("Only square patches are allowed if sampling pattern is 'poisson'") + + hole_size = self.min_width + # Check max_perc_area_to_remove + # - Assuming a hexagonal packing of circles will output the + # most number of samples when using Poisson disc sampling. + # - From https://mathworld.wolfram.com/CirclePacking.html: + # Max packing density when using hexagonal packing is + # pi / (2 * sqrt(3)) + max_pos_area = (self.img_shape[0] * self.img_shape[1]) * (np.pi / (2 * np.sqrt(3))) + max_num_patches = max_pos_area // ((np.pi / 2) * (hole_size ** 2)) + max_pos_perc_area = np.round( + (max_num_patches * (hole_size ** 2)) / (self.img_shape[0] * self.img_shape[1]), + decimals=3, + ) + if self.max_perc_area_to_remove >= max_pos_perc_area: + raise ValueError( + f"Value of 'max_perc_area_to_remove' " + f"(= {self.max_perc_area_to_remove}) is too large. " + f"An overestimate of the maximum possible % area that can " + f"be corrupted given {hole_size} x {hole_size} patches " + f"is: {max_pos_perc_area}. Make sure " + f"'max_perc_area_to_remove' is less than this value." ) - if max_perc_area_to_remove >= max_pos_perc_area: - raise ValueError( - f"Value of 'max_perc_area_to_remove' " - f"(= {max_perc_area_to_remove}) is too large. " - f"An overestimate of the maximum possible % area that can " - f"be corrupted given {hole_size} x {hole_size} patches " - f"is: {max_pos_perc_area}. Make sure " - f"'max_perc_area_to_remove' is less than this value." - ) - # If `sampling_pattern` is "poisson", precompute - # `num_precompute` masks - num_samples = ((img_shape[0] * img_shape[1]) * max_perc_area_to_remove) // ( - hole_size**2 + # If `sampling_pattern` is "poisson", precompute + # `num_precompute` masks + num_samples = ((self.img_shape[0] * self.img_shape[1]) * self.max_perc_area_to_remove) // ( + hole_size ** 2 + ) + logger.info("Precomputing masks...") + for _ in tqdm.tqdm(range(self.num_precompute)): + _, patch_mask = generate_poisson_disc_mask( + (self.img_shape[0], self.img_shape[1]), + min_distance=hole_size * np.sqrt(2), + num_samples=num_samples, + patch_size=hole_size, + k=10, ) - logger.info("Precomputing masks...") - for _ in tqdm.tqdm(range(num_precompute)): - _, patch_mask = generate_poisson_disc_mask( - (img_shape[0], img_shape[1]), - min_distance=hole_size * np.sqrt(2), - num_samples=num_samples, - patch_size=hole_size, - k=10, - ) - self.precomputed_masks.append(patch_mask) - logger.info("Finished precomputing masks!") - else: - logger.info("Precomputing masks...") - img_height = img_shape[0] - img_width = img_shape[1] - max_area_to_remove = int((img_height * img_width) * max_perc_area_to_remove) - for _ in tqdm.tqdm(range(num_precompute)): - patch_mask = np.zeros(img_shape) - cur_num_holes = 0 - if max_holes < 0: - num_holes = np.inf - else: - num_holes = np.random.randint(min_holes, max_holes + 1) - while ( - cur_num_holes <= num_holes and np.count_nonzero(patch_mask) < max_area_to_remove - ): - hole_width = np.random.randint(min_width, max_width + 1) - hole_height = np.random.randint(min_height, max_height + 1) - tl_x = np.random.randint(img_width - hole_width) - tl_y = np.random.randint(img_height - hole_height) - br_x = tl_x + hole_width - br_y = tl_y + hole_height - hole_area = (br_x - tl_x) * (br_y - tl_y) - if hole_area > max_area_to_remove and np.count_nonzero(patch_mask) == 0: - continue - patch_mask[tl_y:br_y, tl_x:br_x] = 1 - cur_num_holes += 1 - self.precomputed_masks.append(patch_mask) - logger.info("Finished precomputing masks!") + self.precomputed_masks.append(patch_mask) + logger.info("Finished precomputing masks!") + + def create_random_masks(self): + """Precomputes masks using random sampling""" + + logger.info("Precomputing masks...") + img_height = self.img_shape[0] + img_width = self.img_shape[1] + max_area_to_remove = int((img_height * img_width) * self.max_perc_area_to_remove) + for _ in tqdm.tqdm(range(self.num_precompute)): + patch_mask = np.zeros(self.img_shape) + cur_num_holes = 0 + if self.max_holes < 0: + num_holes = np.inf + else: + num_holes = np.random.randint(self.min_holes, self.max_holes + 1) + while cur_num_holes <= num_holes and np.count_nonzero(patch_mask) < max_area_to_remove: + hole_width = np.random.randint(self.min_width, self.max_width + 1) + hole_height = np.random.randint(self.min_height, self.max_height + 1) + tl_x = np.random.randint(img_width - hole_width) + tl_y = np.random.randint(img_height - hole_height) + br_x = tl_x + hole_width + br_y = tl_y + hole_height + hole_area = (br_x - tl_x) * (br_y - tl_y) + if hole_area > max_area_to_remove and np.count_nonzero(patch_mask) == 0: + continue + patch_mask[tl_y:br_y, tl_x:br_x] = 1 + cur_num_holes += 1 + self.precomputed_masks.append(patch_mask) + logger.info("Finished precomputing masks!") def get_transform(self, img: np.ndarray) -> MedTransform: """ @@ -337,7 +344,8 @@ def get_transform(self, img: np.ndarray) -> MedTransform: Args: img: A N x H x W x C image, containing N images with height H, - width W, and consisting of C channels. + width W, and consisting of C channels. + Returns: An instance of the MedTransform "FillRegionsWithValue", initialized with the appropriate parameters based on the @@ -426,118 +434,129 @@ def __init__( min_height = max_height if not is_square: if max_width is None: - raise ValueError("Value of 'max_width' must not be None if patch is " "not square") + raise ValueError("Value of 'max_width' must not be None if patch is 'not square'") if min_width is None: min_width = max_width self._init(locals()) self.precomputed_masks = [] if sampling_pattern == "poisson": - assert max_height % 2 == 0, f"'max_height' (= {max_height}) must be even" - assert min_height == max_height, ( - f"If sampling_pattern is 'poisson', min_height (= {min_height}) " - f"must equal max_height (= {max_height})" - ) - assert is_square, "Only square patches are allowed if sampling pattern is " "'poisson'" - - patch_size = max_height - # Check max_perc_area_to_modify - # -- Assuming a hexagonal packing of circles will output the - # most number of samples when using Poisson disc sampling. - # -- From https://mathworld.wolfram.com/CirclePacking.html: - # Max packing density when using hexagonal packing is - # pi / (2 * sqrt(3)) - max_pos_area = (img_shape[0] * img_shape[1]) * (np.pi / (2 * np.sqrt(3))) - max_num_patches = max_pos_area // ((np.pi / 2) * (patch_size**2)) - max_pos_perc_area = np.round( - (max_num_patches * (patch_size**2)) / (img_shape[0] * img_shape[1]), decimals=3 + self.create_poisson_disc_masks() + else: + self.create_random_masks() + + def create_poisson_disc_masks(self): + """Precomputes masks using poisson-disc sampling""" + + assert self.max_height % 2 == 0, f"'max_height' (= {self.max_height}) must be even" + assert self.min_height == self.max_height, ( + f"If sampling_pattern is 'poisson', min_height (= {self.min_height}) " + f"must equal max_height (= {self.max_height})" + ) + assert self.is_square, "Only square patches are allowed if sampling pattern is 'poisson'" + + patch_size = self.max_height + # Check max_perc_area_to_modify + # -- Assuming a hexagonal packing of circles will output the + # most number of samples when using Poisson disc sampling. + # -- From https://mathworld.wolfram.com/CirclePacking.html: + # Max packing density when using hexagonal packing is + # pi / (2 * sqrt(3)) + max_pos_area = (self.img_shape[0] * self.img_shape[1]) * (np.pi / (2 * np.sqrt(3))) + max_num_patches = max_pos_area // ((np.pi / 2) * (patch_size ** 2)) + max_pos_perc_area = np.round( + (max_num_patches * (patch_size ** 2)) / (self.img_shape[0] * self.img_shape[1]), + decimals=3, + ) + if self.max_perc_area_to_modify >= max_pos_perc_area: + raise ValueError( + f"Value of 'max_perc_area_to_modify' " + f"(= {self.max_perc_area_to_modify}) is too large. " + f"An overestimate of the maximum possible % area that can " + f"be corrupted given {patch_size} x {patch_size} patches " + f"is: {max_pos_perc_area}. Make sure " + f"'max_perc_area_to_modify' is less than this value." ) - if max_perc_area_to_modify >= max_pos_perc_area: - raise ValueError( - f"Value of 'max_perc_area_to_modify' " - f"(= {max_perc_area_to_modify}) is too large. " - f"An overestimate of the maximum possible % area that can " - f"be corrupted given {patch_size} x {patch_size} patches " - f"is: {max_pos_perc_area}. Make sure " - f"'max_perc_area_to_modify' is less than this value." - ) - # If `sampling_pattern` is "poisson", precompute - # `num_precompute` masks - num_samples = ((img_shape[0] * img_shape[1]) * max_perc_area_to_modify) // ( - patch_size**2 + # If `sampling_pattern` is "poisson", precompute + # `num_precompute` masks + num_samples = ((self.img_shape[0] * self.img_shape[1]) * self.max_perc_area_to_modify) // ( + patch_size ** 2 + ) + assert num_samples >= 2, f"Number of samples (= {num_samples}) must be >= 2" + # Ensure number of samples is even + if num_samples % 2: + num_samples -= 1 + + logger.info("Precomputing masks...") + for _ in tqdm.tqdm(range(self.num_precompute)): + pd_mask, _ = generate_poisson_disc_mask( + (self.img_shape[0], self.img_shape[1]), + min_distance=patch_size * np.sqrt(2), + num_samples=num_samples, + patch_size=patch_size, + k=10, ) - assert num_samples >= 2, f"Number of samples (= {num_samples}) must be >= 2" - # Ensure number of samples is even - if num_samples % 2: - num_samples -= 1 - - logger.info("Precomputing masks...") - for _ in tqdm.tqdm(range(num_precompute)): - pd_mask, _ = generate_poisson_disc_mask( - (img_shape[0], img_shape[1]), - min_distance=patch_size * np.sqrt(2), - num_samples=num_samples, - patch_size=patch_size, - k=10, - ) - self.precomputed_masks.append(pd_mask) - logger.info("Finished precomputing masks!") - else: - img_height = img_shape[0] - img_width = img_shape[1] - area_check = np.zeros((img_height, img_width)) - max_area_to_modify = int((img_height * img_width) * max_perc_area_to_modify) - - logger.info("Precomputing masks...") - for _ in tqdm.tqdm(range(num_precompute)): - patch_coords = [] - area_check[:] = 0 - if max_iterations < 0: - num_iterations = np.inf + self.precomputed_masks.append(pd_mask) + logger.info("Finished precomputing masks!") + + def create_random_masks(self): + """Precomputes masks using random sampling""" + + img_height = self.img_shape[0] + img_width = self.img_shape[1] + area_check = np.zeros((img_height, img_width)) + max_area_to_modify = int((img_height * img_width) * self.max_perc_area_to_modify) + + logger.info("Precomputing masks...") + for _ in tqdm.tqdm(range(self.num_precompute)): + patch_coords = [] + area_check[:] = 0 + if self.max_iterations < 0: + num_iterations = np.inf + else: + num_iterations = np.random.randint(self.min_iterations, self.max_iterations + 1) + while (len(patch_coords) / 4) <= num_iterations and np.count_nonzero( + area_check + ) < max_area_to_modify: + patch_height = np.random.randint(self.min_height, self.max_height + 1) + if self.is_square: + patch_width = patch_height else: - num_iterations = np.random.randint(min_iterations, max_iterations + 1) - while (len(patch_coords) / 4) <= num_iterations and np.count_nonzero( - area_check - ) < max_area_to_modify: - patch_height = np.random.randint(min_height, max_height + 1) - if is_square: - patch_width = patch_height - else: - patch_width = np.random.randint(min_width, max_width + 1) - # Get coordinates of first patch - tl_x_1 = np.random.randint(img_width - patch_width) - tl_y_1 = np.random.randint(img_height - patch_height) - br_x_1 = tl_x_1 + patch_width - br_y_1 = tl_y_1 + patch_height - patch_area = (br_x_1 - tl_x_1) * (br_y_1 - tl_y_1) - if patch_area > (max_area_to_modify / 2) and not patch_coords: - continue - patch_coords.append([tl_x_1, tl_y_1]) - patch_coords.append([br_x_1, br_y_1]) - - # Get coordinates of second patch, ensuring the second - # patch will not overlap with the first patch - tl_y_2 = np.random.randint(img_height - patch_height) - if tl_y_2 <= tl_y_1 - patch_height or tl_y_2 >= br_y_1: - tl_x_2 = np.random.randint(img_width - patch_width) - else: - possible_columns = list(range(tl_x_1 - patch_width + 1)) + list( - range(br_x_1, img_width - patch_width) - ) - tl_x_2 = np.random.choice(np.array(possible_columns)) - br_x_2 = tl_x_2 + patch_width - br_y_2 = tl_y_2 + patch_height - patch_coords.append([tl_x_2, tl_y_2]) - patch_coords.append([br_x_2, br_y_2]) - - # Record the areas of image that were modified by the current - # pair of patches - area_check[tl_y_1:br_y_1, tl_x_1:br_x_1] = 1 - area_check[tl_y_2:br_y_2, tl_x_2:br_x_2] = 1 - coord_matrix = np.array(patch_coords).T - self.precomputed_masks.append(coord_matrix) - logger.info("Finished precomputing masks!") + patch_width = np.random.randint(self.min_width, self.max_width + 1) + # Get coordinates of first patch + tl_x_1 = np.random.randint(img_width - patch_width) + tl_y_1 = np.random.randint(img_height - patch_height) + br_x_1 = tl_x_1 + patch_width + br_y_1 = tl_y_1 + patch_height + patch_area = (br_x_1 - tl_x_1) * (br_y_1 - tl_y_1) + if patch_area > (max_area_to_modify / 2) and not patch_coords: + continue + patch_coords.append([tl_x_1, tl_y_1]) + patch_coords.append([br_x_1, br_y_1]) + + # Get coordinates of second patch, ensuring the second + # patch will not overlap with the first patch + tl_y_2 = np.random.randint(img_height - patch_height) + if tl_y_2 <= tl_y_1 - patch_height or tl_y_2 >= br_y_1: + tl_x_2 = np.random.randint(img_width - patch_width) + else: + possible_columns = list(range(tl_x_1 - patch_width + 1)) + list( + range(br_x_1, img_width - patch_width) + ) + tl_x_2 = np.random.choice(np.array(possible_columns)) + br_x_2 = tl_x_2 + patch_width + br_y_2 = tl_y_2 + patch_height + patch_coords.append([tl_x_2, tl_y_2]) + patch_coords.append([br_x_2, br_y_2]) + + # Record the areas of image that were modified by the current + # pair of patches + area_check[tl_y_1:br_y_1, tl_x_1:br_x_1] = 1 + area_check[tl_y_2:br_y_2, tl_x_2:br_x_2] = 1 + coord_matrix = np.array(patch_coords).T + self.precomputed_masks.append(coord_matrix) + logger.info("Finished precomputing masks!") def get_transform(self, img: np.ndarray) -> MedTransform: """ @@ -546,6 +565,7 @@ def get_transform(self, img: np.ndarray) -> MedTransform: Args: img: A N x H x W x C image, containing N images with height H, width W, and consisting of C channels. + Returns: An instance of the MedTransform "Swap2DPatches", initialized with the appropriate parameters based on the @@ -664,10 +684,9 @@ def apply_transform_gens( tfms = [] for g in transform_gens: tfm = g.get_transform(img) if isinstance(g, TransformGen) else g - assert isinstance( - tfm, MedTransform - ), "TransformGen {} must return an instance of MedTransform! " "Got {} instead".format( - g, tfm + assert isinstance(tfm, MedTransform), ( + "TransformGen {} must return an instance of MedTransform! " + "Got {} instead".format(g, tfm) ) img = tfm.apply_image(img) tfms.append(tfm) diff --git a/medsegpy/evaluation/metrics.py b/medsegpy/evaluation/metrics.py index 0a8d5477..a00e6247 100644 --- a/medsegpy/evaluation/metrics.py +++ b/medsegpy/evaluation/metrics.py @@ -69,7 +69,7 @@ def rms_cv(y_pred: np.ndarray, y_true: np.ndarray, dim=None): stds = np.std([y_pred, y_true], axis=0) means = np.mean([y_pred, y_true], axis=0) cv = stds / means - return np.sqrt(np.mean(cv**2, axis=dim)) + return np.sqrt(np.mean(cv ** 2, axis=dim)) def rmse_cv(y_pred: np.ndarray, y_true: np.ndarray, dim=None): diff --git a/medsegpy/modeling/deeplab_2d/deeplab_model.py b/medsegpy/modeling/deeplab_2d/deeplab_model.py index d766f4ce..31a503a7 100755 --- a/medsegpy/modeling/deeplab_2d/deeplab_model.py +++ b/medsegpy/modeling/deeplab_2d/deeplab_model.py @@ -733,6 +733,7 @@ def preprocess_input(self, x): """Preprocesses a numpy array encoding a batch of images. # Arguments x: a 4D numpy array consists of RGB values within [0, 255]. + # Returns Input array scaled to [-1.,1.] """ diff --git a/medsegpy/modeling/layers/convolutional.py b/medsegpy/modeling/layers/convolutional.py index c43dd7db..1749084e 100644 --- a/medsegpy/modeling/layers/convolutional.py +++ b/medsegpy/modeling/layers/convolutional.py @@ -86,3 +86,8 @@ def call(self, inputs): if self.activation is not None: return self.activation(outputs) return outputs + + +# Add aliases for class +ConvWeightStandardization2D = ConvStandardized2D +ConvWS2D = ConvStandardized2D diff --git a/medsegpy/modeling/loading.py b/medsegpy/modeling/loading.py index f25feae3..bcd20d76 100644 --- a/medsegpy/modeling/loading.py +++ b/medsegpy/modeling/loading.py @@ -18,8 +18,10 @@ def model_from_config(config, custom_objects=None): custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization. + Returns: A Keras model instance (uncompiled). + Raises: TypeError: if `config` is not a dictionary. """ @@ -41,8 +43,10 @@ def model_from_yaml(yaml_string, custom_objects=None): custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization. + Returns: A Keras model instance (uncompiled). + Raises: ImportError: if yaml module is not found. """ @@ -61,6 +65,7 @@ def model_from_json(json_string, custom_objects=None): custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization. + Returns: A Keras model instance (uncompiled). """ diff --git a/medsegpy/modeling/meta_arch/inpainting_and_seg_unet.py b/medsegpy/modeling/meta_arch/inpainting_and_seg_unet.py index 1d3ec0be..d67b8212 100644 --- a/medsegpy/modeling/meta_arch/inpainting_and_seg_unet.py +++ b/medsegpy/modeling/meta_arch/inpainting_and_seg_unet.py @@ -83,8 +83,10 @@ def build_model(self, input_tensor=None) -> Model: Builds the encoder network and returns the resulting model. This implementation will overload the abstract method defined in the superclass "ModelBuilder". + Args: input_tensor: The input to the network. + Returns: model: A Model that defines the encoder network. """ @@ -145,6 +147,7 @@ def build_encoder_block( Builds one block of the ContextEncoder. Each block consists of the following structure: [Conv -> Activation] -> BN -> Dropout. + Args: x: Input tensor. num_filters: Number of filters to use for each conv layer. @@ -153,6 +156,7 @@ def build_encoder_block( kernel_initializer: Kernel initializer accepted by `keras.layers.Conv(...)`. dropout: Dropout rate. + Returns: Output of encoder block. """ @@ -195,8 +199,10 @@ def build_encoder_block( def _build_input(input_size: Tuple): """ Creates an input tensor of size "input_size". + Args: input_size: The size of the input tensor for the model. + Returns: A symbolic input (i.e. a placeholder) with size "input_size". """ @@ -211,8 +217,10 @@ def _get_pool_size(x: tf.Tensor): """ Determines the right size of the pooling filter based on the dimensions of the input tensor `x`. + Args: x: The input tensor. + Returns: A list with the same number of elements as dimensions of `x`, where each element is either 2 or 3. @@ -281,8 +289,10 @@ def build_model(self, input_tensor=None) -> Model: resulting model. This implementation will overload the abstract method defined in the superclass "ModelBuilder". + Args: input_tensor: The input to the network. + Returns: model: A Model that defines the full encoder/decoder architecture. """ @@ -379,6 +389,7 @@ def build_decoder_block( Each block of the decoder will have the following structure: Input -> Transposed Convolution -> Concatenate Skip Connection -> Convolution -> BN -> Dropout (optional) + Args: x: Input tensor. x_skip: The output of the next highest layer of the @@ -397,6 +408,7 @@ def build_decoder_block( convolutional and regular convolutional layers. dropout: Dropout rate. + Returns: Output of the decoder block. """ @@ -450,8 +462,10 @@ def build_decoder_block( def _build_input(input_size: Tuple): """ Creates an input tensor of size "input_size". + Args: input_size: The size of the input tensor for the model. + Returns: A symbolic input (i.e. a placeholder) with size "input_size". """ diff --git a/medsegpy/modeling/meta_arch/unet.py b/medsegpy/modeling/meta_arch/unet.py index 9bc9ae70..07c70149 100644 --- a/medsegpy/modeling/meta_arch/unet.py +++ b/medsegpy/modeling/meta_arch/unet.py @@ -154,7 +154,7 @@ def build_model(self, input_tensor=None) -> Model: num_classes = cfg.get_num_classes() num_filters = cfg.NUM_FILTERS if not num_filters: - num_filters = [2**feat * 32 for feat in range(depth)] + num_filters = [2 ** feat * 32 for feat in range(depth)] else: depth = len(num_filters) diff --git a/medsegpy/modeling/unet_2d/anisotropic_unet_model.py b/medsegpy/modeling/unet_2d/anisotropic_unet_model.py index 2d241f61..52d7e4f9 100644 --- a/medsegpy/modeling/unet_2d/anisotropic_unet_model.py +++ b/medsegpy/modeling/unet_2d/anisotropic_unet_model.py @@ -38,7 +38,7 @@ def anisotropic_unet_2d( raise ValueError("input_size must be a tuple of size (height, width, 1)") if num_filters is None: - nfeatures = [2**feat * 32 for feat in np.arange(depth)] + nfeatures = [2 ** feat * 32 for feat in np.arange(depth)] else: nfeatures = num_filters assert len(nfeatures) == depth diff --git a/medsegpy/modeling/unet_2d/residual_unet_model.py b/medsegpy/modeling/unet_2d/residual_unet_model.py index a038baab..c367beb1 100644 --- a/medsegpy/modeling/unet_2d/residual_unet_model.py +++ b/medsegpy/modeling/unet_2d/residual_unet_model.py @@ -143,7 +143,7 @@ def residual_unet_2d( raise ValueError("input_size must be a tuple of size (height, width, 1)") if num_filters is None: - nfeatures = [2**feat * 32 for feat in np.arange(depth)] + nfeatures = [2 ** feat * 32 for feat in np.arange(depth)] else: nfeatures = num_filters assert len(nfeatures) == depth diff --git a/medsegpy/modeling/unet_2d/unet_model.py b/medsegpy/modeling/unet_2d/unet_model.py index b210cdee..0c6e1a0b 100755 --- a/medsegpy/modeling/unet_2d/unet_model.py +++ b/medsegpy/modeling/unet_2d/unet_model.py @@ -53,7 +53,7 @@ def unet_2d_model(input_size=DEFAULT_INPUT_SIZE, input_tensor=None, output_mode= if input_tensor is None and (type(input_size) is not tuple or len(input_size) != 3): raise ValueError("input_size must be a tuple of size (height, width, 1)") - nfeatures = [2**feat * 32 for feat in np.arange(6)] + nfeatures = [2 ** feat * 32 for feat in np.arange(6)] depth = len(nfeatures) conv_ptr = [] @@ -182,7 +182,7 @@ def unet_2d_model_v2( raise ValueError("input_size must be a tuple of size (height, width, 1)") if num_filters is None: - nfeatures = [2**feat * 32 for feat in np.arange(depth)] + nfeatures = [2 ** feat * 32 for feat in np.arange(depth)] else: nfeatures = num_filters assert len(nfeatures) == depth diff --git a/medsegpy/modeling/unet_3d_model.py b/medsegpy/modeling/unet_3d_model.py index eae86411..f648ed9b 100644 --- a/medsegpy/modeling/unet_3d_model.py +++ b/medsegpy/modeling/unet_3d_model.py @@ -32,7 +32,7 @@ def unet_3d_model( raise ValueError("input_size must be a tuple of size (height, width, slices, 1)") if num_filters is None: - nfeatures = [2**feat * 32 for feat in np.arange(depth)] + nfeatures = [2 ** feat * 32 for feat in np.arange(depth)] else: nfeatures = num_filters assert len(nfeatures) == depth diff --git a/setup.py b/setup.py index 05d00ee1..639a675c 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,7 @@ def get_required_packages(): "pandas", "medpy", "numpy", + "numba", "h5py", "natsort", "scipy",