Wavelet Transform for Feature Extraction and Image Compression using Python and OpenCV

Wavelet Transform for Feature Extraction & Image Compression

In this tutorial, we explore three important applications of the Discrete Wavelet Transform (DWT) using Python. You will learn how to:

  1. Extract features from an MRI image using wavelet decomposition.
  2. Perform image compression using DWT and hard thresholding.
  3. Compare compression results between FFT and DWT methods.

This tutorial is based on a previously documented experiment.

Task 1: Feature Extraction using Wavelet Transform

We first load an MRI image in grayscale and then apply a 2D Discrete Wavelet Transform (DWT) using the Haar wavelet. We also extract 1D signals from the approximation coefficients and display the wavelet decomposition components.

Import Required Libraries

import numpy as np
import pywt
import cv2
import matplotlib.pyplot as plt

Load Image and Apply DWT

# Load MRI image in grayscale
def load_image(image_path):
    """Read the image at ``image_path`` as a grayscale array.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` for the path.
    """
    grayscale = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if grayscale is not None:
        return grayscale
    raise FileNotFoundError(f"Image not found: {image_path}")

# Apply 2D DWT
def apply_dwt(img, wavelet='haar', level=2):
    """Return the multilevel 2D wavelet decomposition of ``img``.

    The result is whatever ``pywt.wavedec2`` produces for the given
    ``wavelet`` and ``level``.
    """
    return pywt.wavedec2(img, wavelet, level=level)

# Extract a 1D signal from the 2D data (horizontal line by default)
def extract_signal(img, axis=0, index=None):
    """Return one row or one column of a 2D array as a 1D signal.

    Args:
        img: 2D array to sample.
        axis: ``0`` selects a row (horizontal line); any other value
            selects a column (vertical line).
        index: row/column number to take. When ``None`` the middle row
            or middle column is used.
    """
    horizontal = axis == 0
    if index is None:
        # Default to the centre line along the dimension being indexed.
        index = img.shape[0 if horizontal else 1] // 2
    return img[index, :] if horizontal else img[:, index]

# Display wavelet decomposition components and their signals
def display_wavelet_decomposition(coeffs, original_img, wavelet='haar'):
    """Show one figure per decomposition level of a ``pywt.wavedec2`` result.

    Each figure contains the original image, the approximation (LL) band of
    that level with a 1D signal taken through its middle row, and the three
    detail bands of the same level.

    Args:
        coeffs: coefficient list in ``pywt.wavedec2`` layout: the coarsest
            approximation first, followed by one detail tuple per level
            ordered from the coarsest level to the finest.
        original_img: the image that was decomposed (shown for context).
        wavelet: wavelet used to build ``coeffs``. It is needed to rebuild
            the approximation band of the finer levels, which is not stored
            in ``coeffs``. Defaults to ``'haar'``.
    """
    levels = len(coeffs) - 1
    approximation = coeffs[0]
    for offset, details in enumerate(coeffs[1:]):
        # Detail tuples run coarsest-first, so the first one is level `levels`.
        level = levels - offset
        LH, HL, HH = details

        # A rebuilt approximation can be one row/column larger than the
        # detail bands when the finer level had an odd size; trim to match.
        rows, cols = LH.shape
        approximation = approximation[:rows, :cols]

        fig, axes = plt.subplots(3, 3, figsize=(12, 10))

        # Row 1: original image for context; remaining cells stay empty.
        axes[0, 0].imshow(original_img, cmap='gray')
        axes[0, 0].set_title(f'Level {level} - Original Image')
        axes[0, 0].axis("off")
        for j in range(1, 3):
            axes[0, j].axis("off")

        # Row 2: approximation band of this level and its middle-row signal.
        axes[1, 0].imshow(approximation, cmap='gray')
        axes[1, 0].set_title(f'Level {level} - LL (Approximation)')
        axes[1, 0].axis("off")
        signal = extract_signal(approximation)
        axes[1, 1].plot(signal)
        axes[1, 1].set_title(f'Level {level} - LL Signal')
        axes[1, 1].set_xlabel("Pixel Index")
        axes[1, 1].set_ylabel("Intensity")
        axes[1, 2].axis("off")

        # Row 3: the three detail bands of this level.
        for col, (band, label) in enumerate(((LH, 'LH'), (HL, 'HL'), (HH, 'HH'))):
            axes[2, col].imshow(band, cmap='gray')
            axes[2, col].set_title(f'Level {level} - {label}')
            axes[2, col].axis("off")

        plt.tight_layout()
        plt.show()

        # One inverse step yields the approximation band of the next finer level.
        approximation = pywt.idwt2((approximation, details), wavelet)

# Main function for Task 1
def main_wavelet_feature_extraction(image_path):
    """Task 1 driver: load the image, decompose it and plot every level."""
    mri = load_image(image_path)
    decomposition = apply_dwt(mri)
    display_wavelet_decomposition(decomposition, mri)

# Run Task 1
# Input scan for Task 1; the path is resolved by load_image.
image_path = "mri_scan.jpg"
main_wavelet_feature_extraction(image_path)

Task 2: Image Compression using DWT

This section demonstrates how to compress an MRI image by applying DWT, using hard thresholding on the detail coefficients, and then reconstructing the image with inverse DWT (IDWT).

Compression Code

# Load image in grayscale
def load_image(image_path):
    """Load ``image_path`` in grayscale mode via OpenCV.

    Raises:
        FileNotFoundError: when OpenCV yields ``None`` instead of an image.
    """
    loaded = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    missing = loaded is None
    if missing:
        raise FileNotFoundError(f"Image not found: {image_path}")
    return loaded

# Apply 2D DWT for decomposition
def apply_dwt(img, wavelet='haar', level=2):
    """Decompose ``img`` with ``pywt.wavedec2`` and return its coefficient list."""
    decomposition = pywt.wavedec2(img, wavelet, level=level)
    return decomposition

# Apply hard thresholding for compression
def apply_thresholding(coeffs, threshold_ratio=0.1):
    """Hard-threshold the detail sub-bands of a coefficient list.

    The cut-off is ``threshold_ratio`` times the largest absolute value in
    the approximation band ``coeffs[0]``. Detail values whose magnitude is
    below the cut-off become 0; the approximation band is passed through
    unchanged (same object).

    Returns:
        A new list: the approximation band followed by one tuple of
        thresholded sub-bands per entry of ``coeffs[1:]``.
    """
    approximation = coeffs[0]
    cutoff = threshold_ratio * np.max(np.abs(approximation))

    def _hard_threshold(band):
        # Keep values at or above the cut-off, zero the rest.
        return np.where(np.abs(band) < cutoff, 0, band)

    result = [approximation]
    result.extend(
        tuple(_hard_threshold(band) for band in level_bands)
        for level_bands in coeffs[1:]
    )
    return result

# Reconstruct image from thresholded coefficients
def reconstruct_image(coeffs, wavelet='haar'):
    """Invert a 2D wavelet decomposition with ``pywt.waverec2``."""
    reconstruction = pywt.waverec2(coeffs, wavelet)
    return reconstruction

# Compute compression metrics
def compute_compression_ratio(original, thresholded):
    """Return (number of original coefficients) / (non-zero coefficients kept).

    Both arguments may be either a plain array or a wavelet coefficient
    list (an approximation array followed by tuples of detail arrays).
    Coefficients are counted band by band, so the sub-bands do not need
    mutually compatible shapes.

    Returns:
        The ratio as a float, or ``np.inf`` when ``thresholded`` contains no
        non-zero value (instead of raising ``ZeroDivisionError``).
    """
    def _iter_arrays(data):
        # Walk nested lists/tuples down to the individual arrays.
        if isinstance(data, (list, tuple)):
            for item in data:
                yield from _iter_arrays(item)
        else:
            yield np.asarray(data)

    original_size = sum(band.size for band in _iter_arrays(original))
    compressed_size = sum(
        int(np.count_nonzero(band)) for band in _iter_arrays(thresholded)
    )
    if compressed_size == 0:
        return np.inf
    return original_size / compressed_size

def compute_mse(original, reconstructed):
    """Return the mean squared error between two equally shaped arrays.

    Both inputs are converted to float64 before subtracting. Without that,
    two unsigned 8-bit images would wrap around in the subtraction and the
    squaring, and the result would be silently wrong.
    """
    reference = np.asarray(original, dtype=np.float64)
    candidate = np.asarray(reconstructed, dtype=np.float64)
    return np.mean((reference - candidate) ** 2)

def compute_psnr(mse, max_pixel=255):
    """Return the peak signal-to-noise ratio in dB for a given MSE.

    A zero error gives ``np.inf``; otherwise the value is
    ``20 * log10(max_pixel / sqrt(mse))``.
    """
    if mse == 0:
        return np.inf
    return 20 * np.log10(max_pixel / np.sqrt(mse))

# Display original and compressed images
def display_images(original, compressed):
    """Show the original and the compressed image side by side in grayscale."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (original, "Original MRI Image"),
        (compressed, "Compressed MRI Image (DWT + Thresholding)"),
    )
    for axis, (image, title) in zip(axes, panels):
        axis.imshow(image, cmap='gray')
        axis.set_title(title)
        axis.axis("off")
    plt.show()

# Main function for Task 2
def main_image_compression(image_path):
    """Task 2 driver: compress an image with DWT + hard thresholding.

    Loads the image, decomposes it, thresholds the detail coefficients,
    reconstructs the image, prints the compression ratio, MSE and PSNR, and
    finally shows the original next to the reconstruction.

    Args:
        image_path: path of the image to compress.
    """
    img = load_image(image_path)
    dwt_coeffs = apply_dwt(img)
    thresholded_coeffs = apply_thresholding(dwt_coeffs)
    reconstructed_img = reconstruct_image(thresholded_coeffs)

    # The inverse transform returns one extra row/column when the input had
    # an odd size; trim it so the error metrics compare equal shapes.
    reconstructed_img = reconstructed_img[:img.shape[0], :img.shape[1]]

    CR = compute_compression_ratio(dwt_coeffs, thresholded_coeffs)
    MSE = compute_mse(img, reconstructed_img)
    PSNR = compute_psnr(MSE)

    print(f"Compression Ratio (CR): {CR:.2f}")
    print(f"Mean Squared Error (MSE): {MSE:.4f}")
    print(f"Peak Signal-to-Noise Ratio (PSNR): {PSNR:.2f} dB")

    display_images(img, reconstructed_img)

# Run Task 2
# Input scan for Task 2; the path is resolved by load_image.
image_path = "mri_scan.jpg"
main_image_compression(image_path)

Task 3: Comparison of FFT vs. DWT for Image Compression

In this final section, we compare two compression methods: one using DWT and the other using FFT. We apply hard thresholding to both the DWT coefficients and the FFT spectrum, reconstruct the images, and then compute metrics such as Compression Ratio, MSE, PSNR, and computation time.

Comparison Code

import time
from tabulate import tabulate

# FFT based processing functions
def apply_fft(img):
    """Return the 2D FFT of ``img`` with the zero frequency shifted to the centre."""
    spectrum = np.fft.fft2(img)
    return np.fft.fftshift(spectrum)

def threshold_fft(fft_image, threshold_ratio=0.1):
    """Zero every spectrum entry whose magnitude is below the threshold.

    The threshold is ``threshold_ratio`` times the largest magnitude in
    ``fft_image``. The input array is left untouched; a thresholded copy
    with the same dtype is returned.
    """
    magnitude = np.abs(fft_image)
    threshold = threshold_ratio * np.max(magnitude)
    # Work on a copy so the caller's spectrum is not modified as a side effect.
    thresholded = np.array(fft_image)
    thresholded[magnitude < threshold] = 0
    return thresholded

def reconstruct_fft(fft_image):
    """Invert a centred 2D spectrum and return the magnitude of the result."""
    unshifted = np.fft.ifftshift(fft_image)
    spatial = np.fft.ifft2(unshifted)
    return np.abs(spatial)

# apply_dwt is reused from Task 2; threshold_dwt and reconstruct_dwt are defined below
def threshold_dwt(coeffs, threshold_ratio=0.1):
    """Hard-threshold the detail sub-bands of a wavelet coefficient list.

    The threshold is ``threshold_ratio`` times the largest absolute value
    of the approximation band ``coeffs[0]``. The absolute value matters:
    with a signed maximum, an all-negative approximation band would give a
    negative threshold and nothing would ever be zeroed.

    Returns:
        A new list holding the unchanged approximation band followed by one
        tuple of thresholded sub-bands per detail level.
    """
    threshold = threshold_ratio * np.max(np.abs(coeffs[0]))
    coeffs_thresholded = [coeffs[0]]
    for level in coeffs[1:]:
        level_thresh = tuple(
            np.where(np.abs(subband) < threshold, 0, subband) for subband in level
        )
        coeffs_thresholded.append(level_thresh)
    return coeffs_thresholded

def reconstruct_dwt(coeffs):
    """Rebuild an image from ``coeffs`` using the inverse Haar transform."""
    restored = pywt.waverec2(coeffs, 'haar')
    return restored

# Metric functions: compute_compression_ratio, compute_mse, compute_psnr are already defined

# Display images for comparison
def display_comparison_images(original, dwt_compressed, fft_compressed):
    """Show the original, DWT result and FFT result in one row, in grayscale."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    panels = (
        (original, "Original MRI Image"),
        (dwt_compressed, "DWT Compressed Image"),
        (fft_compressed, "FFT Compressed Image"),
    )
    for axis, (image, title) in zip(axes, panels):
        axis.imshow(image, cmap='gray')
        axis.set_title(title)
        axis.axis("off")
    plt.show()

# Main function for Task 3
def main_compression_comparison(image_path):
    """Task 3 driver: compare DWT and FFT compression on one image.

    Runs both pipelines (transform, hard threshold, inverse transform),
    times each one, prints a table with compression ratio, MSE, PSNR and
    run time, and shows the three images side by side.

    Args:
        image_path: path of the image to compress.
    """
    img = load_image(image_path)

    # DWT processing (perf_counter is a monotonic, high-resolution timer).
    start_dwt = time.perf_counter()
    dwt_coeffs = apply_dwt(img)
    dwt_thresholded = threshold_dwt(dwt_coeffs)
    dwt_compressed = reconstruct_dwt(dwt_thresholded)
    dwt_time = time.perf_counter() - start_dwt

    # The inverse DWT returns one extra row/column for odd-sized inputs;
    # trim it so the error metrics compare equal shapes.
    dwt_compressed = dwt_compressed[:img.shape[0], :img.shape[1]]

    # FFT processing
    start_fft = time.perf_counter()
    fft_image = apply_fft(img)
    thresholded_fft = threshold_fft(fft_image)
    fft_compressed = reconstruct_fft(thresholded_fft)
    fft_time = time.perf_counter() - start_fft

    # Metrics Calculation
    # DWT: compare the full coefficient lists before and after thresholding,
    # not the image against the approximation band alone.
    dwt_CR = compute_compression_ratio(dwt_coeffs, dwt_thresholded)
    # FFT: the spectrum has one coefficient per pixel; count the survivors.
    fft_kept = np.count_nonzero(thresholded_fft)
    fft_CR = img.size / fft_kept if fft_kept else np.inf
    dwt_MSE = compute_mse(img, dwt_compressed)
    fft_MSE = compute_mse(img, fft_compressed)
    dwt_PSNR = compute_psnr(dwt_MSE)
    fft_PSNR = compute_psnr(fft_MSE)

    headers = ["Metric", "DWT", "FFT"]
    table = [
        ["Compression Ratio (CR)", f"{dwt_CR:.2f}", f"{fft_CR:.2f}"],
        ["Mean Squared Error (MSE)", f"{dwt_MSE:.4f}", f"{fft_MSE:.4f}"],
        ["Peak Signal-to-Noise Ratio (PSNR) (dB)", f"{dwt_PSNR:.2f}", f"{fft_PSNR:.2f}"],
        ["Computation Time (sec)", f"{dwt_time:.4f}", f"{fft_time:.4f}"]
    ]

    print("\n=== Compression Results Comparison ===")
    print(tabulate(table, headers=headers, tablefmt="grid"))

    display_comparison_images(img, dwt_compressed, fft_compressed)

# Run Task 3
# Input scan for Task 3; the path is resolved by load_image.
image_path = "mri_scan.jpg"
main_compression_comparison(image_path)

Conclusion

In this tutorial, you learned how to use the Discrete Wavelet Transform (DWT) for feature extraction and image compression. We also compared the performance of FFT versus DWT for compression. Experiment with the parameters (such as threshold ratios and decomposition levels) to optimize these techniques for your application.

Comments

Popular posts from this blog

Texture Classification with GLCM and LBP Features in Python

Contrast Enhancement Techniques in Image Processing using Python