Graph Neural Networks for Spatial Data

Implement graph neural networks for spatial data with PyTorch Geometric: model topology-aware relationships for urban mobility, ecological corridors, and infrastructure networks.

Traditional convolutional architectures assume a regular grid — each pixel has the same number of neighbours at the same spacing. Spatial phenomena rarely work that way. Urban mobility unfolds across irregular street networks. Ecological corridors connect habitat patches separated by variable distances. Sensor arrays measure conditions at uneven point distributions. These problems demand a model that can reason directly about arbitrary topology, and that is where graph neural networks (GNNs) come in.

Building a production GNN for spatial data is non-trivial because the graph structure itself must be engineered from geographic relationships, validated against spatial independence assumptions, and kept consistent as real-world networks evolve. This guide walks through every step — from converting a GeoDataFrame into a PyTorch Geometric Data object to deploying drift monitors in production — as part of the broader Training Geospatial Predictive Models in Python framework.

Figure: Spatial GNN production pipeline in five stages: GeoDataFrame (nodes + geometries) → graph construction (k-NN / Delaunay / network topology) → feature engineering (raster sampling, edge weights) → spatial split + GCN training (block CV folds, cosine annealing) → production inference (incremental updates, drift monitoring). Spatial graph structure must be re-validated whenever topology changes (new sensors, road closures, boundary updates).

Prerequisites & Environment Setup

Ensure your environment supports both geospatial I/O and deep learning workloads before constructing any graphs. Mixed-version dependencies between GDAL, PyTorch, and CUDA are a common source of silent failures.

# Pinned dependency versions (tested combination)
# Requires Python >= 3.10 (the code uses PEP 604 `X | None` type hints)
torch==2.2.2
torch-geometric==2.5.3
geopandas==0.14.4
shapely==2.0.4
rasterio==1.3.10
libpysal==4.9.2
networkx==3.3
scipy==1.13.0
scikit-learn==1.4.2
mlflow==2.13.0

# Install geospatial stack (GDAL system dependency required; Debian/Ubuntu shown)
sudo apt-get install gdal-bin libgdal-dev
pip install geopandas==0.14.4 shapely==2.0.4 rasterio==1.3.10 libpysal==4.9.2 \
    networkx==3.3 scipy==1.13.0 scikit-learn==1.4.2

# Install PyTorch Geometric (match your CUDA version)
pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121
pip install torch-geometric==2.5.3

# MLOps tooling
pip install mlflow==2.13.0 evidently prefect

All spatial datasets must be projected to a consistent metric CRS before graph construction. Treating unprojected lat/lon coordinates as Euclidean silently degrades edge weights and neighbour ordering, because a degree of longitude shrinks with latitude. Apply CRS Alignment and Projection Handling as the first step in every ingestion pipeline. For regional analysis, choose a local metric projection (for example, the appropriate UTM zone) to minimise distortion. EPSG:3857 (Web Mercator) is sometimes used for global datasets, but it inflates distances by roughly 1/cos(latitude), so distances and radius thresholds are not comparable across latitudes.
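Below is a minimal sketch of that projection step. It assumes geopandas 0.9 or later, where `GeoDataFrame.estimate_utm_crs` is available, and data that fits within a single UTM zone's area of use. `project_for_graph` is a hypothetical helper name, not part of any library.

```python
import geopandas as gpd


def project_for_graph(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """Return a copy of gdf in a metric CRS suitable for distance-based graph construction."""
    if gdf.crs is None:
        raise ValueError("GeoDataFrame has no CRS; assign one with gdf.set_crs(...) first.")
    if gdf.crs.is_projected:
        # Already metric: return a copy so callers can mutate it safely
        return gdf.copy()
    # estimate_utm_crs() picks a UTM zone covering the data's bounds (regional data only)
    return gdf.to_crs(gdf.estimate_utm_crs())
```

Datasets spanning many UTM zones need a different projection, such as a continental equal-area CRS. The UTM choice here is only a sensible default for regional work.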

Step 1: Spatial Graph Construction

Spatial data must be transformed into a graph where nodes represent geographic entities — parcels, sensors, administrative zones — and edges encode spatial relationships. The choice of edge definition strategy determines what relational patterns your model can learn.

Common strategies:

  • k-Nearest Neighbors (k-NN): Connects each node to its k closest spatial neighbours. Simple and parameter-controlled, but can produce long-range edges in sparse regions.
  • Delaunay Triangulation: Creates non-overlapping triangles covering the point distribution. Guarantees no crossing edges and tends to produce locally consistent connectivity.
  • Distance Threshold: Connects nodes within a fixed radius (e.g., 500 m). Good when proximity has a physical meaning (signal propagation, contamination spread), but produces variable-degree nodes.
  • Network/Topology-Based: Uses existing road, utility, or river networks as edge definitions. Captures real-world traversal cost but requires a clean routable network dataset.
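The following sketch makes the Delaunay strategy concrete. It turns a `scipy.spatial.Delaunay` triangulation into a symmetric COO `edge_index`, the (2, E) layout PyG expects. `delaunay_edge_index` is an illustrative name, and the input is assumed to be projected (N, 2) centroid coordinates with no duplicate points.

```python
import numpy as np
from scipy.spatial import Delaunay


def delaunay_edge_index(coords: np.ndarray) -> np.ndarray:
    """Return a symmetric (2, E) edge_index from the Delaunay triangulation of (N, 2) coords."""
    tri = Delaunay(coords)
    s = tri.simplices  # (T, 3) vertex indices, one row per triangle
    # Collect the three sides of every triangle
    edges = np.concatenate([s[:, [0, 1]], s[:, [1, 2]], s[:, [0, 2]]], axis=0)
    edges = np.sort(edges, axis=1)    # canonical (min, max) ordering per side
    edges = np.unique(edges, axis=0)  # interior sides are shared by two triangles
    # Emit both directions so message passing is symmetric
    return np.concatenate([edges.T, edges[:, ::-1].T], axis=1)
```

Delaunay triangulations can include long edges along the convex hull of the point set. Filtering out edges above a length threshold is a common follow-up, not shown here.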

The following function converts a GeoDataFrame into a PyTorch Geometric Data object. It handles centroid extraction, symmetric adjacency generation, and proper tensor formatting for GPU training.

import geopandas as gpd
import numpy as np
import torch
from torch_geometric.data import Data
from sklearn.neighbors import kneighbors_graph


def build_spatial_graph(
    gdf: gpd.GeoDataFrame,
    k: int = 5,
    feature_cols: list[str] | None = None,
) -> Data:
    """
    Convert a projected GeoDataFrame to a PyTorch Geometric Data object.

    Args:
        gdf: GeoDataFrame with a projected (metric) CRS. Must not contain
             null geometries.
        k: Number of nearest spatial neighbours per node.
        feature_cols: Columns to use as node features. When None, coordinates
                      and degree centrality are used as fallback features.

    Returns:
        PyG Data object with x (node features) and edge_index tensors.
    """
    assert gdf.crs is not None and gdf.crs.is_projected, (
        "GeoDataFrame must use a projected (metric) CRS before graph construction. "
        f"Current CRS: {gdf.crs}"
    )
    assert not gdf.geometry.isna().any(), "Null geometries detected — run gdf.dropna(subset=['geometry']) first."

    # 1. Extract centroids (assumes projected CRS for correct distances)
    coords = np.array([(geom.centroid.x, geom.centroid.y) for geom in gdf.geometry])

    # 2. Build symmetric k-NN adjacency matrix
    adj = kneighbors_graph(coords, n_neighbors=k, mode="connectivity", include_self=False)
    adj = adj + adj.T
    adj[adj > 1] = 1  # Deduplicate bidirectional edges

    # 3. Convert to COO edge_index format required by PyG
    edge_index = torch.tensor(np.array(adj.nonzero()), dtype=torch.long)

    # 4. Prepare node features
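    if feature_cols:
        x = torch.tensor(gdf[feature_cols].to_numpy(dtype=np.float32))
    else:
        # Standardise coordinates: raw projected metres (often in the millions)
        # destabilise gradients in the first GCN layer
        coords_std = (coords - coords.mean(axis=0)) / (coords.std(axis=0) + 1e-9)
        degrees = torch.tensor(adj.sum(axis=1).A1, dtype=torch.float32).unsqueeze(-1)
        x = torch.cat([torch.tensor(coords_std, dtype=torch.float32), degrees], dim=1)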

    graph = Data(x=x, edge_index=edge_index, num_nodes=len(gdf))

    # Inline validation: edge_index must be within node count bounds
    assert edge_index.max().item() < len(gdf), "edge_index references out-of-range node indices."
    return graph

Validation: Call torch_geometric.utils.is_undirected(graph.edge_index) — it must return True for GCN layers that assume symmetric message passing.

Step 2: Node and Edge Feature Engineering

Raw coordinates rarely capture sufficient predictive signal. Production GNNs require enriched node and edge attributes assembled from domain-specific sources. This is where spatial feature engineering for machine learning intersects directly with graph construction.

Common enrichment strategies:

  • Raster-to-Vector Sampling: Extract elevation, NDVI/EVI, or land-cover values at node centroids using rasterio.sample. Align raster CRS with the graph’s CRS before sampling.
  • Network Metrics: Calculate betweenness centrality, closeness, or PageRank for infrastructure graphs using networkx.
  • Temporal Windows: Aggregate historical sensor readings or traffic counts into sliding-window features. For satellite-derived inputs, see Aggregating Daily Satellite Data to Monthly Features.
  • Edge Weights: Replace binary connectivity with inverse distance, travel time, or spectral similarity scores.
  • Spatial Lag Features: Append local spatial statistics — local Moran’s I, spatial lag of target variable — as node-level features to give the GNN an explicit autocorrelation signal before message passing begins.

import geopandas as gpd
import numpy as np
import rasterio


def enrich_nodes_with_raster(
    gdf: gpd.GeoDataFrame,
    raster_path: str,
    band: int = 1,
    col_name: str = "raster_value",
) -> gpd.GeoDataFrame:
    """
    Sample a single raster band at node centroids and attach values as a column.

    Args:
        gdf: Projected GeoDataFrame (CRS must match raster CRS).
        raster_path: Path to GeoTIFF or COG.
        band: 1-indexed band number to sample.
        col_name: Name for the new column.

    Returns:
        GeoDataFrame with raster values appended.
    """
    coords = [(geom.centroid.x, geom.centroid.y) for geom in gdf.geometry]

    with rasterio.open(raster_path) as src:
        assert gdf.crs == src.crs, (
            f"CRS mismatch: GDF is {gdf.crs}, raster is {src.crs}. "
            "Reproject before sampling."
        )
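        # Sample only the requested band and mask the raster's nodata value as NaN
        values = np.array([v[0] for v in src.sample(coords, indexes=band)], dtype=np.float64)
        if src.nodata is not None:
            values[values == src.nodata] = np.nan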

    gdf = gdf.copy()
    gdf[col_name] = values
    return gdf
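The Edge Weights and Spatial Lag Features strategies listed above can both be derived from an existing `edge_index`. The sketch assumes projected (N, 2) coordinates, a (2, E) integer `edge_index` held as a NumPy array, and an illustrative `eps` of 1 metre to avoid division by zero. Both function names are hypothetical.

```python
import numpy as np
import scipy.sparse as sp


def inverse_distance_weights(
    coords: np.ndarray, edge_index: np.ndarray, eps: float = 1.0
) -> np.ndarray:
    """Weight each edge by 1 / (distance + eps); eps (metres) guards coincident centroids."""
    src, dst = edge_index
    dist = np.linalg.norm(coords[src] - coords[dst], axis=1)
    return 1.0 / (dist + eps)


def spatial_lag(
    values: np.ndarray, edge_index: np.ndarray, edge_weight: np.ndarray, num_nodes: int
) -> np.ndarray:
    """Weighted mean of each node's neighbour values (1-D values; isolated nodes get 0)."""
    src, dst = edge_index
    # Row i gathers the weights of edges arriving at node i
    w = sp.csr_matrix((edge_weight, (dst, src)), shape=(num_nodes, num_nodes))
    row_sums = np.asarray(w.sum(axis=1)).ravel()
    row_sums[row_sums == 0] = 1.0  # avoid 0/0 for isolated nodes
    return (w @ values) / row_sums
```

If the values passed to `spatial_lag` are target labels, compute the lag from training-node labels only, for example by zeroing the weights of edges whose source is a test node. Otherwise the feature leaks test labels into training.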

Normalise all continuous features with StandardScaler or MinMaxScaler before constructing feature tensors. Fit the scaler on training nodes only so test-set statistics do not leak into preprocessing. Unnormalised coordinates (raw metres in the millions) cause unstable gradients in the first GCN layer. For high-cardinality categorical attributes such as land-use class, use learned embeddings rather than one-hot encoding to avoid wide, sparse input tensors. When feature dimensionality is high, dimensionality reduction for spatial data can reduce training cost, often with little loss of accuracy.
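One way to normalise without leaking test-set statistics is to fit the scaler on training nodes only and reuse it for every other node. This is a minimal sketch, assuming a dense (N, F) NumPy feature array and an index array of training nodes. `scale_node_features` is a hypothetical helper.

```python
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler


def scale_node_features(
    features: np.ndarray, train_idx: np.ndarray
) -> tuple[torch.Tensor, StandardScaler]:
    """Fit a StandardScaler on training rows only, then transform all nodes."""
    scaler = StandardScaler().fit(features[train_idx])
    # float32 matches the default dtype of PyTorch layer parameters
    x = torch.tensor(scaler.transform(features), dtype=torch.float32)
    return x, scaler
```

Persisting the returned scaler alongside the model keeps production inputs on the same scale as training inputs.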

Step 3: Topology-Aware Validation and Spatial Independence

Random train/test splits violate spatial independence assumptions and produce inflated performance metrics. Spatial autocorrelation means nearby nodes share latent patterns — when neighbours appear in both training and validation sets, the model appears to generalise when it is only interpolating.

Adopt topology-aware splitting strategies to prevent this leakage. Detailed methodologies for partitioning spatial data correctly are covered in Spatial Cross-Validation Strategies, which provides runnable implementations of spatial blocking and leave-one-region-out splits.

The three most effective approaches for graph data:

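  1. Spatial Blocking: Partition the study area into contiguous rectangular or hexagonal tiles and assign entire blocks to folds, so most neighbourhoods fall entirely within one set. Nodes along block boundaries can still be adjacent across sets, which graph cut validation addresses.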
  2. Leave-One-Region-Out: Reserve entire administrative zones, watersheds, or climate regions for the test set. Most conservative and most realistic for deployment scenarios where the model must generalise to unseen geographies.
  3. Graph Cut Validation: After assigning node splits by spatial block, remove edges that cross the boundary between training and test nodes. This prevents message passing from carrying features (or label-derived features) across the split boundary during evaluation.

import numpy as np


def spatial_block_split(
    coords: np.ndarray,
    n_blocks: int = 10,
    test_fraction: float = 0.2,
    random_state: int = 42,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Assign node indices to train/test splits using spatial blocking.

    Args:
        coords: (N, 2) array of projected node coordinates.
        n_blocks: Number of spatial blocks (tiles) along each axis, giving an n_blocks x n_blocks grid.
        test_fraction: Fraction of blocks to reserve for testing.
        random_state: Seed for reproducible block assignment.

    Returns:
        Tuple of (train_indices, test_indices) as 1-D integer arrays.
    """
    rng = np.random.default_rng(random_state)

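    # Assign each node to a block by quantising its coordinates.
    # np.ptp() is used because the ndarray.ptp() method was removed in NumPy 2.0.
    x_bins = np.floor(
        (coords[:, 0] - coords[:, 0].min())
        / (np.ptp(coords[:, 0]) + 1e-9)
        * n_blocks
    ).astype(int)
    y_bins = np.floor(
        (coords[:, 1] - coords[:, 1].min())
        / (np.ptp(coords[:, 1]) + 1e-9)
        * n_blocks
    ).astype(int)
    # Multiplier n_blocks + 1 keeps ids unique even if a maximum coordinate lands in bin n_blocks
    block_ids = x_bins * (n_blocks + 1) + y_bins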

    unique_blocks = np.unique(block_ids)
    n_test_blocks = max(1, int(len(unique_blocks) * test_fraction))
    test_blocks = rng.choice(unique_blocks, n_test_blocks, replace=False)

    test_mask = np.isin(block_ids, test_blocks)
    train_indices = np.where(~test_mask)[0]
    test_indices = np.where(test_mask)[0]
    return train_indices, test_indices
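The Graph Cut Validation approach listed above can be sketched by dropping every edge whose endpoints fall on different sides of the split. This assumes a PyG-style (2, E) `edge_index` tensor and test indices such as those returned by `spatial_block_split`, converted with `torch.as_tensor`. `cut_cross_split_edges` is an illustrative name.

```python
import torch


def cut_cross_split_edges(
    edge_index: torch.Tensor, test_idx: torch.Tensor, num_nodes: int
) -> torch.Tensor:
    """Keep only edges whose two endpoints are both train nodes or both test nodes."""
    is_test = torch.zeros(num_nodes, dtype=torch.bool, device=edge_index.device)
    is_test[test_idx] = True
    src, dst = edge_index
    keep = is_test[src] == is_test[dst]  # False for edges crossing the split
    return edge_index[:, keep]
```

If the cut graph is used for both training and evaluation, test nodes only aggregate from other test nodes. This makes the evaluation closer to deployment on an unseen region.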

After training, quantify residual spatial dependence in the model's test-set residuals using Moran’s I. Significant remaining spatial structure suggests the model is underfitting or the graph topology is incomplete. Handling Spatial Autocorrelation covers statistical diagnostics and correction techniques in detail.
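For a quick diagnostic without extra dependencies, global Moran's I can be computed directly from the graph's own edges as I = (n / S0) · (zᵀWz) / (zᵀz). Here z are the mean-centred residuals, W is the row-standardised adjacency, and S0 is the sum of all weights. This sketch assumes a NumPy (2, E) `edge_index` and a 1-D float residual array. `morans_i` is a hypothetical helper, and it returns only the statistic, not a significance test.

```python
import numpy as np
import scipy.sparse as sp


def morans_i(residuals: np.ndarray, edge_index: np.ndarray, num_nodes: int) -> float:
    """Global Moran's I of node residuals using a row-standardised graph adjacency."""
    src, dst = edge_index
    w = sp.csr_matrix((np.ones(len(src)), (src, dst)), shape=(num_nodes, num_nodes))
    # Row-standardise so each node's neighbour weights sum to 1 (isolated rows stay 0)
    row_sums = np.asarray(w.sum(axis=1)).ravel()
    row_sums[row_sums == 0] = 1.0
    w = sp.diags(1.0 / row_sums) @ w
    z = residuals - residuals.mean()
    s0 = w.sum()  # total weight S0
    return float((num_nodes / s0) * (z @ (w @ z)) / (z @ z))
```

Values near 0 suggest little remaining spatial structure. Clearly positive values mean neighbouring nodes still share errors. For permutation-based significance testing, the esda package from the PySAL ecosystem provides a `Moran` class.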

Step 4: Model Architecture and Batched Training

Production GNNs typically use message-passing architectures. Graph Convolutional Networks (GCN) propagate aggregated neighbourhood information; GraphSAGE samples a fixed number of neighbours per layer to improve scalability on large graphs. The following two-layer GCN includes batch normalisation, dropout, and an output classification head.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class SpatialGCN(nn.Module):
    """
    Two-layer GCN for node-level classification on spatial graphs.

    Args:
        in_channels: Number of input node features.
        hidden_channels: Width of the hidden representation.
        out_channels: Number of output classes.
        dropout: Dropout probability applied between layers.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.norm = nn.BatchNorm1d(hidden_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = self.conv1(x, edge_index, edge_weight)
        x = self.norm(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_weight)
        return x  # Raw logits; apply softmax/sigmoid at inference time

Training loop essentials:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_geometric.data import Data


def train_spatial_gcn(
    model: SpatialGCN,
    graph: Data,
    train_idx: torch.Tensor,
    val_idx: torch.Tensor,
    epochs: int = 200,
    lr: float = 1e-3,
    patience: int = 20,
) -> dict[str, list[float]]:
    """
    Train a SpatialGCN with cosine annealing and early stopping.

    Returns:
        Dict with 'train_loss' and 'val_loss' history lists.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float("inf")
    patience_counter = 0
    history: dict[str, list[float]] = {"train_loss": [], "val_loss": []}

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = criterion(out[train_idx], graph.y[train_idx])
        loss.backward()
        optimizer.step()
        scheduler.step()

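        # Re-run the forward pass in eval mode: `out` above was computed with dropout
        # and batch norm in training mode, and before the optimiser step
        model.eval()
        with torch.no_grad():
            val_out = model(graph.x, graph.edge_index)
            val_loss = criterion(val_out[val_idx], graph.y[val_idx]).item()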

        history["train_loss"].append(loss.item())
        history["val_loss"].append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), "best_spatial_gcn.pt")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

    model.load_state_dict(torch.load("best_spatial_gcn.pt"))
    return history

For large-scale spatial graphs that do not fit in GPU memory, use torch_geometric.loader.NeighborLoader with a num_neighbors list per layer (e.g., [10, 5] for two layers). This samples a fixed-size subgraph per batch, enabling mini-batch training. Track node embedding variance across layers to detect over-smoothing — a common failure mode where deep stacking collapses all representations toward a single vector.
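Here is a hedged sketch of the embedding-spread check. The function computes the mean pairwise cosine distance between node embeddings, which falls toward zero as representations collapse. The subsampling cap (`max_nodes`) is an assumption that keeps the N×N similarity matrix small on large graphs. The function name is illustrative.

```python
import torch
import torch.nn.functional as F


def mean_cosine_distance(h: torch.Tensor, max_nodes: int = 2000, seed: int = 0) -> float:
    """Average pairwise cosine distance between node embeddings (near 0 = over-smoothed)."""
    if h.shape[0] > max_nodes:
        # Subsample nodes reproducibly to bound the O(N^2) similarity computation
        g = torch.Generator().manual_seed(seed)
        idx = torch.randperm(h.shape[0], generator=g)[:max_nodes]
        h = h[idx.to(h.device)]
    n = h.shape[0]
    if n < 2:
        return 0.0
    hn = F.normalize(h, dim=1)
    dist = 1.0 - hn @ hn.T  # cosine distance matrix; diagonal is ~0
    # Average over the n * (n - 1) off-diagonal pairs
    return (dist.sum() / (n * (n - 1))).item()
```

Intermediate embeddings can be captured with a standard PyTorch forward hook, such as `model.conv1.register_forward_hook(...)`, and the metric compared across layers or epochs.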

Verification and Testing

After training, verify that the model generalises spatially rather than memorising local patterns.

import torch
from torch_geometric.data import Data
from sklearn.metrics import classification_report
import numpy as np


def evaluate_spatial_gcn(
    model: SpatialGCN,
    graph: Data,
    test_idx: torch.Tensor,
    class_names: list[str] | None = None,
) -> dict:
    """
    Evaluate a trained SpatialGCN on held-out test nodes.

    Returns:
        Dict of per-class and macro-averaged metrics.
    """
    model.eval()
    with torch.no_grad():
        logits = model(graph.x, graph.edge_index)
        preds = logits[test_idx].argmax(dim=1).cpu().numpy()
        labels = graph.y[test_idx].cpu().numpy()

    report = classification_report(
        labels, preds, target_names=class_names, output_dict=True
    )
    print(classification_report(labels, preds, target_names=class_names))
    return report


# Structural sanity checks
def verify_graph_integrity(graph: Data) -> None:
    """Assert basic structural properties of a spatial PyG graph."""
    from torch_geometric.utils import is_undirected, contains_self_loops

    assert not contains_self_loops(graph.edge_index), "Graph contains self-loops; remove with remove_self_loops()."
    assert is_undirected(graph.edge_index), "Graph is directed; GCNConv assumes symmetric adjacency."
    assert graph.x.shape[0] == graph.num_nodes, "Node feature count does not match num_nodes."
    assert graph.edge_index.max().item() < graph.num_nodes, "edge_index out of bounds."
    print(f"Graph verified: {graph.num_nodes} nodes, {graph.edge_index.shape[1] // 2} undirected edges.")

Visual sanity check: Plot test node predictions on a map using geopandas and compare with ground truth. Spatial prediction errors should not cluster in compact regions — systematic spatial error patterns indicate the model is failing to generalise beyond training blocks.

Troubleshooting and Common Errors

AssertionError: GeoDataFrame must use a projected CRS Root cause: Calling build_spatial_graph on a lat/lon dataset. k-NN distances computed in degrees are meaningless for spatial proximity. Fix: reproject before graph construction, e.g. gdf = gdf.to_crs(gdf.estimate_utm_crs()) for regional data, or use another suitable local metric CRS.

RuntimeError: Expected all tensors to be on the same device Root cause: graph.x or graph.edge_index was not moved to the GPU before the forward pass. Fix: graph = graph.to(device) where device = torch.device("cuda" if torch.cuda.is_available() else "cpu").

Silent: inflated validation accuracy from spatial leakage Root cause: Random train/test split where spatially adjacent nodes land in opposing sets. Fix: Use spatial_block_split (see Step 3) and inspect the geographic separation between folds with a map.

torch_geometric.utils.is_undirected returns False Root cause: The adjacency was not symmetrised after kneighbors_graph. Fix: adj = adj + adj.T; adj[adj > 1] = 1 then rebuild edge_index.

Over-smoothing: all node embeddings converge after 4+ layers Root cause: Excessive message-passing depth causes representations to lose discriminative information. Fix: Limit depth to 2–3 layers; add residual connections (x = x + self.conv1(...)); use DropEdge regularisation.

CUDA out of memory during batched training Root cause: Full-graph training on large spatial datasets loads the entire adjacency into GPU memory. Fix: Switch from full-batch training to NeighborLoader with num_neighbors=[10, 5] per layer and batch_size=512.

ValueError: operands could not be broadcast in edge weight multiplication Root cause: Edge weight tensor shape (E,) does not match edge_index column count after removing self-loops or isolated nodes. Fix: Recompute edge weights after every graph structural operation, not before.

Performance Optimisation

Adjacency caching: Computing the k-NN graph is O(N²) for naive implementations. Cache the resulting edge_index and edge_attr tensors to disk with torch.save and reload them at training time. Rebuild only when node locations or k change.
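Below is a minimal caching sketch. It assumes the cache key can be derived from the raw coordinate bytes and k, so any change to node locations or k produces a new file. `cached_knn_edge_index` and the `graph_cache` directory are illustrative names.

```python
import hashlib
from pathlib import Path

import numpy as np
import torch
from sklearn.neighbors import kneighbors_graph


def cached_knn_edge_index(coords: np.ndarray, k: int, cache_dir: str = "graph_cache") -> torch.Tensor:
    """Load a symmetric k-NN edge_index from disk, building and saving it on a cache miss."""
    coords = np.ascontiguousarray(coords, dtype=np.float64)
    key = hashlib.sha256(coords.tobytes() + str(k).encode()).hexdigest()[:16]
    path = Path(cache_dir) / f"knn_k{k}_{key}.pt"
    if path.exists():
        return torch.load(path)
    adj = kneighbors_graph(coords, n_neighbors=k, mode="connectivity", include_self=False)
    adj = adj.maximum(adj.T).tocoo()  # symmetrise without double-counting
    edge_index = torch.tensor(np.vstack([adj.row, adj.col]), dtype=torch.long)
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(edge_index, path)
    return edge_index
```

Hashing the exact float bytes means even tiny coordinate changes (for example, after reprojection) invalidate the cache. This errs on the side of rebuilding.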

Sparse matrix format: Use scipy.sparse.csr_matrix for adjacency throughout preprocessing. Avoid converting to dense NumPy arrays: a 50,000-node graph produces a 20 GB dense float64 matrix (2.5 billion entries × 8 bytes).

Spatial indexing for distance threshold graphs: Use libpysal.weights.DistanceBand or a scipy.spatial.KDTree for radius-based connectivity. Both use spatial indexing internally and scale to millions of points. Avoid nested loops over geometry pairs.
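For the radius-based case, SciPy's `KDTree.query_pairs` returns every pair of points within distance `r`, which converts directly to a symmetric `edge_index`. This is a sketch assuming projected coordinates in metres. `radius_edge_index` is an illustrative name.

```python
import numpy as np
from scipy.spatial import KDTree


def radius_edge_index(coords: np.ndarray, radius: float) -> np.ndarray:
    """Symmetric (2, E) edge_index connecting all point pairs within `radius`."""
    tree = KDTree(coords)
    pairs = tree.query_pairs(r=radius, output_type="ndarray")  # (P, 2) with i < j
    # Add the reverse direction of every pair
    return np.concatenate([pairs.T, pairs[:, ::-1].T], axis=1)
```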

Mini-batch training with NeighborLoader:

from torch_geometric.loader import NeighborLoader

train_loader = NeighborLoader(
    graph,
    num_neighbors=[10, 5],   # Sample 10 1-hop, 5 2-hop neighbours per node
    batch_size=512,
    input_nodes=train_idx,
    shuffle=True,
    num_workers=4,
)

Feature dtype affects GPU memory as well as convergence. Float32 node features are standard and match the float32 parameters PyTorch creates by default; float64 doubles memory consumption with no accuracy benefit in most GNN architectures.

MLOps, Inference Automation, and Drift Detection

Deploying spatial GNNs requires infrastructure that handles graph serialisation, dynamic topology updates, and continuous monitoring. Unlike tabular models, graph models depend on both node features and structural relationships, making drift detection more complex.

Pipeline orchestration: Wrap graph construction, training, and evaluation in a reproducible DAG using Prefect or Airflow. Log hyperparameters, model artefacts, and spatial validation metrics with MLflow. Version the adjacency matrix separately from node features — they have different update cadences.
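Versioning the adjacency separately is easier with a content hash of the edge set. The hash can be logged as an MLflow tag next to the feature-data version. This is a minimal sketch. `adjacency_version` is a hypothetical helper, and it assumes `edge_index` contains no duplicate edges.

```python
import hashlib

import numpy as np


def adjacency_version(edge_index: np.ndarray, num_nodes: int) -> str:
    """Order-independent short content hash of a graph's directed edge set."""
    ei = np.asarray(edge_index, dtype=np.int64)
    order = np.lexsort((ei[1], ei[0]))  # sort by source, then by target
    canonical = np.ascontiguousarray(ei[:, order])
    h = hashlib.sha256()
    h.update(np.int64(num_nodes).tobytes())
    h.update(canonical.tobytes())
    return h.hexdigest()[:12]
```

For example, calling `mlflow.set_tag("adjacency_version", adjacency_version(edge_index, num_nodes))` inside an active run records which graph a model was trained on.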

Dynamic graph updates: Spatial networks evolve: new sensors deploy, roads close, land use changes. Implement incremental graph updates rather than full recomputation. Cache adjacency matrices in Parquet or a graph database, apply delta updates during inference windows, and trigger lightweight fine-tuning when structural changes exceed a threshold.
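The delta update can be sketched with plain Python sets of directed edges. The sketch computes what was added and removed, applies the change, and flags fine-tuning when it exceeds a threshold. The function names and the 5% default threshold are illustrative assumptions, not recommendations from a specific tool.

```python
import numpy as np

Edge = tuple[int, int]


def edge_set(edge_index: np.ndarray) -> set[Edge]:
    """Convert a (2, E) edge_index into a set of directed (src, dst) tuples."""
    return set(zip(edge_index[0].tolist(), edge_index[1].tolist()))


def compute_edge_delta(old_ei: np.ndarray, new_ei: np.ndarray) -> tuple[set[Edge], set[Edge]]:
    """Return (added, removed) directed edges between two graph versions."""
    old, new = edge_set(old_ei), edge_set(new_ei)
    return new - old, old - new


def apply_edge_delta(ei: np.ndarray, added: set[Edge], removed: set[Edge]) -> np.ndarray:
    """Apply a delta to an edge_index and return a sorted (2, E') array."""
    edges = (edge_set(ei) - removed) | added
    if not edges:
        return np.empty((2, 0), dtype=np.int64)
    return np.array(sorted(edges), dtype=np.int64).T


def needs_finetune(added: set[Edge], removed: set[Edge], num_edges: int, threshold: float = 0.05) -> bool:
    """Flag fine-tuning when the fraction of changed edges exceeds the (illustrative) threshold."""
    return (len(added) + len(removed)) / max(num_edges, 1) > threshold
```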

Drift detection for graphs: Monitor two dimensions:

  • Feature drift: Shifts in node attributes such as sensor calibration changes or seasonal NDVI variation. Compare training vs. production distributions using Kolmogorov-Smirnov tests per feature.
  • Structural drift: Changes in connectivity patterns from new transit lines or urban expansion. Track degree distribution shifts, average path length, and connected component counts. Alert when these deviate beyond historically observed seasonal ranges.
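Both monitoring dimensions can be prototyped with SciPy. The sketch below runs a two-sample Kolmogorov-Smirnov test per feature column and summarises structure through degree statistics and connected-component counts. Average path length is omitted because it is expensive on large graphs. The function names and the `alpha=0.01` default are assumptions; with many features, consider correcting for multiple comparisons.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.csgraph import connected_components
from scipy.stats import ks_2samp


def feature_drift(train_x: np.ndarray, prod_x: np.ndarray, alpha: float = 0.01) -> dict[int, float]:
    """Return {column_index: KS statistic} for columns whose distributions differ at level alpha."""
    flagged: dict[int, float] = {}
    for j in range(train_x.shape[1]):
        result = ks_2samp(train_x[:, j], prod_x[:, j])
        if result.pvalue < alpha:
            flagged[j] = float(result.statistic)
    return flagged


def structural_summary(edge_index: np.ndarray, num_nodes: int) -> dict[str, float]:
    """Degree and connectivity summary to compare against historical ranges."""
    src, dst = edge_index
    adj = sp.csr_matrix((np.ones(len(src)), (src, dst)), shape=(num_nodes, num_nodes))
    degrees = np.asarray((adj > 0).sum(axis=1)).ravel()  # distinct out-neighbours per node
    n_components, _ = connected_components(adj, directed=False)
    return {
        "mean_degree": float(degrees.mean()),
        "max_degree": int(degrees.max()),
        "n_components": int(n_components),
        "isolated_nodes": int((degrees == 0).sum()),
    }
```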


FAQ

When should I use a GNN instead of a CNN for geospatial data?

Use a GNN when geographic entities are irregular (parcels, sensors, administrative zones) and relational topology matters — road networks, hydrological adjacency, ecological connectivity. CNNs assume uniform grids and consistent spatial resolution. If your input is a satellite image tile processed at pixel level, a CNN or vision transformer is likely more appropriate. For entity-level prediction over irregular geometries, GNNs model the relational structure directly. See Building a CNN for Satellite Imagery Classification for a direct architectural comparison.

How do I prevent spatial data leakage in GNN validation?

Split by spatial blocks rather than at random. Random splits allow spatially correlated neighbours to appear in both sets. Because GNNs aggregate information from neighbours during message passing, any neighbour of a test node that appears in the training set is a potential leakage path. Spatial blocking partitions the graph into geographically contiguous blocks, and removing edges that cross partition boundaries during evaluation closes the remaining paths. The Spatial Cross-Validation Strategies page provides reference implementations.

What is over-smoothing and how do I detect it early?

Over-smoothing occurs when stacking many message-passing layers causes all node representations to converge toward the same vector; in the limit, each node’s embedding approaches a graph-wide average. Detect it by computing the mean average distance (MAD), the mean pairwise cosine distance between node embeddings, at each layer. A MAD collapsing toward zero indicates over-smoothing. In practice, two or three GCN layers are sufficient for most spatial problems. Residual connections and DropEdge regularisation delay the onset of smoothing when deeper architectures are required.

How do I handle dynamic spatial graphs where topology changes over time?

Cache the adjacency matrix separately from node features with distinct version identifiers. When structural changes occur (new road segments, decommissioned sensors), compute a delta (added and removed edges) and apply it incrementally. Re-run the structural integrity checks (verify_graph_integrity) after each delta. When the delta is small, trigger lightweight fine-tuning from the most recent checkpoint rather than full retraining. For large structural changes, run a full retraining pass with the updated adjacency, optionally warm-starting from the previous checkpoint to accelerate convergence.

