Feature/gdpo #978
base: main
Conversation
added GDPO
/gemini review
Code Review
This pull request introduces the GDPO algorithm by adding a new GDPOLearner and a corresponding advantage estimation function. The implementation is a good starting point. My review includes two main points. First, there's a significant design issue in how _compute_rewards is overridden, which violates the base class contract and could lead to runtime errors. Second, I've suggested a refactoring of the compute_advantages function to use vectorization, which aligns better with JAX best practices and should improve performance and readability.
def _compute_rewards(
    self,
    prompts: List[str],
    completions: List[str],
    mode: rl_cluster_lib.Mode,
    step: int | None = None,
    **kwargs,
) -> np.ndarray:
  """Computes the rewards for completions using the provided reward functions.

  Args:
    prompts: A list of input prompts.
    completions: A list of generated text completions.
    mode: The mode to use for logging metrics.
    step: The current training step.
    **kwargs: Additional keyword arguments passed to the reward functions.

  Returns:
    A numpy array (shape `[B]`) of scalar rewards for each prompt-completion
    pair. The rewards are the sum across all the provided reward functions.

  Raises:
    RuntimeError: If 'r' reward is None, indicating a failure to obtain the
      result, or if the length of 'r' reward does not match the length of
      'prompts'.
  """
  if "mode" in kwargs:
    raise ValueError(f"kwargs already contains mode as a key: {kwargs}")
  kwargs["mode"] = str(mode)

  num_prompts = len(prompts)
  num_reward_fns = len(self.reward_fns)
  rewards = np.zeros((num_prompts, num_reward_fns))

  # Compute all rewards for each prompt-completion pair.
  for i, reward_fn in enumerate(self.reward_fns):
    r = reward_fn(prompts=prompts, completions=completions, **kwargs)

    if r is None:
      raise RuntimeError(
          f"Failed to obtain result from {reward_fn.__name__}. Result is"
          " None."
      )
    if isinstance(r, list) and len(r) != len(prompts):
      raise RuntimeError(
          f"Length mismatch after {reward_fn.__name__}: "
          f"len(r)={len(r)}, len(prompts)={num_prompts}. "
          f"Content of r: {r}"
      )

    rewards[:, i] = np.array(r)

  # Sum rewards across all reward functions for each prompt.
  sum_rewards = np.nansum(rewards, axis=1)

  # Log all metrics in a single loop.
  for j, (prompt, completion) in enumerate(zip(prompts, completions)):
    metrics_to_log = {}

    # Log prompts and completions.
    metrics_to_log["prompts"] = (prompt, None)
    metrics_to_log["completions"] = (completion, None)

    # Log the summed rewards for this trajectory.
    trajectory_sum = sum_rewards[j]
    metrics_to_log["rewards/sum"] = (trajectory_sum, np.mean)
    metrics_to_log["rewards/min"] = (np.min(rewards[j]), np.min)
    metrics_to_log["rewards/max"] = (np.max(rewards[j]), np.max)

    # Log individual rewards for this trajectory.
    for i, reward_fn in enumerate(self.reward_fns):
      metric_name = f"rewards/{reward_fn.__name__}"
      metrics_to_log[metric_name] = (rewards[j, i], np.mean)

    # Log all metrics for this trajectory in one call.
    if step is not None:
      self.rl_cluster.buffer_metrics_async(
          metrics_to_log, mode=mode, step=step
      )
    else:
      self.rl_cluster.buffer_metrics(metrics_to_log, mode=mode)

  return jnp.array(rewards)
This overridden _compute_rewards method changes the return type contract from its parent class RLLearner. The base method returns a 1D np.ndarray of summed rewards, but this implementation returns a 2D jnp.ndarray of per-function rewards.
This violates the Liskov Substitution Principle and can lead to unexpected behavior or runtime errors in other parts of the system that rely on the base class's contract, particularly the metric functions (metric_fns) which are passed the output of this method and likely expect a 1D array of summed rewards.
Additionally, the docstring for this method (lines 81-84) is incorrect as it claims the method returns a 1D array of summed rewards, which doesn't match the implementation.
To fix this, consider overriding _generate_and_compute_advantage in GDPOLearner to handle the GDPO-specific reward computation, while ensuring that the rewards passed to the metric functions remain the 1D array of summed rewards, to maintain compatibility.
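A lighter-weight variant of that fix, sketched below, keeps the _compute_rewards override but restores the base-class contract: the per-function matrix lives in a separate helper and the override returns the 1D summed array. This is only a rough sketch against the interfaces described in this review; the helper name `_compute_rewards_per_function` and the `_last_rewards_per_function` attribute are hypothetical.

class GDPOLearner(RLLearner):

  def _compute_rewards_per_function(
      self,
      prompts: List[str],
      completions: List[str],
      mode: rl_cluster_lib.Mode,
      step: int | None = None,
      **kwargs,
  ) -> np.ndarray:
    # Same body as the override quoted above (per-function rewards plus
    # metric logging), but returning the [B, num_reward_fns] matrix from a
    # helper instead of from _compute_rewards itself.
    ...

  def _compute_rewards(
      self,
      prompts: List[str],
      completions: List[str],
      mode: rl_cluster_lib.Mode,
      step: int | None = None,
      **kwargs,
  ) -> np.ndarray:
    per_func = self._compute_rewards_per_function(
        prompts, completions, mode, step=step, **kwargs
    )
    # Hypothetical attribute: stash the per-function matrix for the GDPO
    # advantage step (e.g. an overridden _generate_and_compute_advantage
    # could read it rather than re-deriving it from the summed rewards).
    self._last_rewards_per_function = per_func
    # Preserve the RLLearner contract: a 1D array of summed rewards, which
    # is what metric_fns and other callers expect.
    return np.nansum(per_func, axis=1)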
rewards_per_func = jnp.nan_to_num(rewards)
all_reward_advantage = []
for reward_index in range(rewards.shape[-1]):
  reward_for_index = rewards_per_func[:, reward_index]
  each_reward_mean_grouped = reward_for_index.reshape(
      -1, num_generations
  ).mean(axis=1)
  each_reward_std_grouped = reward_for_index.reshape(-1, num_generations).std(
      axis=1
  )
  each_reward_mean_grouped = each_reward_mean_grouped.repeat(num_generations)
  each_reward_std_grouped = each_reward_std_grouped.repeat(num_generations)
  each_reward_advantage = reward_for_index - each_reward_mean_grouped
  each_reward_advantage = each_reward_advantage / (
      each_reward_std_grouped + 1e-4
  )
  all_reward_advantage.append(each_reward_advantage)

combined_reward_advantage = jnp.stack(all_reward_advantage, axis=1)
pre_bn_advantages = jnp.nansum(combined_reward_advantage, axis=1)
The current implementation uses a Python for loop to iterate over reward functions. This is not idiomatic for JAX and can be inefficient. For better performance and readability, this logic can be vectorized using reshape and broadcasting, avoiding the explicit loop.
This change aligns with the JAX-native philosophy of the library.
rewards_per_func = jnp.nan_to_num(rewards)
# Reshape to group by generations. Shape: (num_prompts, num_generations, num_reward_fns)
grouped_rewards = rewards_per_func.reshape(-1, num_generations, rewards.shape[-1])
# Compute mean and std per group for each reward function, using keepdims for broadcasting.
# Shape: (num_prompts, 1, num_reward_fns)
mean_grouped = grouped_rewards.mean(axis=1, keepdims=True)
std_grouped = grouped_rewards.std(axis=1, keepdims=True)
# Normalize within each group and reshape back.
normalized_advantages = (grouped_rewards - mean_grouped) / (std_grouped + 1e-4)
combined_reward_advantage = normalized_advantages.reshape(rewards_per_func.shape)
pre_bn_advantages = jnp.nansum(combined_reward_advantage, axis=1)

References
- The style guide emphasizes following JAX-native patterns. Using vectorization instead of explicit Python loops is a core JAX pattern for performance and readability. (link)
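As a quick sanity check that the vectorized form matches the loop in the diff, a standalone snippet along these lines can be run (the shapes, `num_generations = 4`, and the random rewards are made up for illustration):

import jax.numpy as jnp
import numpy as np

num_generations = 4
rng = np.random.default_rng(0)
# 2 prompts x 4 generations, 3 reward functions.
rewards = jnp.asarray(rng.normal(size=(8, 3)))

# Loop-based version (as in the PR).
rewards_per_func = jnp.nan_to_num(rewards)
all_reward_advantage = []
for reward_index in range(rewards.shape[-1]):
  r = rewards_per_func[:, reward_index]
  mean = r.reshape(-1, num_generations).mean(axis=1).repeat(num_generations)
  std = r.reshape(-1, num_generations).std(axis=1).repeat(num_generations)
  all_reward_advantage.append((r - mean) / (std + 1e-4))
loop_result = jnp.nansum(jnp.stack(all_reward_advantage, axis=1), axis=1)

# Vectorized version (as suggested above).
grouped = rewards_per_func.reshape(-1, num_generations, rewards.shape[-1])
normalized = (grouped - grouped.mean(axis=1, keepdims=True)) / (
    grouped.std(axis=1, keepdims=True) + 1e-4
)
vectorized_result = jnp.nansum(normalized.reshape(rewards_per_func.shape), axis=1)

assert jnp.allclose(loop_result, vectorized_result, atol=1e-6)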
Resolves #977
Added an implementation of GDPO. The reference paper and reference implementation are linked below; GDPO improves upon GRPO.
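In short, the advantage computed in this PR's compute_advantages normalizes each reward function's reward within its prompt group of `num_generations` completions, then sums across reward functions (before any further normalization):

$$A_i = \sum_{k=1}^{K} \frac{r_{i,k} - \mu_{g(i),k}}{\sigma_{g(i),k} + 10^{-4}}$$

where $r_{i,k}$ is the reward of reward function $k$ for completion $i$, and $\mu_{g(i),k}$, $\sigma_{g(i),k}$ are the mean and standard deviation of that reward function over completion $i$'s group. GRPO, by contrast, applies the group normalization once to the summed reward.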
Reference
Reference paper
https://arxiv.org/abs/2601.05242
Reference Implementation
https://github.com/NVlabs/GDPO/blob/b080e63d0126870ad08acc8ebc3f04b728175a9e/trl-GDPO/trl-0.18.0-gdpo/trl/trainer/grpo_trainer.py#L1222
Colab Notebook
NiL
Checklist