Skip to content
Learni
View all tutorials
Machine Learning

Comment implémenter la distillation de modèles en 2026

Introduction

La distillation de modèles permet de transférer les connaissances d'un modèle volumineux (teacher) vers un modèle plus léger (student). Cette technique est essentielle en 2026 pour déployer des LLM sur des infrastructures contraintes. Elle combine la supervision classique avec une perte de distillation qui aligne les logits ou les représentations internes. Contrairement au simple pruning, elle préserve les capacités de généralisation. Dans ce tutoriel avancé, nous construirons un pipeline complet en PyTorch pour distiller un modèle BERT-like vers une version réduite, avec gestion des températures et des poids de loss.

Prérequis

  • Python 3.11+
  • PyTorch 2.4+
  • Transformers 4.45+
  • CUDA 12.1 (recommandé)
  • Connaissances solides en deep learning et optimisation

Configuration et dépendances

terminal
pip install torch==2.4.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install transformers==4.45.0 datasets accelerate

Installation des versions stables compatibles avec la distillation avancée. CUDA est requis pour entraîner efficacement teacher et student en parallèle.

Définition des modèles teacher et student

models.py
from transformers import AutoModelForSequenceClassification, AutoConfig
import torch.nn as nn

class DistillationModel(nn.Module):
    def __init__(self, teacher_name="bert-base-uncased", student_name="bert-base-uncased", num_labels=2):
        super().__init__()
        self.teacher = AutoModelForSequenceClassification.from_pretrained(teacher_name, num_labels=num_labels)
        self.student = AutoModelForSequenceClassification.from_pretrained(student_name, num_labels=num_labels)
        # Réduire la taille du student
        config = AutoConfig.from_pretrained(student_name)
        config.hidden_size = 384
        config.num_hidden_layers = 6
        self.student = AutoModelForSequenceClassification.from_config(config)
    
    def forward(self, input_ids, attention_mask, labels=None):
        with torch.no_grad():
            teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=attention_mask)
        student_outputs = self.student(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return teacher_outputs, student_outputs

Ce module charge un teacher figé et initialise un student plus petit. Les logits du teacher sont extraits sans gradient pour servir de cible soft.

Implémentation de la loss de distillation

distillation_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
    
    def forward(self, teacher_logits, student_logits, labels):
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
        distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.temperature ** 2)
        hard_loss = self.ce_loss(student_logits, labels)
        return self.alpha * distill_loss + (1 - self.alpha) * hard_loss

La loss combine KL-divergence sur les logits adoucis (distillation) et cross-entropy hard. Le paramètre température contrôle la douceur de la distribution.

Boucle d'entraînement complète

train.py
from torch.optim import AdamW
from torch.utils.data import DataLoader

model = DistillationModel()
loss_fn = DistillationLoss(temperature=4.0, alpha=0.7)
optimizer = AdamW(model.student.parameters(), lr=2e-5)

model.train()
for epoch in range(3):
    for batch in train_loader:
        optimizer.zero_grad()
        teacher_out, student_out = model(batch['input_ids'], batch['attention_mask'], batch['labels'])
        loss = loss_fn(teacher_out.logits, student_out.logits, batch['labels'])
        loss.backward()
        optimizer.step()

Entraînement uniquement des paramètres du student. Le teacher reste en mode no_grad pour économiser la mémoire et stabiliser l'apprentissage.

Script d'évaluation et export

evaluate.py
from sklearn.metrics import accuracy_score
import torch

def evaluate(model, dataloader):
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            _, outputs = model(batch['input_ids'], batch['attention_mask'])
            preds.extend(torch.argmax(outputs.logits, dim=-1).cpu().numpy())
            labels.extend(batch['labels'].cpu().numpy())
    return accuracy_score(labels, preds)

# Export du student
model.student.save_pretrained("./distilled-student")

Évaluation stricte du student seul. Sauvegarde finale au format Hugging Face pour déploiement immédiat.

Bonnes pratiques

  • Toujours figer le teacher et ne jamais propager ses gradients
  • Ajuster température et alpha via validation croisée
  • Utiliser un scheduler avec warmup pour le student
  • Surveiller la divergence KL pendant l'entraînement
  • Valider la taille finale du modèle et la latence en inférence

Erreurs courantes à éviter

  • Oublier de diviser les logits par la température dans les deux directions
  • Entraîner le teacher en même temps que le student
  • Ignorer le scaling de la loss KL par température²
  • Utiliser un batch size trop faible qui dégrade la qualité des distributions soft

Pour aller plus loin

Approfondissez la distillation multi-couches et les techniques de quantization post-distillation dans nos formations Learni.