curvelets.torch package
PyTorch implementation of Uniform Discrete Curvelet Transform (UDCT).
- class curvelets.torch.MeyerWavelet(shape)
Bases: object
Multi-dimensional Meyer wavelet transform with pre-computed filters.
This class provides forward and backward Meyer wavelet transforms with filters pre-computed during initialization for improved performance. The forward transform returns all subbands in a nested list structure, and the backward transform accepts this full structure for reconstruction.
The filters are computed with the same method as UDCT's _create_bandpass_windows() for num_scales=2, so this class is compatible with UDCT configured with num_scales=2 and high_frequency_mode="wavelet".
- Parameters:
shape (tuple[int, ...]) – Expected shape of input signals. Used for validation and to determine the number of dimensions. All dimensions must be even (divisible by 2).
- dimension
Number of dimensions.
- Type:
- Raises:
ValueError – If any dimension in shape is odd (not divisible by 2).
Notes
This implementation requires all dimensions to be even. Odd-length signals are not supported: with the downsampling strategy used in the transform, they would produce mismatched subband sizes that cannot be reconstructed correctly.
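The even-dimension rule can be checked before a transform is constructed. The following is a minimal stand-alone sketch of that rule only; `find_odd_axes` and `require_even_shape` are hypothetical helpers introduced for illustration and are not part of curvelets.torch.

```python
def find_odd_axes(shape):
    """Return the indices of the axes whose length is odd.

    Hypothetical helper for illustration; not part of curvelets.torch.
    """
    return [axis for axis, size in enumerate(shape) if size % 2 != 0]


def require_even_shape(shape):
    """Return ``shape`` as a tuple, or raise ValueError if any axis is odd."""
    odd_axes = find_odd_axes(shape)
    if odd_axes:
        raise ValueError(
            f"shape {tuple(shape)} has odd length along axes {odd_axes}"
        )
    return tuple(shape)
```

Calling `require_even_shape((64, 64))` returns the shape unchanged, while `require_even_shape((65, 65))` raises ValueError, mirroring the behaviour documented for the constructor.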
Examples
>>> import torch
>>> from curvelets.torch import MeyerWavelet
>>> wavelet = MeyerWavelet(shape=(64, 64))
>>> signal = torch.randn(64, 64)
>>> coefficients = wavelet.forward(signal)
>>> len(coefficients)  # 2 subband groups
2
>>> coefficients[0][0].shape  # Lowpass subband
torch.Size([32, 32])
>>> len(coefficients[1])  # Highpass subbands
3
>>> reconstructed = wavelet.backward(coefficients)
>>> torch.allclose(signal, reconstructed, atol=1e-10)
True
>>> # Odd dimensions raise an error
>>> try:
...     MeyerWavelet(shape=(65, 65))
... except ValueError:
...     print("Odd dimensions not supported")
Odd dimensions not supported
- backward(coefficients)
Apply multi-dimensional Meyer wavelet inverse transform.
Reconstructs the original signal from the full coefficient structure returned by forward().
- Parameters:
coefficients (list[list[torch.Tensor]]) –
Full coefficient structure from forward() with 2 subband groups:
coefficients[0]: [lowpass] - single lowpass subband
coefficients[1]: [highpass_0, highpass_1, …] - all highpass subbands
- Returns:
Reconstructed signal with shape matching the original input.
- Return type:
- Raises:
ValueError – If coefficients structure is invalid (must have 2 subband groups).
Examples
>>> import torch
>>> from curvelets.torch import MeyerWavelet
>>> wavelet = MeyerWavelet(shape=(64, 64))
>>> signal = torch.randn(64, 64)
>>> coefficients = wavelet.forward(signal)
>>> reconstructed = wavelet.backward(coefficients)
>>> torch.allclose(signal, reconstructed, atol=1e-10)
True
- forward(signal)
Apply multi-dimensional Meyer wavelet forward transform.
Decomposes the input signal into 2^dimension subbands organized into 2 subband groups. Returns all subbands in a nested list structure similar to UDCT.
- Parameters:
signal (torch.Tensor) – Input tensor (real or complex). Must match the shape specified during initialization.
- Returns:
All subbands organized into 2 subband groups:
coefficients[0]: [lowpass] - single lowpass subband (1 subband)
coefficients[1]: [highpass_0, highpass_1, …] - all highpass subbands (2^dimension - 1 subbands)
Each subband has shape approximately half the input shape in each dimension.
- Return type:
- Raises:
ValueError – If signal shape does not match expected shape.
Examples
>>> import torch
>>> from curvelets.torch import MeyerWavelet
>>> wavelet = MeyerWavelet(shape=(64, 64))
>>> signal = torch.randn(64, 64)
>>> coefficients = wavelet.forward(signal)
>>> len(coefficients)  # 2 subband groups
2
>>> coefficients[0][0].shape  # Lowpass subband
torch.Size([32, 32])
>>> len(coefficients[1])  # Highpass subbands
3
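The subband layout described for forward() is simple arithmetic on the number of dimensions. The sketch below tallies it without calling the library; `meyer_subband_layout` is a hypothetical helper, and the exact halving of each axis is an assumption based on the 64x64 example (the documentation only says "approximately half").

```python
def meyer_subband_layout(shape):
    """Tally the subbands described in the forward() documentation.

    Hypothetical helper for illustration; it does not call curvelets.torch.
    """
    dimension = len(shape)
    total = 2 ** dimension        # all subbands produced by the transform
    num_highpass = total - 1      # everything except the single lowpass subband
    # Assumption: each axis is halved exactly, as in the 64x64 example.
    subband_shape = tuple(size // 2 for size in shape)
    return {
        "lowpass": 1,
        "highpass": num_highpass,
        "subband_shape": subband_shape,
    }
```

For a (64, 64) input this gives 1 lowpass subband, 3 highpass subbands, and a subband shape of (32, 32), matching the doctest output above. A 3D input would have 7 highpass subbands.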
- class curvelets.torch.UDCT(shape, angular_wedges_config=None, num_scales=None, wedges_per_direction=None, window_overlap=None, radial_frequency_params=None, window_threshold=1e-05, high_frequency_mode='curvelet', transform_kind='real')
Bases: object
Uniform Discrete Curvelet Transform (UDCT) implementation.
This class provides forward and backward curvelet transforms with support for both real and complex transforms.
- Parameters:
angular_wedges_config (torch.Tensor, optional) – Configuration tensor specifying the number of angular wedges per scale and dimension. Shape is (num_scales - 1, dimension), where num_scales includes the lowpass scale. If provided, cannot be used together with num_scales/wedges_per_direction. Default is None.
num_scales (int, optional) – Total number of scales (including lowpass scale 0). Must be >= 2. Used when angular_wedges_config is not provided. Default is 3.
wedges_per_direction (int, optional) – Number of angular wedges per direction at the coarsest scale. The number of wedges doubles at each finer scale. Must be >= 3. Used when angular_wedges_config is not provided. Default is 3.
window_overlap (float, optional) – Window overlap parameter controlling the smoothness of window transitions. If None and using num_scales/wedges_per_direction, automatically chosen based on wedges_per_direction. Default is None (auto) or 0.15.
radial_frequency_params (tuple[float, float, float, float], optional) – Radial frequency parameters defining the frequency bands. Default is (\(\pi/3\), \(2\pi/3\), \(2\pi/3\), \(4\pi/3\)).
window_threshold (float, optional) – Threshold for sparse window storage (values below this are stored as sparse). Default is 1e-5.
high_frequency_mode ({"curvelet", "wavelet"}, optional) – High frequency mode. “curvelet” uses curvelets at all scales, “wavelet” creates a single ring-shaped window (bandpass filter only, no angular components) at the highest scale with decimation=1. Default is “curvelet”.
transform_kind ({"real", "complex", "monogenic"}, optional) –
Type of transform to use:
"real" (default): Real transform where each band captures both positive and negative frequencies combined.
"complex": Complex transform which separates positive and negative frequency components into different bands. Each band is scaled by \(\sqrt{0.5}\).
"monogenic": Monogenic transform that extends the curvelet transform by applying Riesz transforms, producing ndim+1 components per band (scalar plus all Riesz components).
- high_frequency_mode
High frequency mode.
- Type:
- transform_kind
Type of transform being used (“real”, “complex”, or “monogenic”).
- Type:
- parameters
Internal UDCT parameters.
- Type:
ParamUDCT
- windows
Curvelet windows in sparse format.
- Type:
UDCTWindows
- decimation_ratios
Decimation ratios for each scale/direction.
- Type:
Examples
>>> import torch
>>> from curvelets.torch import UDCT
>>> # Create a 2D transform using num_scales (simplified interface)
>>> transform = UDCT(shape=(64, 64), num_scales=3, wedges_per_direction=3)
>>> data = torch.randn(64, 64)
>>> coeffs = transform.forward(data)
>>> recon = transform.backward(coeffs)
>>> torch.allclose(data, recon, atol=1e-4)
True
>>> # Create using angular_wedges_config (advanced interface)
>>> cfg = torch.tensor([[3, 3], [6, 6]], dtype=torch.int64)
>>> transform2 = UDCT(shape=(64, 64), angular_wedges_config=cfg)
>>> coeffs2 = transform2.forward(data)
>>> recon2 = transform2.backward(coeffs2)
>>> torch.allclose(data, recon2, atol=1e-4)
True
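The two interfaces describe the same kind of information: a table with one row per non-lowpass scale and one column per dimension. The sketch below builds such a table from num_scales and wedges_per_direction using the doubling rule stated in the parameter descriptions. `build_wedge_config` is a hypothetical helper, and the row order (coarsest scale first) is an assumption; the library's own conversion may differ in detail.

```python
def build_wedge_config(num_scales, wedges_per_direction, dimension):
    """Build a (num_scales - 1) x dimension table of wedge counts.

    Hypothetical helper for illustration; not part of curvelets.torch.
    Assumes rows are ordered from the coarsest to the finest scale.
    """
    if num_scales < 2:
        raise ValueError("num_scales must be >= 2")
    if wedges_per_direction < 3:
        raise ValueError("wedges_per_direction must be >= 3")
    # The wedge count doubles at each finer scale.
    return [
        [wedges_per_direction * 2 ** scale] * dimension
        for scale in range(num_scales - 1)
    ]
```

With num_scales=3, wedges_per_direction=3, and dimension=2, the result is [[3, 3], [6, 6]], the same values passed as angular_wedges_config in the example above.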
- apply_to_tensors(fn)
Apply a function to all internal tensors.
This method is used for device and dtype transfers (e.g., when calling model.to(device) on a module containing this UDCT instance). It applies the given function to all internal tensor attributes including windows and decimation ratios.
- Parameters:
fn (Callable[[torch.Tensor], torch.Tensor]) – Function to apply to each tensor. Typically this is a function like lambda t: t.cuda() or lambda t: t.to(device).
- Return type:
Examples
>>> import torch
>>> from curvelets.torch import UDCT
>>> transform = UDCT(shape=(64, 64), num_scales=3, wedges_per_direction=3)
>>> # Move all internal tensors to GPU (if available)
>>> if torch.cuda.is_available():
...     transform.apply_to_tensors(lambda t: t.cuda())
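The general pattern behind a method like this is to walk a nested container and replace every leaf with the result of the function. The sketch below shows that pattern in plain Python, with numbers standing in for tensors. `map_leaves` is a hypothetical helper; how UDCT stores and traverses its internal tensors is not shown in this documentation.

```python
def map_leaves(fn, obj):
    """Apply ``fn`` to every leaf of a nested list/tuple/dict structure.

    Hypothetical helper for illustration; not part of curvelets.torch.
    Containers are rebuilt with the same type; anything else is a leaf.
    """
    if isinstance(obj, dict):
        return {key: map_leaves(fn, value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(map_leaves(fn, item) for item in obj)
    return fn(obj)


# Numbers stand in for tensors; the attribute names are illustrative only.
state = {"windows": [[1, 2], [3]], "decimation_ratios": (4,)}
doubled = map_leaves(lambda value: value * 2, state)
```

With real tensors, the function passed in would be something like `lambda t: t.to(device)`, as in the parameter description above.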
- backward(coefficients)
Apply backward (inverse) curvelet transform.
- Parameters:
coefficients (UDCTCoefficients) – Curvelet coefficients from forward transform.
- Returns:
Reconstructed image with shape self.shape.
- Return type:
- forward(image)
Apply forward curvelet transform.
- Parameters:
image (torch.Tensor) – Input image with shape matching self.shape.
For transform_kind="real" or "monogenic": must be real-valued.
For transform_kind="complex": can be real-valued or complex-valued.
- Returns:
Curvelet coefficients organized by scale, direction, and wedge. For monogenic transforms, returns MUDCTCoefficients where each coefficient is a list of ndim+1 arrays.
- Return type:
UDCTCoefficients | MUDCTCoefficients
- property ndim: int
Number of dimensions.
- property num_scales: int
Number of scales.
- struct(vector)
Restructure vectorized coefficients to nested list format.
- Parameters:
vector (torch.Tensor) – 1D tensor of coefficients.
- Returns:
Restructured coefficients. The dtype is preserved from the forward transform if available, otherwise uses the vector’s dtype.
- Return type:
UDCTCoefficients
Examples
>>> import torch
>>> from curvelets.torch import UDCT
>>> transform = UDCT(shape=(64, 64), angular_wedges_config=torch.tensor([[3, 3]]))
>>> data = torch.randn(64, 64)
>>> coeffs = transform.forward(data)
>>> vec = transform.vect(coeffs)
>>> coeffs_recon = transform.struct(vec)
>>> len(coeffs_recon) == len(coeffs)
True
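The vect()/struct() pair is a flatten-and-restore round trip. The sketch below illustrates the idea on plain Python lists with two nesting levels. `vect_sketch` and `struct_sketch` are hypothetical helpers; the real coefficient structure has more levels, and the real struct() takes only the vector because the transform already knows the block shapes, whereas this sketch needs an explicit template.

```python
def vect_sketch(nested):
    """Concatenate every block of a list-of-lists-of-blocks into one flat list."""
    flat = []
    for group in nested:
        for block in group:
            flat.extend(block)
    return flat


def struct_sketch(flat, template):
    """Cut ``flat`` back into blocks whose lengths match ``template``."""
    restored, position = [], 0
    for group in template:
        new_group = []
        for block in group:
            length = len(block)
            new_group.append(flat[position:position + length])
            position += length
        restored.append(new_group)
    if position != len(flat):
        raise ValueError("vector length does not match the template")
    return restored


# One lowpass block followed by a group of three bands of different sizes.
coeffs = [[[1.0, 2.0]], [[3.0], [4.0, 5.0, 6.0], [7.0]]]
vector = vect_sketch(coeffs)
restored = struct_sketch(vector, coeffs)
```

Here `restored` equals `coeffs`, which is the same invariant the doctest above checks at the outermost level with `len(coeffs_recon) == len(coeffs)`.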
- vect(coefficients)
Vectorize curvelet coefficients.
- Parameters:
coefficients (UDCTCoefficients) – Curvelet coefficients.
- Returns:
1D tensor containing all coefficients.
- Return type:
- class curvelets.torch.UDCTModule(shape, angular_wedges_config=None, num_scales=None, wedges_per_direction=None, window_overlap=None, radial_frequency_params=None, window_threshold=1e-05, high_frequency_mode='curvelet', transform_type='real')
Bases: Module
PyTorch nn.Module wrapper for UDCT with autograd support.
PyTorch module with automatic differentiation. When called, it returns flattened coefficients as a single tensor, enabling gradient computation through the backward transform. The backward transform is automatically used in the autograd graph when computing gradients.
- Parameters:
angular_wedges_config (torch.Tensor, optional) – Configuration tensor specifying the number of angular wedges per scale and dimension. Shape is (num_scales - 1, dimension), where num_scales includes the lowpass scale. If provided, cannot be used together with num_scales/wedges_per_direction. Default is None.
num_scales (int, optional) – Total number of scales (including lowpass scale 0). Must be >= 2. Used when angular_wedges_config is not provided. Default is 3.
wedges_per_direction (int, optional) – Number of angular wedges per direction at the coarsest scale. The number of wedges doubles at each finer scale. Must be >= 3. Used when angular_wedges_config is not provided. Default is 3.
window_overlap (float, optional) – Window overlap parameter controlling the smoothness of window transitions. If None and using num_scales/wedges_per_direction, automatically chosen based on wedges_per_direction. Default is None (auto) or 0.15.
radial_frequency_params (tuple[float, float, float, float], optional) – Radial frequency band parameters. Default is (\(\pi/3\), \(2\pi/3\), \(2\pi/3\), \(4\pi/3\)).
window_threshold (float, optional) – Threshold for sparse window storage. Default is 1e-5.
high_frequency_mode ({"curvelet", "wavelet"}, optional) – High frequency mode. Default is “curvelet”.
transform_type ("real" or "complex", optional) –
Type of transform to use:
"real": Real transform (default). Each band captures both positive and negative frequencies combined.
"complex": Complex transform. Positive and negative frequency bands are separated into different directions.
Default is "real".
Examples
>>> import torch
>>> from curvelets.torch import UDCTModule
>>>
>>> # Create module using num_scales (simplified interface)
>>> udct = UDCTModule(shape=(64, 64), num_scales=3, wedges_per_direction=3)
>>> input_tensor = torch.randn(64, 64, dtype=torch.float64, requires_grad=True)
>>> output = udct(input_tensor)  # Returns flattened coefficients tensor
>>>
>>> # Create module using angular_wedges_config (advanced interface)
>>> cfg = torch.tensor([[3, 3], [6, 6]], dtype=torch.int64)
>>> udct2 = UDCTModule(shape=(64, 64), angular_wedges_config=cfg)
>>> output2 = udct2(input_tensor)
>>>
>>> # Create module with complex transform
>>> udct_complex = UDCTModule(
...     shape=(64, 64),
...     num_scales=3,
...     wedges_per_direction=3,
...     transform_type="complex"
... )
>>> output_complex = udct_complex(input_tensor)
>>>
>>> # Test with gradcheck
>>> torch.autograd.gradcheck(udct, input_tensor, atol=1e-5, rtol=1e-3)
True
>>>
>>> # Use struct() to convert flattened coefficients to nested structure
>>> coeffs_nested = udct.struct(output.detach())
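The statement that the backward transform appears in the gradient is consistent with a general fact about linear maps: if y = A x, the gradient of a scalar loss with respect to x is the adjoint of A applied to the gradient with respect to y. When A is orthonormal, the adjoint is also the inverse. The sketch below checks this numerically on a 2x2 rotation used as a stand-in transform. All names are hypothetical, and whether UDCTModule relies on exactly this property is an assumption, not something this documentation states.

```python
import math

THETA = 0.3  # arbitrary angle for the stand-in transform


def forward(x):
    """Stand-in linear transform: a 2x2 rotation, which is orthonormal."""
    c, s = math.cos(THETA), math.sin(THETA)
    return [c * x[0] - s * x[1], s * x[0] + c * x[1]]


def adjoint(g):
    """Transpose of the rotation; equals its inverse because it is orthonormal."""
    c, s = math.cos(THETA), math.sin(THETA)
    return [c * g[0] + s * g[1], -s * g[0] + c * g[1]]


def loss(x, weights):
    """Scalar loss that is linear in the transform coefficients."""
    return sum(w * y for w, y in zip(weights, forward(x)))


def numerical_gradient(x, weights, eps=1e-6):
    """Central finite differences of ``loss`` with respect to ``x``."""
    grad = []
    for i in range(len(x)):
        hi = list(x)
        lo = list(x)
        hi[i] += eps
        lo[i] -= eps
        grad.append((loss(hi, weights) - loss(lo, weights)) / (2 * eps))
    return grad


x = [1.5, -0.5]
weights = [2.0, 3.0]                 # gradient of the loss w.r.t. the coefficients
analytic = adjoint(weights)          # adjoint applied to the upstream gradient
numeric = numerical_gradient(x, weights)
roundtrip = adjoint(forward(x))      # the adjoint also reconstructs the input
```

The analytic and numerical gradients agree to within finite-difference error, and `roundtrip` recovers `x`. torch.autograd.gradcheck in the example above performs the same kind of comparison on the actual module.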
- forward(image)
Forward pass: compute forward transform and return flattened coefficients.
- Parameters:
image (torch.Tensor) – Input image with shape matching self.shape.
- Returns:
Flattened curvelet coefficients as a single tensor.
- Return type:
- property ndim: int
Number of dimensions.
- property num_scales: int
Number of scales.
- struct(vector)
Restructure vectorized coefficients to nested list format.
- Parameters:
vector (torch.Tensor) – 1D tensor of coefficients.
- Returns:
Restructured coefficients. The dtype is preserved from the forward transform if available, otherwise uses the vector’s dtype.
- Return type:
UDCTCoefficients
Examples
>>> import torch
>>> from curvelets.torch import UDCTModule
>>> udct = UDCTModule(shape=(64, 64), angular_wedges_config=torch.tensor([[3, 3]]))
>>> input_tensor = torch.randn(64, 64)
>>> output = udct(input_tensor)
>>> coeffs_nested = udct.struct(output.detach())
>>> len(coeffs_nested) > 0
True
- vect(coefficients)
Vectorize curvelet coefficients.
- Parameters:
coefficients (UDCTCoefficients) – Curvelet coefficients.
- Returns:
1D tensor containing all coefficients.
- Return type: