Skip to content
Learni
Voir tous les tutoriels
Machine Learning

Comment entraîner un modèle CNN avec TensorFlow en 2026

Read in English

Introduction

TensorFlow, framework open-source de Google, domine le deep learning en 2026 grâce à son intégration native de Keras pour une API intuitive et performante. Ce tutoriel intermédiaire vous guide pour entraîner un réseau de neurones convolutif (CNN) sur le dataset Fashion MNIST, un benchmark classique de classification d'images (10 classes comme t-shirts, pantalons).

Pourquoi ce sujet ? Les CNN excellent dans la vision par ordinateur, essentiels pour applications comme la reconnaissance d'objets en production. Vous apprendrez à charger des données, bâtir une architecture robuste avec Conv2D et MaxPooling, optimiser l'entraînement via Adam et callbacks, puis évaluer précisément.

À la fin, vous obtiendrez un modèle à >90% d'accuracy, prêt pour déploiement (TensorFlow Lite). Structure progressive : bases → complexité, avec code 100% fonctionnel. Durée estimée : 20 min d'exécution sur CPU standard.

Prérequis

  • Python 3.9+ installé
  • Bases en Python et NumPy/Pandas
  • Connaissances introductives en machine learning (réseaux de neurones, loss functions)
  • Environnement virtuel (venv ou conda) recommandé
  • Accès internet pour pip install (une seule fois)

Installation de TensorFlow

install.sh
python -m venv tf_env
source tf_env/bin/activate  # Sur Windows : tf_env\Scripts\activate
pip install --upgrade pip
pip install tensorflow==2.16.1
pip install matplotlib seaborn
python -c "import tensorflow as tf; print(tf.__version__)" # Vérification

Ce script crée un environnement virtuel isolé, installe TensorFlow 2.16 (stable en 2026) et ses dépendances pour visualisation. La commande de vérification confirme l'installation sans erreurs. Évitez les versions GPU si pas de CUDA configuré, pour simplicité CPU.

Chargement et exploration des données

Fashion MNIST (70k images 28x28 grayscale, 60k train / 10k test) simule des vêtements réels, plus challenging que MNIST chiffré. On le charge via tf.keras.datasets, split auto. Visualisez pour valider : shapes (60000, 28, 28, 1), labels [0-9]. Analogie : comme préparer une recette, normaliser pixels (0-1) évite gradients explosifs.

Préparation du dataset

prepare_data.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Chargement
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# Normalisation et reshape
train_images = train_images.reshape(60000, 28, 28, 1).astype('float32') / 255.0
test_images = test_images.reshape(10000, 28, 28, 1).astype('float32') / 255.0

# Labels en one-hot (optionnel pour sparse_categorical)
train_labels = tf.keras.utils.to_categorical(train_labels, 10)
test_labels = tf.keras.utils.to_categorical(test_labels, 10)

# Visualisation exemple
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i].reshape(28, 28), cmap=plt.cm.binary)
    plt.xlabel(train_labels[i].argmax())
plt.show()

print(f'Train: {train_images.shape}, Test: {test_images.shape}')

Ce code charge, normalise (division par 255 pour [0,1]) et reshape en (samples, height, width, channels=1). One-hot encoding pour labels multiclasse. Visualisation matplotlib confirme intégrité. Piège : oublier reshape cause 'channels_last' error dans Conv2D.

Construction de l'architecture CNN

Un CNN typique : Conv2D extrait features (filtres), MaxPooling réduit dimensions (invariance translation), Dropout prévient overfitting. Flatten → Dense pour classification. Analogie : comme un détecteur de motifs successifs (bords → textures → objets).

Définition du modèle Sequential

build_model.py
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    Flatten(),
    Dense(64, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])

model.summary()

Architecture VGG-like simplifiée : 3 Conv2D croissants (32→64 filtres), 2 MaxPool (downsample /2), Dropout 50% anti-overfit. summary() affiche params (~1M). Input_shape fixe pour grayscale. Piège : sans relu, vanishing gradients ; testez model.output_shape par couche.

Compilation et entraînement avec callbacks

train_model.py
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Compilation
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Callbacks
callbacks = [
    EarlyStopping(patience=3, restore_best_weights=True),
    ModelCheckpoint('best_model.keras', save_best_only=True)
]

# Entraînement
history = model.fit(train_images, train_labels,
                    batch_size=128,
                    epochs=20,
                    validation_split=0.2,
                    callbacks=callbacks,
                    verbose=1)

# Plot history
import matplotlib.pyplot as plt
plt.plot(history.history['accuracy'], label='Train')
plt.plot(history.history['val_accuracy'], label='Val')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Adam optimise (adaptatif), categorical_crossentropy pour multiclasse softmax. Batch 128 accélère, validation_split=0.2 monitore overfitting. EarlyStopping arrête si stagnation, Checkpoint sauve best. Attendez ~92% val_accuracy après 10 epochs. Piège : epochs trop hauts sans callbacks = waste compute.

Évaluation et prédictions

evaluate_predict.py
# Évaluation
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0)
print(f'Test accuracy: {test_acc:.4f}')

# Prédictions sur 10 exemples
predictions = model.predict(test_images[:10])
labels = ['T-shirt', 'Pantalon', 'Pull', 'Robe', 'Manteau', 'Sandale', 'Chemise', 'Sneaker', 'Sac', 'Bottine']

import numpy as np
for i in range(10):
    predicted_label = np.argmax(predictions[i])
    true_label = np.argmax(test_labels[i])
    print(f'Ex {i}: Prédit={labels[predicted_label]}, Vrai={labels[true_label]}, Conf={predictions[i][predicted_label]:.2f}')

Evaluate sur test set donne métrique finale (~90%). Predict retourne proba softmax ; argmax pour classe. Affichage concret valide. Chargez 'best_model.keras' si besoin : model = tf.keras.models.load_model('best_model.keras'). Piège : ne pas utiliser test set pour train.

Sauvegarde et export TensorFlow Lite

save_export.py
# Sauvegarde complète
model.save('fashion_cnn_full.keras')

# Export Lite pour mobile/edge
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('fashion_cnn.tflite', 'wb') as f:
    f.write(tflite_model)

# Vérif Lite
interpreter = tf.lite.Interpreter(model_path='fashion_cnn.tflite')
interpreter.allocate_tensors()
print('TensorFlow Lite exporté avec succès !')

Save Keras pour reload facile. TFLite optimise (quantization) pour inférence rapide/low-mémoire. ~90% taille réduite, accuracy quasi-identique. Idéal déploiement Android/iOS. Piège : oublier optimizations = modèle gonflé.

Bonnes pratiques

  • Toujours splitter train/val/test (80/10/10) pour détection overfitting via courbes history.
  • Utilisez callbacks (ReduceLROnPlateau pour LR decay auto) sur datasets >10k samples.
  • Data Augmentation : ImageDataGenerator(rotation=10, zoom=0.1) booste généralisation (+5% acc).
  • Monitorez TensorBoard : tf.keras.callbacks.TensorBoard(log_dir='./logs') pour visuals live.
  • Seed reproductible : tf.random.set_seed(42) pour runs identiques.

Erreurs courantes à éviter

  • Oubli normalisation : pixels [0,255] causent NaN losses ; toujours /255 ou StandardScaler.
  • Batch_size trop petit (>512 GPU ok, mais CPU : 64-256 pour stabilité gradients).
  • Pas de Dropout/Reg : overfitting sur val (>10% gap train/val) ; ajoutez L2 kernel_regularizer.
  • Input_shape mismatch : Conv2D attend (H,W,C) ; vérifiez model.summary(). Ne pas fit sans compile.

Pour aller plus loin

Maîtrisez Transfer Learning avec MobileNetV2 sur custom datasets. Explorez TensorFlow docs. Pour pros : formations avancées Learni Group sur déploiement MLOps et GANs. Testez sur Kaggle Fashion MNIST competitions.