Implement pass rate-based curriculum learning with weighted sampling #153
jb3618columbia wants to merge 1 commit into verl-latest-cispo
Conversation
Pull request overview
This PR implements pass rate-based curriculum learning for GRPO training by introducing weighted sampling that prioritizes harder samples (those with lower historical success rates).
Changes:
- Added `PassRateTracker` class to track attempt counts and success rates for each prompt in the dataset
- Added `PassRateWeightedSampler` class that implements curriculum learning through dynamic weighted sampling based on historical pass rates
- Integrated curriculum learning into the DAPO trainer with pass rate tracking and curriculum-specific metrics logging
- Updated configuration files and training scripts with curriculum learning examples
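The tracker listed above could look roughly like the following. This is a minimal sketch: only the class name and the `dataset_index` field appear in the diff, so the counter layout, method names, and the neutral prior for unseen samples are assumptions.

```python
# Sketch of a pass-rate tracker, assuming per-index attempt/success
# counters keyed by the dataset_index field added in rl_dataset.py.
from collections import defaultdict


class PassRateTracker:
    def __init__(self):
        self.attempts = defaultdict(int)   # dataset_index -> rollouts seen
        self.successes = defaultdict(int)  # dataset_index -> rollouts passed

    def update(self, dataset_index: int, num_attempts: int, num_successes: int) -> None:
        self.attempts[dataset_index] += num_attempts
        self.successes[dataset_index] += num_successes

    def pass_rate(self, dataset_index: int, default: float = 0.5) -> float:
        # Unseen samples get a neutral prior so they are neither
        # over- nor under-sampled before any rollouts exist.
        n = self.attempts[dataset_index]
        return self.successes[dataset_index] / n if n > 0 else default


tracker = PassRateTracker()
tracker.update(dataset_index=7, num_attempts=8, num_successes=2)
print(tracker.pass_rate(7))  # 0.25
```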
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 18 comments.
| File | Description |
|---|---|
| verl/utils/pass_rate_tracker.py | Core tracker for maintaining historical pass rates and attempt counts per sample |
| verl/utils/pass_rate_weighted_sampler.py | Weighted sampler that adjusts sampling probabilities based on pass rates |
| verl/utils/dataset/rl_dataset.py | Added dataset_index field to enable sample tracking |
| verl/trainer/ppo/ray_trainer.py | Added comment clarifying sampler creation |
| verl/trainer/ppo/metric_utils.py | Added reward standard deviation metric |
| verl/trainer/config/data/legacy_data.yaml | Added curriculum sampler configuration parameters |
| recipe/dapo/dapo_ray_trainer.py | Integrated pass rate tracking and curriculum metrics logging into training loop |
| scripts/* | Added example training scripts demonstrating curriculum learning usage |
Resolved review thread (outdated) on scripts/train/pass_rate_weighted_sampler_multinode_rl_qwen2.5_32b_base_fsdp.sh
nightlessbaron left a comment:
Hey @jb3618columbia , this is a good start. However, there are a lot of things we can improve on:
- Can we add customization so that the user can supply a custom strategy in the weighted sampler?
- Also add customization so that the user can define how many steps to wait before updating the weights.
- Please add some tests, or optionally add them to RL360.
- Please add short documentation as a quick-start guide for people to use.
- Address all of my comments as well as those from codex :)
Good job overall :D
remove this file from your changes
Do we want to keep the changes we made to the cluster environment setup so that people can use the example script in the future? If not, I can remove all the above changes.
One thought is to update/clean up these files as we go along so people can refer to and use them out of the box.
as discussed, let's remove them for now
remove this file from your changes
```yaml
# num dataloader workers
dataloader_num_workers: 8
# NOTE: Must be 0 when using curriculum learning samplers (e.g., PassRateWeightedSampler)
# to prevent data caching before batches are reordered.
```
Not sure I understand this.
My understanding is that with dataloader_num_workers > 0 the loader will not sample batches from the latest set of weights because of caching, but perhaps I am incorrect here.
It just means the DataLoader will use multiple worker subprocesses to fetch and preprocess batches in parallel with your training loop. If you set it to 0, data loading would happen in the main training process.
Hmm, would it lead to "off-policyness" in terms of the weighting distribution? That was my concern.
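The staleness concern in this thread can be illustrated with a toy example. This is only an illustration, not code from the PR: it assumes the trainer updates the sampler's weights between batches, and mimics worker prefetching with a plain list.

```python
# A prefetching loader (num_workers > 0) decides future batches *before*
# the trainer updates the pass-rate weights, so those batches are drawn
# from a stale distribution -- the "off-policy" weighting concern.
import random

weights = {"easy": 0.1, "hard": 0.9}

def draw_batch():
    # Reads the *current* weights at call time.
    items = list(weights)
    return random.choices(items, weights=[weights[k] for k in items], k=4)

# num_workers > 0: two batches are prefetched up front.
prefetched = [draw_batch(), draw_batch()]
weights["hard"] = 0.2        # trainer updates pass rates after batch 1 ...
stale_batch = prefetched[1]  # ... but batch 2 was already drawn

# num_workers = 0: the next batch is drawn only when requested,
# so it reflects the updated weights.
fresh_batch = draw_batch()
```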
```shell
"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
    --config-path=config \
    --config-name="dapo_fsdp_config_with_resampling.yaml" \
```
@nightlessbaron this config is now used for pass rate-based weighted sampling
Summary
Implements curriculum learning using pass rate-based weighted sampling for GRPO training.
Changes
- `PassRateTracker` class to track attempt and success counts for each prompt. This tracker can be used by multiple curriculum samplers.
- `PassRateWeightedSampler` class, a weighted sampler that adjusts sampling probabilities (the probability of prompts being sampled in a batch) based on historical pass rates (optionally using an exponential moving average).
Testing
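The sampler described above could be sketched as follows. Only the class name and the EMA option come from this PR; the weight formula `1 - pass_rate + eps`, the parameter names, and the neutral prior are assumptions for illustration.

```python
import random


class PassRateWeightedSampler:
    """Sketch of a pass rate-weighted sampler: indices are drawn with
    probability increasing as their historical pass rate falls."""

    def __init__(self, num_samples: int, ema_alpha: float = 0.1, eps: float = 0.05):
        self.pass_rates = [0.5] * num_samples  # neutral prior for unseen samples
        self.ema_alpha = ema_alpha
        self.eps = eps  # floor so solved samples are never fully excluded

    def update(self, index: int, observed_pass_rate: float) -> None:
        # Exponential moving average smooths noisy per-batch estimates.
        a = self.ema_alpha
        self.pass_rates[index] = (1 - a) * self.pass_rates[index] + a * observed_pass_rate

    def sample(self, batch_size: int) -> list:
        # Harder samples (low pass rate) get proportionally higher weight.
        weights = [1.0 - p + self.eps for p in self.pass_rates]
        return random.choices(range(len(weights)), weights=weights, k=batch_size)
```

With `eps > 0`, easy samples keep a small residual probability, which avoids catastrophic forgetting on prompts the model has already mastered.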
Tested with local single-node runs
Tested with multi-node SLURM runs (2 nodes, 8 GPUs each)
Logs curriculum metrics: hardest_10pct/25pct/50pct/75pct pass rates, batch-level statistics
See curriculum learning runs: https://wandb.ai/mbzuai-llm/Reasoning360/runs/qab27nv0?nw=nwuserjalajbhandari
Example run: Curriculum-1435219-qwen2.5-32b-base-fsdp-temp_0.5_data_mixtures_round2_train_prompt_bsz_32
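The hardest-k% metrics mentioned under Testing could be computed as below. The metric names follow the log keys listed above; the computation itself (mean pass rate over the lowest k% of samples) is an assumption.

```python
# Sketch of the logged curriculum metrics: mean pass rate over the
# hardest k% of samples (lowest historical pass rates).
def hardest_pct_pass_rate(pass_rates, pct):
    ranked = sorted(pass_rates)               # ascending: hardest first
    k = max(1, int(len(ranked) * pct / 100))  # size of the hardest slice
    subset = ranked[:k]
    return sum(subset) / len(subset)


rates = [0.1, 0.2, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 1.0]
metrics = {f"hardest_{p}pct": hardest_pct_pass_rate(rates, p) for p in (10, 25, 50, 75)}
```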