v1.2.0
Features
New SSDTrainer — Simple Self-Distillation
<img width="778" height="334" alt="Screenshot 2026-04-16 at 9 08 04 PM" src="https://github.com/user-attachments/assets/8ca223f0-6740-48a8-967c-ec10cb262a93" />
A new experimental SSDTrainer implements the method described in Embarrassingly Simple Self-Distillation Improves Code Generation. SSD samples completions from the model itself at a training-time temperature/truncation setting, then fine-tunes on those raw, unverified samples with standard cross-entropy loss. No reward model, verifier, teacher model, or RL: just prompts and the model.
from datasets import Dataset
from trl.experimental.ssd import SSDConfig, SSDTrainer

dataset = Dataset.from_dict({
    "prompt": [
        [{"role": "user", "content": "Write a function to add two numbers."}],
        [{"role": "user", "content": "Write a function to check if a number is prime."}],
    ],
})

trainer = SSDTrainer(
    model="Qwen/Qwen3-4B-Instruct",
    args=SSDConfig(
        output_dir="ssd-model",
        temperature=0.6,  # T_train from the paper
        top_k=20,
        top_p=0.95,
        learning_rate=5e-6,
    ),
    train_dataset=dataset,
)
trainer.train()
by @kashif in https://github.com/huggingface/trl/pull/5505
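To make the temperature/truncation setting concrete, here is a minimal pure-Python sketch of how temperature, top-k, and top-p reshape a next-token distribution before sampling. It is illustrative only: `truncated_distribution` is a hypothetical helper, not a TRL API, and it assumes the filters are applied in sequence (temperature, then top-k, then top-p). The trainer's actual sampling path may differ in details.

```python
import math


def truncated_distribution(logits, temperature=0.6, top_k=20, top_p=0.95):
    """Return {token_id: prob} after temperature, top-k, then top-p filtering.

    Hypothetical helper for illustration; not part of TRL.
    """
    # Temperature scaling, then unnormalized softmax weights (shifted for stability).
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]

    # Top-k: keep the k most likely tokens and renormalize over them.
    ranked = sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)[:top_k]
    k_mass = sum(weights[i] for i in ranked)
    probs = {i: weights[i] / k_mass for i in ranked}

    # Top-p: keep the smallest prefix whose cumulative mass reaches top_p.
    kept, mass = [], 0.0
    for i in ranked:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Renormalize over the surviving tokens.
    return {i: probs[i] / mass for i in kept}


dist = truncated_distribution([2.0, 1.0, 0.5, -1.0, -3.0], top_k=3, top_p=0.9)
print(dist)
```

Lower temperatures and tighter top_k/top_p values concentrate the samples on the model's most likely tokens, which is the knob SSD turns at training time.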
Drop, don't truncate, overlong tool results in GRPOTrainer
When tool calls produce more tokens than max_completion_length allows, GRPOTrainer now rolls back the tool messages/images added in the current iteration instead of trying to truncate them. This removes ~80 lines of fragile, image-boundary-aware bookkeeping in favor of a ~15-line snapshot-and-rollback. Since overlong samples almost always get rewarded as failures anyway, the learning signal is effectively unchanged — but the code is dramatically simpler and no longer needs per-VLM-family vision-token lookup tables.
by @qgallouedec in https://github.com/huggingface/trl/pull/5521
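The snapshot-and-rollback idea can be sketched as follows. This is a simplified illustration, not the GRPOTrainer code: `try_append_tool_results`, `count_tokens`, and the message layout are hypothetical stand-ins, and real token counting would go through the tokenizer or processor.

```python
def count_tokens(messages):
    # Hypothetical stand-in for tokenizer-based counting: whitespace-separated words.
    return sum(len(m["content"].split()) for m in messages)


def try_append_tool_results(completion, images, tool_messages, tool_images, max_completion_length):
    """Append one iteration of tool results, or roll back if they overflow the budget."""
    # Snapshot: remember the lengths before this iteration mutates anything.
    n_messages, n_images = len(completion), len(images)

    completion.extend(tool_messages)
    images.extend(tool_images)

    if count_tokens(completion) > max_completion_length:
        # Rollback: drop everything added in this iteration as a whole,
        # so no partial tool message or orphaned image is left behind.
        del completion[n_messages:]
        del images[n_images:]
        return False
    return True


completion = [{"role": "assistant", "content": "calling the search tool"}]
images = []

ok = try_append_tool_results(completion, images, [{"role": "tool", "content": "short result"}], [], 8)
print(ok, len(completion))

overlong = [{"role": "tool", "content": "word " * 50}]
ok = try_append_tool_results(completion, images, overlong, ["img0"], 8)
print(ok, len(completion), images)
```

Because the rollback only restores list lengths, it needs no knowledge of where image tokens begin or end inside a message.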
Expanded tool-calling model support: LLaMA 3.1 / 3.2 & DeepSeek-V3
Continuing the effort from v1.1:
- LLaMA 3.1 and 3.2 tool-calling response schemas, with dedicated templates for identity matching. Note that these templates only support a single tool call and no content alongside the tool call — limitations inherited from the models' native templates. By @qgallouedec in https://github.com/huggingface/trl/pull/5518
- DeepSeek-V3 training chat template with {% generation %} markers, enabling assistant-only loss masking for DeepSeek-V3 models. By @RudrenduPaul in https://github.com/huggingface/trl/pull/5527
As a result of tightened detection (see Fixes below), the list of templates reported as tool-calling capable is now correct. Notably, the basic Llama 3 template is no longer falsely classified as tool-calling capable.
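To illustrate what assistant-only loss masking does with the spans delimited by {% generation %} markers, the sketch below turns a per-token assistant mask into training labels. It assumes the mask has already been produced when the chat template was applied (1 for tokens inside a generation block, 0 elsewhere). `mask_non_assistant` and the token ids are hypothetical; -100 is the label value that PyTorch's cross-entropy loss ignores by default.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_non_assistant(input_ids, assistant_mask):
    """Keep labels only where the assistant mask is 1; ignore everything else."""
    if len(input_ids) != len(assistant_mask):
        raise ValueError("input_ids and assistant_mask must have the same length")
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, assistant_mask)]


# Made-up ids: three prompt tokens followed by four assistant tokens.
input_ids = [11, 12, 13, 21, 22, 23, 24]
assistant_mask = [0, 0, 0, 1, 1, 1, 1]

labels = mask_non_assistant(input_ids, assistant_mask)
print(labels)  # [-100, -100, -100, 21, 22, 23, 24]
```

With labels built this way, system and user tokens contribute nothing to the loss, so the model is trained only on the assistant's turns.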
KTO/DPO alignment push
A major cleanup sweep keeps KTOTrainer and DPOTrainer in lockstep, with the same initialization patterns, config surface, and precompute behavior:
- Add precompute_ref_batch_size to KTO (https://github.com/huggingface/trl/pull/5530)
- Align ref_model initialization (https://github.com/huggingface/trl/pull/5534)
- Align model initialization (https://github.com/huggingface/trl/pull/5533)
- Support None args (https://github.com/huggingface/trl/pull/5531)
- Remove generate_during_eval (https://github.com/huggingface/trl/pull/5551)
- Remove model and ref adapter names (https://github.com/huggingface/trl/pull/5552)
- Don't load ref_model when precompute_ref_log_probs is set in DPO/KTO (https://github.com/huggingface/trl/pull/5542)
All by @albertvillanova.
Other
- Support messages with images in prepare_multimodal_messages by @albertvillanova in https://github.com/huggingface/trl/pull/5474
- Simplify role handling in prepare_multimodal_messages by @albertvillanova in https://github.com/huggingface/trl/pull/5508
- Update vLLM version support to 0.18.0 by @qgallouedec in https://github.com/huggingface/trl/pull/5547
Fixes
- Fix supports_tool_calling falsely accepting templates that drop assistant tool_calls by @qgallouedec in https://github.com/huggingface/trl/pull/5517
- Fix add_response_schema for VLM processors: the schema was being set on the outer processor instead of the inner tokenizer, so it had no effect. This also collapses a handful of __init__/decode-gate workarounds. By @qgallouedec in https://github.com/huggingface/trl/pull/5520
- Remove xfail condition for Gemma 4 response_schema regex bug by @qgallouedec in https://github.com/huggingface/trl/pull/5510
- Remove unused dependencies for judges from dev requirements by @qgallouedec in https://github.com/huggingface/trl/pull/5515
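The failure mode behind the supports_tool_calling fix is a template that renders without error but silently omits the assistant's tool_calls. One way to detect that, sketched here under assumptions, is to render a probe conversation containing a sentinel tool name and check that the sentinel survives. The `renders_tool_calls` function and the two template callables are hypothetical stand-ins for chat templates; this is not TRL's actual implementation.

```python
SENTINEL = "probe_tool_7f3a"  # arbitrary name unlikely to appear in any template


def renders_tool_calls(render):
    """Return True only if the assistant's tool call shows up in the rendered text."""
    probe = [
        {"role": "user", "content": "hi"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"type": "function", "function": {"name": SENTINEL, "arguments": {"x": 1}}}
            ],
        },
    ]
    try:
        text = render(probe)
    except Exception:
        # A template that raises on tool_calls does not support them either.
        return False
    return SENTINEL in text


def basic_template(messages):
    # Hypothetical template that ignores tool_calls entirely.
    return "".join(f"<|{m['role']}|>{m['content']}" for m in messages)


def tool_aware_template(messages):
    # Hypothetical template that serializes each tool call.
    parts = []
    for m in messages:
        parts.append(f"<|{m['role']}|>{m['content']}")
        for call in m.get("tool_calls", []):
            fn = call["function"]
            parts.append(f"<tool_call>{fn['name']}({fn['arguments']})</tool_call>")
    return "".join(parts)


print(renders_tool_calls(basic_template))       # False
print(renders_tool_calls(tool_aware_template))  # True
```

Checking the rendered output, rather than only checking that rendering succeeds, is what separates templates that handle tool calls from templates that merely tolerate them.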
Deprecations
- Deprecate use_transformers_paged in GRPOConfig and RLOOConfig (and remove it entirely from the experimental OnlineDPOConfig, GOLDConfig, and SelfDistillationConfig). It will be removed from the remaining configs in v2.0.0. In a small A/B benchmark (Qwen3-0.6B GRPO), the paged path is ~20% slower and uses ~6x more peak VRAM than the default; it is also superseded by transformers continuous batching. By @qgallouedec in https://github.com/huggingface/trl/pull/5544
Documentation and Examples
- Add example script section to experimental trainer docs by @sergiopaniego in https://github.com/huggingface/trl/pull/5543
- [Docs] Fix formatting in SSD training example script by @kashif in https://github.com/huggingface/trl/pull/5548
- Nits in SSD docs by @sergiopaniego in https://github.com/huggingface/trl/pull/5554
- [docs] Add LLaMA 3 / Qwen 2.5 entries to chat_templates/README by @qgallouedec in https://github.com/huggingface/trl/pull/5545
- Update CARLA VLM example scripts by @sergiopaniego in https://github.com/huggingface/trl/pull/5557
CI
- Fix CI dependency installs to use a single resolve by @qgallouedec in https://github.com/huggingface/trl/pull/5513
- Set upper transformers version to skip distributed test_rloo after fixed by @albertvillanova in https://github.com/huggingface/trl/pull/5535
- Update tests with zero3 for RLOO and GRPO once fixed in transformers 5.5.4 by @albertvillanova in https://github.com/huggingface/trl/pull/5541
- Bump doc-builder SHA for PR upload workflow by @rtrompier in https://github.com/huggingface/trl/pull/5553
What's Changed
- ⬆️ Bump dev version by @qgallouedec in https://github.com/huggingface/trl/pull/5525
- Simplify role handling in prepare_multimodal_messages by @albertvillanova in https://github.com/huggingface/trl/pull/5508
- Fix CI dependency installs to use a single resolve by @qgallouedec in https://github.com/huggingface/trl/pull/5513
- Fix supports_tool_calling falsely accepting templates that drop assistant tool_calls by @qgallouedec in https://github.com/huggingface/trl/pull/5517
- feat: add DeepSeek-V3 training chat template with generation markers by @RudrenduPaul in https://github.com/huggingface/trl/pull/5527
- Drop, don't truncate, overlong tool results in GRPOTrainer by @qgallouedec in https://github.com/huggingface/trl/pull/5521
- Set upper transformers version to skip distributed test_rloo after fixed by @albertvillanova in https://github.com/huggingface/trl/pull/5535
- Align KTO with DPO: Add precompute_ref_batch_size by @albertvillanova in https://github.com/huggingface/trl/pull/5530
- Update tests with zero3 for RLOO and GRPO once fixed in transformers 5.5.4 by @albertvillanova in https://github.com/huggingface/trl/pull/5541
- Align KTO with DPO: Align ref_model initialization by @albertvillanova in https://github.com/huggingface/trl/pull/5534
- Align KTO with DPO: Align model initialization by @albertvillanova in https://github.com/huggingface/trl/pull/5533
- Remove unused dependencies for judges from dev requirements by @qgallouedec in https://github.com/huggingface/trl/pull/5515
- Remove xfail condition for Gemma4 response_schema regex bug by @qgallouedec in https://github.com/huggingface/trl/pull/5510
- Align KTO with DPO: Support None args by @albertvillanova in https://github.com/huggingface/trl/pull/5531
- Add example script section to experimental trainer docs by @sergiopaniego in https://github.com/huggingface/trl/pull/5543
- [SSD] Added SSD trainer in experimental by @kashif in https://github.com/huggingface/trl/pull/5505
- [Docs] Fix formatting in SSD training example script by @kashif in https://github.com/huggingface/trl/pull/5548
- Don't load ref_model when precompute_ref_log_probs in DPO/KTO by @albertvillanova in https://github.com/huggingface/trl/pull/5542
- chore: bump doc-builder SHA for PR upload workflow by @rtrompier in https://github.com/huggingface/trl/pull/5553
- Nits in SSD docs by @sergiopaniego in https://github.com/huggingface/trl/pull/5554
- Deprecate use_transformers_paged by @qgallouedec in https://github.com/huggingface/trl/pull/5544
- Update vLLM version support to 0.18.0 by @qgallouedec in https://github.com/huggingface/trl/pull/5547
- Align KTO with DPO: Remove generate_during_eval by @albertvillanova in https://github.com/huggingface/trl/pull/5551
- Align KTO with DPO: Remove model and ref adapter names by @albertvillanova in https://github.com/huggingface/trl/pull/5552
- Support messages with images in prepare_multimodal_messages by @albertvillanova in https://github.com/huggingface/trl/pull/5474
- Update CARLA VLM example scripts by @sergiopaniego in https://github.com/huggingface/trl/pull/5557
- Fix add_response_schema for VLM processors by @qgallouedec in https://github.com/huggingface/trl/pull/5520
- [docs] Add LLaMA 3 / Qwen 2.5 entries to chat_templates/README by @qgallouedec in https://github.com/huggingface/trl/pull/5545
- Add LLaMA 3.1 and 3.2 tool calling support by @qgallouedec in https://github.com/huggingface/trl/pull/5518
- Release: v1.2 by @qgallouedec in https://github.com/huggingface/trl/pull/5576
Full Changelog: https://github.com/huggingface/trl/compare/v1.1.0...v1.2.0
Fetched April 17, 2026
