Skip to content
Learni
View all tutorials
Machine Learning

Comment créer un modèle de classification avec TensorFlow en 2026

Introduction

TensorFlow, framework open-source de Google, reste en 2026 le pilier du machine learning en production grâce à sa maturité, son écosystème Keras et son support GPU/TPU natif. Ce tutoriel intermédiaire vous guide pour créer un réseau de neurones convolutif (CNN) classifiant des images de vêtements (dataset Fashion MNIST).

Pourquoi ce projet ? Il illustre les flux complets : préparation de données, architecture modulaire, entraînement optimisé et déploiement. À la fin, vous maîtriserez les APIs Keras pour scaler vers des tâches réelles comme la détection d'objets.

Adapté aux développeurs connaissant Python et les bases du ML, ce guide délivre des codes 100% fonctionnels, testés sur TensorFlow 2.17+. Durée estimée : 20 minutes pour un premier run. Prêt à booster vos compétences en deep learning ?

Prérequis

  • Python 3.10+ installé
  • Environnement virtuel (venv ou conda)
  • Connaissances de base en NumPy, Pandas et ML supervisé
  • GPU optionnel (CUDA 12+ pour accélération)
  • IDE comme VS Code ou Jupyter Notebook

Installation de TensorFlow

install.sh
python -m venv tf_env
source tf_env/bin/activate  # Sur Windows : tf_env\Scripts\activate
pip install tensorflow==2.17.0 matplotlib numpy scikit-learn
pip list | grep tensorflow

Ce script crée un environnement virtuel isolé, installe TensorFlow 2.17 (stable en 2026) et ses dépendances pour visualisation et métriques. Vérifiez l'installation avec pip list. Évitez les versions nightly pour la stabilité en prod ; activez le GPU via tensorflow-gpu si CUDA est configuré.

Chargement et exploration des données

Fashion MNIST est un dataset standard de 70 000 images 28x28 en niveaux de gris (10 classes : t-shirt, pantalon...). On le charge via tf.keras.datasets. Analogie : comme un catalogue e-commerce numérisé, prêt pour l'entraînement sans preprocessing externe.

Visualisez pour valider : 60k train, 10k test. Normalisez les pixels (0-1) pour stabiliser les gradients.

Chargement et préparation du dataset

data_prep.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Chargement du dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# Normalisation
train_images = train_images / 255.0
test_images = test_images / 255.0

# Classes
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Visualisation exemple
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()

print(f'Train: {train_images.shape}, Test: {test_images.shape}')

Ce code charge, normalise et visualise le dataset. La division par 255 scale les pixels pour une convergence rapide. Ajout de tf.data.Dataset pour batching en prod. Piège : oublier la reshape pour CNN (28x28x1) – géré plus loin.

Construction de l'architecture CNN

Un CNN typique : Conv2D pour extraction de features (filtres), MaxPooling2D pour réduction dimensionnelle, Dropout anti-overfit, Dense pour classification. Analogie : comme un détecteur de motifs successifs (bords → textures → objets).

On utilise Sequential pour simplicité ; passez à Functional pour branches complexes.

Définition du modèle CNN

model_build.py
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.summary()

# Reshape des données pour CNN
train_images = np.expand_dims(train_images, -1)
test_images = np.expand_dims(test_images, -1)

print(f'Shape train après reshape: {train_images.shape}')

Architecture CNN légère : 3 Conv2D croissants, pooling pour invariance translationnelle, Flatten vers Dense. summary() affiche params (~130k). Piège : input_shape seulement sur première couche ; reshape ajoute canal (1 pour gris).

Compilation et entraînement

model_train.py
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

history = model.fit(
    train_images, train_labels,
    epochs=10,
    validation_split=0.2,
    batch_size=128,
    verbose=1
)

# Visualisation courbes
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Adam optimise adaptativement ; sparse_crossentropy pour labels entiers. 10 epochs suffisent pour ~90% acc. Validation_split surveille l'overfit. Augmentez batch_size pour GPU ; callbacks comme EarlyStopping en prod.

Évaluation et prédictions

test_loss, test_acc mesure la généralisation. Prédictions via predict() avec argmax pour classes. Analogie : comme tester un modèle en A/B sur données vues.

Évaluation et prédictions

model_eval.py
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f'Test accuracy: {test_acc:.4f}')

# Prédictions
predictions = model.predict(test_images)
pred_label = np.argmax(predictions[0])
true_label = test_labels[0]

print(f'Prédiction: {class_names[pred_label]}, Vrai: {class_names[true_label]}')

# Matrice de confusion (bonus)
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
pred_classes = np.argmax(predictions, axis=1)
cm = confusion_matrix(test_labels, pred_classes)
plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.ylabel('Vrai')
plt.xlabel('Prédit')
plt.show()
print(classification_report(test_labels, pred_classes, target_names=class_names))

Évaluation quantitative ; matrice confusion visualise erreurs par classe. argmax sur softmax. Installez seaborn si besoin. Piège : ignorer les classes minoritaires – report le détecte.

Sauvegarde et chargement du modèle

model_save.py
# Sauvegarde
model.save('fashion_cnn_model.keras')

# Chargement et inférence
loaded_model = tf.keras.models.load_model('fashion_cnn_model.keras')
loaded_pred = loaded_model.predict(np.expand_dims(test_images[0], 0))
print(f'Chargé - Prédiction: {class_names[np.argmax(loaded_pred[0])]}')

# Export TensorFlow Lite pour mobile
converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model)
tflite_model = converter.convert()
with open('fashion_cnn.tflite', 'wb') as f:
    f.write(tflite_model)
print('Modèle TFLite sauvé.')

Format .keras natif (2026 standard). TFLite pour edge devices. Chargement préserve weights. Piège : version TensorFlow mismatch – utilisez saved_model pour compatibilité multi-plateforme.

Bonnes pratiques

  • Callbacks : EarlyStopping, ModelCheckpoint pour automatiser.
  • Data augmentation : ImageDataGenerator pour robustesse (rotations, flips).
  • Hyperparam tuning : Keras Tuner ou Optuna pour optimiser layers/learning rate.
  • Monitoring : TensorBoard intégré (tf.keras.callbacks.TensorBoard).
  • Scalabilité : tf.data pour datasets massifs ; MirroredStrategy pour multi-GPU.

Erreurs courantes à éviter

  • Oublier reshape canal (ValueError shape mismatch).
  • Pas de validation_split : détection overfit impossible.
  • Learning rate par défaut trop haut : divergence loss (use ReduceLROnPlateau).
  • Sauvegarde sans compile : métriques perdues au reload.

Pour aller plus loin