Building a CNN for Satellite Imagery Classification

Build a CNN for satellite imagery classification: geospatial raster I/O with rasterio, spectral normalization, spatially blocked cross-validation, and production inference deployment.

Building a CNN for satellite imagery classification requires a strict, geospatial-aware pipeline that handles multi-band raster I/O, explicit spectral feature engineering, and consistent coordinate reference systems (CRS) before any tensor reaches a convolutional architecture. Standard computer vision workflows often fail on remote sensing data because of spectral variability, atmospheric interference, and geographic leakage during cross-validation. A production-ready workflow has five stages: (1) streaming Cloud-Optimized GeoTIFFs (COGs) with memory-efficient tiling, (2) per-scene spectral normalization and index computation, (3) adapting a standard CNN backbone to variable band counts, (4) training with spatially blocked cross-validation to prevent leakage, and (5) deploying inference with automated drift detection and CRS-aware vectorization.

1. Geospatial-Aware Raster I/O & CRS Validation

Satellite imagery introduces constraints that standard image loaders ignore. GeoTIFF files carry georeferencing metadata (an affine transform and a CRS) and sensor-specific band orders, all of which must survive the transition from training to inference. Validate projection alignment before model ingestion: mismatched CRS values cause silent spatial misregistration during inference. Use rasterio for deterministic, windowed reads; because COGs are internally tiled, rasterio (through GDAL) can stream only the windows it needs from cloud storage instead of downloading whole files. The rasterio documentation provides authoritative guidance on windowed I/O, profile management, and CRS handling. When reading patches, store the window's own affine transform and the CRS in a metadata dictionary so downstream vector alignment and map tiling remain accurate.

2. Spatial Feature Engineering & Normalization

Convolutional networks train best on inputs with stable distributions, but satellite sensors record top-of-atmosphere (TOA) or surface reflectance values that shift with season, latitude, and sensor calibration. Normalize each scene independently using per-band percentiles (e.g., 2nd–98th) or scene-level mean/variance to stabilize gradients. Compute derived spectral indices such as NDVI, NDWI, or NDBI as additional input channels; these explicit features can speed convergence and improve class separability for land-cover tasks.
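The sketch below shows both steps with NumPy only. It assumes a `(bands, H, W)` reflectance array and that the caller knows the 0-based positions of the green, red, and NIR bands (the `*_idx` parameters are hypothetical). Indices use the standard normalized-difference forms NDVI = (NIR − Red) / (NIR + Red) and NDWI = (Green − NIR) / (Green + NIR), and are computed on the raw reflectance before the percentile stretch, because stretching each band separately would distort the ratios.

```python
import numpy as np


def percentile_normalize(scene: np.ndarray, low_pct: float = 2.0,
                         high_pct: float = 98.0) -> np.ndarray:
    """Stretch each band of a (bands, H, W) array to [0, 1] using its own percentiles."""
    scene = scene.astype(np.float32)
    low = np.nanpercentile(scene, low_pct, axis=(1, 2), keepdims=True)
    high = np.nanpercentile(scene, high_pct, axis=(1, 2), keepdims=True)
    # Values outside the percentile range are clipped rather than left unbounded
    return np.clip((scene - low) / (high - low + 1e-6), 0.0, 1.0)


def normalized_difference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Compute (a - b) / (a + b), returning 0 where the denominator is zero."""
    a = a.astype(np.float32)
    b = b.astype(np.float32)
    denom = a + b
    out = np.zeros_like(denom)
    np.divide(a - b, denom, out=out, where=denom != 0)
    return out


def add_index_channels(scene: np.ndarray, red_idx: int, nir_idx: int,
                       green_idx: int) -> np.ndarray:
    """Stack normalized bands with NDVI and NDWI as two extra channels."""
    # Indices come from raw reflectance, before the per-band stretch
    ndvi = normalized_difference(scene[nir_idx], scene[red_idx])
    ndwi = normalized_difference(scene[green_idx], scene[nir_idx])
    return np.concatenate([percentile_normalize(scene), ndvi[None], ndwi[None]], axis=0)
```

The index channels stay in [-1, 1] while the stretched bands lie in [0, 1]; that mismatch is usually harmless, but they can be rescaled with `(idx + 1) / 2` if uniform ranges are preferred.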

When spatial relationships extend beyond local receptive fields, CNNs may struggle with regional connectivity or topological constraints. In such cases, methods from Training Geospatial Predictive Models in Python help your pipeline account for spatial autocorrelation and scale-dependent feature interactions. For adjacency-aware modeling, where land-cover transitions depend on watershed boundaries or road networks rather than local texture, Graph Neural Networks for Spatial Data offer a complementary architecture that models spatial graphs explicitly alongside raster features.

3. CNN Architecture & Spatial Cross-Validation

Standard vision backbones expect 3-channel RGB inputs, whereas multispectral sensors such as Sentinel-2 record 13 spectral bands. Replace the first convolutional layer with one that accepts in_channels=len(bands), and initialize its weights by averaging the pretrained RGB kernels and repeating the result across every input band. The PyTorch code below loads multi-band GeoTIFFs, applies per-scene normalization, and adapts a pretrained ResNet-18 stem; spatially blocked splitting to prevent leakage is a separate step, described after the code.

import os
import numpy as np
import rasterio
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights

class SatelliteDataset(Dataset):
    def __init__(self, img_dir: str, label_map: dict, bands: tuple = (1, 2, 3, 7, 8), patch_size: int = 256):
        # Keep only GeoTIFFs that have a label; unlabeled files are skipped
        # rather than silently assigned to class 0
        self.img_paths = sorted(
            os.path.join(img_dir, f) for f in os.listdir(img_dir)
            if f.endswith('.tif') and f in label_map
        )
        self.label_map = label_map  # {filename: class_id}
        self.bands = bands
        self.patch_size = patch_size

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, idx: int):
        path = self.img_paths[idx]
        with rasterio.open(path) as src:
            # rasterio band indexes are 1-based; which sensor band each index
            # holds depends on how the file was written
            data = src.read(list(self.bands)).astype(np.float32)
            # Treat the declared nodata value (or 0 if none is set) as missing.
            # Cloud masking needs a separate QA/SCL band and is not done here.
            nodata = src.nodata if src.nodata is not None else 0
            data[data == nodata] = np.nan
        # Per-scene, per-band percentile stretch (2nd-98th), clipped to [0, 1]
        low, high = np.nanpercentile(data, [2, 98], axis=(1, 2), keepdims=True)
        data = np.clip((data - low) / (high - low + 1e-6), 0.0, 1.0)
        data = np.nan_to_num(data, nan=0.0)
        tensor = torch.from_numpy(data)
        # Center-crop, then zero-pad if needed, so every sample has shape
        # (bands, patch_size, patch_size) and batches collate cleanly
        ps = self.patch_size
        _, h, w = tensor.shape
        y, x = max((h - ps) // 2, 0), max((w - ps) // 2, 0)
        tensor = tensor[:, y:y + ps, x:x + ps]
        pad_h, pad_w = ps - tensor.shape[1], ps - tensor.shape[2]
        if pad_h > 0 or pad_w > 0:
            tensor = nn.functional.pad(tensor, (0, pad_w, 0, pad_h))
        label = self.label_map[os.path.basename(path)]
        return tensor, torch.tensor(label, dtype=torch.long)

class MultiBandResNet(nn.Module):
    def __init__(self, in_channels: int, num_classes: int = 10):
        super().__init__()
        self.model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        old_conv = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            in_channels, old_conv.out_channels,
            kernel_size=old_conv.kernel_size, stride=old_conv.stride,
            padding=old_conv.padding, bias=old_conv.bias is not None
        )
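        # Initialize new weights by averaging the pretrained RGB kernels and
        # repeating them across all input bands; scaling by 3 / in_channels keeps
        # the summed first-layer response near its pretrained magnitude
        with torch.no_grad():
            self.model.conv1.weight.copy_(
                old_conv.weight.mean(dim=1, keepdim=True).repeat(1, in_channels, 1, 1)
                * (3.0 / in_channels)
            )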
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

# Training loop setup
# dataset = SatelliteDataset("data/tiles/", {"scene1.tif": 0, "scene2.tif": 1})
# loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
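
The setup comments above stop at the DataLoader. A minimal training epoch might look like the following sketch; it assumes a single-label cross-entropy objective, and `train_one_epoch` is a hypothetical helper, not part of PyTorch.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_one_epoch(model: nn.Module, loader: DataLoader,
                    optimizer: torch.optim.Optimizer, device: str = "cpu") -> float:
    """Run one pass over `loader` and return the mean cross-entropy loss per sample."""
    criterion = nn.CrossEntropyLoss()
    model.train()
    total_loss, total_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch is averaged correctly
        total_loss += loss.item() * images.size(0)
        total_samples += images.size(0)
    return total_loss / max(total_samples, 1)
```

One way to use it: `model = MultiBandResNet(in_channels=5, num_classes=2)`, `optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)`, then call `train_one_epoch(model, loader, optimizer)` once per epoch. The learning rate is illustrative, not tuned.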

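To prevent geographic leakage, do not use random train/test splits: neighboring tiles are spatially autocorrelated, so a random split puts near-duplicates in both sets and inflates validation accuracy. Instead, group samples by spatial blocks, administrative boundaries, or acquisition dates and apply blocked cross-validation, holding out whole groups at a time. In PyTorch, the resulting fold indices can be applied with torch.utils.data.Subset or a custom Sampler, so every hyperparameter-tuning run validates only on geographically isolated data.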
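A NumPy-only sketch of spatial blocking follows. It assumes each sample has a centroid `(x, y)` in a projected CRS (metres, for example); the helper name `spatial_block_folds` and its parameters are hypothetical. Samples are snapped to a square grid, and whole grid cells, not individual samples, are randomly assigned to folds.

```python
import numpy as np


def spatial_block_folds(x: np.ndarray, y: np.ndarray, block_size: float,
                        n_folds: int = 5, seed: int = 0):
    """Yield (train_idx, val_idx) pairs in which whole spatial blocks are held out together.

    x, y: projected centroid coordinates of each sample, in CRS units.
    """
    # Snap every sample to the grid cell (block) that contains it
    bx = np.floor(np.asarray(x) / block_size).astype(np.int64)
    by = np.floor(np.asarray(y) / block_size).astype(np.int64)
    _, block_id = np.unique(np.stack([bx, by], axis=1), axis=0, return_inverse=True)
    block_id = block_id.ravel()  # some NumPy versions return a 2-D inverse here
    n_blocks = block_id.max() + 1
    # Shuffle blocks, then deal them round-robin into folds
    rng = np.random.default_rng(seed)
    block_fold = rng.permutation(n_blocks) % n_folds
    sample_fold = block_fold[block_id]
    for k in range(n_folds):
        yield np.flatnonzero(sample_fold != k), np.flatnonzero(sample_fold == k)
```

Centroids can come from each tile, for example `src.xy(src.height // 2, src.width // 2)` in rasterio, and each fold can then be wrapped as `Subset(dataset, train_idx.tolist())`. Blocks on either side of a fold boundary still touch, so adding a buffer zone between training and validation blocks is a common refinement not shown here; `block_size` should be at least as large as the range of spatial autocorrelation in the data.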

4. MLOps, Inference & Automated Drift Detection

Deployment requires preserving the exact preprocessing logic used during training. Wrap normalization, band selection, and CRS validation in a reproducible inference container. Track input distributions with statistical drift metrics (e.g., Jensen-Shannon divergence on per-band histograms or Earth Mover’s Distance on NDVI distributions), and trigger retraining when sensor calibration shifts or seasonal cycles alter reflectance baselines. Automate batch inference with message queues, and attach CRS and affine-transform metadata to output predictions so downstream GIS tools can place vectorized results correctly without guessing the projection. Implement automated rollback hooks that revert to the previous model version when drift or validation metrics exceed predefined thresholds, so a degraded model does not remain in production.
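A NumPy-only sketch of the per-band Jensen-Shannon check follows. It assumes the reference and incoming arrays have both been through the training preprocessing, so values lie in [0, 1] and fixed histogram bin edges are comparable; `js_divergence` and `band_drift` are hypothetical helper names. Using base-2 logarithms bounds the divergence in [0, 1].

```python
import numpy as np


def js_divergence(p, q, eps: float = 1e-12) -> float:
    """Jensen-Shannon divergence (base 2, bounded in [0, 1]) between two histograms."""
    p = np.asarray(p, dtype=np.float64) + eps
    q = np.asarray(q, dtype=np.float64) + eps
    p /= p.sum()
    q /= q.sum()
    m = 0.5 * (p + q)

    def kl(a: np.ndarray, b: np.ndarray) -> float:
        return float(np.sum(a * np.log2(a / b)))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def band_drift(reference: np.ndarray, incoming: np.ndarray, bins: int = 64,
               value_range=(0.0, 1.0)) -> np.ndarray:
    """Per-band JS divergence between two (bands, H, W) arrays."""
    scores = []
    for ref_band, new_band in zip(reference, incoming):
        # Shared, fixed bin edges so the two histograms are directly comparable
        ref_hist, _ = np.histogram(ref_band[np.isfinite(ref_band)], bins=bins, range=value_range)
        new_hist, _ = np.histogram(new_band[np.isfinite(new_band)], bins=bins, range=value_range)
        scores.append(js_divergence(ref_hist, new_hist))
    return np.array(scores)
```

A monitoring job might flag a scene when `(band_drift(reference_scene, new_scene) > 0.1).any()`, where 0.1 is a hypothetical threshold to calibrate against historical seasonal variation. Note that SciPy's `scipy.spatial.distance.jensenshannon` returns the JS *distance* (the square root of this divergence), so thresholds are not interchangeable between the two.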

Summary

A CNN for satellite imagery classification succeeds only when geospatial constraints are treated as first-class pipeline components. Enforcing CRS consistency, applying per-scene spectral normalization, adapting backbone architectures to multi-band inputs, and blocking spatial leakage during validation together create a robust foundation for production-grade remote sensing models.