
Conversation

@kmsgnnew

Resolves #977


Added an implementation of GDPO. The reference paper and reference implementation are linked below. GDPO improves upon GRPO.

Reference

Reference paper: https://arxiv.org/abs/2601.05242
Reference implementation: https://github.com/NVlabs/GDPO/blob/b080e63d0126870ad08acc8ebc3f04b728175a9e/trl-GDPO/trl-0.18.0-gdpo/trl/trainer/grpo_trainer.py#L1222

Colab Notebook
None

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

@abheesht17
Collaborator

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request introduces the GDPO algorithm by adding a new GDPOLearner and a corresponding advantage estimation function. The implementation is a good starting point. My review includes two main points. First, there's a significant design issue in how _compute_rewards is overridden, which violates the base class contract and could lead to runtime errors. Second, I've suggested a refactoring of the compute_advantages function to use vectorization, which aligns better with JAX best practices and should improve performance and readability.

Comment on lines +64 to +147
def _compute_rewards(
    self,
    prompts: List[str],
    completions: List[str],
    mode: rl_cluster_lib.Mode,
    step: int | None = None,
    **kwargs,
) -> np.ndarray:
  """Computes the rewards for completions using the provided reward functions.

  Args:
    prompts: A list of input prompts.
    completions: A list of generated text completions.
    mode: The mode to use for logging metrics.
    step: The current training step.
    **kwargs: Additional keyword arguments passed to the reward functions.

  Returns:
    A numpy array (shape `[B]`) of scalar rewards for each prompt-completion
    pair. The rewards are the sum across all the provided reward functions.

  Raises:
    RuntimeError: If 'r' reward is None, indicating a failure to obtain the
      result, or if the length of 'r' reward does not match the length of
      'prompts'.
  """
  if "mode" in kwargs:
    raise ValueError(f"kwargs already contains mode as a key: {kwargs}")
  kwargs["mode"] = str(mode)

  num_prompts = len(prompts)
  num_reward_fns = len(self.reward_fns)
  rewards = np.zeros((num_prompts, num_reward_fns))

  # Compute all rewards for each prompt-completion pair.
  for i, reward_fn in enumerate(self.reward_fns):
    r = reward_fn(prompts=prompts, completions=completions, **kwargs)

    if r is None:
      raise RuntimeError(
          f"Failed to obtain result from {reward_fn.__name__}. Result is"
          " None."
      )
    if isinstance(r, list) and len(r) != len(prompts):
      raise RuntimeError(
          f"Length mismatch after {reward_fn.__name__}: "
          f"len(r)={len(r)}, len(prompts)={num_prompts}. "
          f"Content of r: {r}"
      )

    rewards[:, i] = np.array(r)

  # Sum rewards across all reward functions for each prompt.
  sum_rewards = np.nansum(rewards, axis=1)

  # Log all metrics in a single loop.
  for j, (prompt, completion) in enumerate(zip(prompts, completions)):
    metrics_to_log = {}

    # Log prompts and completions.
    metrics_to_log["prompts"] = (prompt, None)
    metrics_to_log["completions"] = (completion, None)

    # Log the summed rewards for this trajectory.
    trajectory_sum = sum_rewards[j]
    metrics_to_log["rewards/sum"] = (trajectory_sum, np.mean)
    metrics_to_log["rewards/min"] = (np.min(rewards[j]), np.min)
    metrics_to_log["rewards/max"] = (np.max(rewards[j]), np.max)

    # Log individual rewards for this trajectory.
    for i, reward_fn in enumerate(self.reward_fns):
      metric_name = f"rewards/{reward_fn.__name__}"
      metrics_to_log[metric_name] = (rewards[j, i], np.mean)

    # Log all metrics for this trajectory in one call.
    if step is not None:
      self.rl_cluster.buffer_metrics_async(
          metrics_to_log, mode=mode, step=step
      )
    else:
      self.rl_cluster.buffer_metrics(metrics_to_log, mode=mode)

  return jnp.array(rewards)


high

This overridden _compute_rewards method changes the return type contract from its parent class RLLearner. The base method returns a 1D np.ndarray of summed rewards, but this implementation returns a 2D jnp.ndarray of per-function rewards.

This violates the Liskov Substitution Principle and can lead to unexpected behavior or runtime errors in other parts of the system that rely on the base class's contract, particularly the metric functions (metric_fns) which are passed the output of this method and likely expect a 1D array of summed rewards.

Additionally, the docstring for this method (lines 81-84) is incorrect as it claims the method returns a 1D array of summed rewards, which doesn't match the implementation.

To fix this, consider overriding _generate_and_compute_advantage in GDPOLearner to handle the GDPO-specific reward computation, while ensuring that the rewards passed to the metric functions are the 1D array of summed rewards, to maintain compatibility.
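For concreteness, one contract-preserving shape is sketched below. It is only illustrative: compute_reward_matrix and the toy reward functions are stand-ins, not part of the actual RLLearner API. The idea is to build the [B, num_reward_fns] matrix that GDPO's advantage estimation needs as an intermediate value, while only the 1D summed rewards ever flow back through the base-class path (and hence to metric_fns).

  # Illustrative sketch only; the names below are not the library's API.
  import numpy as np
  import jax.numpy as jnp

  def compute_reward_matrix(reward_fns, prompts, completions, **kwargs):
    """Stacks each reward function's scores into a [B, num_reward_fns] matrix."""
    rewards = np.zeros((len(prompts), len(reward_fns)))
    for i, reward_fn in enumerate(reward_fns):
      rewards[:, i] = np.array(
          reward_fn(prompts=prompts, completions=completions, **kwargs)
      )
    return rewards

  # Toy reward functions, purely for illustration.
  def length_reward(prompts, completions, **kwargs):
    return [len(c) / 10.0 for c in completions]

  def keyword_reward(prompts, completions, **kwargs):
    return [1.0 if "jax" in c else 0.0 for c in completions]

  matrix = compute_reward_matrix(
      [length_reward, keyword_reward], ["q1", "q2"], ["a jax answer", "plain answer"]
  )
  summed_rewards = np.nansum(matrix, axis=1)  # 1D [B]: what RLLearner and metric_fns expect.
  per_fn_rewards = jnp.array(matrix)          # 2D [B, K]: what GDPO's advantage step consumes.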

Comment on lines +161 to +180
rewards_per_func = jnp.nan_to_num(rewards)
all_reward_advantage = []
for reward_index in range(rewards.shape[-1]):
  reward_for_index = rewards_per_func[:, reward_index]
  each_reward_mean_grouped = reward_for_index.reshape(
      -1, num_generations
  ).mean(axis=1)
  each_reward_std_grouped = reward_for_index.reshape(-1, num_generations).std(
      axis=1
  )
  each_reward_mean_grouped = each_reward_mean_grouped.repeat(num_generations)
  each_reward_std_grouped = each_reward_std_grouped.repeat(num_generations)
  each_reward_advantage = reward_for_index - each_reward_mean_grouped
  each_reward_advantage = each_reward_advantage / (
      each_reward_std_grouped + 1e-4
  )
  all_reward_advantage.append(each_reward_advantage)

combined_reward_advantage = jnp.stack(all_reward_advantage, axis=1)
pre_bn_advantages = jnp.nansum(combined_reward_advantage, axis=1)
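Reading directly off the quoted code (this restates the snippet, not the paper's notation): for completion $i$ in a group of `num_generations` samples drawn from the same prompt,

$$
A_i^{\text{pre-bn}} = \sum_{k=1}^{K} \frac{r_{i,k} - \mu_{g(i),k}}{\sigma_{g(i),k} + 10^{-4}},
$$

where $r_{i,k}$ is the $k$-th reward function's score for completion $i$, and $\mu_{g(i),k}$, $\sigma_{g(i),k}$ are the mean and standard deviation of reward $k$ over completion $i$'s group. Each reward function is normalized within its own group before the per-function advantages are summed; GRPO, by contrast, sums the rewards first and normalizes that single total once per group.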


medium

The current implementation uses a Python for loop to iterate over reward functions. This is not idiomatic for JAX and can be inefficient. For better performance and readability, this logic can be vectorized using reshape and broadcasting, avoiding the explicit loop.

This change aligns with the JAX-native philosophy of the library.

  rewards_per_func = jnp.nan_to_num(rewards)

  # Reshape to group by generations. Shape: (num_prompts, num_generations, num_reward_fns)
  grouped_rewards = rewards_per_func.reshape(-1, num_generations, rewards.shape[-1])

  # Compute mean and std per group for each reward function, using keepdims for broadcasting.
  # Shape: (num_prompts, 1, num_reward_fns)
  mean_grouped = grouped_rewards.mean(axis=1, keepdims=True)
  std_grouped = grouped_rewards.std(axis=1, keepdims=True)

  # Normalize within each group and reshape back.
  normalized_advantages = (grouped_rewards - mean_grouped) / (std_grouped + 1e-4)
  combined_reward_advantage = normalized_advantages.reshape(rewards_per_func.shape)

  pre_bn_advantages = jnp.nansum(combined_reward_advantage, axis=1)

References
  1. The style guide emphasizes following JAX-native patterns. Using vectorization instead of explicit Python loops is a core JAX pattern for performance and readability. (link)
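
A quick standalone check that the vectorized form matches the original loop (shapes and data below are made up for illustration, not taken from the PR):

  import numpy as np
  import jax.numpy as jnp

  num_generations, num_reward_fns = 4, 3
  rewards = jnp.array(np.random.randn(8 * num_generations, num_reward_fns))
  rewards_per_func = jnp.nan_to_num(rewards)

  # Loop-based computation, as in the current PR code.
  loop_advantages = []
  for k in range(rewards.shape[-1]):
    r = rewards_per_func[:, k]
    mean = r.reshape(-1, num_generations).mean(axis=1).repeat(num_generations)
    std = r.reshape(-1, num_generations).std(axis=1).repeat(num_generations)
    loop_advantages.append((r - mean) / (std + 1e-4))
  loop_out = jnp.nansum(jnp.stack(loop_advantages, axis=1), axis=1)

  # Vectorized computation, as in the suggestion above.
  grouped = rewards_per_func.reshape(-1, num_generations, rewards.shape[-1])
  normalized = (grouped - grouped.mean(axis=1, keepdims=True)) / (
      grouped.std(axis=1, keepdims=True) + 1e-4
  )
  vectorized_out = jnp.nansum(normalized.reshape(rewards_per_func.shape), axis=1)

  assert jnp.allclose(loop_out, vectorized_out, atol=1e-5)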
