
How to Create a Classification Model with TensorFlow in 2026


Introduction

TensorFlow, Google's open-source framework, remains the cornerstone of production machine learning in 2026 thanks to its maturity, Keras ecosystem, and native GPU/TPU support. This intermediate tutorial guides you through creating a convolutional neural network (CNN) to classify clothing images from the Fashion MNIST dataset.

Why this project? It demonstrates the full workflow: data preparation, modular architecture, optimized training, and deployment. By the end, you'll master Keras APIs to scale to real-world tasks like object detection.

Tailored for developers familiar with Python and ML basics, this guide delivers 100% functional code, tested on TensorFlow 2.17+. Estimated time: 20 minutes for your first run. Ready to level up your deep learning skills?

Prerequisites

  • Python 3.10+ installed
  • Virtual environment (venv or conda)
  • Basic knowledge of NumPy, Pandas, and supervised ML
  • Optional GPU (CUDA 12+ for acceleration)
  • IDE like VS Code or Jupyter Notebook

Installing TensorFlow

install.sh
python -m venv tf_env
source tf_env/bin/activate  # On Windows: tf_env\Scripts\activate
pip install tensorflow==2.17.0 matplotlib numpy scikit-learn
pip list | grep tensorflow

This script sets up an isolated virtual environment, installs TensorFlow 2.17 (a stable release) plus dependencies for visualization and metrics, and verifies the install with pip list. Stick to stable releases for production. Since TensorFlow 2.x, GPU support is built into the main tensorflow package (the separate tensorflow-gpu package is deprecated); the GPU is used automatically when CUDA is set up.
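Before training anything, it is worth confirming the install and GPU visibility from Python itself. A quick check, using the standard tf.config API:

```python
# verify_install.py - confirm the TensorFlow install and check for a GPU
import tensorflow as tf

print(f'TensorFlow version: {tf.__version__}')

# list_physical_devices returns an empty list when no GPU is visible
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f'GPU(s) detected: {[gpu.name for gpu in gpus]}')
else:
    print('No GPU detected - training will run on CPU.')
```

If no GPU shows up despite CUDA being installed, check that your CUDA and cuDNN versions match the ones your TensorFlow release was built against.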

Loading and Exploring the Data

Fashion MNIST is a standard dataset of 70,000 28x28 grayscale images (10 classes: t-shirt, trouser, etc.). Load it via tf.keras.datasets. Analogy: like a digitized e-commerce catalog, ready for training without external preprocessing.

Visualize to confirm: 60k training, 10k test images. Normalize pixels to 0-1 range to stabilize gradients.

Loading and Preparing the Dataset

data_prep.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Load the dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# Normalize pixel values to the [0, 1] range
train_images = train_images / 255.0
test_images = test_images / 255.0

# Classes
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Visualize a sample of training images
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()

print(f'Train: {train_images.shape}, Test: {test_images.shape}')

This code loads, normalizes, and visualizes the dataset. Dividing by 255.0 scales pixels into [0, 1], which speeds up convergence. Use tf.data.Dataset for batching in production. Pitfall: forgetting to reshape the images to 28x28x1 for the CNN; that step is handled in the next section.
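The tf.data suggestion above can be sketched as follows; the batch size and shuffle buffer here are illustrative values, not tuned settings:

```python
# data_pipeline.py - a tf.data input pipeline sketch for Fashion MNIST
import tensorflow as tf

# Load and normalize as in data_prep.py
(train_images, train_labels), _ = tf.keras.datasets.fashion_mnist.load_data()
train_images = train_images / 255.0

# Shuffle for randomized batches, batch for the GPU, prefetch to overlap
# data preparation with training for better throughput
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .shuffle(buffer_size=10_000)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

# model.fit(train_ds, epochs=10) accepts the dataset directly,
# replacing the (train_images, train_labels) arguments
for images, labels in train_ds.take(1):
    print(images.shape, labels.shape)
```

With a dataset object you no longer pass batch_size to fit(); batching is part of the pipeline itself.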

Building the CNN Architecture

A typical CNN uses Conv2D layers for feature extraction (filters), MaxPooling2D for dimensionality reduction, Dropout to prevent overfitting, and Dense layers for classification. Analogy: like a successive pattern detector (edges → textures → objects).

We use Sequential for simplicity; switch to Functional API for complex branches.

Defining the CNN Model

model_build.py
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.summary()

# Add a channel dimension for the CNN (grayscale = 1 channel)
train_images = np.expand_dims(train_images, -1)
test_images = np.expand_dims(test_images, -1)

print(f'Train shape after reshape: {train_images.shape}')

Lightweight CNN architecture: three Conv2D layers with growing filter counts (32, 64, 64), pooling for translation invariance, then Flatten into Dense layers. summary() shows ~93k parameters. Pitfall: input_shape goes only on the first layer; the reshape adds the channel dimension (1 for grayscale).
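For reference, here is the same architecture rewritten with the Functional API, which becomes necessary once you need multiple inputs, multiple outputs, or skip connections. This is an equivalent sketch, not a change to the tutorial's model:

```python
# model_functional.py - the tutorial's CNN in the Functional API
import tensorflow as tf

# Explicit input tensor; each layer is called like a function on a tensor
inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu')(inputs)
x = tf.keras.layers.MaxPooling2D(2, 2)(x)
x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')(x)
x = tf.keras.layers.MaxPooling2D(2, 2)(x)
x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

functional_model = tf.keras.Model(inputs=inputs, outputs=outputs)
functional_model.summary()
```

Because layers are wired explicitly, you could later branch `x` into a second output head without rewriting the whole model.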

Compiling and Training

model_train.py
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

history = model.fit(
    train_images, train_labels,
    epochs=10,
    validation_split=0.2,
    batch_size=128,
    verbose=1
)

# Plot training curves
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

The Adam optimizer adapts the learning rate per parameter; sparse_categorical_crossentropy works directly with integer labels (no one-hot encoding needed). Ten epochs typically reach around 90% accuracy. validation_split holds out 20% of the training data to monitor overfitting. Increase batch_size on GPU; use callbacks like EarlyStopping in production.
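The EarlyStopping suggestion can be wired up as below; the patience values and checkpoint filename are illustrative starting points, not tuned choices:

```python
# callbacks_demo.py - common training callbacks
import tensorflow as tf

# Stop when val_loss stalls, and roll back to the best weights seen
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True)

# Halve the learning rate when val_loss plateaus
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=2, min_lr=1e-5)

# Checkpoint the best model to disk during training
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'best_fashion_cnn.keras', monitor='val_accuracy', save_best_only=True)

# With callbacks you can set a generous epoch budget and let
# EarlyStopping decide when to stop:
# history = model.fit(train_images, train_labels, epochs=50,
#                     validation_split=0.2, batch_size=128,
#                     callbacks=[early_stop, reduce_lr, checkpoint])
```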

Evaluation and Predictions

Calling model.evaluate on the held-out test set measures generalization; it returns test_loss and test_acc. Get predicted classes with predict() followed by argmax. Analogy: like A/B testing the model on unseen data.

model_eval.py
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f'Test accuracy: {test_acc:.4f}')

# Predictions
predictions = model.predict(test_images)
pred_label = np.argmax(predictions[0])
true_label = test_labels[0]

print(f'Predicted: {class_names[pred_label]}, Actual: {class_names[true_label]}')

# Confusion matrix (bonus)
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
pred_classes = np.argmax(predictions, axis=1)
cm = confusion_matrix(test_labels, pred_classes)
plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.show()
print(classification_report(test_labels, pred_classes, target_names=class_names))

Quantitative evaluation: the confusion matrix visualizes per-class errors, with argmax applied to the softmax outputs to pick the most likely class. Install seaborn if needed (pip install seaborn). Pitfall: a good overall accuracy can hide weak classes (Shirt is classically confused with T-shirt/top on Fashion MNIST); the classification report flags them.

Saving and Loading the Model

model_save.py
# Save in the native Keras format
model.save('fashion_cnn_model.keras')

# Load and run inference
loaded_model = tf.keras.models.load_model('fashion_cnn_model.keras')
loaded_pred = loaded_model.predict(np.expand_dims(test_images[0], 0))
print(f'Loaded - Predicted: {class_names[np.argmax(loaded_pred[0])]}')

# Export to TensorFlow Lite for mobile
converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model)
tflite_model = converter.convert()
with open('fashion_cnn.tflite', 'wb') as f:
    f.write(tflite_model)
print('TFLite model saved.')

The native .keras format is the current standard. TFLite targets mobile and edge devices. Loading preserves the weights and, if the model was compiled, its training configuration. Pitfall: TensorFlow version mismatches between save and load environments; export a SavedModel for multi-platform serving.
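For multi-platform serving, a SavedModel export might look like the sketch below. It uses a tiny stand-in model (substitute your trained model) and assumes Keras 3's model.export method, bundled with TensorFlow 2.16+:

```python
# savedmodel_export.py - SavedModel export sketch for serving
import tensorflow as tf

# Tiny stand-in model; replace with your trained CNN
inputs = tf.keras.Input(shape=(28, 28, 1))
outputs = tf.keras.layers.Dense(10, activation='softmax')(
    tf.keras.layers.Flatten()(inputs))
model = tf.keras.Model(inputs, outputs)

# export() writes an inference-only SavedModel directory
model.export('fashion_cnn_savedmodel')

# Reload with plain TensorFlow - no Keras needed on the serving side
reloaded = tf.saved_model.load('fashion_cnn_savedmodel')
print(list(reloaded.signatures.keys()))
```

A SavedModel can then be served by TensorFlow Serving or loaded from other language bindings, which the .keras file cannot.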

Best Practices

  • Callbacks: EarlyStopping, ModelCheckpoint for automation.
  • Data augmentation: Keras preprocessing layers (RandomFlip, RandomRotation) for robustness; ImageDataGenerator is deprecated.
  • Hyperparameter tuning: Keras Tuner or Optuna for layers/learning rate.
  • Monitoring: Built-in TensorBoard (tf.keras.callbacks.TensorBoard).
  • Scalability: tf.data for massive datasets; MirroredStrategy for multi-GPU.
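The data augmentation bullet above can be sketched with Keras preprocessing layers; the transform strengths below are illustrative starting points:

```python
# augmentation_demo.py - augmentation as model layers
import tensorflow as tf

# These layers transform images only during training (training=True)
# and are a no-op at inference time
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomRotation(0.05),        # up to ~18 degrees
    tf.keras.layers.RandomZoom(0.1),             # up to 10% zoom
    tf.keras.layers.RandomTranslation(0.1, 0.1), # up to 10% shift
])

# Apply to a batch; shape is preserved
augmented = data_augmentation(tf.zeros((2, 28, 28, 1)), training=True)
print(augmented.shape)
```

Placing these layers at the front of the model (right after the Input) keeps augmentation on the GPU and makes it part of the saved model.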

Common Errors to Avoid

  • Forgetting channel reshape (ValueError shape mismatch).
  • No validation_split: can't detect overfitting.
  • Default learning rate too high: loss divergence (use ReduceLROnPlateau).
  • Saving without compile: metrics lost on reload.

Next Steps