.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_examples/plot_10_udct_mnist.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end <sphx_glr_download_auto_examples_plot_10_udct_mnist.py>` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_examples_plot_10_udct_mnist.py: UDCT MNIST Classification ========================= This example demonstrates how to use ``UDCTModule`` as a feature extractor for image classification on the MNIST dataset. Instead of using convolutional layers, this network: 1. Applies the Uniform Discrete Curvelet Transform (UDCT) to each image 2. Uses all curvelet coefficients as features (every pixel in every wedge) 3. Passes features through a two-layer MLP to classify into 10 digit classes The key insight is that curvelet coefficients capture directional information at multiple scales, providing a meaningful representation for classification. Architecture Overview ##################### - **Input**: MNIST images (28x28 grayscale) - **Feature extraction**: UDCTModule transforms each image into curvelet coefficients - **Features**: All coefficient values (every pixel in every wedge) form the feature vector - **Classification**: Two-layer MLP (Linear -> ReLU -> Linear) maps features to 10 classes - **Batch processing**: ``torch.vmap`` enables efficient batched inference **Credits** This example is adapted from the `PyTorch MNIST example <https://github.com/pytorch/examples/tree/main/mnist>`_. .. GENERATED FROM PYTHON SOURCE LINES 30-35 .. code-block:: Python # sphinx_gallery_thumbnail_number = 2 from __future__ import annotations .. GENERATED FROM PYTHON SOURCE LINES 36-46 .. code-block:: Python import matplotlib.pyplot as plt import torch import torch.nn.functional as F from sklearn.manifold import TSNE from torch import nn, optim from torch.optim.lr_scheduler import StepLR from torchvision import datasets, transforms from curvelets.torch import UDCTModule .. 
# %%
# UDCTNet: Curvelet-Based Classifier
# ###################################
# This network replaces convolutional layers with the curvelet transform.
# The key components are:
#
# 1. ``UDCTModule``: Computes curvelet coefficients for each image
# 2. Feature extraction: Uses all coefficient values as features
# 3. Two-layer MLP: Maps features to class probabilities
#
# We use ``torch.vmap`` to efficiently process batches of images.


class UDCTNet(nn.Module):  # type: ignore[misc]
    """Neural network using UDCT for feature extraction.

    Each image is passed through ``UDCTModule``; the magnitudes of the
    resulting coefficients feed a two-layer MLP
    (Linear -> ReLU -> Linear -> log-softmax).

    Parameters
    ----------
    shape : tuple[int, int], optional
        Shape of a single input image, forwarded to ``UDCTModule``.
    num_scales : int, optional
        Forwarded to ``UDCTModule``.
    wedges_per_direction : int, optional
        Forwarded to ``UDCTModule``.
    hidden_size : int, optional
        Width of the hidden linear layer.
    num_classes : int, optional
        Number of output classes. Defaults to 10 (the MNIST digits).
    """

    def __init__(
        self,
        shape: tuple[int, int] = (28, 28),
        num_scales: int = 2,
        wedges_per_direction: int = 3,
        hidden_size: int = 128,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        self.udct = UDCTModule(
            shape=shape,
            num_scales=num_scales,
            wedges_per_direction=wedges_per_direction,
        )
        # The feature count is not known in closed form here, so measure it
        # once with a dummy forward pass. UDCTModule returns flattened
        # coefficients (all pixels in all wedges).
        with torch.inference_mode():
            n_features = self.udct(torch.zeros(shape)).numel()
        self.fc1 = nn.Linear(n_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract all curvelet coefficients from a single 2D image."""
        # UDCTModule returns flattened coefficients - take abs to get real features
        return self.udct(x).abs()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Batch with a singleton channel axis at dim 1,
            e.g. ``(batch, 1, 28, 28)``.

        Returns
        -------
        torch.Tensor
            Log-probabilities over the classes (log-softmax along dim 1).
        """
        # Reuse get_features so training and visualization share one code path.
        features = self.get_features(x)
        hidden = F.relu(self.fc1(features))
        return F.log_softmax(self.fc2(hidden), dim=1)

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features without classification (for visualization)."""
        # x: (batch, 1, 28, 28) -> squeeze to (batch, 28, 28)
        x = x.squeeze(1)
        # vmap maps the single-image extractor over the batch dimension.
        return torch.vmap(self._extract_features)(x)
GENERATED FROM PYTHON SOURCE LINES 105-109 Training and Testing Functions ############################## Standard PyTorch training loop with negative log-likelihood loss. .. GENERATED FROM PYTHON SOURCE LINES 109-235 .. code-block:: Python def train( model: nn.Module, device: torch.device, train_loader: torch.utils.data.DataLoader, optimizer: optim.Optimizer, ) -> tuple[float, float]: """Train the model for one epoch and return average loss and accuracy.""" model.train() total_loss = 0.0 correct = 0 for _, (data, target) in enumerate(train_loader): data_device = data.to(device) target_device = target.to(device) optimizer.zero_grad() output = model(data_device) loss = F.nll_loss(output, target_device) loss.backward() optimizer.step() total_loss += loss.item() pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target_device.view_as(pred)).sum().item() avg_loss = total_loss / len(train_loader) accuracy = 100.0 * correct / len(train_loader.dataset) return avg_loss, accuracy def test( model: nn.Module, device: torch.device, test_loader: torch.utils.data.DataLoader, ) -> tuple[float, float]: """Evaluate the model on the test set and return average loss and accuracy.""" model.eval() test_loss = 0.0 correct = 0 with torch.inference_mode(): for data, target in test_loader: data_device = data.to(device) target_device = target.to(device) output = model(data_device) test_loss += F.nll_loss(output, target_device, reduction="sum").item() pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target_device.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) accuracy = 100.0 * correct / len(test_loader.dataset) return test_loss, accuracy def collect_worst_loss_misclassifications( model: nn.Module, device: torch.device, test_loader: torch.utils.data.DataLoader, ) -> dict[int, tuple[torch.Tensor, int, int, float]]: """Collect worst loss misclassified examples grouped by ground truth digit. 
For each digit class (0-9), finds the misclassified example with the highest loss value. Parameters ---------- model : nn.Module Trained model to evaluate. device : torch.device Device to run inference on. test_loader : torch.utils.data.DataLoader DataLoader for test dataset. Returns ------- dict[int, tuple[torch.Tensor, int, int, float]] Dictionary mapping ground truth digit (0-9) to a tuple of (image tensor, ground truth label, predicted label, loss value). Only contains digits that have misclassifications. """ model.eval() worst_misclassifications: dict[int, tuple[torch.Tensor, int, int, float]] = {} with torch.inference_mode(): for data, target in test_loader: data_device = data.to(device) target_device = target.to(device) output = model(data_device) pred = output.argmax(dim=1) # Compute loss for each sample # F.nll_loss with reduction='none' gives per-sample losses per_sample_loss = F.nll_loss(output, target_device, reduction="none") # Find misclassifications incorrect_mask = pred != target_device incorrect_indices = incorrect_mask.nonzero(as_tuple=True)[0] for idx in incorrect_indices: true_label = int(target_device[idx].item()) pred_label = int(pred[idx].item()) loss_value = float(per_sample_loss[idx].item()) # Store worst loss example for each digit class if true_label not in worst_misclassifications: worst_misclassifications[true_label] = ( data[idx].cpu(), true_label, pred_label, loss_value, ) else: # Update if this example has higher loss _, _, _, current_loss = worst_misclassifications[true_label] if loss_value > current_loss: worst_misclassifications[true_label] = ( data[idx].cpu(), true_label, pred_label, loss_value, ) # Stop if we have examples for all 10 digits if len(worst_misclassifications) == 10: break return worst_misclassifications .. 
# %%
# Main Training Script
# ####################
# We use a simplified configuration suitable for a gallery example:
#
# - 10 epochs
# - Batch size of 64
# - Adadelta optimizer with learning rate scheduling

# Run on an accelerator when one is available, otherwise on the CPU.
device = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else torch.device("cpu")
)

# Convert images to tensors, then normalize with the constants given below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Only the training split requests a download.
train_dataset = datasets.MNIST(
    "./data",
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST("./data", train=False, transform=transform)

# Shuffle the training data; evaluate in larger, unshuffled batches.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000)

# %%
# Model Initialization
# ####################
# Create the UDCTNet model with 2 scales and 3 wedges per direction.
# With all curvelet coefficients as features, we get a high-dimensional
# feature vector that preserves all transform information.

model = UDCTNet(
    shape=(28, 28),
    num_scales=2,
    wedges_per_direction=3,
).to(device)

# %%
# Training Loop
# #############
# Train for 10 epochs with Adadelta optimizer and step learning rate decay.
# Adadelta with the learning rate multiplied by 0.7 after every epoch.
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

num_epochs = 10

# Per-epoch history, consumed by the plots below.
train_losses: list[float] = []
test_losses: list[float] = []
train_accuracies: list[float] = []
test_accuracies: list[float] = []

for _ in range(num_epochs):
    epoch_train = train(model, device, train_loader, optimizer)
    epoch_test = test(model, device, test_loader)

    train_loss, train_acc = epoch_train
    test_loss, test_acc = epoch_test

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_accuracies.append(train_acc)
    test_accuracies.append(test_acc)

    scheduler.step()

# %%
# Loss and Accuracy Plot
# ######################
# Visualize the training and test loss and accuracy over epochs.

epochs = range(1, num_epochs + 1)

# %%
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Both panels share the same line styling.
curve_style = {"linewidth": 2, "markersize": 8}

# Left panel: loss curves.
ax1.plot(epochs, train_losses, "o-", label="Train Loss", **curve_style)
ax1.plot(epochs, test_losses, "s-", label="Test Loss", **curve_style)
ax1.set_xlabel("Epoch", fontsize=12)
ax1.set_ylabel("Loss", fontsize=12)
ax1.set_title("Training and Test Loss", fontsize=14)

# Right panel: accuracy curves.
ax2.plot(epochs, train_accuracies, "o-", label="Train Accuracy", **curve_style)
ax2.plot(epochs, test_accuracies, "s-", label="Test Accuracy", **curve_style)
ax2.set_xlabel("Epoch", fontsize=12)
ax2.set_ylabel("Accuracy (%)", fontsize=12)
ax2.set_title("Training and Test Accuracy", fontsize=14)

# Legend, grid and one tick per epoch are identical for both panels.
for panel in (ax1, ax2):
    panel.legend(fontsize=11)
    panel.grid(True, alpha=0.3)
    panel.set_xticks(epochs)

plt.tight_layout()
plt.show()
# %%
# t-SNE Feature Visualization
# ###########################
# Visualize the curvelet features in 2D using t-SNE. Each digit class is shown
# in a different color, revealing how well the features separate the classes.
# t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
# dimensionality reduction technique that preserves local structure.

# A single shuffled batch of this size is the subset used for visualization.
n_samples = 2000
subset_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=n_samples,
    shuffle=True,
)
data_batch, labels_batch = next(iter(subset_loader))

# Compute curvelet features with the trained model, without tracking gradients.
model.eval()
with torch.inference_mode():
    subset_features = model.get_features(data_batch.to(device))
    features_batch = subset_features.cpu().numpy()
labels_np = labels_batch.numpy()

# Embed the features in two dimensions.
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
features_tsne = tsne.fit_transform(features_batch)

# %%
fig, ax = plt.subplots(figsize=(10, 8))

# One point per sample, colored by its label.
scatter = ax.scatter(
    features_tsne[:, 0],
    features_tsne[:, 1],
    c=labels_np,
    cmap="tab10",
    vmin=-0.5,
    vmax=9.5,
    alpha=0.6,
    s=10,
)

cbar = plt.colorbar(scatter, ax=ax, ticks=range(10))
cbar.set_label("Digit Class", fontsize=12)

ax.set_xlabel("t-SNE 1", fontsize=12)
ax.set_ylabel("t-SNE 2", fontsize=12)
ax.set_title("UDCT Features: 2D t-SNE Projection", fontsize=14)

plt.tight_layout()
plt.show()
# %%
# Misclassification Visualization
# ###############################
# Display the worst loss misclassified example for each digit class (0-9) in a
# 2x5 grid. Each subplot shows the image with a title indicating the ground
# truth digit and the predicted digit. For each digit, the example with the
# highest loss is shown.

# Collect worst loss misclassifications from test set
worst_misclassifications = collect_worst_loss_misclassifications(
    model, device, test_loader
)

# %%
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
axes = axes.flatten()

# Denormalization parameters (reverse of Normalize((0.1307,), (0.3081,)))
mean = 0.1307
std = 0.3081

for digit in range(10):
    ax = axes[digit]
    entry = worst_misclassifications.get(digit)

    if entry is None:
        # No misclassified sample for this digit: show a placeholder message.
        ax.text(
            0.5,
            0.5,
            f"No misclassifications\nfor digit {digit}",
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=10,
        )
        ax.set_title(f"Digit {digit}", fontsize=11)
    else:
        img, true_label, pred_label, loss_value = entry
        # Undo the normalization, then clamp to [0, 1] for display.
        img_denorm = (img.squeeze(0) * std + mean).clamp(0, 1)
        ax.imshow(img_denorm.numpy(), cmap="gray")
        ax.set_title(
            f"Ground truth: {true_label}, Predicted: {pred_label}", fontsize=11
        )

    ax.axis("off")

plt.suptitle("Worst Loss Misclassified Examples by Digit Class", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
image-sg:: /auto_examples/images/sphx_glr_plot_10_udct_mnist_003.png :alt: Worst Loss Misclassified Examples by Digit Class, Ground truth: 0, Predicted: 2, Ground truth: 1, Predicted: 2, Ground truth: 2, Predicted: 7, Ground truth: 3, Predicted: 5, Ground truth: 4, Predicted: 9, Ground truth: 5, Predicted: 3, Ground truth: 6, Predicted: 5, Ground truth: 7, Predicted: 8, Ground truth: 8, Predicted: 0, Ground truth: 9, Predicted: 4 :srcset: /auto_examples/images/sphx_glr_plot_10_udct_mnist_003.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-timing **Total running time of the script:** (3 minutes 6.149 seconds) .. _sphx_glr_download_auto_examples_plot_10_udct_mnist.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_10_udct_mnist.ipynb <plot_10_udct_mnist.ipynb>` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_10_udct_mnist.py <plot_10_udct_mnist.py>` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: plot_10_udct_mnist.zip <plot_10_udct_mnist.zip>` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_