Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions httomolibgpu/memory_estimator_helpers.py

This file was deleted.

6 changes: 3 additions & 3 deletions httomolibgpu/prep/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import numpy as np
from httomolibgpu import cupywrapper
from httomolibgpu.memory_estimator_helpers import _DeviceMemStack
from tomobar.supp.memory_estimator_helpers import DeviceMemStack

cp = cupywrapper.cp
cupy_run = cupywrapper.cupy_run
Expand Down Expand Up @@ -91,7 +91,7 @@ def paganin_filter(
cp.ndarray
The 3D array of Paganin phase-filtered projection images.
"""
mem_stack = _DeviceMemStack() if calc_peak_gpu_mem else None
mem_stack = DeviceMemStack() if calc_peak_gpu_mem else None
# Check the input data is valid
if not mem_stack and tomo.ndim != 3:
raise ValueError(
Expand Down Expand Up @@ -301,7 +301,7 @@ def _pad_projections(
"next_power_of_2", "next_fast_length", "use_pad_x_y"
],
pad_x_y: Optional[list],
mem_stack: Optional[_DeviceMemStack],
mem_stack: Optional[DeviceMemStack],
) -> Tuple[cp.ndarray, Tuple[int, int]]:
"""
Performs padding of each projection to a size optimal for FFT.
Expand Down
41 changes: 9 additions & 32 deletions httomolibgpu/prep/stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from unittest.mock import Mock

if cupy_run:
from tomobar.supp.memory_estimator_helpers import DeviceMemStack
from cupyx.scipy.ndimage import median_filter, binary_dilation, uniform_filter1d
from cupyx.scipy.fft import fft2, ifft2, fftshift
from cupyx.scipy.fftpack import get_fft_plan
Expand Down Expand Up @@ -204,32 +205,8 @@ def _reflect(x: np.ndarray, minx: float, maxx: float) -> np.ndarray:
return np.array(out, dtype=x.dtype)


class _DeviceMemStack:
def __init__(self) -> None:
self.allocations = []
self.current = 0
self.highwater = 0

def malloc(self, bytes):
self.allocations.append(bytes)
allocated = self._round_up(bytes)
self.current += allocated
self.highwater = max(self.current, self.highwater)

def free(self, bytes):
assert bytes in self.allocations
self.allocations.remove(bytes)
self.current -= self._round_up(bytes)
assert self.current >= 0

def _round_up(self, size):
ALLOCATION_UNIT_SIZE = 512
size = (size + ALLOCATION_UNIT_SIZE - 1) // ALLOCATION_UNIT_SIZE
return size * ALLOCATION_UNIT_SIZE


def _mypad(
x: cp.ndarray, pad: Tuple[int, int, int, int], mem_stack: Optional[_DeviceMemStack]
x: cp.ndarray, pad: Tuple[int, int, int, int], mem_stack: Optional[DeviceMemStack]
) -> cp.ndarray:
"""Function to do numpy like padding on Arrays. Only works for 2-D
padding.
Expand Down Expand Up @@ -272,7 +249,7 @@ def _conv2d(
w: np.ndarray,
stride: Tuple[int, int],
groups: int,
mem_stack: Optional[_DeviceMemStack],
mem_stack: Optional[DeviceMemStack],
) -> cp.ndarray:
"""Convolution (equivalent pytorch.conv2d)"""
b, ci, hi, wi = x.shape if not mem_stack else x
Expand Down Expand Up @@ -355,7 +332,7 @@ def _conv_transpose2d(
stride: Tuple[int, int],
pad: Tuple[int, int],
groups: int,
mem_stack: Optional[_DeviceMemStack],
mem_stack: Optional[DeviceMemStack],
) -> cp.ndarray:
"""Transposed convolution (equivalent pytorch.conv_transpose2d)"""
b, co, ho, wo = x.shape if not mem_stack else x
Expand Down Expand Up @@ -428,7 +405,7 @@ def _afb1d(
h0: np.ndarray,
h1: np.ndarray,
dim: int,
mem_stack: Optional[_DeviceMemStack],
mem_stack: Optional[DeviceMemStack],
) -> cp.ndarray:
"""1D analysis filter bank (along one dimension only) of an image

Expand Down Expand Up @@ -476,7 +453,7 @@ def _sfb1d(
g0: np.ndarray,
g1: np.ndarray,
dim: int,
mem_stack: Optional[_DeviceMemStack],
mem_stack: Optional[DeviceMemStack],
) -> cp.ndarray:
"""1D synthesis filter bank of an image Array"""

Expand Down Expand Up @@ -520,7 +497,7 @@ def __init__(self, wave: str):
self.h1_row = np.array(h1_row).astype("float32")[::-1].reshape((1, 1, 1, -1))

def apply(
self, x: cp.ndarray, mem_stack: Optional[_DeviceMemStack] = None
self, x: cp.ndarray, mem_stack: Optional[DeviceMemStack] = None
) -> Tuple[cp.ndarray, cp.ndarray]:
"""Forward pass of the DWT.

Expand Down Expand Up @@ -582,7 +559,7 @@ def __init__(self, wave: str):
def apply(
self,
coeffs: Tuple[cp.ndarray, cp.ndarray],
mem_stack: Optional[_DeviceMemStack] = None,
mem_stack: Optional[DeviceMemStack] = None,
) -> cp.ndarray:
"""
Args:
Expand Down Expand Up @@ -672,7 +649,7 @@ def remove_stripe_fw(
sli_shape = [nz, 1, nproj_pad, ni]

if calc_peak_gpu_mem:
mem_stack = _DeviceMemStack()
mem_stack = DeviceMemStack()
# A data copy is assumed when invoking the function
mem_stack.malloc(np.prod(data) * np.float32().itemsize)
mem_stack.malloc(np.prod(sli_shape) * np.float32().itemsize)
Expand Down
34 changes: 23 additions & 11 deletions httomolibgpu/recon/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from unittest.mock import Mock

if cupy_run:
from tomobar.supp.memory_estimator_helpers import DeviceMemStack
from tomobar.methodsDIR import RecToolsDIR
from tomobar.methodsDIR_CuPy import RecToolsDIRCuPy
from tomobar.methodsIR_CuPy import RecToolsIRCuPy
Expand All @@ -38,7 +39,7 @@
RecToolsIRCuPy = Mock()

from numpy import float32
from typing import Optional, Type, Union
from typing import Optional, Tuple, Type, Union


__all__ = [
Expand Down Expand Up @@ -192,7 +193,7 @@ def FBP3d_tomobar(

## %%%%%%%%%%%%%%%%%%%%%%% LPRec %%%%%%%%%%%%%%%%%%%%%%%%%%%% ##
def LPRec3d_tomobar(
data: cp.ndarray,
data: cp.ndarray | Tuple[int, int, int],
angles: np.ndarray,
center: Optional[float] = None,
detector_pad: Union[bool, int] = False,
Expand All @@ -204,6 +205,7 @@ def LPRec3d_tomobar(
power_of_2_cropping: Optional[bool] = False,
min_mem_usage_filter: Optional[bool] = True,
min_mem_usage_ifft2: Optional[bool] = True,
**kwargs,
) -> cp.ndarray:
"""
Fourier direct inversion in 3D on unequally spaced (also called as Log-Polar) grids using
Expand Down Expand Up @@ -232,6 +234,8 @@ def LPRec3d_tomobar(
The radius of the circular mask that applies to the reconstructed slice in order to crop
out some undesirable artifacts. The values outside the given diameter will be set to zero.
To implement the cropping one can use the range [0.7-1.0] or set to 2.0 when no cropping required.
calc_peak_gpu_mem: bool
Parameter to support memory estimation in HTTomo. Irrelevant to the method itself and can be ignored by user.

Returns
-------
Expand All @@ -243,7 +247,7 @@ def LPRec3d_tomobar(
data, angles, center, detector_pad, recon_size, 0
)

reconstruction = RecToolsCP.FOURIER_INV(
result = RecToolsCP.FOURIER_INV(
data,
recon_mask_radius=recon_mask_radius,
data_axes_labels_order=input_data_axis_labels,
Expand All @@ -253,9 +257,14 @@ def LPRec3d_tomobar(
power_of_2_cropping=power_of_2_cropping,
min_mem_usage_filter=min_mem_usage_filter,
min_mem_usage_ifft2=min_mem_usage_ifft2,
**kwargs,
)
cp._default_memory_pool.free_all_blocks()
return cp.require(cp.swapaxes(reconstruction, 0, 1), requirements="C")

if DeviceMemStack.instance():
return result

return cp.require(cp.swapaxes(result, 0, 1), requirements="C")


## %%%%%%%%%%%%%%%%%%%%%%% SIRT reconstruction %%%%%%%%%%%%%%%%%%%%%%%%%%%% ##
Expand Down Expand Up @@ -491,7 +500,7 @@ def FISTA3d_tomobar(

## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ##
def _instantiate_direct_recon_class(
data: cp.ndarray,
data: cp.ndarray | Tuple[int, int, int],
angles: np.ndarray,
center: Optional[float] = None,
detector_pad: Union[bool, int] = False,
Expand All @@ -511,19 +520,22 @@ def _instantiate_direct_recon_class(
Returns:
Type[RecToolsDIRCuPy]: an instance of the direct recon class
"""

data_shape = data if isinstance(data, tuple) else data.shape

if center is None:
center = data.shape[2] // 2 # making a crude guess
center = data_shape[2] // 2 # making a crude guess
if recon_size is None:
recon_size = data.shape[2]
recon_size = data_shape[2]
if detector_pad is True:
detector_pad = __estimate_detectorHoriz_padding(data.shape[2])
detector_pad = __estimate_detectorHoriz_padding(data_shape[2])
elif detector_pad is False:
detector_pad = 0
RecToolsCP = RecToolsDIRCuPy(
DetectorsDimH=data.shape[2], # Horizontal detector dimension
DetectorsDimH=data_shape[2], # Horizontal detector dimension
DetectorsDimH_pad=detector_pad, # padding for horizontal detector
DetectorsDimV=data.shape[1], # Vertical detector dimension (3D case)
CenterRotOffset=data.shape[2] / 2
DetectorsDimV=data_shape[1], # Vertical detector dimension (3D case)
CenterRotOffset=data_shape[2] / 2
- center
- 0.5, # Center of Rotation scalar or a vector
AnglesVec=-angles, # A vector of projection angles in radians
Expand Down
Loading