Wavelet Transform for Feature Extraction and Image Compression using Python and OpenCV
In this tutorial, we explore three important applications of the Discrete Wavelet Transform (DWT) using Python. You will learn how to:
- Extract features from an MRI image using wavelet decomposition.
- Perform image compression using DWT and hard thresholding.
- Compare compression results between FFT and DWT methods.
This tutorial is based on a previously documented lab experiment.
Task 1: Feature Extraction using Wavelet Transform
We first load an MRI image in grayscale and then apply a 2D Discrete Wavelet Transform (DWT) using the Haar wavelet. We also extract 1D signals from the approximation coefficients and display the wavelet decomposition components.
Import Required Libraries
import numpy as np
import pywt
import cv2
import matplotlib.pyplot as plt
Load Image and Apply DWT
# Load MRI image in grayscale
def load_image(image_path):
    """Read the file at ``image_path`` as a grayscale image.

    Args:
        image_path: Path handed to ``cv2.imread``.

    Returns:
        The array returned by ``cv2.imread`` with ``cv2.IMREAD_GRAYSCALE``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None``.
    """
    grayscale = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if grayscale is not None:
        return grayscale
    raise FileNotFoundError(f"Image not found: {image_path}")
# Apply 2D DWT
def apply_dwt(img, wavelet='haar', level=2):
    """Run a multilevel 2D discrete wavelet transform on ``img``.

    Args:
        img: 2D image array to decompose.
        wavelet: Wavelet name passed to ``pywt.wavedec2``.
        level: Number of decomposition levels.

    Returns:
        The coefficient list produced by ``pywt.wavedec2``.
    """
    return pywt.wavedec2(img, wavelet, level=level)
# Extract a 1D signal from the 2D data (horizontal line by default)
def extract_signal(img, axis=0, index=None):
    """Pull a single row or column out of a 2D array.

    Args:
        img: 2D array to sample.
        axis: ``0`` selects a row (horizontal line); any other value selects
            a column (vertical line).
        index: Row or column number. When ``None``, the middle one
            (``shape // 2`` along the relevant dimension) is used.

    Returns:
        ``img[index, :]`` for a row or ``img[:, index]`` for a column.
    """
    take_row = axis == 0
    if index is None:
        # Default to the middle row/column of the array.
        index = img.shape[0 if take_row else 1] // 2
    return img[index, :] if take_row else img[:, index]
# Display wavelet decomposition components and their signals
def display_wavelet_decomposition(coeffs, original_img, wavelet='haar'):
    """Plot each level of a ``pywt.wavedec2`` decomposition, one figure per level.

    ``pywt.wavedec2`` returns ``[cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)]``:
    only the coarsest approximation is stored and the detail tuples run from
    the coarsest level ``n`` down to level 1. The approximation of each finer
    level is therefore rebuilt here with one inverse DWT step, and figures are
    labelled with the real decomposition level (coarsest level first).

    Args:
        coeffs: Coefficient list in ``pywt.wavedec2`` layout.
        original_img: Image shown in the first row of every figure for context.
        wavelet: Wavelet that produced ``coeffs``; needed to rebuild the finer
            approximations. Defaults to ``'haar'``, the default of ``apply_dwt``.
    """
    levels = len(coeffs) - 1
    approx = coeffs[0]
    for position, (LH, HL, HH) in enumerate(coeffs[1:]):
        # Detail tuples are ordered coarsest -> finest.
        level = levels - position
        # A rebuilt approximation can be one row/column larger than the next
        # level's detail bands; trim so the shapes line up.
        approx = approx[:LH.shape[0], :LH.shape[1]]

        _, axes = plt.subplots(3, 3, figsize=(12, 10))

        # Row 1: original image (for context); remaining cells stay empty.
        axes[0, 0].imshow(original_img, cmap='gray')
        axes[0, 0].set_title(f'Level {level} - Original Image')
        axes[0, 0].axis("off")
        for j in range(1, 3):
            axes[0, j].axis("off")

        # Row 2: approximation coefficients (LL) and their middle-row signal.
        axes[1, 0].imshow(approx, cmap='gray')
        axes[1, 0].set_title(f'Level {level} - LL (Approximation)')
        axes[1, 0].axis("off")
        signal = extract_signal(approx)
        axes[1, 1].plot(signal)
        axes[1, 1].set_title(f'Level {level} - LL Signal')
        axes[1, 1].set_xlabel("Pixel Index")
        axes[1, 1].set_ylabel("Intensity")
        axes[1, 2].axis("off")

        # Row 3: detail coefficients LH, HL and HH.
        for column, (band, label) in enumerate(((LH, 'LH'), (HL, 'HL'), (HH, 'HH'))):
            axes[2, column].imshow(band, cmap='gray')
            axes[2, column].set_title(f'Level {level} - {label}')
            axes[2, column].axis("off")

        plt.tight_layout()
        plt.show()

        if level > 1:
            # One inverse step yields the approximation of the next finer level.
            approx = pywt.idwt2((approx, (LH, HL, HH)), wavelet)
# Main function for Task 1
def main_wavelet_feature_extraction(image_path):
    """Run Task 1: load the image, decompose it and plot the result.

    Args:
        image_path: Path of the image passed to ``load_image``.
    """
    mri = load_image(image_path)
    display_wavelet_decomposition(apply_dwt(mri), mri)
# Run Task 1
# Path of the input image; it is handed unchanged to load_image, which raises
# FileNotFoundError when cv2.imread cannot read it.
image_path = "mri_scan.jpg"
# Runs the Task 1 pipeline (load -> DWT -> one figure per level).
main_wavelet_feature_extraction(image_path)
Task 2: Image Compression using DWT
This section demonstrates how to compress an MRI image by applying DWT, using hard thresholding on the detail coefficients, and then reconstructing the image with inverse DWT (IDWT).
Compression Code
# Load image in grayscale
def load_image(image_path):
    """Load ``image_path`` from disk as a grayscale image.

    Args:
        image_path: Path handed to ``cv2.imread``.

    Returns:
        The array returned by ``cv2.imread`` with ``cv2.IMREAD_GRAYSCALE``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None``.
    """
    pixels = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if pixels is not None:
        return pixels
    raise FileNotFoundError(f"Image not found: {image_path}")
# Apply 2D DWT for decomposition
def apply_dwt(img, wavelet='haar', level=2):
    """Decompose ``img`` with a multilevel 2D discrete wavelet transform.

    Args:
        img: 2D image array to decompose.
        wavelet: Wavelet name passed to ``pywt.wavedec2``.
        level: Number of decomposition levels.

    Returns:
        The coefficient list produced by ``pywt.wavedec2``.
    """
    coefficients = pywt.wavedec2(img, wavelet, level=level)
    return coefficients
# Apply hard thresholding for compression
def apply_thresholding(coeffs, threshold_ratio=0.1):
    """Hard-threshold the detail bands of a wavelet coefficient list.

    The cutoff is ``threshold_ratio`` times the largest absolute value in the
    approximation band ``coeffs[0]``. Detail values whose magnitude is below
    the cutoff become 0; the approximation band is passed through unchanged.

    Args:
        coeffs: ``[approximation, (band, band, band), ...]`` coefficient list.
        threshold_ratio: Fraction of the approximation peak used as cutoff.

    Returns:
        A new list with the same approximation array followed by one tuple of
        thresholded detail arrays per level.
    """
    approximation = coeffs[0]
    cutoff = threshold_ratio * np.max(np.abs(approximation))
    return [approximation] + [
        tuple(np.where(np.abs(band) < cutoff, 0, band) for band in detail_bands)
        for detail_bands in coeffs[1:]
    ]
# Reconstruct image from thresholded coefficients
def reconstruct_image(coeffs, wavelet='haar'):
    """Rebuild an image from a wavelet coefficient list.

    Args:
        coeffs: Coefficient list accepted by ``pywt.waverec2``.
        wavelet: Wavelet name passed to ``pywt.waverec2``.

    Returns:
        The array returned by ``pywt.waverec2``.
    """
    restored = pywt.waverec2(coeffs, wavelet)
    return restored
# Compute compression metrics
def _coefficient_arrays(coeffs):
    """Yield every array held in a plain ndarray or a wavelet coefficient list."""
    if isinstance(coeffs, np.ndarray):
        yield coeffs
        return
    for entry in coeffs:
        if isinstance(entry, np.ndarray):
            yield entry
        elif isinstance(entry, dict):
            yield from entry.values()
        else:
            # Tuple of detail bands for one decomposition level.
            yield from entry


def compute_compression_ratio(original, thresholded):
    """Return the ratio of total coefficients to non-zero coefficients kept.

    Both arguments may be either a wavelet coefficient list
    (``[approximation, (band, band, band), ...]``) or a single ndarray such
    as an FFT spectrum, so the same metric works for both transforms.

    Args:
        original: Coefficients before thresholding; only their count is used.
        thresholded: Coefficients after thresholding; non-zero entries are
            counted.

    Returns:
        ``total / non_zero`` as a float, or ``inf`` when no coefficient
        survived thresholding (instead of raising ``ZeroDivisionError``).
    """
    total = sum(band.size for band in _coefficient_arrays(original))
    kept = sum(int(np.count_nonzero(band)) for band in _coefficient_arrays(thresholded))
    if kept == 0:
        return float("inf")
    return total / kept
def compute_mse(original, reconstructed):
    """Return the mean squared error between two images.

    Both inputs are converted to float64 before subtracting. Without that,
    two ``uint8`` images are subtracted and squared in ``uint8`` arithmetic,
    which wraps around modulo 256 and yields a wrong error.

    Args:
        original: Reference image.
        reconstructed: Image to compare; must be broadcastable to ``original``.

    Returns:
        The mean of the squared element-wise differences.
    """
    difference = (
        np.asarray(original, dtype=np.float64)
        - np.asarray(reconstructed, dtype=np.float64)
    )
    return np.mean(difference ** 2)
def compute_psnr(mse, max_pixel=255):
    """Convert a mean squared error into a peak signal-to-noise ratio in dB.

    Args:
        mse: Mean squared error between the two images.
        max_pixel: Largest possible pixel value.

    Returns:
        ``20 * log10(max_pixel / sqrt(mse))``, or ``inf`` when ``mse`` is 0.
    """
    if mse == 0:
        # Identical images: the ratio is unbounded.
        return np.inf
    return 20 * np.log10(max_pixel / np.sqrt(mse))
# Display original and compressed images
def display_images(original, compressed):
    """Show the original and the compressed image side by side.

    Args:
        original: Image drawn in the left panel.
        compressed: Image drawn in the right panel.
    """
    panels = (
        (original, "Original MRI Image"),
        (compressed, "Compressed MRI Image (DWT + Thresholding)"),
    )
    _, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, (picture, caption) in zip(axes, panels):
        ax.imshow(picture, cmap='gray')
        ax.set_title(caption)
        ax.axis("off")
    plt.show()
# Main function for Task 2
def main_image_compression(image_path):
    """Run Task 2: compress an image with DWT + hard thresholding and report metrics.

    Loads the image, decomposes it, thresholds the detail coefficients,
    reconstructs the image, prints compression ratio, MSE and PSNR, and shows
    the original next to the reconstruction.

    Args:
        image_path: Path of the image passed to ``load_image``.
    """
    img = load_image(image_path)
    dwt_coeffs = apply_dwt(img)
    thresholded_coeffs = apply_thresholding(dwt_coeffs)
    reconstructed_img = reconstruct_image(thresholded_coeffs)
    # The inverse transform can come back one row/column larger than an
    # odd-sized input; crop so the error metrics compare equal shapes.
    reconstructed_img = reconstructed_img[:img.shape[0], :img.shape[1]]

    compression_ratio = compute_compression_ratio(dwt_coeffs, thresholded_coeffs)
    mse = compute_mse(img, reconstructed_img)
    psnr = compute_psnr(mse)

    print(f"Compression Ratio (CR): {compression_ratio:.2f}")
    print(f"Mean Squared Error (MSE): {mse:.4f}")
    print(f"Peak Signal-to-Noise Ratio (PSNR): {psnr:.2f} dB")
    display_images(img, reconstructed_img)
# Run Task 2
# Path of the input image; it is handed unchanged to load_image, which raises
# FileNotFoundError when cv2.imread cannot read it.
image_path = "mri_scan.jpg"
# Runs the Task 2 pipeline (DWT -> threshold -> reconstruct -> metrics -> plot).
main_image_compression(image_path)
Task 3: Comparison of FFT vs. DWT for Image Compression
In this final section, we compare two compression methods: one using DWT and the other using FFT. We apply hard thresholding to both the DWT detail coefficients and the FFT spectrum, reconstruct the images, and then compute metrics such as Compression Ratio, MSE, PSNR, and computation time.
Comparison Code
import time
from tabulate import tabulate
# FFT based processing functions
def apply_fft(img):
    """Return the 2D FFT of ``img`` with the zero-frequency term centred.

    Args:
        img: 2D array to transform.

    Returns:
        The result of ``np.fft.fft2`` after ``np.fft.fftshift``.
    """
    spectrum = np.fft.fft2(img)
    return np.fft.fftshift(spectrum)
def threshold_fft(fft_image, threshold_ratio=0.1):
    """Hard-threshold an FFT spectrum by magnitude.

    Entries whose magnitude is below ``threshold_ratio`` times the largest
    magnitude in the spectrum are set to 0. The input array is left untouched;
    a new array is returned, so the caller can still compare against the
    original spectrum.

    Args:
        fft_image: Spectrum to threshold.
        threshold_ratio: Fraction of the peak magnitude used as cutoff.

    Returns:
        A new array with the small-magnitude entries zeroed.
    """
    magnitude = np.abs(fft_image)
    threshold = threshold_ratio * np.max(magnitude)
    return np.where(magnitude < threshold, 0, fft_image)
def reconstruct_fft(fft_image):
    """Invert a centred 2D spectrum and return the magnitude of the result.

    Args:
        fft_image: Spectrum with the zero-frequency term centred.

    Returns:
        ``np.abs`` of the inverse 2D FFT taken after ``np.fft.ifftshift``.
    """
    unshifted = np.fft.ifftshift(fft_image)
    spatial = np.fft.ifft2(unshifted)
    return np.abs(spatial)
# apply_dwt is already defined above; threshold_dwt and reconstruct_dwt are defined here
def threshold_dwt(coeffs, threshold_ratio=0.1):
    """Hard-threshold the detail bands of a wavelet coefficient list.

    The cutoff is ``threshold_ratio`` times the largest *absolute* value in
    the approximation band ``coeffs[0]``, the same rule ``apply_thresholding``
    uses. Taking the magnitude keeps the cutoff meaningful when the
    approximation contains negative values; a plain maximum could produce a
    negative cutoff that removes nothing.

    Args:
        coeffs: ``[approximation, (band, band, band), ...]`` coefficient list.
        threshold_ratio: Fraction of the approximation peak used as cutoff.

    Returns:
        A new list with the same approximation array followed by one tuple of
        thresholded detail arrays per level.
    """
    threshold = threshold_ratio * np.max(np.abs(coeffs[0]))
    coeffs_thresholded = [coeffs[0]]
    coeffs_thresholded.extend(
        tuple(np.where(np.abs(subband) < threshold, 0, subband) for subband in level)
        for level in coeffs[1:]
    )
    return coeffs_thresholded
def reconstruct_dwt(coeffs):
    """Rebuild an image from Haar wavelet coefficients.

    Args:
        coeffs: Coefficient list accepted by ``pywt.waverec2``.

    Returns:
        The array returned by ``pywt.waverec2`` with the ``'haar'`` wavelet.
    """
    restored = pywt.waverec2(coeffs, 'haar')
    return restored
# Metric functions: compute_compression_ratio, compute_mse, compute_psnr are already defined
# Display images for comparison
def display_comparison_images(original, dwt_compressed, fft_compressed):
    """Show the original, DWT-compressed and FFT-compressed images in one row.

    Args:
        original: Image drawn in the left panel.
        dwt_compressed: Image drawn in the middle panel.
        fft_compressed: Image drawn in the right panel.
    """
    panels = (
        (original, "Original MRI Image"),
        (dwt_compressed, "DWT Compressed Image"),
        (fft_compressed, "FFT Compressed Image"),
    )
    _, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (picture, caption) in zip(axes, panels):
        ax.imshow(picture, cmap='gray')
        ax.set_title(caption)
        ax.axis("off")
    plt.show()
# Main function for Task 3
def main_compression_comparison(image_path):
    """Run Task 3: compress one image with DWT and with FFT and compare them.

    For each method the image is transformed, hard-thresholded and
    reconstructed while being timed. Compression ratio, MSE and PSNR are then
    printed in a table and the three images are displayed.

    The compression ratio is "total coefficients / non-zero coefficients
    kept" for both methods: the DWT ratio compares the full coefficient list
    before and after thresholding, and the FFT ratio is counted directly on
    the thresholded spectrum.

    Args:
        image_path: Path of the image passed to ``load_image``.
    """
    img = load_image(image_path)
    rows, cols = img.shape[:2]

    # DWT processing
    start_dwt = time.perf_counter()
    dwt_coeffs = apply_dwt(img)
    dwt_thresholded = threshold_dwt(dwt_coeffs)
    dwt_compressed = reconstruct_dwt(dwt_thresholded)
    dwt_time = time.perf_counter() - start_dwt
    # The inverse DWT can come back one row/column larger than an odd-sized
    # input; crop so the error metrics compare equal shapes.
    dwt_compressed = dwt_compressed[:rows, :cols]

    # FFT processing
    start_fft = time.perf_counter()
    fft_image = apply_fft(img)
    thresholded_fft = threshold_fft(fft_image)
    fft_compressed = reconstruct_fft(thresholded_fft)
    fft_time = time.perf_counter() - start_fft

    # Metrics calculation
    dwt_CR = compute_compression_ratio(dwt_coeffs, dwt_thresholded)
    fft_kept = np.count_nonzero(thresholded_fft)
    fft_CR = thresholded_fft.size / fft_kept if fft_kept else np.inf
    dwt_MSE = compute_mse(img, dwt_compressed)
    fft_MSE = compute_mse(img, fft_compressed)
    dwt_PSNR = compute_psnr(dwt_MSE)
    fft_PSNR = compute_psnr(fft_MSE)

    headers = ["Metric", "DWT", "FFT"]
    table = [
        ["Compression Ratio (CR)", f"{dwt_CR:.2f}", f"{fft_CR:.2f}"],
        ["Mean Squared Error (MSE)", f"{dwt_MSE:.4f}", f"{fft_MSE:.4f}"],
        ["Peak Signal-to-Noise Ratio (PSNR) (dB)", f"{dwt_PSNR:.2f}", f"{fft_PSNR:.2f}"],
        ["Computation Time (sec)", f"{dwt_time:.4f}", f"{fft_time:.4f}"],
    ]
    print("\n=== Compression Results Comparison ===")
    print(tabulate(table, headers=headers, tablefmt="grid"))
    display_comparison_images(img, dwt_compressed, fft_compressed)
# Run Task 3
# Path of the input image; it is handed unchanged to load_image, which raises
# FileNotFoundError when cv2.imread cannot read it.
image_path = "mri_scan.jpg"
# Runs the Task 3 pipeline (DWT vs. FFT compression, metrics table, plots).
main_compression_comparison(image_path)
Conclusion
In this tutorial, you learned how to use the Discrete Wavelet Transform (DWT) for feature extraction and image compression. We also compared the performance of FFT versus DWT for compression. Experiment with the parameters (such as threshold ratios and decomposition levels) to optimize these techniques for your application.
Comments
Post a Comment