
Conversation

@parthchadha (Contributor) commented Jan 28, 2026

What does this PR do?

Adds epoch tracking and checkpoint save/restore support to AsyncTrajectoryCollector, so async GRPO training can run for multiple epochs and resume from the saved epoch when loading a checkpoint.

Issues

#1814

Usage

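A minimal, hypothetical sketch of driving the new checkpoint round-trip; the `collector` handle, the Ray-actor call style, and the paths shown are assumptions rather than confirmed API:

```python
import ray
import torch

# Assumed: `collector` is a Ray actor handle for AsyncTrajectoryCollector
# created elsewhere during setup; paths are illustrative.
ckpt_in = "results/step_100/train_dataloader.pt"
ckpt_out = "results/step_200/train_dataloader.pt"

# Restore epoch + dataloader position saved by a previous run.
saved_state = torch.load(ckpt_in)
ray.get(collector.set_dataloader_state.remote(saved_state))

# ... training runs; the collector bumps current_epoch on each full pass ...

# Export state when writing the next checkpoint.
state = ray.get(collector.get_dataloader_state.remote())
# e.g. {"current_epoch": 3, "dataloader_state": {...}}
torch.save(state, ckpt_out)
```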

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features

    • Added epoch-based checkpointing support to save and restore training progress.
    • Training can now be resumed from a specific epoch when loading from a checkpoint.
    • Real-time epoch tracking during training operations.
  • Tests

    • Added comprehensive test suite for multi-epoch collection behavior and checkpoint restoration scenarios.


@parthchadha parthchadha requested review from a team as code owners January 28, 2026 16:11
@coderabbitai coderabbitai bot commented Jan 28, 2026

📝 Walkthrough

Introduces epoch tracking and checkpointing support to async trajectory collection in reinforcement learning algorithms. Adds current_epoch attribute to AsyncTrajectoryCollector with methods to export and restore dataloader state, enabling training resumption from saved checkpoints with proper epoch context propagation.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **Async trajectory collection core**<br>`nemo_rl/algorithms/async_utils.py` | Added `current_epoch` attribute initialized to 0. Introduced epoch-based restarting with an increment on each epoch cycle. Added checkpoint support: `get_dataloader_state()` exports the current epoch and optional dataloader state as a dict; `set_dataloader_state(state)` restores the epoch and dataloader state. |
| **GRPO training integration**<br>`nemo_rl/algorithms/grpo.py` | Extended `setup()` to track `restored_epoch` from checkpoint loading, handling two checkpoint formats: a dict with epoch and dataloader state, or a direct dataloader state dict. Added an optional `restored_epoch` parameter to `async_grpo_train()` to propagate the resumed epoch context to the trajectory collector via remote call. |
| **Multi-epoch collection tests**<br>`tests/unit/algorithms/test_async_utils.py` | Added `TestMultiEpochCollection` test suite validating initial epoch state, epoch restoration via `set_dataloader_state()`, epoch increments during collection, checkpoint-format handling, and the full lifecycle including start/stop operations. |
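For orientation, here is a minimal sketch of the collector-side API summarized above, assuming a plain-Python shape; the real class manages background collection, and the `_on_dataloader_exhausted` hook and `dataloader` attribute are illustrative names only:

```python
from typing import Any, Optional

class AsyncTrajectoryCollector:  # simplified stand-in for the real class
    def __init__(self) -> None:
        self.current_epoch = 0  # new attribute, starts at 0
        self.dataloader: Optional[Any] = None  # real class wraps an actual dataloader

    def _on_dataloader_exhausted(self) -> None:
        # Illustrative hook: called when the dataloader restarts for a new pass.
        self.current_epoch += 1

    def get_dataloader_state(self) -> dict[str, Any]:
        # Export the current epoch plus optional dataloader state as a dict.
        state: dict[str, Any] = {"current_epoch": self.current_epoch}
        if self.dataloader is not None and hasattr(self.dataloader, "state_dict"):
            state["dataloader_state"] = self.dataloader.state_dict()
        return state

    def set_dataloader_state(self, state: dict[str, Any]) -> None:
        # Restore the epoch and, when present, the dataloader position.
        self.current_epoch = state.get("current_epoch", 0)
        if self.dataloader is not None and "dataloader_state" in state:
            self.dataloader.load_state_dict(state["dataloader_state"])
```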

Sequence Diagram

```mermaid
sequenceDiagram
    participant Trainer as Trainer/Setup
    participant Checkpoint as Checkpoint Storage
    participant Collector as AsyncTrajectoryCollector
    participant Buffer as ReplayBuffer

    Trainer->>Checkpoint: Load checkpoint
    Checkpoint-->>Trainer: {current_epoch, dataloader_state}
    Trainer->>Trainer: Store restored_epoch
    Trainer->>Collector: Start collection
    Trainer->>Collector: set_dataloader_state(restored_epoch)
    Collector->>Collector: Restore current_epoch
    loop Each Epoch
        Collector->>Buffer: Collect trajectories
        Buffer-->>Collector: Trajectories
        Collector->>Collector: Increment current_epoch
        Collector->>Collector: Emit epoch progress
    end
    Trainer->>Collector: get_dataloader_state()
    Collector-->>Trainer: {current_epoch, dataloader_state}
    Trainer->>Checkpoint: Save checkpoint
```
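The checkpoint payload exchanged in the diagram comes in two shapes, matching the loading code quoted further down; the key values below are illustrative only:

```python
# New format, produced by get_dataloader_state():
new_format = {
    "current_epoch": 3,                          # epoch to resume from
    "dataloader_state": {"samples_seen": 4096},  # optional inner state_dict
}

# Legacy format from older checkpoints: the bare dataloader state_dict,
# loaded directly via dataloader.load_state_dict(...).
legacy_format = {"samples_seen": 4096}
```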

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

asyncRL, r0.5.0

Suggested reviewers

  • terrykong
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Test Results For Major Changes | ⚠️ Warning | The PR description lacks documentation of test results, testing methodology, regression validation, and performance metrics for this major feature addition to async GRPO training. | Update the PR description to document the test scenarios executed, confirm multi-epoch convergence consistency, provide end-to-end training results, and address review comments regarding timing issues. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and concisely describes the main change: enabling multi-epoch training for async GRPO, which aligns with the core modifications to AsyncTrajectoryCollector and grpo.py that implement epoch tracking and checkpoint restoration. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 82.35%, which meets the required threshold of 80.00%. |



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2


Comment on lines +271 to +289
```python
# Track epoch for async GRPO (will be None for sync GRPO or old checkpoints)
restored_epoch: Optional[int] = None
if last_checkpoint_path is not None:
    dataloader_state_dict = torch.load(
        os.path.join(last_checkpoint_path, "train_dataloader.pt")
    )
    # Handle both a dict with epoch info and a direct state_dict
    if (
        isinstance(dataloader_state_dict, dict)
        and "current_epoch" in dataloader_state_dict
    ):
        # Extract epoch and dataloader state
        restored_epoch = dataloader_state_dict["current_epoch"]
        if "dataloader_state" in dataloader_state_dict:
            dataloader.load_state_dict(dataloader_state_dict["dataloader_state"])
        print(f" ✓ Restored from epoch {restored_epoch}")
    else:
        # Direct dataloader state_dict
        dataloader.load_state_dict(dataloader_state_dict)
```

⚠️ Potential issue | 🟠 Major

Expose restored_epoch to the async training path

restored_epoch is computed during checkpoint load but isn’t returned or stored anywhere, so callers can’t pass it into async_grpo_train and epoch restoration becomes a no-op on resume. Please surface it (e.g., return it from setup or persist it alongside grpo_save_state) so async resume can reliably restore epoch context.
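A sketch of one way the suggestion could land, with illustrative signatures only; the real `setup` and `async_grpo_train` take and return far more than shown:

```python
from typing import Optional

def setup(last_checkpoint_path: Optional[str]) -> Optional[int]:
    restored_epoch: Optional[int] = None
    if last_checkpoint_path is not None:
        restored_epoch = 3  # stands in for the value read from train_dataloader.pt
    return restored_epoch  # returned alongside setup's usual outputs in practice

def async_grpo_train(restored_epoch: Optional[int] = None) -> None:
    if restored_epoch is not None:
        # Propagate to the trajectory collector before collection advances.
        print(f"Resuming from epoch {restored_epoch}")

async_grpo_train(restored_epoch=setup("results/step_100"))
```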

🤖 Prompt for AI Agents
In `@nemo_rl/algorithms/grpo.py` around lines 271 - 289, restored_epoch is
computed during checkpoint loading but never exposed, so async resume can't
restore epoch state; update the setup function (where restored_epoch is set) to
return restored_epoch (or attach it to the object/state returned by setup)
and/or persist restored_epoch inside grpo_save_state whenever checkpoints are
written, and update async_grpo_train to accept and use the restored_epoch value
to resume the epoch counter; reference symbols: restored_epoch, setup,
async_grpo_train, grpo_save_state, and the checkpoint load block that sets
restored_epoch.

Comment on lines +2085 to 2086
```python
    restored_epoch: Optional[int] = None,
) -> None:
```

⚠️ Potential issue | 🟠 Major

Restore state before the collection thread starts

start_collection launches the background thread immediately, and set_dataloader_state runs afterward. The collector can advance epochs or consume batches before the restore lands, making resumes nondeterministic. Please move the restore into start_collection (before starting the thread), or add an explicit pause/initial_state parameter to guarantee state is applied before collection begins.

Also applies to: 2233-2241
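A sketch of the suggested ordering using a simplified stand-in class; only `start_collection` and `set_dataloader_state` are names from the PR, everything else is assumed:

```python
import threading
from typing import Any, Optional

class CollectorSketch:
    def __init__(self) -> None:
        self.current_epoch = 0
        self._thread: Optional[threading.Thread] = None

    def set_dataloader_state(self, state: dict[str, Any]) -> None:
        self.current_epoch = state.get("current_epoch", 0)

    def start_collection(self, initial_state: Optional[dict[str, Any]] = None) -> None:
        # Apply restored state BEFORE the loop can consume any batches.
        if initial_state is not None:
            self.set_dataloader_state(initial_state)
        self._thread = threading.Thread(target=self._collect_loop, daemon=True)
        self._thread.start()

    def _collect_loop(self) -> None:
        pass  # stand-in for the real trajectory-collection loop

collector = CollectorSketch()
collector.start_collection(initial_state={"current_epoch": 2})
```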

🤖 Prompt for AI Agents
In `@nemo_rl/algorithms/grpo.py` around lines 2085 - 2086, The restore of
dataloader/epoch state must be applied before the background collector thread is
started to avoid races; modify start_collection so it accepts/uses
restored_epoch (and any restored dataloader state) and calls
set_dataloader_state/restoration logic inside start_collection before spawning
the collection thread, or alternatively add an initial_state or start_paused
boolean plus an explicit event/notify that the caller sets after
set_dataloader_state completes; update the code paths around start_collection
and the restoration code referenced by restored_epoch and set_dataloader_state
(also fix the same pattern at the block around lines 2233-2241) so the collector
cannot advance/consume batches until the restored state is applied.

Signed-off-by: Parth Chadha <pchadha@nvidia.com>
@parthchadha parthchadha added the CI:L0 Run doctests and unit tests label Jan 29, 2026