Graph Neural Networks 🕸️

Class 12 · Age 16-17 · Lesson 04 of 12 · 🆓 Free

Story
Sneha's Fraud Detection Network
👩‍💼 Sneha · Kolkata · Age 17

Sneha interned with a Kolkata fintech startup. Their UPI fraud detector used 30 hand-written rules ("flag if > ₹50,000 to a new payee at 3 AM") and had a 32% false-positive rate, which frustrated legitimate users.

The fraud team showed her a graph: nodes are accounts, edges are transactions. Fraud rings form distinctive patterns, such as a "fan" of small transfers into one account followed by a single large outflow. Traditional ML treats each transaction independently and misses these patterns. Graph Neural Networks see the whole structure.

Sneha built a GraphSAGE model that reduced false positives from 32% to 9% while catching 28% more actual fraud rings.

Concepts
Why Graphs?

Many real-world problems are naturally graphs. In Sneha's fraud case, accounts are the nodes and transactions are the edges.

The core operation in a GNN is message passing: each node updates its representation by aggregating information from its neighbours. Stack k layers, and each node "sees" k hops away.
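
Message passing can be sketched in plain Python, without any GNN library. This is a minimal illustration that assumes mean aggregation and no learned weights; the toy path graph, the single-number node features and the `message_passing_layer` helper are made up for this example.

```python
# Toy undirected path graph: 0 - 1 - 2 - 3 (hypothetical example data)
neighbours = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}

# One scalar feature per node; only node 0 carries a signal.
features = {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0}

def message_passing_layer(h, neighbours):
    """One round: each node averages its own value with its neighbours' values."""
    new_h = {}
    for node, nbrs in neighbours.items():
        messages = [h[n] for n in nbrs] + [h[node]]   # include the node itself
        new_h[node] = sum(messages) / len(messages)
    return new_h

h1 = message_passing_layer(features, neighbours)   # 1 layer: information travels 1 hop
h2 = message_passing_layer(h1, neighbours)         # 2 layers: information travels 2 hops

print(h1)
print(h2)
```

After one layer, node 2 (two hops from node 0) still holds 0.0. After two layers it holds a non-zero value, while node 3 (three hops away) is still 0.0. A real GNN layer adds learned weight matrices and a non-linearity around this aggregation step.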

GCN

Aggregates neighbours with symmetric degree normalisation. The foundational architecture, but transductive in its original form: it is trained on one fixed graph.

GraphSAGE

Samples a fixed-size neighbourhood. Scales to billions of nodes. Inductive (works on unseen nodes).

GAT

Graph Attention Network: neighbours are weighted by learned attention scores. Best when some edges matter more than others.

GIN

Graph Isomorphism Network. The most expressive of the four: its sum aggregation makes it as powerful as the Weisfeiler-Lehman test at telling structurally different graphs apart.
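
The four variants differ mainly in how a node combines its neighbours' messages. The sketch below compares the aggregation rules on one node in plain Python. It is a simplification: learned weight matrices and self-loops are left out, the attention scores are hand-picked rather than learned, and all function names and numbers are made up for illustration.

```python
import math

# One target account with three incoming neighbours (all numbers are made up).
neighbour_feats = [1.0, 1.0, 4.0]   # one scalar feature per neighbour
neighbour_degs = [1, 1, 100]        # the third neighbour is a high-degree hub
target_deg = 3

def mean_aggregate(feats):
    """GraphSAGE-style mean: every neighbour counts equally."""
    return sum(feats) / len(feats)

def gcn_aggregate(feats, degs, target_deg):
    """GCN-style: each message is scaled by 1 / sqrt(deg(target) * deg(neighbour))."""
    return sum(f / math.sqrt(target_deg * d) for f, d in zip(feats, degs))

def attention_aggregate(feats, scores):
    """GAT-style: softmax over per-edge scores, then a weighted sum."""
    peak = max(scores)                                 # subtract the max for stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    alphas = [e / total for e in exps]
    return sum(a * f for a, f in zip(alphas, feats))

def sum_aggregate(feats):
    """GIN-style sum: keeps the size of the neighbourhood."""
    return sum(feats)

mean_out = mean_aggregate(neighbour_feats)
gcn_out = gcn_aggregate(neighbour_feats, neighbour_degs, target_deg)
flat_attention = attention_aggregate(neighbour_feats, [0.0, 0.0, 0.0])
sharp_attention = attention_aggregate(neighbour_feats, [0.0, 0.0, 5.0])
print(mean_out, gcn_out, flat_attention, sharp_attention)

# A mean cannot tell a 2-neighbour fan from a 6-neighbour fan; a sum can.
small_fan, large_fan = [1.0, 1.0], [1.0] * 6
print(mean_aggregate(small_fan) == mean_aggregate(large_fan))   # True
print(sum_aggregate(small_fan) == sum_aggregate(large_fan))     # False
```

With equal scores, attention reduces to the plain mean. With a high score on the third edge, the result moves towards that neighbour's value. The GCN scaling shrinks the hub's message because of its high degree.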

Code
Build a Fraud Detector with PyTorch Geometric
!pip install -q torch-geometric

Build the transaction graph:

import torch
from torch_geometric.data import Data

# Each account is a node with features (age_days, kyc_level, avg_balance, ...)
node_features = torch.tensor([
    [180, 2, 12000.0, 23, 0],   # account 0: 180 days old, KYC L2, ...
    [12,  1, 800.0,   45, 1],   # account 1: new, low KYC, suspicious
    # ... thousands of accounts (one row per node; the edges and labels below refer to accounts 0-4)
], dtype=torch.float)

# Edges: each transaction is a directed edge from sender to receiver
edge_index = torch.tensor([
    [0, 1, 1, 2, 3],   # source accounts
    [1, 2, 3, 4, 4],   # destination accounts
], dtype=torch.long)

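# Edge features: amount, seconds since the first transaction, channel
# (raw Unix timestamps such as 1700000000 and 1700000050 collapse to the
# same value in float32, so relative times are stored instead)
edge_attr = torch.tensor([
    [5000.0,  0.0, 0],
    [4900.0, 50.0, 0],
    # ... one row per edge, in the same order as the edge_index columns
], dtype=torch.float)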

# Labels: 0 = clean, 1 = fraud (only on labelled accounts)
y = torch.tensor([0, 1, 1, 0, 0], dtype=torch.long)

graph = Data(x=node_features, edge_index=edge_index,
             edge_attr=edge_attr, y=y)
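
Because rows are elided in the listing above, it is easy to end up with tensors whose sizes disagree. Below is a minimal consistency check, written against plain Python lists so it runs without PyTorch. `check_graph` is a hypothetical helper for this lesson, not part of PyTorch Geometric.

```python
def check_graph(node_features, edge_index, edge_attr, labels):
    """Return a list of human-readable problems; an empty list means the sizes agree."""
    problems = []
    num_nodes = len(node_features)
    sources, targets = edge_index
    if len(sources) != len(targets):
        problems.append("edge_index rows have different lengths")
    num_edges = min(len(sources), len(targets))
    highest = max(list(sources) + list(targets), default=-1)
    if highest >= num_nodes:
        problems.append(f"an edge refers to node {highest} but only {num_nodes} feature rows exist")
    if len(edge_attr) != num_edges:
        problems.append(f"{len(edge_attr)} edge_attr rows for {num_edges} edges")
    if len(labels) != num_nodes:
        problems.append(f"{len(labels)} labels for {num_nodes} nodes")
    return problems

# Same row counts as the listing when the elided rows are left out
# (the feature values themselves do not matter for this check).
as_printed = check_graph(
    node_features=[[180, 2, 12000.0, 23, 0], [12, 1, 800.0, 45, 1]],
    edge_index=[[0, 1, 1, 2, 3], [1, 2, 3, 4, 4]],
    edge_attr=[[5000.0, 0.0, 0], [4900.0, 50.0, 0]],
    labels=[0, 1, 1, 0, 0],
)
for problem in as_printed:
    print(problem)
```

The check reports three mismatches for the abbreviated listing. They disappear once every account has a feature row and every edge has an attribute row.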

Define the GraphSAGE classifier:

import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class FraudGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.sage1 = SAGEConv(in_dim, hidden_dim, aggr="mean")
        self.sage2 = SAGEConv(hidden_dim, hidden_dim, aggr="mean")
        self.classifier = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        # 1-hop neighbourhood aggregation
        h = self.sage1(x, edge_index)
        h = F.relu(h)
        h = F.dropout(h, p=0.3, training=self.training)
        # 2-hop: each node now sees neighbours-of-neighbours
        h = self.sage2(h, edge_index)
        h = F.relu(h)
        return self.classifier(h)  # logits per node

Training
Mini-batch Training with NeighborLoader

Real fraud graphs have millions of nodes, too many to process in a single batch. NeighborLoader instead samples a fixed-size neighbourhood around each labelled seed node:

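import torch
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train_mask marks the labelled accounts used as seed nodes. It is not defined
# in this lesson; as a placeholder assumption, treat every account as labelled.
train_mask = torch.ones(graph.num_nodes, dtype=torch.bool)

train_loader = NeighborLoader(
    graph,
    # Up to 20 neighbours per seed node at hop 1, then up to 10 neighbours
    # per sampled node at hop 2
    num_neighbors=[20, 10],
    batch_size=256,
    input_nodes=train_mask,
    shuffle=True,
)

model = FraudGNN(in_dim=5, hidden_dim=64, out_dim=2).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Class weighting: fraud is rare (~0.5% of accounts)
class_weights = torch.tensor([1.0, 50.0], device=device)

for epoch in range(20):
    model.train()
    total_loss, num_batches = 0.0, 0
    for batch in train_loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index)
        # The first batch.batch_size rows are the seed nodes; only they are scored
        loss = F.cross_entropy(
            out[:batch.batch_size],
            batch.y[:batch.batch_size],
            weight=class_weights,
        )
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
        num_batches += 1
    # Report the mean over all batches, not just the last one
    print(f"Epoch {epoch}  mean loss={total_loss / num_batches:.4f}")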
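
To see what the `weight=class_weights` argument does, the weighted cross-entropy can be computed by hand. This plain-Python sketch mirrors how `F.cross_entropy` applies class weights under its default mean reduction: each sample's loss is multiplied by the weight of its true class, and the total is divided by the sum of those weights. The logits are made-up values and `weighted_cross_entropy` is a hypothetical helper.

```python
import math

def weighted_cross_entropy(logits, targets, class_weights):
    """Weighted mean of per-sample cross-entropy losses."""
    weighted_sum, weight_total = 0.0, 0.0
    for row, target in zip(logits, targets):
        # Negative log-softmax of the true class, computed stably
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        sample_loss = log_norm - row[target]
        w = class_weights[target]
        weighted_sum += w * sample_loss
        weight_total += w
    return weighted_sum / weight_total

# Made-up logits for two accounts: both are predicted "clean" with the same margin.
logits = [[2.0, 0.0],    # truly clean: small loss
          [2.0, 0.0]]    # truly fraud: large loss (a missed fraud)
targets = [0, 1]

unweighted = weighted_cross_entropy(logits, targets, [1.0, 1.0])
weighted = weighted_cross_entropy(logits, targets, [1.0, 50.0])
print(f"unweighted={unweighted:.4f}  weighted={weighted:.4f}")
```

With the 50x weight, the missed fraud dominates the batch loss, so gradient updates are pulled towards fixing it rather than towards the many easy clean accounts.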

Why NeighborLoader matters: Around high-degree hubs, a node's full 2-hop neighbourhood can contain millions of nodes. Sampling caps the compute per batch and forces the model to generalise from partial neighbourhoods, a regularising effect loosely comparable to dropout.
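
The effect of capping neighbours per hop can be illustrated with a small library-free sketch. `sample_neighbourhood` is a hypothetical helper that imitates the idea behind fan-outs such as `[20, 10]`; it is not how NeighborLoader is implemented internally, and the hub graph is synthetic.

```python
import random

def sample_neighbourhood(adj, seed, fanouts, rng):
    """Collect nodes reached from `seed`, keeping at most fanouts[k] neighbours per node at hop k."""
    visited = {seed}
    frontier = [seed]
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            nbrs = adj.get(node, [])
            # Keep every neighbour if there are few, otherwise draw a random subset.
            chosen = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            for n in chosen:
                if n not in visited:
                    visited.add(n)
                    next_frontier.append(n)
        frontier = next_frontier
    return visited

# Synthetic hub: account 0 transacts with 1,000 accounts, each with 5 further contacts.
adj = {0: list(range(1, 1001))}
next_id = 1001
for leaf in range(1, 1001):
    adj[leaf] = list(range(next_id, next_id + 5))
    next_id += 5

full = sample_neighbourhood(adj, 0, fanouts=[10**9, 10**9], rng=random.Random(0))
capped = sample_neighbourhood(adj, 0, fanouts=[20, 10], rng=random.Random(0))
print(len(full), len(capped))   # 6001 121
```

The full 2-hop neighbourhood of the hub contains 6,001 nodes, while the capped version touches 121. With fan-outs of `[20, 10]`, a single seed can never pull in more than 1 + 20 + 20 x 10 = 221 nodes, however large the hub is.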

Production
Deploying the Fraud Model
Sneha's outcome: the false-positive rate fell from 32% to 9%. The fintech team productionised the model, Sneha's GitHub repo earned 800 stars, and she received an internship offer.

📝 Check Your Understanding (8 Questions)

1. Why do traditional ML models miss fraud rings that GNNs catch?
a) Traditional models can't process amounts above ₹50,000
b) Traditional ML treats each transaction as an independent row of features and ignores the graph structure of relationships between accounts; fraud rings are defined by their structural pattern, not individual transactions
c) GNNs use newer hardware that traditional ML cannot access
d) Traditional ML always overfits on imbalanced classes
2. What is 'message passing' in a GNN?
a) A way for nodes to send each other email-like alerts about suspicious activity
b) The core operation where each node updates its representation by aggregating (e.g., averaging) feature vectors from its neighbours; stacking k layers means each node sees information from k hops away
c) A network protocol used by GraphSAGE to coordinate distributed training
d) A debugging mechanism that prints intermediate node states
3. Why does Sneha choose GraphSAGE over GCN for production fraud detection?
a) GraphSAGE always achieves higher accuracy than GCN on every benchmark
b) GraphSAGE samples a fixed-size neighbourhood per node, scales to millions of nodes via mini-batching, and is inductive: it can produce embeddings for brand-new accounts not seen during training
c) GCN does not support PyTorch Geometric
d) GraphSAGE was specifically designed for UPI fraud by NPCI
4. What does the class_weights = [1.0, 50.0] argument do?
a) It scales the learning rate differently for clean vs fraud nodes
b) It makes the loss function penalise misclassifying rare fraud cases 50× more than misclassifying clean accounts, counteracting the severe class imbalance (~0.5% fraud rate)
c) It normalises node degrees in the message-passing aggregation
d) It controls the dropout rate for fraud vs clean predictions
5. What problem does NeighborLoader solve?
a) It encrypts neighbour data for privacy compliance
b) A node's full k-hop neighbourhood can explode to millions of nodes (especially with high-degree hubs); NeighborLoader samples a fixed number of neighbours per hop, capping per-batch compute and enabling training on huge graphs
c) It removes neighbours that have already been seen in previous batches
d) It loads neighbour features from disk one at a time to save RAM
6. Why does Sneha need GNNExplainer in production?
a) To make the model run faster on CPU
b) When the model flags an account as fraud, regulators and the fraud team need to know which neighbours and edges drove the prediction; GNNExplainer highlights the influential subgraph, supporting audit and appeal processes
c) To convert the model into a decision tree for serving
d) To compress the model for mobile deployment
7. What is the cold-start problem for GNN fraud detection?
a) The model fails when training data is older than 6 months
b) Brand-new accounts have no transaction edges yet, so their k-hop neighbourhoods are empty and the GNN cannot generate meaningful embeddings; a fallback to feature-only models is needed for the first few days
c) GPUs need to warm up for 10 minutes before serving GNN models
d) The first batch of every epoch always has anomalously high loss
8. Why does stacking 2 SAGEConv layers help catch fraud rings?
a) Two layers always produce more accurate predictions than one
b) With 2 layers each node aggregates information from neighbours-of-neighbours (2-hop); fraud rings often span 2-3 hops (sender → mule → cash-out), so 2 hops capture the structural pattern
c) PyTorch Geometric requires at least 2 layers per GNN
d) One layer cannot be combined with dropout