
How to Train a CNN Model with TensorFlow in 2026


Introduction

TensorFlow, Google's open-source framework, remains one of the leading deep learning frameworks in 2026, with native Keras integration providing an intuitive, high-performance API. This intermediate tutorial guides you through training a convolutional neural network (CNN) on Fashion MNIST, a classic image classification benchmark with 10 classes (T-shirts, trousers, and so on).

Why this topic? CNNs excel in computer vision, crucial for production apps like object recognition. You'll learn to load data, build a robust architecture with Conv2D and MaxPooling, optimize training with Adam and callbacks, and evaluate precisely.

By the end, you'll have a model with >90% test accuracy, ready for deployment via TensorFlow Lite. The tutorial builds progressively from basics to a complete pipeline, with fully runnable code. Estimated runtime: about 20 minutes on a standard CPU.

Prerequisites

  • Python 3.9+ installed
  • Basic Python and NumPy/Pandas knowledge
  • Introductory machine learning concepts (neural networks, loss functions)
  • Virtual environment (venv or conda) recommended
  • Internet access for pip install (one time only)

Installing TensorFlow

install.sh
python -m venv tf_env
source tf_env/bin/activate  # On Windows: tf_env\Scripts\activate
pip install --upgrade pip
pip install tensorflow==2.16.1
pip install matplotlib seaborn
python -c "import tensorflow as tf; print(tf.__version__)" # Verification

This script creates an isolated virtual environment, then installs TensorFlow 2.16 (a stable release as of 2026) and the visualization dependencies. The verification command confirms the installation imports cleanly. If you don't have a CUDA setup, stick with the default CPU build; there's no need for GPU-specific packages.

Loading and Exploring the Data

Fashion MNIST (70k 28x28 grayscale images, 60k train / 10k test) depicts real clothing items and is more challenging than digit MNIST. Load it via tf.keras.datasets, which returns the train/test splits automatically. Visualize to validate: raw shape (60000, 28, 28), labels in [0, 9]; the channel dimension is added during preprocessing. Think of it like prepping a recipe: normalize pixels to [0, 1] to keep gradients well behaved.

Preparing the Dataset

prepare_data.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Loading
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# Normalization and reshape
train_images = train_images.reshape(60000, 28, 28, 1).astype('float32') / 255.0
test_images = test_images.reshape(10000, 28, 28, 1).astype('float32') / 255.0

# One-hot encode labels (skip this and use sparse_categorical_crossentropy if you prefer integer labels)
train_labels = tf.keras.utils.to_categorical(train_labels, 10)
test_labels = tf.keras.utils.to_categorical(test_labels, 10)

# Example visualization
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i].reshape(28, 28), cmap=plt.cm.binary)
    plt.xlabel(train_labels[i].argmax())
plt.show()

print(f'Train: {train_images.shape}, Test: {test_images.shape}')

This code loads, normalizes (divide by 255 for a [0, 1] range), and reshapes to (samples, height, width, channels=1). Labels are one-hot encoded for multiclass training. The Matplotlib visualization confirms data integrity. Pitfall: forgetting the reshape makes Conv2D raise a shape error, since it expects 4D (batch, height, width, channels) input.

Building the CNN Architecture

A typical CNN uses Conv2D to extract features (filters), MaxPooling to reduce dimensions (translation invariance), and Dropout to prevent overfitting. Flatten → Dense for classification. Analogy: like layered pattern detectors (edges → textures → objects).

Defining the Sequential Model

build_model.py
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    Flatten(),
    Dense(64, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])

model.summary()

Simplified VGG-like architecture: three Conv2D layers with increasing filters (32→64), two MaxPool layers (each downsampling by 2), and 50% Dropout against overfitting. summary() shows roughly 93k parameters. input_shape is set for single-channel (grayscale) input. Pitfall: omitting the ReLU activations collapses the stacked convolutions into a single linear map; check the output shape of each layer (see the snippet below).
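
To check each layer's output shape, a quick inspection loop works on the built model. A minimal sketch (file name illustrative), assuming the model was built with input_shape as above:

inspect_shapes.py
# Print each layer's name and output shape on the built model
for layer in model.layers:
    print(f'{layer.name}: {layer.output.shape}')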

Compilation and Training with Callbacks

train_model.py
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Compilation
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Callbacks
callbacks = [
    EarlyStopping(patience=3, restore_best_weights=True),
    ModelCheckpoint('best_model.keras', save_best_only=True)
]

# Training
history = model.fit(train_images, train_labels,
                    batch_size=128,
                    epochs=20,
                    validation_split=0.2,
                    callbacks=callbacks,
                    verbose=1)

# Plot history
import matplotlib.pyplot as plt
plt.plot(history.history['accuracy'], label='Train')
plt.plot(history.history['val_accuracy'], label='Val')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Adam optimizer (adaptive learning rates), categorical_crossentropy for multiclass softmax. A batch size of 128 speeds up training, and validation_split=0.2 monitors overfitting. EarlyStopping halts training when validation stops improving; ModelCheckpoint saves the best weights. Expect validation accuracy around 92% within roughly 10 epochs. Pitfall: running many epochs without callbacks wastes compute.

Evaluation and Predictions

evaluate_predict.py
# Evaluation
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0)
print(f'Test accuracy: {test_acc:.4f}')

# Predictions on 10 examples
predictions = model.predict(test_images[:10])
labels = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

import numpy as np
for i in range(10):
    predicted_label = np.argmax(predictions[i])
    true_label = np.argmax(test_labels[i])
    print(f'Ex {i}: Pred={labels[predicted_label]}, True={labels[true_label]}, Conf={predictions[i][predicted_label]:.2f}')

Evaluate on the held-out test set for final metrics (~90%). predict returns softmax probabilities; argmax picks the class. Printing a few predictions gives a concrete sanity check. Reload the checkpoint if needed: model = tf.keras.models.load_model('best_model.keras'). Pitfall: never use the test set for training.
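
Since seaborn was installed at the start, a per-class confusion matrix is a useful complement to overall accuracy. A minimal sketch (file name illustrative), reusing model, test_images, test_labels, and labels from the snippets above:

confusion_matrix.py
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf

# Predict the full test set and compare to ground truth
probs = model.predict(test_images, verbose=0)
y_pred = np.argmax(probs, axis=1)
y_true = np.argmax(test_labels, axis=1)

# 10x10 matrix: rows = true class, columns = predicted class
cm = tf.math.confusion_matrix(y_true, y_pred, num_classes=10).numpy()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=labels, yticklabels=labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

The classes that get confused most often (typically Shirt vs T-shirt/top) stand out immediately on the heatmap.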

Saving and Exporting to TensorFlow Lite

save_export.py
import tensorflow as tf  # needed if running this file standalone

# Full save
model.save('fashion_cnn_full.keras')

# Lite export for mobile/edge
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('fashion_cnn.tflite', 'wb') as f:
    f.write(tflite_model)

# Lite verification
interpreter = tf.lite.Interpreter(model_path='fashion_cnn.tflite')
interpreter.allocate_tensors()
print('TensorFlow Lite exported successfully!')

Keras save for easy reload. TFLite optimization (dynamic-range quantization via Optimize.DEFAULT) gives fast, low-memory inference: roughly a 3-4x size reduction with near-identical accuracy. Well suited to Android/iOS deployment. Pitfall: skipping the optimizations leaves the model unnecessarily large.
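
To confirm the exported file actually runs, here is a minimal inference sketch with the TFLite interpreter (file name illustrative; reuses test_images from earlier):

tflite_inference.py
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='fashion_cnn.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed one test image through the optimized model
sample = test_images[:1].astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], sample)
interpreter.invoke()
probs = interpreter.get_tensor(output_details[0]['index'])
print('Predicted class:', np.argmax(probs))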

Best Practices

  • Always split train/val/test (80/10/10) to detect overfitting via history curves.
  • Use callbacks (ReduceLROnPlateau for auto LR decay) on datasets >10k samples.
  • Data augmentation: small random rotations and zooms (e.g. ImageDataGenerator with rotation_range=10, zoom_range=0.1, or the Keras RandomRotation/RandomZoom layers) can add a few points of accuracy; see the sketch after this list.
  • Monitor with TensorBoard: tf.keras.callbacks.TensorBoard(log_dir='./logs') for live visuals.
  • Reproducible seed: tf.random.set_seed(42) for identical runs.
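
A sketch combining several of these practices (file name, log directory, augmentation factors, and patience values are illustrative and worth tuning):

best_practices.py
import tensorflow as tf

tf.random.set_seed(42)  # reproducible runs

# Keras preprocessing layers: a current alternative to ImageDataGenerator;
# place them at the start of the model so they run only during training
augment = tf.keras.Sequential([
    tf.keras.layers.RandomRotation(10 / 360),  # about +/-10 degrees
    tf.keras.layers.RandomZoom(0.1),
])

callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         factor=0.5, patience=2),
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
]
# Pass callbacks=callbacks to model.fit() alongside EarlyStopping/Checkpoint.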

Common Errors to Avoid

  • Forgot normalization: [0,255] pixels cause NaN losses; always /255 or StandardScaler.
  • Mis-sized batch_size: 512+ works on GPU, but on CPU stay in the 64-256 range for speed and gradient stability.
  • No Dropout/regularization: overfitting shows up as a >10% train/val accuracy gap; add Dropout or an L2 kernel_regularizer (see the snippet after this list).
  • Input_shape mismatch: Conv2D expects (H,W,C); check model.summary(). Don't fit without compile.
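
A minimal example of the L2 regularizer mentioned above (the 1e-4 weight is an illustrative value to tune):

l2_regularizer.py
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Dense

# Dense layer with L2 weight decay to shrink the train/val gap
dense = Dense(64, activation='relu',
              kernel_regularizer=regularizers.l2(1e-4))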

Next Steps

Master transfer learning with MobileNetV2 on custom datasets (see the skeleton below). Check the TensorFlow docs. For professionals, Learni Group offers advanced training on MLOps deployment and GANs. Test your skills on Kaggle's Fashion MNIST competitions.
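
As a starting point for that transfer-learning step, here is a minimal skeleton (file name illustrative; the 96x96 RGB input shape and frozen base are assumptions for a custom dataset, and Fashion MNIST itself would first need resizing and grayscale-to-RGB conversion):

transfer_learning.py
import tensorflow as tf

# Pretrained feature extractor with the classification head removed
base = tf.keras.applications.MobileNetV2(input_shape=(96, 96, 3),
                                         include_top=False,
                                         weights='imagenet')
base.trainable = False  # freeze pretrained convolutional features

model = tf.keras.Sequential([
    base,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])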