Building a CNN for Satellite Imagery Classification

Build a production CNN for satellite imagery classification with rasterio: multi-band I/O, spectral normalization, spatial cross-validation, and inference deployment.

TL;DR: Adapt a standard ResNet backbone to accept multi-band satellite inputs by replacing conv1 and initializing new spectral weights from averaged RGB kernels. Apply per-scene percentile normalization before building tensors, then enforce spatially blocked cross-validation to prevent geographic leakage. The core pattern is src.read(bands).astype(np.float32) → percentile clip → torch.from_numpy() → modified backbone → blocked GroupKFold by tile region.

Part of: Graph Neural Networks for Spatial Data


Why This Fails in Geospatial ML Pipelines

Standard computer vision pipelines assume three things that satellite imagery violates immediately: a fixed 3-channel input, i.i.d. training/validation splits, and pixel intensities that are stationary across the dataset. Each assumption breaks at a different point in the pipeline, often silently.

Band count mismatch is the most visible failure: calling resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) and forwarding a 5- or 12-channel tensor raises a RuntimeError at the first convolution, whose weights expect exactly 3 input channels. Less visible is the weight initialization trap: replacing the layer with random weights discards the edge-detector priors that make transfer learning valuable on small datasets.

Spectral non-stationarity is subtler. Top-of-atmosphere reflectance values for a Saharan sand scene in July and a boreal forest scene in February occupy completely different ranges, even within the same band. Without per-scene normalization, training can become unstable early on, or the model learns only to distinguish bright from dark scenes rather than land cover classes.
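
To make the range mismatch concrete, here is a minimal sketch with two synthetic single-band scenes. The value ranges are invented for illustration and are not real reflectances, and normalize_scene is a hypothetical helper that mirrors the 2nd–98th percentile scaling used later in this article. It compares how much of the [0, 1] range a dark scene occupies under dataset-wide statistics versus its own statistics.

```python
import numpy as np

def normalize_scene(scene: np.ndarray, low_pct: float = 2.0, high_pct: float = 98.0) -> np.ndarray:
    """Scale one (C, H, W) scene to [0, 1] using its own per-band percentiles."""
    low, high = np.nanpercentile(scene, [low_pct, high_pct], axis=(1, 2), keepdims=True)
    scaled = (scene - low) / (high - low + 1e-6)
    return np.clip(scaled, 0.0, 1.0)

rng = np.random.default_rng(0)
# Synthetic stand-ins: a bright scene and a dark scene, one band each
bright = rng.uniform(0.35, 0.60, size=(1, 64, 64)).astype(np.float32)
dark = rng.uniform(0.02, 0.10, size=(1, 64, 64)).astype(np.float32)

# Dataset-wide statistics: both scenes share one low/high pair
stacked = np.concatenate([bright, dark], axis=0)
g_low, g_high = np.percentile(stacked, [2, 98])
dark_global = np.clip((dark - g_low) / (g_high - g_low + 1e-6), 0.0, 1.0)

# Per-scene statistics: the dark scene is scaled by its own range
dark_scene = normalize_scene(dark)

print("dark scene spread, global stats:   ", float(dark_global.max() - dark_global.min()))
print("dark scene spread, per-scene stats:", float(dark_scene.max() - dark_scene.min()))
```

Under the shared statistics the dark scene is squeezed into a small slice of the output range, while per-scene scaling lets it use the full range.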

Geographic leakage is the most dangerous failure because the metrics look fine. When image tiles from the same acquisition footprint are split randomly into train and test, the model exploits local spectral correlation (nearby tiles look nearly identical) and reports validation accuracy far above what will generalize to an unseen region or season. Without reducing spatial leakage in model training, you may ship a model that loses on the order of 20–40% accuracy on deployment.


Core Principles for a Geospatial-Aware CNN Pipeline

  • Validate CRS consistency before anything else. Mismatched projections produce silently misregistered patches; applying CRS alignment and projection handling in the ingestion step catches this before tensors are built.
  • Read only the bands you need via windowed I/O. Streaming Cloud-Optimized GeoTIFFs (COGs) with rasterio windowed reads avoids loading entire scenes into memory for patch extraction.
  • Normalize per-scene, not per-dataset. Use 2nd–98th percentile clipping computed from each individual scene to reduce the effect of sensor calibration drift, seasonal illumination shifts, and atmospheric haze.
  • Encode derived spectral indices as explicit channels. NDVI, NDWI, or NDBI computed from band math give the network explicit land-cover priors; see raster band math and index calculation for validated implementations.
  • Block splits by spatial region, not by sample index. Group tiles by acquisition zone or administrative boundary and use GroupKFold to guarantee that no region appears in both train and validation folds.
  • Preserve the affine transform and CRS through inference. Vectorizing predictions back to geographic space requires the same spatial metadata captured at read time.
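
The windowed-read principle above can be sketched as a patch grid. This is a minimal illustration, not the article's implementation: iter_patch_windows is a hypothetical helper, and it assumes that partial patches at the right and bottom edges are simply skipped.

```python
from typing import Iterator, Tuple

def iter_patch_windows(
    width: int, height: int, patch_size: int, stride: int
) -> Iterator[Tuple[int, int, int, int]]:
    """Yield (col_off, row_off, width, height) for full patches inside a scene.

    Partial patches at the right and bottom edges are skipped in this sketch.
    """
    for row_off in range(0, height - patch_size + 1, stride):
        for col_off in range(0, width - patch_size + 1, stride):
            yield col_off, row_off, patch_size, patch_size

# Example: a 1000 x 600 pixel scene cut into non-overlapping 256-pixel patches
windows = list(iter_patch_windows(width=1000, height=600, patch_size=256, stride=256))
print(len(windows), windows[0], windows[-1])
```

Each tuple follows the argument order of rasterio.windows.Window(col_off, row_off, width, height), so it can be unpacked into a Window and passed as the window argument of src.read to fetch only that patch from a COG.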

Pipeline Architecture

The diagram below shows the five-stage flow from raw COG to a georeferenced prediction:

[Figure: Five-stage CNN satellite imagery classification pipeline. COG Streaming (windowed rasterio) → Spectral Normalization (per-scene p2–p98) → Multi-Band CNN (adapted ResNet18) → Spatial CV (GroupKFold by zone) → Georeferenced Output (CRS-attached vector)]

Production-Ready Code

Dataset: windowed COG reads with per-scene normalization

import os
import numpy as np
import rasterio
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights
import logging

logger = logging.getLogger(__name__)


class SatelliteDataset(Dataset):
    """
    PyTorch Dataset for multi-band GeoTIFF patches.

    Parameters
    ----------
    img_dir : str
        Directory containing .tif scene files.
    label_map : dict[str, int]
        Mapping of filename (basename) to integer class id.
    bands : tuple[int, ...]
        1-indexed band numbers to read. The default (2, 3, 4, 8, 9) selects
        B2, B3, B4, B8, B8A when the file is a 13-band Sentinel-2 stack in
        native band order.
    patch_size : int
        Square centre-crop size in pixels applied after reading.
    max_cloud_fraction : float
        Discard patches where the fraction of nodata pixels exceeds this value.
    """

    def __init__(
        self,
        img_dir: str,
        label_map: dict,
        bands: tuple = (2, 3, 4, 8, 9),
        patch_size: int = 256,
        max_cloud_fraction: float = 0.20,
    ):
        self.img_paths = sorted(
            os.path.join(img_dir, f)
            for f in os.listdir(img_dir)
            if f.endswith(".tif")
        )
        self.label_map = label_map
        self.bands = bands
        self.patch_size = patch_size
        self.max_cloud_fraction = max_cloud_fraction

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, idx: int):
        path = self.img_paths[idx]
        basename = os.path.basename(path)

        with rasterio.open(path) as src:
            # Validate CRS before reading — silent misregistration otherwise
            if src.crs is None:
                raise ValueError(f"No CRS on {basename}; reproject before ingestion.")

            data = src.read(self.bands).astype(np.float32)  # (C, H, W)

        # Replace nodata sentinel (0) with NaN before normalization
        data[data == 0] = np.nan

        # Cloud fraction check — skip heavily obscured patches
        cloud_fraction = np.isnan(data).any(axis=0).mean()
        if cloud_fraction > self.max_cloud_fraction:
            logger.warning(
                "%s skipped: cloud fraction %.2f exceeds threshold %.2f",
                basename,
                cloud_fraction,
                self.max_cloud_fraction,
            )
            # Return a zero tensor; callers should filter these with a collate_fn
            n_bands = len(self.bands)
            label = self.label_map.get(basename, 0)
            return torch.zeros(n_bands, self.patch_size, self.patch_size), torch.tensor(label, dtype=torch.long)

        # Per-scene percentile normalization (2nd–98th) — one set of stats per scene
        low, high = np.nanpercentile(data, [2, 98], axis=(1, 2), keepdims=True)
        data = (data - low) / (high - low + 1e-6)
        data = np.nan_to_num(data, nan=0.0).clip(0.0, 1.0)

        tensor = torch.from_numpy(data)

        # Centre-crop to patch_size
        _, h, w = tensor.shape
        if h > self.patch_size:
            y0 = (h - self.patch_size) // 2
            tensor = tensor[:, y0 : y0 + self.patch_size, :]
        if w > self.patch_size:
            x0 = (w - self.patch_size) // 2
            tensor = tensor[:, :, x0 : x0 + self.patch_size]

        label = self.label_map.get(basename, 0)
        return tensor, torch.tensor(label, dtype=torch.long)

Model: multi-band ResNet with averaged weight initialization

class MultiBandResNet(nn.Module):
    """
    ResNet18 adapted for an arbitrary number of spectral input bands.

    Weight initialization strategy: average the three pretrained RGB kernels
    and tile the result across `in_channels`. This preserves edge-detector
    priors rather than starting from random noise.
    """

    def __init__(self, in_channels: int, num_classes: int = 10):
        super().__init__()
        base = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        old_conv = base.conv1  # original (64, 3, 7, 7) kernel

        base.conv1 = nn.Conv2d(
            in_channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=False,
        )

        # Average RGB kernels → (64, 1, 7, 7), then tile to (64, in_channels, 7, 7)
        with torch.no_grad():
            avg_kernel = old_conv.weight.mean(dim=1, keepdim=True)
            base.conv1.weight.copy_(avg_kernel.repeat(1, in_channels, 1, 1))

        base.fc = nn.Linear(base.fc.in_features, num_classes)
        self.model = base

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
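
One design choice is worth checking here. Tiling the averaged kernel across more than three channels without rescaling increases the magnitude of first-layer responses, roughly in proportion to the channel count when all bands carry similar values. A common variant, not used in the class above, rescales the tiled kernel by 3 / in_channels. The sketch below illustrates the effect with NumPy on a random stand-in for the pretrained (64, 3, 7, 7) weights, so it runs without downloading a checkpoint; expand_kernel is a hypothetical helper.

```python
import numpy as np

def expand_kernel(rgb_weight: np.ndarray, in_channels: int, rescale: bool = True) -> np.ndarray:
    """Average a (out, 3, k, k) kernel over RGB and tile it to (out, in_channels, k, k)."""
    avg = rgb_weight.mean(axis=1, keepdims=True)      # (out, 1, k, k)
    tiled = np.repeat(avg, in_channels, axis=1)       # (out, in_channels, k, k)
    if rescale:
        # Keep the summed response to a uniform input equal to the RGB original
        tiled = tiled * (rgb_weight.shape[1] / in_channels)
    return tiled

rng = np.random.default_rng(0)
rgb_weight = rng.normal(size=(64, 3, 7, 7))           # stand-in for pretrained conv1 weights

plain = expand_kernel(rgb_weight, in_channels=5, rescale=False)
scaled = expand_kernel(rgb_weight, in_channels=5, rescale=True)

# The response of each filter to an all-ones patch equals the sum of its weights
original_response = rgb_weight.sum(axis=(1, 2, 3))
rescaled_matches = np.allclose(scaled.sum(axis=(1, 2, 3)), original_response)
plain_is_larger = np.allclose(plain.sum(axis=(1, 2, 3)), original_response * 5 / 3)
print(rescaled_matches, plain_is_larger)
```

In the PyTorch class, the equivalent change would be multiplying the repeated kernel by 3 / in_channels before copying it into the new conv1. Whether this helps depends on the data, so treat it as something to validate rather than a default.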

Step-by-Step Walkthrough

Step 1: Prepare scene tiles with CRS validation

Organize raw Sentinel-2 scenes into a flat tile directory. Before any model training, verify that every .tif shares the same projected CRS. EPSG:32632 (UTM zone 32N), for example, covers the strip of central Europe between 6°E and 12°E; acquisitions outside that strip arrive in a different UTM zone. Use rasterio.open(p).crs.to_epsg() and assert equality across all tiles. Mismatched tiles should be reprojected before this step, following the guidance in CRS alignment and projection handling.

Step 2: Compute optional spectral index channels

Derived indices improve class separability for vegetation and water classification tasks. Compute NDVI as (NIR - Red) / (NIR + Red + 1e-6) and append it as an extra channel alongside the raw reflectance bands. This is an application of raster band math and index calculation: compute indices before normalization so the ratio algebra operates on physically meaningful reflectance values.
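
A minimal sketch of this step, using the standard NDVI definition (NIR - Red) / (NIR + Red) with a small epsilon to avoid division by zero. The helper name append_ndvi and the toy reflectance values are assumptions for illustration; the channel positions follow the B2, B3, B4, B8, B8A order used in this article.

```python
import numpy as np

def append_ndvi(data: np.ndarray, red_idx: int, nir_idx: int, eps: float = 1e-6) -> np.ndarray:
    """Append NDVI = (NIR - Red) / (NIR + Red) as an extra channel to a (C, H, W) array."""
    red = data[red_idx].astype(np.float32)
    nir = data[nir_idx].astype(np.float32)
    ndvi = (nir - red) / (nir + red + eps)
    return np.concatenate([data.astype(np.float32), ndvi[np.newaxis]], axis=0)

# Toy 5-band patch ordered as B2, B3, B4 (red), B8 (NIR), B8A
patch = np.full((5, 2, 2), 0.1, dtype=np.float32)
patch[2] = 0.05   # red reflectance
patch[3] = 0.45   # NIR reflectance, vegetation-like

with_ndvi = append_ndvi(patch, red_idx=2, nir_idx=3)
print(with_ndvi.shape)             # (6, 2, 2)
print(float(with_ndvi[5, 0, 0]))   # close to 0.8
```

NDVI lies in [-1, 1] and can legitimately be exactly 0, so it should not go through the zero-as-nodata masking applied to reflectance bands; one option is to rescale it separately, for example as (ndvi + 1) / 2.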

Step 3: Build geographically blocked splits

Assign each tile a zone label (e.g., acquisition footprint ID or 50 km grid block). Use scikit-learn’s GroupKFold with these zone labels as groups:

from sklearn.model_selection import GroupKFold
import numpy as np

# filenames: list of .tif basenames
# zone_ids: list of zone labels (e.g., MGRS tile ID like "32UNB") per filename
# labels:   class id per filename

gkf = GroupKFold(n_splits=5)
for fold, (train_idx, val_idx) in enumerate(
    gkf.split(filenames, labels, groups=zone_ids)
):
    train_files = [filenames[i] for i in train_idx]
    val_files   = [filenames[i] for i in val_idx]
    # Build SatelliteDataset and DataLoader for this fold
    # ... (model.fit / evaluate)

Without zone-based grouping, the standard random split leaks spatial correlation across folds — see spatial cross-validation strategies for the theoretical basis and a complete SpatialKFold implementation.
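
If tiles do not come with a ready-made zone label, one option is to derive it from each tile's centroid. The sketch below assumes centroids in a projected CRS with metre units and uses the 50 km block size mentioned above; grid_block_id and assert_no_zone_overlap are hypothetical helpers, and the coordinates are invented. The second helper is a cheap guard to run on every fold.

```python
import math
from typing import List, Sequence, Tuple

def grid_block_id(x: float, y: float, block_size_m: float = 50_000.0) -> str:
    """Map a projected coordinate (metres) to the ID of the grid block containing it."""
    return f"{math.floor(x / block_size_m)}_{math.floor(y / block_size_m)}"

def assert_no_zone_overlap(
    zone_ids: Sequence[str], train_idx: Sequence[int], val_idx: Sequence[int]
) -> None:
    """Raise if any zone contributes tiles to both the training and validation indices."""
    shared = {zone_ids[i] for i in train_idx} & {zone_ids[i] for i in val_idx}
    if shared:
        raise ValueError(f"Zones present in both splits: {sorted(shared)}")

# Hypothetical tile centroids (easting, northing in metres)
centroids: List[Tuple[float, float]] = [
    (500_000.0, 5_400_000.0),
    (510_000.0, 5_410_000.0),   # same 50 km block as the first tile
    (560_000.0, 5_400_000.0),   # next block to the east
    (500_000.0, 5_460_000.0),   # next block to the north
]
zone_ids = [grid_block_id(x, y) for x, y in centroids]
print(zone_ids)

# Tiles 0 and 1 share a block, so they must stay on the same side of the split
assert_no_zone_overlap(zone_ids, train_idx=[0, 1], val_idx=[2, 3])
```

The resulting zone_ids list can be passed as the groups argument of GroupKFold.split. Adjacent blocks can still be spatially correlated at their shared edge, so a buffer between train and validation blocks may be worth considering.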

Step 4: Train with the multi-band backbone

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
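# in_channels must equal the number of channels the dataset returns: 5 bands here.
# Use 6 only if the dataset also appends an NDVI channel.
model = MultiBandResNet(in_channels=5, num_classes=8).to(device)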
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

label_map = {"tile_001.tif": 0, "tile_002.tif": 2}  # etc.
dataset = SatelliteDataset("data/tiles/", label_map, bands=(2, 3, 4, 8, 9), patch_size=224)
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)

for epoch in range(30):
    model.train()
    running_loss = 0.0
    for imgs, targets in loader:
        imgs, targets = imgs.to(device), targets.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}: loss={running_loss / len(loader):.4f}")

Step 5: Export and attach CRS for georeferenced predictions

During inference, re-open each source tile to recover its affine transform and CRS. Write predicted class rasters as GeoTIFFs with the original spatial metadata so downstream GIS tools can consume results directly:

import rasterio

def infer_and_georeference(
    model: nn.Module,
    tile_path: str,
    bands: tuple,
    device: torch.device,
    output_path: str,
) -> None:
    model.eval()
    with rasterio.open(tile_path) as src:
        profile = src.profile.copy()
        data = src.read(bands).astype(np.float32)
        crs = src.crs
        transform = src.transform

    data[data == 0] = np.nan
    low, high = np.nanpercentile(data, [2, 98], axis=(1, 2), keepdims=True)
    data = (data - low) / (high - low + 1e-6)
    data = np.nan_to_num(data, nan=0.0).clip(0.0, 1.0)

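    tensor = torch.from_numpy(data).unsqueeze(0).to(device)  # (1, C, H, W)

    with torch.no_grad():
        # The classifier emits one logit vector per tile, so argmax yields a single class id
        pred_class = int(model(tensor).argmax(dim=1).item())

    # Broadcast the tile-level class over the tile footprint so it can be written as a raster
    pred_raster = np.full(data.shape[1:], pred_class, dtype=np.uint8)  # (H, W)

    # Drop the source nodata value, which may not fit in uint8
    profile.update(
        dtype=rasterio.uint8, count=1, crs=crs, transform=transform, nodata=None
    )
    with rasterio.open(output_path, "w", **profile) as dst:
        dst.write(pred_raster, 1)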

    logger.info("Saved georeferenced prediction to %s (CRS: %s)", output_path, crs)
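
The function above reads the full tile, so the source transform applies unchanged. If predictions are instead written for a centre-cropped patch, as produced by the dataset class, the transform's origin has to shift to the crop's top-left pixel. The sketch below shows that arithmetic in plain Python under the assumption of a north-up tile with invented coordinates; pixel_to_map and crop_transform are hypothetical helpers, and the six coefficients follow the same order as rasterio's Affine.

```python
from typing import Tuple

# Coefficients in the same order as rasterio's Affine: (a, b, c, d, e, f), where
# x = a * col + b * row + c and y = d * col + e * row + f
AffineCoeffs = Tuple[float, float, float, float, float, float]

def pixel_to_map(t: AffineCoeffs, col: float, row: float) -> Tuple[float, float]:
    """Map a pixel position (col, row) to map coordinates (x, y)."""
    a, b, c, d, e, f = t
    return a * col + b * row + c, d * col + e * row + f

def crop_transform(t: AffineCoeffs, col_off: int, row_off: int) -> AffineCoeffs:
    """Return the transform of a sub-window whose top-left pixel is (col_off, row_off)."""
    a, b, _, d, e, _ = t
    x0, y0 = pixel_to_map(t, col_off, row_off)
    return (a, b, x0, d, e, y0)

# Hypothetical north-up tile: 10 m pixels, top-left corner at (600000, 5500000)
tile_t: AffineCoeffs = (10.0, 0.0, 600_000.0, 0.0, -10.0, 5_500_000.0)

# A centre crop of 224 pixels from a 256 x 256 tile starts at offset (16, 16)
offset = (256 - 224) // 2
patch_t = crop_transform(tile_t, col_off=offset, row_off=offset)
print(patch_t)
```

In a rasterio workflow the same result should be available from src.window_transform(window), which avoids hand-rolling the arithmetic.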

Verification

Confirm the pipeline is working correctly with three checks:

import rasterio

# 1. Shape and value range
ds = SatelliteDataset("data/tiles/", label_map)
img, lbl = ds[0]
assert img.shape == (5, 256, 256), f"Unexpected shape: {img.shape}"
assert 0.0 <= img.min() and img.max() <= 1.0, "Normalization out of range"

# 2. Fewer than 5% all-zero patches in a well-populated tile set
zero_count = sum(1 for i in range(len(ds)) if ds[i][0].sum() == 0)
assert zero_count / len(ds) < 0.05, "Too many zero/cloud-masked patches"

# 3. Output GeoTIFF carries source CRS
infer_and_georeference(model, "data/tiles/tile_001.tif", (2,3,4,8,9), device, "/tmp/pred.tif")
with rasterio.open("/tmp/pred.tif") as out:
    assert out.crs is not None, "Output prediction missing CRS"
    assert out.crs.to_epsg() == 32632, f"Wrong CRS on output: {out.crs}"
    print("Output shape:", out.shape, "| CRS:", out.crs)

FAQ

Why can’t I use a standard ImageNet-pretrained CNN directly on Sentinel-2 imagery?

ImageNet-pretrained models expect exactly 3 RGB channels. Sentinel-2 provides 13 spectral bands (12 in Level-2A products, which omit the cirrus band B10). You must replace the first convolutional layer to accept the correct number of input channels, then initialize its weights by averaging the pretrained RGB kernels and tiling the result across the new spectral dimension so that prior feature-detector knowledge transfers without random initialization noise.

What is spatial leakage and why does it inflate CNN accuracy on satellite data?

Spatial leakage occurs when geographically proximate tiles end up in both the training and validation split. Because nearby pixels share spectral signatures, surface cover, and atmospheric conditions, the model effectively memorizes local context rather than learning transferable patterns. The result is validation accuracy that does not generalize to a new acquisition area or season.

How do I handle cloud-masked or nodata pixels in a satellite CNN training set?

Replace nodata values (typically 0 or the raster’s nodata sentinel) with NaN before normalization, then convert NaN to 0.0 after percentile scaling. Do not normalize across nodata pixels — they corrupt the band statistics. For training, optionally discard patches exceeding a cloud fraction threshold (e.g., 20%) using the scene’s cloud-mask band or a pre-computed QA layer.
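
A small synthetic example of why nodata pixels must be excluded from the statistics. The band values and the 30% nodata share are invented for illustration; the point is only the comparison between np.percentile on the raw array and np.nanpercentile on the masked one.

```python
import numpy as np

rng = np.random.default_rng(0)
band = rng.uniform(0.20, 0.40, size=(100, 100)).astype(np.float32)  # synthetic valid reflectance
band[:30, :] = 0.0                                                    # 30% nodata stored as 0

# Naive: nodata zeros are included in the statistics
naive_low, naive_high = np.percentile(band, [2, 98])

# Masked: zeros become NaN and nanpercentile ignores them
masked = np.where(band == 0.0, np.nan, band)
low, high = np.nanpercentile(masked, [2, 98])

print(f"naive  p2={naive_low:.3f}  p98={naive_high:.3f}")
print(f"masked p2={low:.3f}  p98={high:.3f}")
```

With the zeros included, the lower percentile collapses to the nodata value, so valid pixels are compressed into the upper part of the output range after scaling.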

