fix: allow multi epoch training for async grpo #1836
base: main
Conversation
… checkpoint restoration Signed-off-by: Parth Chadha <pchadha@nvidia.com>
📝 Walkthrough
Introduces epoch tracking and checkpointing support to async trajectory collection in reinforcement learning algorithms. Adds …
Changes
Sequence Diagram
```mermaid
sequenceDiagram
    participant Trainer as Trainer/Setup
    participant Checkpoint as Checkpoint Storage
    participant Collector as AsyncTrajectoryCollector
    participant Buffer as ReplayBuffer
    Trainer->>Checkpoint: Load checkpoint
    Checkpoint-->>Trainer: {current_epoch, dataloader_state}
    Trainer->>Trainer: Store restored_epoch
    Trainer->>Collector: Start collection
    Trainer->>Collector: set_dataloader_state(restored_epoch)
    Collector->>Collector: Restore current_epoch
    loop Each Epoch
        Collector->>Buffer: Collect trajectories
        Buffer-->>Collector: Trajectories
        Collector->>Collector: Increment current_epoch
        Collector->>Collector: Emit epoch progress
    end
    Trainer->>Collector: get_dataloader_state()
    Collector-->>Trainer: {current_epoch, dataloader_state}
    Trainer->>Checkpoint: Save checkpoint
```
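As a reference for the flow above, here is a minimal sketch of the checkpoint round trip the diagram depicts, assuming the payload written to `train_dataloader.pt` is a dict with `current_epoch` and `dataloader_state` keys; the helper function names are illustrative, not the actual nemo_rl API.

```python
# Illustrative sketch only: helper names are hypothetical; the payload shape
# ({"current_epoch": ..., "dataloader_state": ...}) follows the diagram above.
import os

import torch


def save_collector_state(checkpoint_dir: str, collector) -> None:
    # Captured from the collector at checkpoint time.
    state = collector.get_dataloader_state()
    torch.save(state, os.path.join(checkpoint_dir, "train_dataloader.pt"))


def restore_collector_state(checkpoint_dir: str, collector) -> None:
    # Applied on resume so the collector continues from the saved epoch.
    state = torch.load(os.path.join(checkpoint_dir, "train_dataloader.pt"))
    collector.set_dataloader_state(state)
```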
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 2
`nemo_rl/algorithms/grpo.py`, around lines 271–289:

```python
# Track epoch for async GRPO (will be None for sync GRPO or old checkpoints)
restored_epoch: Optional[int] = None
if last_checkpoint_path is not None:
    dataloader_state_dict = torch.load(
        os.path.join(last_checkpoint_path, "train_dataloader.pt")
    )
    dataloader.load_state_dict(dataloader_state_dict)
    # Handle direct state_dict and dict with epoch
    if (
        isinstance(dataloader_state_dict, dict)
        and "current_epoch" in dataloader_state_dict
    ):
        # Extract epoch and dataloader state
        restored_epoch = dataloader_state_dict["current_epoch"]
        if "dataloader_state" in dataloader_state_dict:
            dataloader.load_state_dict(dataloader_state_dict["dataloader_state"])
        print(f" ✓ Restored from epoch {restored_epoch}")
    else:
        # Direct dataloader state_dict
        dataloader.load_state_dict(dataloader_state_dict)
```
Expose restored_epoch to the async training path
restored_epoch is computed during checkpoint load but isn’t returned or stored anywhere, so callers can’t pass it into async_grpo_train and epoch restoration becomes a no-op on resume. Please surface it (e.g., return it from setup or persist it alongside grpo_save_state) so async resume can reliably restore epoch context.
🤖 Prompt for AI Agents
In `@nemo_rl/algorithms/grpo.py` around lines 271 - 289, restored_epoch is
computed during checkpoint loading but never exposed, so async resume can't
restore epoch state; update the setup function (where restored_epoch is set) to
return restored_epoch (or attach it to the object/state returned by setup)
and/or persist restored_epoch inside grpo_save_state whenever checkpoints are
written, and update async_grpo_train to accept and use the restored_epoch value
to resume the epoch counter; reference symbols: restored_epoch, setup,
async_grpo_train, grpo_save_state, and the checkpoint load block that sets
restored_epoch.
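One way to address this, sketched here with heavily simplified signatures that are not the actual nemo_rl API, is to return restored_epoch from setup and accept it in async_grpo_train:

```python
# Hedged sketch of the suggested fix; setup() in nemo_rl returns many more
# objects, and the argument lists here are illustrative only.
import os
from typing import Any, Optional, Tuple

import torch


def setup(last_checkpoint_path: Optional[str], dataloader: Any) -> Tuple[Any, Optional[int]]:
    restored_epoch: Optional[int] = None
    if last_checkpoint_path is not None:
        state = torch.load(os.path.join(last_checkpoint_path, "train_dataloader.pt"))
        if isinstance(state, dict) and "current_epoch" in state:
            restored_epoch = state["current_epoch"]
            if "dataloader_state" in state:
                dataloader.load_state_dict(state["dataloader_state"])
        else:
            dataloader.load_state_dict(state)
    # Surface restored_epoch so the caller can forward it to async_grpo_train.
    return dataloader, restored_epoch


def async_grpo_train(dataloader: Any, restored_epoch: Optional[int] = None) -> None:
    # Resume the epoch counter instead of silently starting from zero.
    current_epoch = restored_epoch if restored_epoch is not None else 0
    # ... trajectory collection / training loop elided ...
```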
`nemo_rl/algorithms/grpo.py`, around lines 2085–2086:

```python
    restored_epoch: Optional[int] = None,
) -> None:
```
Restore state before the collection thread starts
start_collection launches the background thread immediately, and set_dataloader_state runs afterward. The collector can advance epochs or consume batches before the restore lands, making resumes nondeterministic. Please move the restore into start_collection (before starting the thread), or add an explicit pause/initial_state parameter to guarantee state is applied before collection begins.
Also applies to: 2233-2241
🤖 Prompt for AI Agents
In `@nemo_rl/algorithms/grpo.py` around lines 2085 - 2086, The restore of
dataloader/epoch state must be applied before the background collector thread is
started to avoid races; modify start_collection so it accepts/uses
restored_epoch (and any restored dataloader state) and calls
set_dataloader_state/restoration logic inside start_collection before spawning
the collection thread, or alternatively add an initial_state or start_paused
boolean plus an explicit event/notify that the caller sets after
set_dataloader_state completes; update the code paths around start_collection
and the restoration code referenced by restored_epoch and set_dataloader_state
(also fix the same pattern at the block around lines 2233-2241) so the collector
cannot advance/consume batches until the restored state is applied.
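A minimal sketch of the race-free ordering the review asks for, using an illustrative collector class (everything beyond the start_collection/set_dataloader_state names is hypothetical): the restored state is applied inside start_collection before the background thread exists, so collection can never begin with stale state.

```python
# Illustrative sketch, not the actual nemo_rl collector: restored state is
# applied before the collection thread is spawned, eliminating the race.
import threading
from typing import Any, Optional


class AsyncTrajectoryCollector:
    def __init__(self, dataloader: Any) -> None:
        self.dataloader = dataloader
        self.current_epoch = 0
        self._thread: Optional[threading.Thread] = None

    def set_dataloader_state(self, state: dict) -> None:
        # Restore the epoch counter and, if present, the dataloader iterator.
        self.current_epoch = state.get("current_epoch", 0)
        if "dataloader_state" in state:
            self.dataloader.load_state_dict(state["dataloader_state"])

    def start_collection(self, restored_state: Optional[dict] = None) -> None:
        # Apply any restored epoch/dataloader state before the thread starts,
        # so no batches are consumed with stale state.
        if restored_state is not None:
            self.set_dataloader_state(restored_state)
        self._thread = threading.Thread(target=self._collect_loop, daemon=True)
        self._thread.start()

    def _collect_loop(self) -> None:
        # Background trajectory collection; details elided.
        pass
```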
Signed-off-by: Parth Chadha <pchadha@nvidia.com>
What does this PR do?
Allows multi-epoch training for async GRPO by tracking the trajectory collector's current epoch and restoring it from checkpoints on resume.
Issues
#1814
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Tests