Learni
Machine Learning

How to Implement Model Distillation in 2026


Introduction

Model distillation transfers knowledge from a large model (the teacher) to a smaller one (the student). In 2026 it remains a key technique for deploying LLMs on constrained infrastructure. Training combines standard supervision on the true labels with a distillation loss that aligns the student's logits, or its internal representations, with the teacher's. Unlike pruning alone, which removes weights from an existing network, distillation trains a compact model to imitate the teacher's full output distribution, which tends to preserve more of its generalization ability. In this advanced tutorial, we build a complete PyTorch pipeline that distills a BERT teacher into a smaller BERT-style student, including temperature and loss-weight handling.

Prerequisites

  • Python 3.11+
  • PyTorch 2.4+
  • Transformers 4.45+
  • CUDA 12.1 (recommended)
  • Solid knowledge of deep learning and optimization

Setup and Dependencies

terminal
pip install torch==2.4.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install transformers==4.45.0 datasets accelerate scikit-learn

These pinned versions satisfy the prerequisites above. A CUDA GPU is recommended rather than strictly required: it greatly speeds up both the teacher's forward passes and the student's training.
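To confirm that the installed packages meet the minimums from the Prerequisites list, you can read package metadata with the standard library's importlib.metadata, which avoids importing torch itself. This is a minimal sketch: `check_versions` and `parse_version` are hypothetical helpers, and the simple parser only keeps the leading numeric parts of a version string (the `packaging` library offers a complete comparison).

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions taken from the Prerequisites list
MINIMUMS = {"torch": (2, 4), "transformers": (4, 45)}


def parse_version(text):
    """Keep the leading numeric components: '2.4.0+cu121' -> (2, 4, 0)."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:  # stop at suffixes such as 'dev0' or 'rc1'
            break
        parts.append(int(digits))
    return tuple(parts)


def check_versions(minimums=MINIMUMS):
    """Return a status string per package without importing the packages."""
    report = {}
    for package, minimum in minimums.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            report[package] = "missing"
            continue
        ok = parse_version(installed) >= minimum
        report[package] = f"{installed} ({'ok' if ok else 'too old'})"
    return report


if __name__ == "__main__":
    print(f"python: {sys.version.split()[0]} ({'ok' if sys.version_info >= (3, 11) else 'too old'})")
    for package, status in check_versions().items():
        print(f"{package}: {status}")
```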

Defining Teacher and Student Models

models.py
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForSequenceClassification

class DistillationModel(nn.Module):
    def __init__(self, teacher_name="bert-base-uncased", student_name="bert-base-uncased", num_labels=2):
        super().__init__()
        # The teacher should be a checkpoint already fine-tuned on the target task:
        # a raw pretrained checkpoint gets a randomly initialized classification head.
        self.teacher = AutoModelForSequenceClassification.from_pretrained(teacher_name, num_labels=num_labels)
        self.teacher.requires_grad_(False)  # freeze: no gradients are computed for the teacher
        self.teacher.eval()                 # disable dropout in the teacher

        # Shrink the student: fewer, narrower layers, randomly initialized from the config
        config = AutoConfig.from_pretrained(student_name, num_labels=num_labels)
        config.hidden_size = 384        # must stay divisible by num_attention_heads (12)
        config.num_hidden_layers = 6
        self.student = AutoModelForSequenceClassification.from_config(config)

    def train(self, mode=True):
        # Keep the teacher in eval mode even when model.train() is called
        super().train(mode)
        self.teacher.eval()
        return self

    def forward(self, input_ids, attention_mask, labels=None):
        with torch.no_grad():
            teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=attention_mask)
        student_outputs = self.student(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return teacher_outputs, student_outputs

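This module loads a frozen teacher and builds a smaller student (6 layers, hidden size 384) from a modified config, so the student starts from random weights rather than pretrained ones. The teacher's logits are computed without gradients and serve as soft targets. The teacher checkpoint should already be fine-tuned on the target task; otherwise its soft targets carry no task knowledge.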
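To see how much smaller the student actually is, the parameter count of a standard BERT classifier can be computed from its config dimensions. This is a dependency-free sketch: `bert_classifier_params` is a hypothetical helper that assumes the usual BERT layout (bert-base-uncased values: 30,522-token vocabulary, 512 positions, 2 token types) and that intermediate_size is left at its default of 3072, as in the student config above.

```python
def bert_classifier_params(vocab_size=30522, hidden=768, layers=12, intermediate=3072,
                           max_positions=512, type_vocab=2, num_labels=2):
    """Parameter count of a BERT encoder plus pooler and classification head."""
    def linear(n_in, n_out):
        return n_in * n_out + n_out  # weight matrix + bias

    layer_norm = 2 * hidden  # scale and shift vectors
    # Word, position and token-type embeddings, followed by one LayerNorm
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm
    per_layer = (
        4 * linear(hidden, hidden)       # query, key, value and attention output projections
        + layer_norm
        + linear(hidden, intermediate)   # feed-forward up-projection
        + linear(intermediate, hidden)   # feed-forward down-projection
        + layer_norm
    )
    pooler = linear(hidden, hidden)
    classifier = linear(hidden, num_labels)
    return embeddings + layers * per_layer + pooler + classifier


teacher = bert_classifier_params()                      # 109,483,778
student = bert_classifier_params(hidden=384, layers=6)  # 29,801,090
print(f"teacher: {teacher:,}  student: {student:,}  ratio: {teacher / student:.1f}x")
```

On the real models, `sum(p.numel() for p in model.student.parameters())` gives the authoritative figure. Note that roughly 40% of the student's parameters sit in its word embeddings (30,522 × 384), so reducing depth and width cuts compute more than it cuts memory.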

Implementing the Distillation Loss

distillation_loss.py
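import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature  # T > 1 softens both distributions
        self.alpha = alpha              # weight of the distillation term
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, teacher_logits, student_logits, labels):
        # Teacher probabilities and student log-probabilities, both at temperature T
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
        # F.kl_div takes log-probabilities as input and probabilities as target;
        # multiplying by T^2 keeps the soft-term gradients on a comparable scale as T changes
        distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.temperature ** 2)
        # Standard cross-entropy against the ground-truth labels (no temperature)
        hard_loss = self.ce_loss(student_logits, labels)
        return self.alpha * distill_loss + (1 - self.alpha) * hard_loss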

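The loss is a weighted sum of two terms: the KL divergence between the teacher's and the student's temperature-softened distributions (the distillation term, weighted by alpha) and the ordinary cross-entropy against the true labels (weighted by 1 - alpha). A higher temperature flattens the distributions, which exposes how the teacher ranks the incorrect classes.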
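Written as a formula, with $z_t$ and $z_s$ the teacher and student logits, $\sigma$ the softmax, $T$ the temperature, $\alpha$ the weight and $y$ the true label, the DistillationLoss module computes the following per example, then averages over the batch:

```latex
\mathcal{L} = \alpha \, T^{2} \, \mathrm{KL}\left( \sigma\!\left(\frac{z_t}{T}\right) \,\middle\|\, \sigma\!\left(\frac{z_s}{T}\right) \right)
            + (1 - \alpha) \, \mathrm{CE}\left( y, \sigma(z_s) \right),
\qquad
\mathrm{KL}(p \,\|\, q) = \sum_i p_i \log \frac{p_i}{q_i}
```

Because F.kl_div receives the student's log-probabilities as its input and the teacher's probabilities as its target, it computes the divergence in this direction, from the teacher distribution $p$ to the student distribution $q$.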

Complete Training Loop

train.py
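import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

from models import DistillationModel
from distillation_loss import DistillationLoss

# train_loader must be a DataLoader yielding dicts of tensors with
# 'input_ids', 'attention_mask' and 'labels' (tokenization not shown here)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DistillationModel().to(device)
loss_fn = DistillationLoss(temperature=4.0, alpha=0.7)
optimizer = AdamW(model.student.parameters(), lr=2e-5)  # only the student is optimized

num_epochs = 3
total_steps = num_epochs * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps  # 10% warmup
)

model.train()
model.teacher.eval()  # keep teacher dropout off so its soft targets are deterministic
for epoch in range(num_epochs):
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        teacher_out, student_out = model(batch['input_ids'], batch['attention_mask'], batch['labels'])
        loss = loss_fn(teacher_out.logits, student_out.logits, batch['labels'])
        loss.backward()
        optimizer.step()
        scheduler.step()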
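Only the student's parameters are passed to the optimizer. The teacher's forward pass runs under torch.no_grad(), so its activations are not stored for backpropagation, which saves memory, and its weights are never updated.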

Evaluation and Export Script

evaluate.py
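import torch
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer

def evaluate(model, dataloader, device):
    # Evaluate the student alone: it is the model that will be deployed
    model.student.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            outputs = model.student(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            )
            preds.extend(torch.argmax(outputs.logits, dim=-1).cpu().numpy())
            labels.extend(batch['labels'].cpu().numpy())
    return accuracy_score(labels, preds)

# Export the trained student (`model` is the DistillationModel trained in train.py).
# The student reuses the bert-base-uncased vocabulary, so that tokenizer is saved alongside it.
model.student.save_pretrained("./distilled-student")
AutoTokenizer.from_pretrained("bert-base-uncased").save_pretrained("./distilled-student")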

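Evaluation runs the student alone, since it is the model that will be deployed. save_pretrained writes the weights and config in Hugging Face format, so the student can later be reloaded with AutoModelForSequenceClassification.from_pretrained("./distilled-student").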
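To check the exported student's inference latency, a small timing helper built on time.perf_counter gives a first estimate. This is a sketch: `median_latency_ms` is a hypothetical helper, and the warmup and run counts are arbitrary defaults.

```python
import statistics
import time
from typing import Callable


def median_latency_ms(fn: Callable[[], object], warmup: int = 5, runs: int = 50) -> float:
    """Median wall-clock time of fn() in milliseconds, after a few untimed warmup calls."""
    for _ in range(warmup):
        fn()  # warmup calls absorb one-off costs such as lazy initialization
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)  # the median is less sensitive to outliers than the mean
```

For example, call it with `lambda: model.student(**inputs)` inside `torch.no_grad()`, where `inputs` is a tokenized batch. On a GPU, include `torch.cuda.synchronize()` in the timed callable: CUDA kernels launch asynchronously, so the timer would otherwise stop before the work finishes.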

Best Practices

  • Always freeze the teacher and never propagate its gradients
  • Tune temperature and alpha via cross-validation
  • Use a scheduler with warmup for the student
  • Monitor KL divergence during training
  • Validate final model size and inference latency
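For the warmup scheduler, a common choice is linear warmup followed by linear decay to zero. The sketch below computes the learning-rate multiplier for a given step; it is intended to mirror the schedule of transformers' get_linear_schedule_with_warmup, but `linear_warmup_multiplier` itself is a hypothetical, dependency-free helper written for illustration.

```python
def linear_warmup_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Ramp linearly from 0 to 1 during warmup
        return step / max(1, warmup_steps)
    # Then decay linearly from 1 to 0 at total_steps, never going negative
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


print(linear_warmup_multiplier(5, 10, 100))   # 0.5 (halfway through warmup)
print(linear_warmup_multiplier(55, 10, 100))  # 0.5 (halfway through decay)
```

In PyTorch, the same function can be plugged into torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer's base learning rate by the returned factor at each step.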

Common Mistakes to Avoid

  • Dividing only one side's logits by the temperature: teacher and student logits must both be scaled by T
  • Training the teacher at the same time as the student
  • Omitting the temperature² scaling of the KL term
  • Using a very small batch size, which makes the batch-averaged KL term and its gradients noisy
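The temperature pitfalls above are easy to see on toy numbers. This dependency-free sketch uses made-up logits for three classes (not taken from any real model) and hypothetical helpers `softmax` and `kl_divergence`; for T = 1 and T = 4 it prints the teacher's softened distribution, the raw KL term, and the KL term after T² scaling.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    # KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute nothing
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher_logits = [4.0, 1.0, 0.5]  # toy values for illustration only
student_logits = [3.0, 1.5, 0.0]

for T in (1.0, 4.0):
    p = softmax(teacher_logits, T)  # both sides are divided by the same T
    q = softmax(student_logits, T)
    kl = kl_divergence(p, q)
    print(f"T={T}: teacher={[round(x, 3) for x in p]} KL={kl:.4f} KL*T^2={kl * T ** 2:.4f}")
```

With these logits, raising T flattens the teacher's distribution and makes the raw KL divergence several times smaller, while the T²-scaled value stays on the same order as at T = 1. Without the T² factor, the distillation term would quietly lose weight relative to the hard loss as T grows.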

Going Further

Deepen your knowledge of multi-layer distillation and post-distillation quantization techniques in our Learni training courses.