TL;DR: Adapt a standard ResNet backbone to accept multi-band satellite inputs by replacing conv1 and initializing new spectral weights from averaged RGB kernels. Apply per-scene percentile normalization before building tensors, then enforce spatially blocked cross-validation to prevent geographic leakage. The core pattern is src.read(bands).astype(np.float32) → percentile clip → torch.from_numpy() → modified backbone → blocked GroupKFold by tile region.
Part of: Graph Neural Networks for Spatial Data
Why This Fails in Geospatial ML Pipelines
Standard computer vision pipelines assume three things that satellite imagery violates immediately: a fixed 3-channel input, i.i.d. training/validation splits, and pixel intensities that are stationary across the dataset. Each assumption breaks in a different place of the pipeline, often silently.
Band count mismatch is the most visible failure: loading resnet18 with ImageNet weights (resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)) and forwarding a 5- or 12-channel tensor raises a RuntimeError at the first convolution, which expects exactly 3 input channels. Less visible is the weight initialization trap: replacing that layer with randomly initialized weights discards the edge-detector priors that make transfer learning valuable on small datasets.
Spectral non-stationarity is subtler. Top-of-atmosphere reflectance for a Saharan sand scene in July and a boreal forest scene in February occupy completely different value ranges, even within the same band. Without per-scene normalization, training can suffer exploding gradients early on, or converge to a model that only distinguishes bright scenes from dark ones rather than land cover classes.
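As a minimal sketch of the effect, the snippet below builds two synthetic single-band scenes with very different value ranges and scales each one by its own 2nd–98th percentiles. The arrays, the random seed, and the helper name normalize_scene are illustrative assumptions, not real Sentinel-2 data.

```python
import numpy as np

def normalize_scene(scene: np.ndarray) -> np.ndarray:
    """Scale one (C, H, W) scene to [0, 1] using its own 2nd-98th percentiles per band."""
    low, high = np.percentile(scene, [2, 98], axis=(1, 2), keepdims=True)
    return ((scene - low) / (high - low + 1e-6)).clip(0.0, 1.0)

rng = np.random.default_rng(0)
# Hypothetical reflectance values: a bright desert scene and a dark winter forest scene.
bright = rng.normal(4000.0, 300.0, size=(1, 64, 64)).astype(np.float32)
dark = rng.normal(400.0, 50.0, size=(1, 64, 64)).astype(np.float32)

for name, scene in [("bright", bright), ("dark", dark)]:
    norm = normalize_scene(scene)
    print(name, "raw mean:", round(float(scene.mean())),
          "normalized mean:", round(float(norm.mean()), 2))
```

The raw means differ by roughly an order of magnitude, while the normalized scenes occupy the same [0, 1] range, which is what the first convolution needs to see.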
Geographic leakage is the most dangerous failure because the metrics look fine. When image tiles from the same acquisition footprint are split randomly into train and test, the model exploits local spectral correlation (nearby tiles look nearly identical) and reports validation accuracy far above what will generalize to an unseen region or season. Unless you apply the techniques for reducing spatial leakage in model training, you may ship a model whose accuracy drops by 20–40% on deployment.
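The leakage mechanism can be illustrated without any imagery. The sketch below uses a hypothetical catalogue of 20 zones with 10 tiles each (all names are made up) and counts how many zones end up on both sides of a random split versus a split that holds out whole zones.

```python
import random

# Hypothetical tile catalogue: 20 acquisition zones with 10 tiles each.
tiles = [(f"zone{z:02d}_tile{t}", f"zone{z:02d}") for z in range(20) for t in range(10)]

def shared_zones(train, val):
    """Return the set of zones that appear on both sides of a split."""
    return {zone for _, zone in train} & {zone for _, zone in val}

# Random split by sample index: tiles from one zone land on both sides.
rng = random.Random(0)
shuffled = tiles[:]
rng.shuffle(shuffled)
random_train, random_val = shuffled[:160], shuffled[160:]

# Blocked split: hold out whole zones, so no zone is shared.
val_zones = {f"zone{z:02d}" for z in range(16, 20)}
blocked_train = [t for t in tiles if t[1] not in val_zones]
blocked_val = [t for t in tiles if t[1] in val_zones]

print("zones shared by random split: ", len(shared_zones(random_train, random_val)))
print("zones shared by blocked split:", len(shared_zones(blocked_train, blocked_val)))
```

Any zone shared between the two sides is a place where near-duplicate tiles can inflate the validation score.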
Core Principles for a Geospatial-Aware CNN Pipeline
- Validate CRS consistency before anything else. Mismatched projections produce silently misregistered patches; applying CRS alignment and projection handling in the ingestion step catches this before tensors are built.
- Read only the bands you need via windowed I/O. Streaming Cloud-Optimized GeoTIFFs (COGs) with rasterio windowed reads avoids loading entire scenes into memory for patch extraction.
- Normalize per-scene, not per-dataset. Use 2nd–98th percentile clipping computed from each individual scene to cancel out sensor calibration drift, seasonal illumination shifts, and atmospheric haze.
- Encode derived spectral indices as explicit channels. NDVI, NDWI, or NDBI computed from band math give the network explicit land-cover priors; see raster band math and index calculation for validated implementations.
- Block splits by spatial region, not by sample index. Group tiles by acquisition zone or administrative boundary and use GroupKFold to guarantee that no region appears in both train and validation folds.
- Preserve the affine transform and CRS through inference. Vectorizing predictions back to geographic space requires the same spatial metadata captured at read time.
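For the windowed I/O principle, one possible way to plan patch extraction is to enumerate pixel offsets first and read each patch on demand. The helper name patch_offsets and the scene dimensions below are hypothetical; the commented lines show how an offset would feed a rasterio windowed read.

```python
from typing import Iterator, Tuple

def patch_offsets(height: int, width: int, patch_size: int, stride: int) -> Iterator[Tuple[int, int]]:
    """Yield (row_off, col_off) for every full patch that fits inside the scene."""
    for row_off in range(0, height - patch_size + 1, stride):
        for col_off in range(0, width - patch_size + 1, stride):
            yield row_off, col_off

# A 1000 x 600 pixel scene cut into non-overlapping 256-pixel patches.
offsets = list(patch_offsets(height=1000, width=600, patch_size=256, stride=256))
print(len(offsets), "patches; first:", offsets[0], "last:", offsets[-1])

# With rasterio, each offset would become a windowed read, e.g.:
#   from rasterio.windows import Window
#   patch = src.read(bands, window=Window(col_off, row_off, patch_size, patch_size))
```

Partial patches at the right and bottom edges are dropped in this sketch; padding them is an alternative design choice.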
Pipeline Architecture
The pipeline is a five-stage flow from raw COG to a georeferenced prediction.
Production-Ready Code
Dataset: windowed COG reads with per-scene normalization
import os
import numpy as np
import rasterio
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights
import logging

logger = logging.getLogger(__name__)


class SatelliteDataset(Dataset):
    """
    PyTorch Dataset for multi-band GeoTIFF patches.

    Parameters
    ----------
    img_dir : str
        Directory containing .tif scene files.
    label_map : dict[str, int]
        Mapping of filename (basename) to integer class id.
    bands : tuple[int, ...]
        1-indexed band numbers to read (Sentinel-2 defaults: B2,B3,B4,B8,B8A).
    patch_size : int
        Square centre-crop size in pixels applied after reading.
    max_cloud_fraction : float
        Discard patches where the fraction of nodata pixels exceeds this value.
    """

    def __init__(
        self,
        img_dir: str,
        label_map: dict,
        bands: tuple = (2, 3, 4, 8, 9),
        patch_size: int = 256,
        max_cloud_fraction: float = 0.20,
    ):
        self.img_paths = sorted(
            os.path.join(img_dir, f)
            for f in os.listdir(img_dir)
            if f.endswith(".tif")
        )
        self.label_map = label_map
        self.bands = bands
        self.patch_size = patch_size
        self.max_cloud_fraction = max_cloud_fraction

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, idx: int):
        path = self.img_paths[idx]
        basename = os.path.basename(path)
        with rasterio.open(path) as src:
            # Validate CRS before reading: silent misregistration otherwise
            if src.crs is None:
                raise ValueError(f"No CRS on {basename}; reproject before ingestion.")
            data = src.read(self.bands).astype(np.float32)  # (C, H, W)

        # Replace nodata sentinel (0) with NaN before normalization
        data[data == 0] = np.nan

        # Cloud fraction check: skip heavily obscured patches
        cloud_fraction = np.isnan(data).any(axis=0).mean()
        if cloud_fraction > self.max_cloud_fraction:
            logger.warning(
                "%s skipped: cloud fraction %.2f exceeds threshold %.2f",
                basename,
                cloud_fraction,
                self.max_cloud_fraction,
            )
            # Return a zero tensor; callers should filter these with a collate_fn
            n_bands = len(self.bands)
            label = self.label_map.get(basename, 0)
            return (
                torch.zeros(n_bands, self.patch_size, self.patch_size),
                torch.tensor(label, dtype=torch.long),
            )

        # Per-scene percentile normalization (2nd–98th): one set of stats per scene
        low, high = np.nanpercentile(data, [2, 98], axis=(1, 2), keepdims=True)
        data = (data - low) / (high - low + 1e-6)
        data = np.nan_to_num(data, nan=0.0).clip(0.0, 1.0)
        tensor = torch.from_numpy(data)

        # Centre-crop to patch_size
        _, h, w = tensor.shape
        if h > self.patch_size:
            y0 = (h - self.patch_size) // 2
            tensor = tensor[:, y0 : y0 + self.patch_size, :]
        if w > self.patch_size:
            x0 = (w - self.patch_size) // 2
            tensor = tensor[:, :, x0 : x0 + self.patch_size]

        label = self.label_map.get(basename, 0)
        return tensor, torch.tensor(label, dtype=torch.long)

Model: multi-band ResNet with averaged weight initialization
class MultiBandResNet(nn.Module):
    """
    ResNet18 adapted for an arbitrary number of spectral input bands.

    Weight initialization strategy: average the three pretrained RGB kernels
    and tile the result across `in_channels`. This preserves edge-detector
    priors rather than starting from random noise.
    """

    def __init__(self, in_channels: int, num_classes: int = 10):
        super().__init__()
        base = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        old_conv = base.conv1  # original (64, 3, 7, 7) kernel
        base.conv1 = nn.Conv2d(
            in_channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=False,
        )
        # Average RGB kernels → (64, 1, 7, 7), then tile to (64, in_channels, 7, 7)
        with torch.no_grad():
            avg_kernel = old_conv.weight.mean(dim=1, keepdim=True)
            base.conv1.weight.copy_(avg_kernel.repeat(1, in_channels, 1, 1))
        base.fc = nn.Linear(base.fc.in_features, num_classes)
        self.model = base

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

Step-by-Step Walkthrough
Step 1: Prepare scene tiles with CRS validation
Organize raw Sentinel-2 scenes into a flat tile directory. Before any model training, verify every .tif shares the same projected CRS — EPSG:32632 (UTM zone 32N) is typical for European acquisitions. Use rasterio.open(p).crs.to_epsg() and assert equality across all tiles. Mismatched tiles should be reprojected before this step, following the guidance in CRS alignment and projection handling.
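One way to turn the equality check into an actionable report is to find the majority EPSG code and list the tiles that deviate from it. This is a sketch under assumptions: the helper name find_crs_outliers and the catalogue values are hypothetical, and the EPSG codes are supplied directly so the logic runs without any files on disk.

```python
from collections import Counter
from typing import Dict, List, Optional, Tuple

def find_crs_outliers(epsg_by_tile: Dict[str, Optional[int]]) -> Tuple[Optional[int], List[str]]:
    """Return the most common EPSG code and the tiles that do not match it."""
    counts = Counter(epsg_by_tile.values())
    majority_epsg, _ = counts.most_common(1)[0]
    outliers = sorted(name for name, epsg in epsg_by_tile.items() if epsg != majority_epsg)
    return majority_epsg, outliers

# Hypothetical catalogue; in practice each value would come from
# rasterio.open(path).crs.to_epsg() as described above.
epsg_by_tile = {
    "tile_001.tif": 32632,
    "tile_002.tif": 32632,
    "tile_003.tif": 32633,  # neighbouring UTM zone, needs reprojection
    "tile_004.tif": 32632,
}
majority, outliers = find_crs_outliers(epsg_by_tile)
print("majority EPSG:", majority, "| tiles to reproject:", outliers)
```

Listing the outliers rather than failing on the first mismatch makes it easier to reproject everything in one pass.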
Step 2: Compute optional spectral index channels
Derived indices improve class separability for vegetation and water classification tasks. Compute NDVI as (NIR - Red) / (NIR + Red + 1e-6) and append it as an extra channel alongside the raw reflectance bands. This is an application of raster band math and index calculation. Compute indices before normalization so the ratio algebra operates on physically meaningful reflectance values.
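A minimal sketch of the channel-append step, assuming a (C, H, W) array whose red and NIR bands sit at known 0-based positions. The helper name append_ndvi and the synthetic values are illustrative; note that the indices here are array positions, not rasterio's 1-indexed band numbers.

```python
import numpy as np

def append_ndvi(data: np.ndarray, red_idx: int, nir_idx: int) -> np.ndarray:
    """Append NDVI as an extra channel to a (C, H, W) reflectance array."""
    red = data[red_idx].astype(np.float32)
    nir = data[nir_idx].astype(np.float32)
    ndvi = (nir - red) / (nir + red + 1e-6)  # sum in the denominator, difference on top
    return np.concatenate([data.astype(np.float32), ndvi[np.newaxis]], axis=0)

# Synthetic 5-band patch ordered B2, B3, B4 (red), B8 (NIR), B8A; values are made up.
patch = np.full((5, 4, 4), 1000.0, dtype=np.float32)
patch[2] = 500.0   # red
patch[3] = 3000.0  # NIR
stacked = append_ndvi(patch, red_idx=2, nir_idx=3)
print(stacked.shape, float(stacked[-1, 0, 0]))
```

With an index channel appended this way, the model's in_channels must be the number of raw bands plus one.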
Step 3: Build geographically blocked splits
Assign each tile a zone label (e.g., acquisition footprint ID or 50 km grid block). Use scikit-learn’s GroupKFold with these zone labels as groups:
from sklearn.model_selection import GroupKFold
import numpy as np

# filenames: list of .tif basenames
# zone_ids: list of zone labels (e.g., UTM tile ID like "32UNB") per filename
# labels: class id per filename
gkf = GroupKFold(n_splits=5)
for fold, (train_idx, val_idx) in enumerate(
    gkf.split(filenames, labels, groups=zone_ids)
):
    train_files = [filenames[i] for i in train_idx]
    val_files = [filenames[i] for i in val_idx]
    # Build SatelliteDataset and DataLoader for this fold
    # ... (model.fit / evaluate)

Without zone-based grouping, the standard random split leaks spatial correlation across folds. See spatial cross-validation strategies for the theoretical basis and a complete SpatialKFold implementation.
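If no acquisition footprint ID is available, the 50 km grid-block option can be sketched as a simple integer division of each tile's centre coordinates. The helper name grid_block_id, the label format, and the coordinates below are assumptions for illustration; the centres are presumed to be in a projected CRS with metre units (for example, derived from each tile's bounds).

```python
def grid_block_id(x: float, y: float, block_size_m: float = 50_000.0) -> str:
    """Map a tile centre in projected metres to a square grid-block label."""
    col = int(x // block_size_m)
    row = int(y // block_size_m)
    return f"blk_{col}_{row}"

# Hypothetical tile centres in a projected CRS such as UTM (easting, northing in metres).
centres = {
    "tile_001.tif": (512_300.0, 5_401_900.0),
    "tile_002.tif": (530_100.0, 5_420_000.0),  # same 50 km block as tile_001
    "tile_003.tif": (561_000.0, 5_401_900.0),  # next block to the east
}
filenames = sorted(centres)
zone_ids = [grid_block_id(*centres[name]) for name in filenames]
print(dict(zip(filenames, zone_ids)))
```

The resulting zone_ids list has one label per filename, which is the shape the groups argument of GroupKFold expects.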
Step 4: Train with the multi-band backbone
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bands = (2, 3, 4, 8, 9)
# in_channels must match what the dataset yields: one channel per band read.
# Use len(bands) + 1 only if an NDVI channel is appended upstream.
model = MultiBandResNet(in_channels=len(bands), num_classes=8).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
label_map = {"tile_001.tif": 0, "tile_002.tif": 2}  # etc.
dataset = SatelliteDataset("data/tiles/", label_map, bands=bands, patch_size=224)
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)

for epoch in range(30):
    model.train()
    running_loss = 0.0
    for imgs, targets in loader:
        imgs, targets = imgs.to(device), targets.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}: loss={running_loss / len(loader):.4f}")

Step 5: Export and attach CRS for georeferenced predictions
During inference, re-open each source tile to recover its affine transform and CRS. Write predicted class rasters as GeoTIFFs with the original spatial metadata so downstream GIS tools can consume results directly:
import rasterio


def infer_and_georeference(
    model: nn.Module,
    tile_path: str,
    bands: tuple,
    device: torch.device,
    output_path: str,
) -> None:
    model.eval()
    with rasterio.open(tile_path) as src:
        profile = src.profile.copy()
        data = src.read(bands).astype(np.float32)
        crs = src.crs
        transform = src.transform

    # Same nodata handling and per-scene normalization as at training time
    data[data == 0] = np.nan
    low, high = np.nanpercentile(data, [2, 98], axis=(1, 2), keepdims=True)
    data = (data - low) / (high - low + 1e-6)
    data = np.nan_to_num(data, nan=0.0).clip(0.0, 1.0)
    tensor = torch.from_numpy(data).unsqueeze(0).to(device)  # (1, C, H, W)

    with torch.no_grad():
        # The classifier emits one label per tile: logits are (1, num_classes)
        pred_class = int(model(tensor).argmax(dim=1).item())

    # Broadcast the tile-level label over the tile footprint so the output
    # raster shares the source grid, transform, and CRS
    pred_raster = np.full(data.shape[1:], pred_class, dtype=np.uint8)  # (H, W)

    # Clear the source nodata value so class 0 is not read back as nodata
    profile.update(
        dtype=rasterio.uint8, count=1, crs=crs, transform=transform, nodata=None
    )
    with rasterio.open(output_path, "w", **profile) as dst:
        dst.write(pred_raster, 1)
    logger.info("Saved georeferenced prediction to %s (CRS: %s)", output_path, crs)

Verification
Confirm the pipeline is working correctly with three checks:
import rasterio

# 1. Shape and value range
ds = SatelliteDataset("data/tiles/", label_map)
img, lbl = ds[0]
assert img.shape == (5, 256, 256), f"Unexpected shape: {img.shape}"
assert 0.0 <= img.min() and img.max() <= 1.0, "Normalization out of range"

# 2. No all-zero patches in a well-populated tile set
zero_count = sum(1 for i in range(len(ds)) if ds[i][0].sum() == 0)
assert zero_count / len(ds) < 0.05, "Too many zero/cloud-masked patches"

# 3. Output GeoTIFF carries source CRS
infer_and_georeference(model, "data/tiles/tile_001.tif", (2, 3, 4, 8, 9), device, "/tmp/pred.tif")
with rasterio.open("/tmp/pred.tif") as out:
    assert out.crs is not None, "Output prediction missing CRS"
    assert out.crs.to_epsg() == 32632, f"Wrong CRS on output: {out.crs}"
    print("Output shape:", out.shape, "| CRS:", out.crs)

FAQ
Why can’t I use a standard ImageNet-pretrained CNN directly on Sentinel-2 imagery?
ImageNet-pretrained models expect exactly 3 RGB channels. Sentinel-2 provides 13 spectral bands (Level-2A products typically ship 12, omitting the cirrus band B10). You must replace the first convolutional layer to accept the correct number of input channels, then initialize its weights by averaging the pretrained RGB kernels and copying the result to every new input channel, so that prior feature-detector knowledge transfers without random initialization noise.
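The averaging step can be checked on plain arrays. The sketch below uses random numbers as a stand-in for the pretrained conv1 weights (an assumption, so that nothing needs downloading) and shows one side effect worth knowing: tiling the averaged kernel across C channels scales each filter's response to a constant input by C/3. Multiplying by 3/C is an optional variant that is not part of the MultiBandResNet class above; it keeps first-layer activations close to their pretrained magnitude.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for pretrained conv1 weights of shape (out_channels, 3, k, k); random here.
rgb_weight = rng.normal(size=(64, 3, 7, 7)).astype(np.float32)
in_channels = 5

avg_kernel = rgb_weight.mean(axis=1, keepdims=True)   # (64, 1, 7, 7)
tiled = np.repeat(avg_kernel, in_channels, axis=1)    # (64, 5, 7, 7)
rescaled = tiled * (3.0 / in_channels)                # optional variant

# Response of each filter to a constant input patch equals the sum of its weights.
orig_resp = rgb_weight.sum(axis=(1, 2, 3))
tiled_resp = tiled.sum(axis=(1, 2, 3))
rescaled_resp = rescaled.sum(axis=(1, 2, 3))

print(tiled.shape)
print(np.allclose(tiled_resp, orig_resp * in_channels / 3, atol=1e-3))
print(np.allclose(rescaled_resp, orig_resp, atol=1e-3))
```

Whether the rescaling helps in practice depends on the learning rate and on how the inputs are normalized, so treat it as something to test rather than a default.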
What is spatial leakage and why does it inflate CNN accuracy on satellite data?
Spatial leakage occurs when geographically proximate tiles end up in both the training and validation split. Because nearby pixels share spectral signatures, surface cover, and atmospheric conditions, the model effectively memorizes local context rather than learning transferable patterns. The result is validation accuracy that does not generalize to a new acquisition area or season.
How do I handle cloud-masked or nodata pixels in a satellite CNN training set?
Replace nodata values (typically 0 or the raster’s nodata sentinel) with NaN before normalization, then convert NaN to 0.0 after percentile scaling. Do not normalize across nodata pixels — they corrupt the band statistics. For training, optionally discard patches exceeding a cloud fraction threshold (e.g., 20%) using the scene’s cloud-mask band or a pre-computed QA layer.
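To see why nodata pixels must be excluded from the statistics, the sketch below compares percentiles computed with and without masking on a synthetic band. The array, the 30% nodata strip, and the use of 0 as the sentinel are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
# Synthetic single-band scene: valid reflectance around 2000, with 30% nodata zeros.
band = rng.normal(2000.0, 200.0, size=(100, 100)).astype(np.float32)
band[:30, :] = 0.0  # nodata strip

# Naive statistics include the zeros and drag the 2nd percentile down to 0.
naive_low, naive_high = np.percentile(band, [2, 98])

# Masked statistics: zeros become NaN and nanpercentile ignores them.
masked = np.where(band == 0, np.nan, band)
low, high = np.nanpercentile(masked, [2, 98])

# Scale with the masked statistics, then map NaN back to 0.0.
scaled = np.nan_to_num((masked - low) / (high - low + 1e-6), nan=0.0).clip(0.0, 1.0)
print("naive low:", float(naive_low), "| masked low:", round(float(low)))
```

With the naive lower bound at 0, the valid pixels would be squeezed into the upper part of the [0, 1] range; the masked bounds spread them across the full range instead.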
Related
- Spatial Cross-Validation Strategies — geographically blocked fold construction to prevent leakage
- Reducing Spatial Leakage in Model Training — buffer exclusion zones and grid-blocking techniques
- Raster Band Math and Index Calculation — NDVI, EVI, and NDWI as engineered input channels
- Gradient Boosting for Raster Data — tabular alternative when patch context is not required