Introduction
La distillation de modèles permet de transférer les connaissances d'un modèle volumineux (teacher) vers un modèle plus léger (student). Cette technique est essentielle en 2026 pour déployer des LLM sur des infrastructures contraintes. Elle combine la supervision classique avec une perte de distillation qui aligne les logits ou les représentations internes. Contrairement au simple pruning, elle préserve les capacités de généralisation. Dans ce tutoriel avancé, nous construirons un pipeline complet en PyTorch pour distiller un modèle BERT-like vers une version réduite, avec gestion des températures et des poids de loss.
Prérequis
- Python 3.11+
- PyTorch 2.4+
- Transformers 4.45+
- CUDA 12.1 (recommandé)
- Connaissances solides en deep learning et optimisation
Configuration et dépendances
pip install torch==2.4.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install transformers==4.45.0 datasets accelerateInstallation des versions stables compatibles avec la distillation avancée. CUDA est requis pour entraîner efficacement teacher et student en parallèle.
Définition des modèles teacher et student
from transformers import AutoModelForSequenceClassification, AutoConfig
import torch.nn as nn
class DistillationModel(nn.Module):
def __init__(self, teacher_name="bert-base-uncased", student_name="bert-base-uncased", num_labels=2):
super().__init__()
self.teacher = AutoModelForSequenceClassification.from_pretrained(teacher_name, num_labels=num_labels)
self.student = AutoModelForSequenceClassification.from_pretrained(student_name, num_labels=num_labels)
# Réduire la taille du student
config = AutoConfig.from_pretrained(student_name)
config.hidden_size = 384
config.num_hidden_layers = 6
self.student = AutoModelForSequenceClassification.from_config(config)
def forward(self, input_ids, attention_mask, labels=None):
with torch.no_grad():
teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=attention_mask)
student_outputs = self.student(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
return teacher_outputs, student_outputsCe module charge un teacher figé et initialise un student plus petit. Les logits du teacher sont extraits sans gradient pour servir de cible soft.
Implémentation de la loss de distillation
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
def __init__(self, temperature=4.0, alpha=0.5):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.ce_loss = nn.CrossEntropyLoss()
def forward(self, teacher_logits, student_logits, labels):
soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.temperature ** 2)
hard_loss = self.ce_loss(student_logits, labels)
return self.alpha * distill_loss + (1 - self.alpha) * hard_lossLa loss combine KL-divergence sur les logits adoucis (distillation) et cross-entropy hard. Le paramètre température contrôle la douceur de la distribution.
Boucle d'entraînement complète
from torch.optim import AdamW
from torch.utils.data import DataLoader
model = DistillationModel()
loss_fn = DistillationLoss(temperature=4.0, alpha=0.7)
optimizer = AdamW(model.student.parameters(), lr=2e-5)
model.train()
for epoch in range(3):
for batch in train_loader:
optimizer.zero_grad()
teacher_out, student_out = model(batch['input_ids'], batch['attention_mask'], batch['labels'])
loss = loss_fn(teacher_out.logits, student_out.logits, batch['labels'])
loss.backward()
optimizer.step()Entraînement uniquement des paramètres du student. Le teacher reste en mode no_grad pour économiser la mémoire et stabiliser l'apprentissage.
Script d'évaluation et export
from sklearn.metrics import accuracy_score
import torch
def evaluate(model, dataloader):
model.eval()
preds, labels = [], []
with torch.no_grad():
for batch in dataloader:
_, outputs = model(batch['input_ids'], batch['attention_mask'])
preds.extend(torch.argmax(outputs.logits, dim=-1).cpu().numpy())
labels.extend(batch['labels'].cpu().numpy())
return accuracy_score(labels, preds)
# Export du student
model.student.save_pretrained("./distilled-student")Évaluation stricte du student seul. Sauvegarde finale au format Hugging Face pour déploiement immédiat.
Bonnes pratiques
- Toujours figer le teacher et ne jamais propager ses gradients
- Ajuster température et alpha via validation croisée
- Utiliser un scheduler avec warmup pour le student
- Surveiller la divergence KL pendant l'entraînement
- Valider la taille finale du modèle et la latence en inférence
Erreurs courantes à éviter
- Oublier de diviser les logits par la température dans les deux directions
- Entraîner le teacher en même temps que le student
- Ignorer le scaling de la loss KL par température²
- Utiliser un batch size trop faible qui dégrade la qualité des distributions soft
Pour aller plus loin
Approfondissez la distillation multi-couches et les techniques de quantization post-distillation dans nos formations Learni.