We introduced a new trainer for training Process-supervised Reward Models (PRMs) in TRL. A PRM rewards the quality of intermediate steps, promoting structured reasoning over focusing solely on the final outcome. With this trainer, we also introduce a new dataset type: stepwise supervision, a variant of the prompt-completion type in which the completion is divided into several intermediate steps and each step is associated with a label. Find out more in the stepwise-supervision section of the TRL documentation.
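As a rough illustration, a stepwise-supervision example might look like the snippet below (the column names follow the stepwise-supervision format described in the TRL docs and trl-lib/math_shepherd; the sample content itself is made up):
example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional part of 9.8 is 0.8, and the fractional part of 9.11 is 0.11.",  # correct step
        "Since 0.11 > 0.8, 9.11 is larger.",  # incorrect step
    ],
    "labels": [True, False],  # one correctness label per intermediate step
}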
Here is an example of how to use the PRMTrainer to train a PRM on the Math Shepherd dataset:
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer
model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")
training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
For more information, check out the PRMTrainer documentation.
by @qgallouedec and @gaetanlop in https://github.com/huggingface/trl/pull/2127 and https://github.com/huggingface/trl/pull/2148
MergeModelCallback
Various works show that model merging can non-trivially improve performance, especially when the models share the same architecture. TRL now features a callback that merges the reference model with the current policy and optionally pushes the merged checkpoint to the Hub. The merge can happen at the end of each step/epoch and/or at the end of training. This callback uses Arcee's mergekit library: https://github.com/arcee-ai/mergekit
from trl import DPOTrainer, MergeModelCallback
from trl.mergekit_utils import MergeConfig
config = MergeConfig()
merge_callback = MergeModelCallback(config)
trainer = DPOTrainer(..., callbacks=[merge_callback])
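As a sketch of the optional behavior mentioned above (the merge_at_every_checkpoint and push_to_hub keyword names are assumptions for illustration, not confirmed here):
config = MergeConfig()
# Assumed keywords: merge at every checkpoint and push the merged model to the Hub
merge_callback = MergeModelCallback(config, merge_at_every_checkpoint=True, push_to_hub=True)
trainer = DPOTrainer(..., callbacks=[merge_callback])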
by @August-murr in https://github.com/huggingface/trl/pull/2282
TRL preprocessing utilities now support tools, a first step toward agent fine-tuning.
from transformers import AutoTokenizer
from trl import apply_chat_template

def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.
    Args:
        location: The location to get the temperature for
    """
    return 22.0

# Illustrative: use a tokenizer whose chat template supports tool definitions
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
example = {"prompt": [{"role": "user", "content": "What is the temperature in Paris?"}]}
example = apply_chat_template(example, tokenizer, tools=[get_current_temperature])
by @August-murr in https://github.com/huggingface/trl/pull/2455
VLM support in DPOTrainer
VLMs have their own specificities which require special treatment in the trainer. DPOTrainer now natively supports LLaVA-Next models.
from transformers import AutoModelForVision2Seq

model = AutoModelForVision2Seq.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
trainer = DPOTrainer(model=model, ...)
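A slightly fuller sketch is shown below; the processor usage, dataset placeholder, and column layout (an image-bearing preference dataset with prompt/chosen/rejected) are assumptions for illustration, not prescribed here:
from datasets import load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
from trl import DPOConfig, DPOTrainer

model = AutoModelForVision2Seq.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
# Placeholder dataset id: any preference dataset that includes images (assumption)
train_dataset = load_dataset("path/to/vision_preference_dataset", split="train")
training_args = DPOConfig(output_dir="llava-v1.6-mistral-7b-dpo")
trainer = DPOTrainer(
    model=model,
    args=training_args,
    processing_class=processor,
    train_dataset=train_dataset,
)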
by @chenweize1998 in https://github.com/huggingface/trl/pull/2413
The TRL CLI has been refactored to be more user-friendly and easier to extend. We plan to extend support to all trainers soon.
(simplified output, for readability)
$ trl dpo --help
usage: trl dpo [-h] --dataset_name DATASET_NAME [--dataset_config DATASET_CONFIG] --output_dir OUTPUT_DIR [--loss_type {sigmoid,hinge,ipo}]
options:
-h, --help show this help message and exit
--dataset_name DATASET_NAME, --dataset-name DATASET_NAME
--dataset_config DATASET_CONFIG, --dataset-config DATASET_CONFIG
--output_dir OUTPUT_DIR, --output-dir OUTPUT_DIR
The output directory where the model predictions and checkpoints will be written. (default: None)
--loss_type {sigmoid,hinge,ipo}, --loss-type {sigmoid,hinge,ipo}
by @qgallouedec in https://github.com/huggingface/trl/pull/2380 and https://github.com/huggingface/trl/pull/2412
TRL features a new judge, AllTrueJudge, that unifies the decisions of multiple binary judges. This judge implements the Mixture of Judges as described in the CGPO paper.
import random

from trl import AllTrueJudge, BaseBinaryJudge

class RandomBinaryJudge(BaseBinaryJudge):
    """
    Random binary judge, for testing purposes.
    """

    def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
        return [random.choice([0, 1, -1]) for _ in range(len(prompts))]

prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
judgements = judge.judge(prompts=prompts, completions=completions)
print(judgements)  # e.g. [0, 1]
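Conceptually, the aggregation behaves like a logical AND over the binary judgments. The tiny sketch below is my own paraphrase of that idea and deliberately ignores how the library treats the special -1 "missing" value:
def all_true(binary_judgments):
    # 1 only when every sub-judge answered 1, otherwise 0 (handling of -1 not shown)
    return 1 if all(j == 1 for j in binary_judgments) else 0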
by @gaetanlop in https://github.com/huggingface/trl/pull/2159
num_logits_to_keep to save memory
Save memory in the DPO trainer by computing only the logits that are actually needed, via the model's num_logits_to_keep argument.
training_args = DPOConfig(..., use_num_logits_to_keep=True)
by @xyangk in https://github.com/huggingface/trl/pull/2129
The DiscoPOP paper uses LLMs to discover more efficient offline preference optimization losses. In the paper, the proposed DiscoPOP loss (a log-ratio modulated loss) outperformed other preference optimization losses on several tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0).
training_args = DPOConfig(..., loss_type="discopop", discopop_tau=0.05)
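For intuition, here is a rough sketch of the log-ratio modulated loss as described in the paper: a logistic (DPO-style) term and an exponential term are blended with a mixing weight driven by the scaled log-ratio difference. It uses beta and discopop_tau as they appear in DPOConfig; the exact TRL implementation may differ.
import torch
import torch.nn.functional as F

def discopop_loss_sketch(policy_logratios, ref_logratios, beta=0.1, tau=0.05):
    # rho: scaled difference between policy and reference chosen/rejected log-ratios
    rho = beta * (policy_logratios - ref_logratios)
    # Mixing weight modulated by the log ratio itself
    mix = torch.sigmoid(rho / tau)
    logistic_term = -F.logsigmoid(rho)   # DPO-style sigmoid loss
    exp_term = torch.exp(-rho)           # exponential loss
    return (1 - mix) * logistic_term + mix * exp_term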
by @fanconic in https://github.com/huggingface/trl/pull/2323
Batch size for precomputing reference log probs in DPOTrainer
You can now control the batch size used to precompute the reference model log probabilities.
training_args = DPOConfig(
...
precompute_ref_log_probs=True,
precompute_ref_batch_size=4,
)
by @SwayamInSync in https://github.com/huggingface/trl/pull/2426
The SFTTrainer has long supported packing datasets for faster training. It now supports packing pre-tokenized datasets as well.
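A minimal sketch: the dataset construction below is illustrative, and the key point is that the dataset already contains an input_ids column, which I'm assuming is what the packing path expects.
from datasets import Dataset
from transformers import AutoTokenizer
from trl import SFTConfig, SFTTrainer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
texts = ["The capital of France is Paris.", "Jupiter is the biggest planet in the solar system."]
# A dataset that is already tokenized: only an "input_ids" column, no raw text (assumption)
train_dataset = Dataset.from_dict({"input_ids": [tokenizer(t)["input_ids"] for t in texts]})

training_args = SFTConfig(output_dir="Qwen2-0.5B-SFT", packing=True, max_seq_length=512)
trainer = SFTTrainer(model="Qwen/Qwen2-0.5B", args=training_args, train_dataset=train_dataset)
trainer.train()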
by @kmehant in https://github.com/huggingface/trl/pull/2011
PEFT support in PPOTrainer
PPOTrainer now supports PEFT for efficient training.
from peft import LoraConfig

peft_config = LoraConfig(task_type="CAUSAL_LM")  # illustrative LoRA configuration
PPOTrainer(
    ...,
    peft_config=peft_config,
)
by @ccs96307 in https://github.com/huggingface/trl/pull/2344
config in favor of args in PPOTrainer
config has been deprecated in favor of args in PPOTrainer.
PPOTrainer(
- config=training_args,
+ args=training_args,
)
by @qgallouedec in https://github.com/huggingface/trl/pull/2384
policy in favor of model in PPOTrainer
policy has been deprecated in favor of model in PPOTrainer.
PPOTrainer(
- policy=model,
+ model=model,
)
by @qgallouedec in https://github.com/huggingface/trl/pull/2386
What's Changed
0.13.0.dev0 by @qgallouedec in https://github.com/huggingface/trl/pull/2305
token_id instead of token in DPOTrainer by @qgallouedec in https://github.com/huggingface/trl/pull/2324
output_layer to the list of lm_head_namings in AutoModelForCausalLMWithValueHead by @qgallouedec in https://github.com/huggingface/trl/pull/2328
tokenizer arg back and add deprecation guidelines by @qgallouedec in https://github.com/huggingface/trl/pull/2348
tokenizer argument in BCO, GKD, Iterative SFT, Nash MD and XPO by @qgallouedec in https://github.com/huggingface/trl/pull/2349
use_soft_judge option to WinRateCallback by @kashif in https://github.com/huggingface/trl/pull/2347
GeometricMixtureWrapper.forward by @kashif in https://github.com/huggingface/trl/pull/2345
data_collator in RLOOTrainer and PPOTrainer by @bartoszzuk in https://github.com/huggingface/trl/pull/2360
PPOTrainer by @ccs96307 in https://github.com/huggingface/trl/pull/2344
require_bitsandbytes by @qgallouedec in https://github.com/huggingface/trl/pull/2370
start_time to _maybe_log_save_evaluate by @qgallouedec in https://github.com/huggingface/trl/pull/2373
MergeModelCallBack by @August-murr in https://github.com/huggingface/trl/pull/2282
start_time parameter by @qgallouedec in https://github.com/huggingface/trl/pull/2381
config in favor of args in PPOTrainer by @qgallouedec in https://github.com/huggingface/trl/pull/2384
policy in favor of model in PPOTrainer by @qgallouedec in https://github.com/huggingface/trl/pull/2386
KTOTrainer by @qgallouedec in https://github.com/huggingface/trl/pull/2394
SmolVLM models via standalone script sft_vlm_smol_vlm.py by @sergiopaniego in https://github.com/huggingface/trl/pull/2409
AutoModelForCausalLMWithValueHead by @qgallouedec in https://github.com/huggingface/trl/pull/2398
DPOTrainer by @chenweize1998 in https://github.com/huggingface/trl/pull/2413
DPOTrainer for reference model by @SwayamInSync in https://github.com/huggingface/trl/pull/2426
TrlParser by @qgallouedec in https://github.com/huggingface/trl/pull/2412
max_steps calculation in RLOOTrainer by @qgallouedec in https://github.com/huggingface/trl/pull/2433
dataset_config to ScriptArguments by @qgallouedec in https://github.com/huggingface/trl/pull/2440
ref_model in OnlineDPOTrainer by @qgallouedec in https://github.com/huggingface/trl/pull/2417
model_args by @qgallouedec in https://github.com/huggingface/trl/pull/2442
tests_latest.yml workflow file by @qgallouedec in https://github.com/huggingface/trl/pull/2457
BitsAndBytesConfig import in doc by @August-murr in https://github.com/huggingface/trl/pull/2478
Full Changelog: https://github.com/huggingface/trl/compare/v0.12.0...v0.13.0