Image Classification with Transfer Learning 🖼️

Class 10 · Age 14–15 · Lesson 2 of 12 · 🆓 Free
[Image: a student in Bengaluru looking at a transfer learning diagram on a laptop, with MobileNet layers visualised and a custom classification head shown]

Meet Kiran: Class 10, Bengaluru

Kiran wants to build an app that identifies Karnataka flowers: jasmine, rose, hibiscus, lotus, and champaka. She has about 200 photos per flower species collected from her garden and nearby parks. She tried training a CNN from scratch (the Lesson 1 approach) and after 30 epochs got only 60% accuracy. "200 photos is just not enough," she thought.

Her uncle at an AI company said: "Use transfer learning. MobileNetV2 was trained on 1.4 million ImageNet photos. It already knows how to detect edges, textures, petals, and colour gradients. You just need to teach it the difference between Karnataka flowers. 20 minutes in Colab." Kiran followed his advice and hit 92% accuracy.

The Core Idea
What Is Transfer Learning?

Transfer learning means taking a model that was trained on a large dataset and reusing its learned features for a new, smaller task. Instead of learning filters from scratch, you start from filters that already know about shapes, textures, and patterns from millions of images.

🧊
Frozen Base Model (MobileNetV2)
~150 layers pre-trained on 1.4 million ImageNet images. Weights are locked, so we don't retrain them. They already know how to extract features from images.
🔥
Trainable Classification Head
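A new Dense layer you add on top. Only these weights are trained on your small dataset. Very fast: a single Dense output layer on MobileNetV2's 1,280 pooled features has only a few thousand parameters (1,280 × 5 + 5 = 6,405 for five classes), and even the slightly larger head in the full code below has about 165,000, a small fraction of the frozen base.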
Why this works: The early layers of a CNN trained on natural images detect much the same things whatever the subject: edges, textures, colour blobs. Whether the image is a dog or a flower, the low-level features are the same. Transfer learning reuses this shared knowledge.
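To make the size difference concrete, the arithmetic below counts the weights in a few possible classification heads. It is a rough sketch that assumes MobileNetV2's last feature map is 7×7×1280 for a 224×224 input and that there are five flower classes; `dense_params` is a helper name made up for this illustration.

```python
def dense_params(n_inputs, n_units):
    """Weights plus biases in one fully connected (Dense) layer."""
    return n_inputs * n_units + n_units


FEATURE_CHANNELS = 1280  # channels in MobileNetV2's last feature map (assumed, default width)
FEATURE_MAP_SIDE = 7     # spatial size of that map for a 224x224 input (assumed)
NUM_CLASSES = 5          # five flower types

# Head A: average-pooled features -> one Dense output layer
single_dense_head = dense_params(FEATURE_CHANNELS, NUM_CLASSES)

# Head B: average-pooled features -> Dense(128) -> Dense(5)
two_layer_head = dense_params(FEATURE_CHANNELS, 128) + dense_params(128, NUM_CLASSES)

# Head B again, but flattening the 7x7x1280 map instead of average-pooling it
flat_inputs = FEATURE_MAP_SIDE * FEATURE_MAP_SIDE * FEATURE_CHANNELS
flatten_head = dense_params(flat_inputs, 128) + dense_params(128, NUM_CLASSES)

print(f"Single Dense head:       {single_dense_head:>10,}")
print(f"Dense(128) + Dense(5):   {two_layer_head:>10,}")
print(f"Same head after Flatten: {flatten_head:>10,}")
```

Both pooled heads are small next to the frozen base, which holds a little over two million weights. Flattening instead of pooling would make the first Dense layer alone larger than the whole base.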
Popular Pre-trained Models
Which Base Model to Choose

All are available in tensorflow.keras.applications, and each loads with ImageNet weights in one line (the weights are downloaded the first time you use the model).

Full Code
Fine-Tune MobileNetV2 on Your Own Images
# Transfer Learning with MobileNetV2 - Google Colab
# Example: Flower classifier (or replace with any image dataset)

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import MobileNetV2
import matplotlib.pyplot as plt
import numpy as np

# ── Step 1: Load dataset ──
# Using tf_flowers dataset as a stand-in for your own photos
# For your own images, use ImageDataGenerator.flow_from_directory()
import tensorflow_datasets as tfds

(ds_train, ds_val, ds_test), info = tfds.load(
    'tf_flowers',
    split=['train[:70%]', 'train[70%:85%]', 'train[85%:]'],
    as_supervised=True,
    with_info=True
)
NUM_CLASSES = info.features['label'].num_classes  # 5 flower types
print(f"Classes: {info.features['label'].names}")

# ── Step 2: Preprocess (resize and normalise) ──
IMG_SIZE = 224  # MobileNetV2 expects 224×224

def preprocess(image, label):
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    # Scale pixels to [-1, 1], the range MobileNetV2's ImageNet weights were trained on
    image = tf.cast(image, tf.float32) / 127.5 - 1.0
    return image, label

BATCH_SIZE = 32
ds_train = ds_train.map(preprocess).shuffle(1000).batch(BATCH_SIZE).prefetch(1)
ds_val   = ds_val.map(preprocess).batch(BATCH_SIZE).prefetch(1)
ds_test  = ds_test.map(preprocess).batch(BATCH_SIZE).prefetch(1)

# ── Step 3: Load MobileNetV2 base (frozen) ──
base_model = MobileNetV2(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,    # remove ImageNet classification head
    weights='imagenet'    # download pre-trained weights
)
base_model.trainable = False  # FREEZE: don't retrain these layers
print(f"Base model parameters: {base_model.count_params():,}")

# ── Step 4: Add classification head ──
inputs  = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
x       = base_model(inputs, training=False)  # frozen base
x       = layers.GlobalAveragePooling2D()(x)  # pool feature maps → 1D vector
x       = layers.Dense(128, activation='relu')(x)
x       = layers.Dropout(0.3)(x)
outputs = layers.Dense(NUM_CLASSES, activation='softmax')(x)

model = keras.Model(inputs, outputs)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
print(f"Trainable parameters: {sum([tf.size(w).numpy() for w in model.trainable_variables]):,}")

# ── Step 5: Phase 1 - Train the head only ──
print("\nPhase 1: Training classification head...")
history1 = model.fit(ds_train, epochs=5, validation_data=ds_val)

# ── Step 6: Phase 2 - Fine-tune top layers of base model ──
# Unfreeze the top 30 layers for fine-tuning
base_model.trainable = True
for layer in base_model.layers[:-30]:
    layer.trainable = False

# Recompile with a lower learning rate (important!)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

print("\nPhase 2: Fine-tuning top layers...")
history2 = model.fit(ds_train, epochs=10, validation_data=ds_val)

# ── Step 7: Evaluate on test set ──
test_loss, test_acc = model.evaluate(ds_test)
print(f"\nTest accuracy after fine-tuning: {test_acc:.2%}")

# ── Step 8: Predict on a few test images ──
flower_names = info.features['label'].names
for images, labels in ds_test.take(1):
    preds = model.predict(images[:5])
    for i in range(5):
        print(f"Predicted: {flower_names[np.argmax(preds[i])]:15s}  "
              f"Actual: {flower_names[labels[i].numpy()]}")
Typical results: Phase 1 alone (5 epochs, head only) → ~80–85% accuracy. After Phase 2 fine-tuning (10 more epochs) → 90–95% on flowers. Training time in Colab (T4 GPU): under 15 minutes.
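The code above imports matplotlib but never draws anything. One way to see both phases on a single chart is sketched below. It assumes `history1` and `history2` are the objects returned by the two `model.fit()` calls and that the metric keys are `accuracy` and `val_accuracy`, the names Keras uses with `metrics=['accuracy']`. The helpers `merge_histories` and `plot_accuracy` are names invented for this sketch.

```python
def merge_histories(*histories):
    """Join the per-epoch metric lists from several History.history dicts."""
    merged = {}
    for history in histories:
        for key, values in history.items():
            merged.setdefault(key, []).extend(values)
    return merged


def plot_accuracy(merged, phase1_epochs):
    """Plot training and validation accuracy, marking where fine-tuning begins."""
    # Imported here so merge_histories can be used without matplotlib installed
    import matplotlib.pyplot as plt

    epochs = range(1, len(merged["accuracy"]) + 1)
    plt.plot(epochs, merged["accuracy"], label="training")
    plt.plot(epochs, merged["val_accuracy"], label="validation")
    plt.axvline(phase1_epochs + 0.5, linestyle="--", color="grey",
                label="fine-tuning starts")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.show()


# Usage after both training phases have finished:
# merged = merge_histories(history1.history, history2.history)
# plot_accuracy(merged, phase1_epochs=5)
```

If validation accuracy falls while training accuracy keeps rising after the dashed line, the fine-tuning phase may be overfitting, and fewer epochs or fewer unfrozen layers are worth trying.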
Using Your Own Photos
Folder Structure for Custom Datasets
# If you have your own images, organise them like this:
#
# my_dataset/
# ├── train/
# │   ├── healthy/       (put healthy fruit photos here)
# │   ├── diseased/      (put diseased fruit photos here)
# │   └── unripe/
# ├── val/
# │   ├── healthy/
# │   ├── diseased/
# │   └── unripe/
# └── test/
#     ├── healthy/
#     ├── diseased/
#     └── unripe/

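from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Augment training images to improve generalisation.
# preprocess_input scales pixels to [-1, 1], the range MobileNetV2's
# ImageNet weights expect (used here in place of rescale=1./255).
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=20,
    horizontal_flip=True,
    zoom_range=0.2,
    shear_range=0.1
)
# Validation images get the same scaling but no augmentation
val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)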

train_gen = train_datagen.flow_from_directory(
    'my_dataset/train',
    target_size=(224, 224),
    batch_size=32,
    class_mode='sparse'
)
val_gen = val_datagen.flow_from_directory(
    'my_dataset/val',
    target_size=(224, 224),
    batch_size=32,
    class_mode='sparse'
)

# Replace ds_train / ds_val with train_gen / val_gen in model.fit()
# Google Colab: upload zip of your dataset with files.upload() or mount Google Drive
How many photos do you need? As a rough guide, 100–200 per class often works with transfer learning, while training from scratch typically needs 2,000–5,000+ per class. This is the practical power of transfer learning for student projects.
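Before training on your own photos, it helps to check how many images each class folder actually holds. The sketch below uses only the Python standard library and assumes the `my_dataset/train` layout shown above. The function name `count_images_per_class`, the list of file extensions, and the 100-image warning threshold are choices made for this example, not part of Keras.

```python
from pathlib import Path

# File types counted as images (an assumption; extend it if your photos differ)
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(split_dir):
    """Return {class_name: image_count} for each subfolder of split_dir."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts


train_dir = Path("my_dataset/train")
if train_dir.is_dir():  # skip quietly if the dataset has not been uploaded yet
    for name, n in count_images_per_class(train_dir).items():
        note = "" if n >= 100 else "  <- fewer than 100, consider collecting more"
        print(f"{name:12s} {n:5d}{note}")
```

Classes with very different counts are also worth noticing here, since a model trained on an unbalanced set tends to favour the largest class.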

🧪 Check Your Understanding: Lesson 2 Quiz

1. In transfer learning, `base_model.trainable = False` means:
a) The model cannot be used for predictions
b) The pre-trained base model's weights are frozen, so they won't be updated during training on your dataset
c) The model will train faster because no GPU is needed
d) The model uses random initial weights
2. Why is Phase 2 fine-tuning (unfreezing top layers) compiled with a much lower learning rate (1e-5)?
a) Lower learning rate is always better for all training phases
b) To avoid catastrophically overwriting the valuable pre-trained features: small updates nudge the weights rather than destroying them
c) Because the dataset is small and a high learning rate would crash Colab
d) Lower learning rate makes the model use less memory
3. `include_top=False` when loading MobileNetV2 means:
a) Only the top 10 layers are loaded
b) The pre-trained ImageNet classification head (1000-class output) is removed, leaving only the feature extraction layers
c) The model runs without a GPU
d) The base model's weights are not downloaded
4. GlobalAveragePooling2D() is used in the classification head because:
a) It applies a softmax activation to the output
b) It converts the 3D feature volume output of the base model into a 1D vector suitable for a Dense layer, with fewer parameters than Flatten
c) It normalises pixel values to 0–1
d) It applies data augmentation during training
5. MobileNetV2 is preferred over VGG16 for student projects on small hardware because:
a) VGG16 is too old and no longer available in Keras
b) MobileNetV2 is designed for efficiency: far fewer parameters (~3.4M vs ~138M), runs faster, uses less memory, and can even be deployed on mobile devices
c) MobileNetV2 was trained on more images than VGG16
d) VGG16 doesn't support image classification tasks
6. In `ImageDataGenerator`, `horizontal_flip=True` is used to:
a) Mirror images so the model learns that a flipped flower is still a flower, which increases effective dataset size
b) Flip labels so the model sees balanced classes
c) Convert colour images to greyscale
d) Rotate images by exactly 180 degrees
7. Why is `flow_from_directory` useful for custom image datasets?
a) It automatically labels images based on their pixel colours
b) It reads images from a folder structure where each subfolder is a class, so labels are inferred from folder names and no manual CSV is needed
c) It uploads images to AWS S3 automatically
d) It works only with black and white images
8. Transfer learning is most useful when:
a) You have millions of labelled images and plenty of compute
b) Your task is completely unlike anything in the pre-trained dataset
c) You have a small dataset (100–1000 images per class) for a visual task similar to natural images, where transfer learning dramatically outperforms training from scratch
d) You are working with tabular data, not images