curvelets.torch package

PyTorch implementation of Uniform Discrete Curvelet Transform (UDCT).

class curvelets.torch.MeyerWavelet(shape)

Bases: object

Multi-dimensional Meyer wavelet transform with pre-computed filters.

This class provides forward and backward Meyer wavelet transforms with filters pre-computed during initialization for improved performance. The forward transform returns all subbands in a nested list structure, and the backward transform accepts this full structure for reconstruction.

The filter computation uses the same method as UDCT’s _create_bandpass_windows() for num_scales=2, ensuring compatibility when used with UDCT in wavelet mode with num_scales=2 and high_frequency_mode=“wavelet”.
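The compatibility claim above depends on the lowpass and highpass windows forming a smooth partition of unity in frequency. The exact windows produced by _create_bandpass_windows() are not documented here, so the following is only an illustrative sketch: it uses the textbook Meyer auxiliary polynomial nu(x) = x^4 (35 - 84x + 70x^2 - 20x^3) and a hypothetical transition band [a, b] to build a lowpass/highpass pair whose squared magnitudes sum to one.

```python
import math


def meyer_nu(x: float) -> float:
    """Textbook Meyer auxiliary polynomial, clamped to [0, 1]."""
    if x <= 0.0:
        return 0.0
    if x >= 1.0:
        return 1.0
    return x**4 * (35 - 84 * x + 70 * x**2 - 20 * x**3)


def meyer_pair(omega: float, a: float, b: float) -> tuple[float, float]:
    """Return (lowpass, highpass) magnitudes at frequency omega.

    a and b are hypothetical transition-band edges chosen for illustration;
    they are not the library's internal parameters.
    """
    t = meyer_nu((abs(omega) - a) / (b - a))
    angle = 0.5 * math.pi * t
    return math.cos(angle), math.sin(angle)


# Sample [0, pi] and measure how far low**2 + high**2 strays from 1.
a, b = math.pi / 3, 2 * math.pi / 3
samples = [k * math.pi / 200 for k in range(201)]
pairs = [meyer_pair(w, a, b) for w in samples]
max_error = max(abs(lo**2 + hi**2 - 1.0) for lo, hi in pairs)
print(f"max |low^2 + high^2 - 1| = {max_error:.2e}")
```

Because low^2 + high^2 = 1 at every frequency, multiplying a spectrum by each window during analysis and again during synthesis, then summing, returns the original spectrum; decimation adds further conditions that this sketch does not cover.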

Parameters:

shape (tuple[int, ...]) – Expected shape of input signals. Used for validation and to determine the number of dimensions. All dimensions must be even (divisible by 2).

shape

Expected signal shape (all dimensions are even).

Type:

tuple[int, …]

dimension

Number of dimensions.

Type:

int

Raises:

ValueError – If any dimension in shape is odd (not divisible by 2).

Notes

This implementation requires all dimensions to be even. Odd-length signals are not supported due to the downsampling strategy used in the transform, which produces mismatched subband sizes that cannot be correctly reconstructed.
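To see why odd lengths cause trouble, consider the generic case of splitting a length-n axis into two decimated-by-two phases. This is an illustration of the general effect, not the library's exact downsampling code: for even n both phases have n/2 samples, while for odd n they differ by one. The validate_even_shape helper is hypothetical and only mirrors the documented ValueError.

```python
def phase_lengths(n: int) -> tuple[int, int]:
    """Lengths of the even- and odd-indexed phases of a length-n axis."""
    samples = list(range(n))
    return len(samples[0::2]), len(samples[1::2])


def validate_even_shape(shape: tuple[int, ...]) -> None:
    """Hypothetical check mirroring the documented ValueError for odd sizes."""
    odd_axes = [axis for axis, size in enumerate(shape) if size % 2 != 0]
    if odd_axes:
        raise ValueError(f"All dimensions must be even; odd axes: {odd_axes}")


print(phase_lengths(64))  # (32, 32): both phases match
print(phase_lengths(65))  # (33, 32): phases differ by one sample
```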

Examples

>>> import torch
>>> from curvelets.torch import MeyerWavelet
>>> wavelet = MeyerWavelet(shape=(64, 64))
>>> signal = torch.randn(64, 64, dtype=torch.float64)
>>> coefficients = wavelet.forward(signal)
>>> len(coefficients)  # 2 subband groups
2
>>> coefficients[0][0].shape  # Lowpass subband
torch.Size([32, 32])
>>> len(coefficients[1])  # Highpass subbands
3
>>> reconstructed = wavelet.backward(coefficients)
>>> torch.allclose(signal, reconstructed, atol=1e-10)
True
>>> # Odd dimensions raise an error
>>> try:
...     MeyerWavelet(shape=(65, 65))
... except ValueError:
...     print("Odd dimensions not supported")
Odd dimensions not supported

backward(coefficients)

Apply multi-dimensional Meyer wavelet inverse transform.

Reconstructs the original signal from the full coefficient structure returned by forward().

Parameters:

coefficients (list[list[torch.Tensor]]) –

Full coefficient structure from forward() with 2 subband groups:

  • coefficients[0]: [lowpass] - single lowpass subband

  • coefficients[1]: [highpass_0, highpass_1, …] - all highpass subbands

Returns:

Reconstructed signal with shape matching the original input.

Return type:

torch.Tensor

Raises:

ValueError – If coefficients structure is invalid (must have 2 subband groups).

Examples

>>> import torch
>>> from curvelets.torch import MeyerWavelet
>>> wavelet = MeyerWavelet(shape=(64, 64))
>>> signal = torch.randn(64, 64, dtype=torch.float64)
>>> coefficients = wavelet.forward(signal)
>>> reconstructed = wavelet.backward(coefficients)
>>> torch.allclose(signal, reconstructed, atol=1e-10)
True

forward(signal)

Apply multi-dimensional Meyer wavelet forward transform.

Decomposes the input signal into 2^dimension subbands organized into 2 subband groups and returns them in a nested list structure similar to that of UDCT.
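The layout described above can be checked without running a transform. In this sketch, meyer_layout is a hypothetical helper that computes only subband shapes. It assumes, as in a separable wavelet transform, that each subband corresponds to choosing the low or high band along every axis, and that every subband has exactly half the input size per axis, which the even-dimension requirement makes possible.

```python
from itertools import product


def meyer_layout(shape: tuple[int, ...]) -> list[list[tuple[int, ...]]]:
    """Return the expected subband shapes as [[lowpass], [highpass, ...]]."""
    if any(size % 2 for size in shape):
        raise ValueError("All dimensions must be even")
    half = tuple(size // 2 for size in shape)
    # Each subband picks "low" (0) or "high" (1) along every axis;
    # the all-low combination is the single lowpass subband.
    combos = list(product((0, 1), repeat=len(shape)))
    highpass = [half for combo in combos if any(combo)]
    return [[half], highpass]


layout_2d = meyer_layout((64, 64))
layout_3d = meyer_layout((32, 32, 32))
print(len(layout_2d[1]), len(layout_3d[1]))  # 3 7
```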

Parameters:

signal (torch.Tensor) – Input tensor (real or complex). Must match the shape specified during initialization.

Returns:

All subbands organized into 2 subband groups:

  • coefficients[0]: [lowpass] - single lowpass subband (1 subband)

  • coefficients[1]: [highpass_0, highpass_1, …] - all highpass subbands (2^dimension - 1 subbands)

Each subband has half the input size along every dimension (for example, (32, 32) for a (64, 64) input); the halving is exact because all dimensions are required to be even.

Return type:

list[list[torch.Tensor]]

Raises:

ValueError – If signal shape does not match expected shape.

Examples

>>> import torch
>>> from curvelets.torch import MeyerWavelet
>>> wavelet = MeyerWavelet(shape=(64, 64))
>>> signal = torch.randn(64, 64)
>>> coefficients = wavelet.forward(signal)
>>> len(coefficients)  # 2 subband groups
2
>>> coefficients[0][0].shape  # Lowpass subband
torch.Size([32, 32])
>>> len(coefficients[1])  # Highpass subbands
3

class curvelets.torch.UDCT(shape, angular_wedges_config=None, num_scales=None, wedges_per_direction=None, window_overlap=None, radial_frequency_params=None, window_threshold=1e-05, high_frequency_mode='curvelet', transform_kind='real')

Bases: object

Uniform Discrete Curvelet Transform (UDCT) implementation.

This class provides forward and backward curvelet transforms with support for both real and complex transforms.

Parameters:
  • shape (tuple[int, ...]) – Shape of the input data.

  • angular_wedges_config (torch.Tensor, optional) – Configuration tensor specifying the number of angular wedges per scale and dimension. Shape is (num_scales - 1, dimension), where num_scales includes the lowpass scale. If provided, cannot be used together with num_scales/wedges_per_direction. Default is None.

  • num_scales (int, optional) – Total number of scales (including lowpass scale 0). Must be >= 2. Used when angular_wedges_config is not provided. Default is 3.

  • wedges_per_direction (int, optional) – Number of angular wedges per direction at the coarsest scale. The number of wedges doubles at each finer scale. Must be >= 3. Used when angular_wedges_config is not provided. Default is 3.

  • window_overlap (float, optional) – Window overlap parameter controlling the smoothness of window transitions. If None and the num_scales/wedges_per_direction interface is used, the overlap is chosen automatically based on wedges_per_direction; otherwise it defaults to 0.15. Default is None.

  • radial_frequency_params (tuple[float, float, float, float], optional) – Radial frequency parameters defining the frequency bands. Default is (\(\pi/3\), \(2\pi/3\), \(2\pi/3\), \(4\pi/3\)).

  • window_threshold (float, optional) – Threshold for sparse window storage: window values below this threshold are discarded so that windows can be stored sparsely. Default is 1e-5.

  • high_frequency_mode ({"curvelet", "wavelet"}, optional) – High frequency mode. “curvelet” uses curvelets at all scales, “wavelet” creates a single ring-shaped window (bandpass filter only, no angular components) at the highest scale with decimation=1. Default is “curvelet”.

  • transform_kind ({"real", "complex", "monogenic"}, optional) –

    Type of transform to use:

    • “real” (default): Real transform where each band captures both positive and negative frequencies combined.

    • “complex”: Complex transform that separates positive and negative frequency components into different bands. Each band is scaled by \(\sqrt{0.5}\).

    • “monogenic”: Monogenic transform that extends the curvelet transform by applying Riesz transforms, producing ndim+1 components per band (scalar plus all Riesz components).
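The two configuration interfaces are related through the doubling rule stated for wedges_per_direction. The sketch below builds a nested-list version of angular_wedges_config from num_scales and wedges_per_direction, assuming every direction at a given scale uses the same wedge count. It reproduces the [[3, 3], [6, 6]] configuration used in the class examples for num_scales=3 and wedges_per_direction=3 in 2D. The wedges_config helper is hypothetical and not part of the package.

```python
def wedges_config(num_scales: int, wedges_per_direction: int, ndim: int) -> list[list[int]]:
    """Rows are scales 1..num_scales-1 (lowpass excluded), columns are dimensions."""
    if num_scales < 2:
        raise ValueError("num_scales must be >= 2")
    if wedges_per_direction < 3:
        raise ValueError("wedges_per_direction must be >= 3")
    # The wedge count doubles at each finer scale.
    return [
        [wedges_per_direction * 2**scale] * ndim
        for scale in range(num_scales - 1)
    ]


print(wedges_config(3, 3, 2))  # [[3, 3], [6, 6]]
print(wedges_config(4, 3, 3))  # [[3, 3, 3], [6, 6, 6], [12, 12, 12]]
```

A list built this way could be wrapped with torch.tensor(..., dtype=torch.int64), as in the examples, if you want to pass it through angular_wedges_config instead.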

shape

Shape of the input data.

Type:

tuple[int, …]

high_frequency_mode

High frequency mode.

Type:

str

transform_kind

Type of transform being used (“real”, “complex”, or “monogenic”).

Type:

str

parameters

Internal UDCT parameters.

Type:

ParamUDCT

windows

Curvelet windows in sparse format.

Type:

UDCTWindows

decimation_ratios

Decimation ratios for each scale/direction.

Type:

list

Examples

>>> import torch
>>> from curvelets.torch import UDCT
>>> # Create a 2D transform using num_scales (simplified interface)
>>> transform = UDCT(shape=(64, 64), num_scales=3, wedges_per_direction=3)
>>> data = torch.randn(64, 64)
>>> coeffs = transform.forward(data)
>>> recon = transform.backward(coeffs)
>>> torch.allclose(data, recon, atol=1e-4)
True
>>> # Create using angular_wedges_config (advanced interface)
>>> cfg = torch.tensor([[3, 3], [6, 6]], dtype=torch.int64)
>>> transform2 = UDCT(shape=(64, 64), angular_wedges_config=cfg)
>>> coeffs2 = transform2.forward(data)
>>> recon2 = transform2.backward(coeffs2)
>>> torch.allclose(data, recon2, atol=1e-4)
True

apply_to_tensors(fn)

Apply a function to all internal tensors.

This method is used for device and dtype transfers (e.g., when calling model.to(device) on a module containing this UDCT instance). It applies the given function to all internal tensor attributes including windows and decimation ratios.

Parameters:

fn (Callable[[torch.Tensor], torch.Tensor]) – Function to apply to each tensor. Typically this is a function like lambda t: t.cuda() or lambda t: t.to(device).

Return type:

None

Examples

>>> import torch
>>> from curvelets.torch import UDCT
>>> transform = UDCT(shape=(64, 64), num_scales=3, wedges_per_direction=3)
>>> # Move all internal tensors to GPU (if available)
>>> if torch.cuda.is_available():
...     transform.apply_to_tensors(lambda t: t.cuda())

backward(coefficients)

Apply backward (inverse) curvelet transform.

Parameters:

coefficients (UDCTCoefficients) – Curvelet coefficients from forward transform.

Returns:

Reconstructed image with shape self.shape.

Return type:

torch.Tensor

property decimation_ratios: list[Tensor]

Decimation ratios for each scale.

forward(image)

Apply forward curvelet transform.

Parameters:

image (torch.Tensor) –

Input image with shape matching self.shape.

  • For transform_kind=“real” or “monogenic”: must be real-valued.

  • For transform_kind=“complex”: can be real-valued or complex-valued.

Returns:

Curvelet coefficients organized by scale, direction, and wedge. For monogenic transforms, returns MUDCTCoefficients, where each coefficient is a list of ndim+1 tensors (the scalar component plus the ndim Riesz components).

Return type:

UDCTCoefficients | MUDCTCoefficients

property ndim: int

Number of dimensions.

property num_scales: int

Number of scales.

property shape: tuple[int, ...]

Shape of the transform.

struct(vector)

Restructure vectorized coefficients to nested list format.

Parameters:

vector (torch.Tensor) – 1D tensor of coefficients.

Returns:

Restructured coefficients. If a forward transform has already been run, its coefficient dtype is reused; otherwise the dtype of vector is used.

Return type:

UDCTCoefficients

Examples

>>> import torch
>>> from curvelets.torch import UDCT
>>> transform = UDCT(shape=(64, 64), angular_wedges_config=torch.tensor([[3, 3]]))
>>> data = torch.randn(64, 64)
>>> coeffs = transform.forward(data)
>>> vec = transform.vect(coeffs)
>>> coeffs_recon = transform.struct(vec)
>>> len(coeffs_recon) == len(coeffs)
True

vect(coefficients)

Vectorize curvelet coefficients.

Parameters:

coefficients (UDCTCoefficients) – Curvelet coefficients.

Returns:

1D tensor containing all coefficients.

Return type:

torch.Tensor
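vect() and struct() are inverse operations on the coefficient layout. The following is a minimal pure-Python sketch of that idea, with flat lists of floats standing in for tensors and a recorded list of lengths standing in for shape bookkeeping. The helper names are hypothetical, and the scale/direction/wedge nesting is assumed from the description of forward(). In the real transform, struct() receives only the vector, so the layout presumably comes from the transform's own configuration.

```python
Nested = list[list[list[list[float]]]]  # scale -> direction -> wedge -> values
Lengths = list[list[list[int]]]


def vect_nested(coeffs: Nested) -> tuple[list[float], Lengths]:
    """Flatten every wedge into one list, recording each wedge's length."""
    flat: list[float] = []
    lengths: Lengths = []
    for scale in coeffs:
        scale_lengths = []
        for direction in scale:
            scale_lengths.append([len(wedge) for wedge in direction])
            for wedge in direction:
                flat.extend(wedge)
        lengths.append(scale_lengths)
    return flat, lengths


def struct_nested(flat: list[float], lengths: Lengths) -> Nested:
    """Rebuild the nested layout by consuming flat in the same order."""
    out: Nested = []
    pos = 0
    for scale_lengths in lengths:
        scale = []
        for direction_lengths in scale_lengths:
            direction = []
            for n in direction_lengths:
                direction.append(flat[pos:pos + n])
                pos += n
            scale.append(direction)
        out.append(scale)
    return out


# Two scales: one lowpass wedge, then two directions with 2 and 1 wedges.
coeffs = [[[[1.0, 2.0]]], [[[3.0], [4.0, 5.0]], [[6.0]]]]
vec, lengths = vect_nested(coeffs)
print(vec)  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```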

property windows: list[list[list[tuple[Tensor, Tensor]]]]

Curvelet windows in sparse format.
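Each window is stored as a tuple of two tensors, and window_threshold controls which values are kept. A plausible reading, assumed here rather than stated in the source, is that each tuple holds the flat indices and the values of the window entries above the threshold. The sketch shows that storage scheme on a 1D list and how such a window would be applied to a spectrum; sparsify and apply_sparse are hypothetical names.

```python
def sparsify(window: list[float], threshold: float = 1e-5) -> tuple[list[int], list[float]]:
    """Keep only entries whose magnitude exceeds the threshold."""
    indices = [i for i, w in enumerate(window) if abs(w) > threshold]
    return indices, [window[i] for i in indices]


def apply_sparse(spectrum: list[complex], sparse: tuple[list[int], list[float]]) -> list[complex]:
    """Multiply the spectrum by the window, touching only the stored entries."""
    indices, values = sparse
    out = [0j] * len(spectrum)
    for i, w in zip(indices, values):
        out[i] = spectrum[i] * w
    return out


window = [0.0, 1e-7, 0.25, 1.0, 0.25, 1e-7, 0.0]
sparse = sparsify(window)
print(sparse)  # ([2, 3, 4], [0.25, 1.0, 0.25])
```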

class curvelets.torch.UDCTModule(shape, angular_wedges_config=None, num_scales=None, wedges_per_direction=None, window_overlap=None, radial_frequency_params=None, window_threshold=1e-05, high_frequency_mode='curvelet', transform_type='real')

Bases: Module

PyTorch nn.Module wrapper for UDCT with autograd support.

Calling the module runs the forward transform and returns all curvelet coefficients flattened into a single tensor, so the transform can be placed inside an autograd graph. During backpropagation, gradients are computed with the backward transform.
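Using the backward transform to compute gradients is exact when the backward transform is the adjoint of the forward transform, as it is for a tight (Parseval) frame. The sketch below demonstrates that property on a tiny stand-in: a three-vector Parseval frame for R^2, where analysis plays the role of the forward transform and synthesis plays the role of the backward transform. A finite-difference check confirms that the gradient of a linear loss on the coefficients equals synthesis applied to the loss weights. Whether UDCTModule relies on exactly this property internally is an assumption, not something stated in the source.

```python
import math

# Three unit vectors 120 degrees apart, scaled by sqrt(2/3): a Parseval frame
# for R^2, so A^T A = I and synthesis undoes analysis.
SCALE = math.sqrt(2.0 / 3.0)
A = [
    [SCALE * math.cos(t), SCALE * math.sin(t)]
    for t in (math.pi / 2 + 2 * math.pi * k / 3 for k in range(3))
]


def analysis(x: list[float]) -> list[float]:
    """Stand-in for the forward transform: y = A x."""
    return [row[0] * x[0] + row[1] * x[1] for row in A]


def synthesis(y: list[float]) -> list[float]:
    """Stand-in for the backward transform: x = A^T y."""
    return [sum(A[k][j] * y[k] for k in range(3)) for j in range(2)]


def loss(x: list[float], weights: list[float]) -> float:
    """A linear loss on the coefficients: sum_k weights[k] * y[k]."""
    return sum(w * c for w, c in zip(weights, analysis(x)))


def shifted(x: list[float], j: int, delta: float) -> list[float]:
    """Copy of x with component j moved by delta."""
    return [v + delta if i == j else v for i, v in enumerate(x)]


x0 = [0.3, -1.2]
weights = [0.5, -2.0, 1.5]
eps = 1e-6
# Central differences of the loss with respect to each input component.
numeric_grad = [
    (loss(shifted(x0, j, eps), weights) - loss(shifted(x0, j, -eps), weights)) / (2 * eps)
    for j in range(2)
]
print(numeric_grad, synthesis(weights))  # the two should agree closely
```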

Parameters:
  • shape (tuple[int, ...]) – Shape of the input data.

  • angular_wedges_config (torch.Tensor, optional) – Configuration tensor specifying the number of angular wedges per scale and dimension. Shape is (num_scales - 1, dimension), where num_scales includes the lowpass scale. If provided, cannot be used together with num_scales/wedges_per_direction. Default is None.

  • num_scales (int, optional) – Total number of scales (including lowpass scale 0). Must be >= 2. Used when angular_wedges_config is not provided. Default is 3.

  • wedges_per_direction (int, optional) – Number of angular wedges per direction at the coarsest scale. The number of wedges doubles at each finer scale. Must be >= 3. Used when angular_wedges_config is not provided. Default is 3.

  • window_overlap (float, optional) – Window overlap parameter controlling the smoothness of window transitions. If None and the num_scales/wedges_per_direction interface is used, the overlap is chosen automatically based on wedges_per_direction; otherwise it defaults to 0.15. Default is None.

  • radial_frequency_params (tuple[float, float, float, float], optional) – Radial frequency band parameters. Default is (\(\pi/3\), \(2\pi/3\), \(2\pi/3\), \(4\pi/3\)).

  • window_threshold (float, optional) – Threshold for sparse window storage. Default is 1e-5.

  • high_frequency_mode ({"curvelet", "wavelet"}, optional) – High frequency mode. Default is “curvelet”.

  • transform_type ("real" or "complex", optional) –

    Type of transform to use:

    • "real": Real transform (default). Each band captures both positive and negative frequencies combined.

    • "complex": Complex transform. Positive and negative frequency bands are separated into different directions.

    Default is "real".

Examples

>>> import torch
>>> from curvelets.torch import UDCTModule
>>>
>>> # Create module using num_scales (simplified interface)
>>> udct = UDCTModule(shape=(64, 64), num_scales=3, wedges_per_direction=3)
>>> input_tensor = torch.randn(64, 64, dtype=torch.float64, requires_grad=True)
>>> output = udct(input_tensor)  # Returns flattened coefficients tensor
>>>
>>> # Create module using angular_wedges_config (advanced interface)
>>> cfg = torch.tensor([[3, 3], [6, 6]], dtype=torch.int64)
>>> udct2 = UDCTModule(shape=(64, 64), angular_wedges_config=cfg)
>>> output2 = udct2(input_tensor)
>>>
>>> # Create module with complex transform
>>> udct_complex = UDCTModule(
...     shape=(64, 64),
...     num_scales=3,
...     wedges_per_direction=3,
...     transform_type="complex"
... )
>>> output_complex = udct_complex(input_tensor)
>>>
>>> # Test with gradcheck
>>> torch.autograd.gradcheck(udct, input_tensor, atol=1e-5, rtol=1e-3)
True
>>>
>>> # Use struct() to convert flattened coefficients to nested structure
>>> coeffs_nested = udct.struct(output.detach())

property decimation_ratios: list[Tensor]

Decimation ratios for each scale.

forward(image)

Forward pass: compute forward transform and return flattened coefficients.

Parameters:

image (torch.Tensor) – Input image with shape matching self.shape.

Returns:

Flattened curvelet coefficients as a single tensor.

Return type:

torch.Tensor

property ndim: int

Number of dimensions.

property num_scales: int

Number of scales.

property shape: tuple[int, ...]

Shape of the transform.

struct(vector)

Restructure vectorized coefficients to nested list format.

Parameters:

vector (torch.Tensor) – 1D tensor of coefficients.

Returns:

Restructured coefficients. If a forward transform has already been run, its coefficient dtype is reused; otherwise the dtype of vector is used.

Return type:

UDCTCoefficients

Examples

>>> import torch
>>> from curvelets.torch import UDCTModule
>>> udct = UDCTModule(shape=(64, 64), angular_wedges_config=torch.tensor([[3, 3]]))
>>> input_tensor = torch.randn(64, 64)
>>> output = udct(input_tensor)
>>> coeffs_nested = udct.struct(output.detach())
>>> len(coeffs_nested) > 0
True

vect(coefficients)

Vectorize curvelet coefficients.

Parameters:

coefficients (UDCTCoefficients) – Curvelet coefficients.

Returns:

1D tensor containing all coefficients.

Return type:

torch.Tensor

property windows: list[list[list[tuple[Tensor, Tensor]]]]

Curvelet windows in sparse format.
