Perform gradient clipping on global batch when using gradient accumulation #6
ashors1 wants to merge 9 commits into google:main
Conversation
praxis/optimizers.py (Outdated)

    raw_grad_norm = _compute_grad_norm(raw_grads)
    grads, grad_scale = clip_grads(raw_grads, raw_grad_norm)
do we need to compute and return grad_scale?
This is not needed. I no longer return grad_scale with the latest commit.
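For context, a hedged sketch of what the simplified helper might look like once the extra return value is dropped; `_compute_grad_norm` and the clipping formula here are stand-ins based on the diff above, not the PR's actual implementation.

```python
import jax
import jax.numpy as jnp


def _compute_grad_norm(grads):
  # Global L2 norm over the (possibly nested) gradient pytree.
  return jnp.sqrt(sum(jnp.sum(jnp.square(g))
                      for g in jax.tree_util.tree_leaves(grads)))


def clip_grads(raw_grads, raw_grad_norm, clip_grad_norm_to_value: float = 1.0):
  # Previously this returned (grads, grad_scale); the scale factor was not
  # needed by callers, so only the clipped gradients are returned now.
  grad_scale = jnp.minimum(
      1.0, clip_grad_norm_to_value / (raw_grad_norm + 1e-6))
  return jax.tree_util.tree_map(lambda g: g * grad_scale, raw_grads)
```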
praxis/optimizers.py (Outdated)

    grad_scale = jnp.array(1.0)
    return grads, grad_scale

    raw_grad_norm = _compute_grad_norm(raw_grads)
IIUC, if clip_grad_single_norm_to_value is set, then raw_grad_norm is not used and we have to compute the per-tensor norms separately anyway.
Can we move the if-elif-else statement inside out and avoid the redundant computation?
Definitely. I have addressed this with my latest commit.
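To illustrate the restructuring discussed above, here is a rough sketch (not the actual commit) in which each branch computes only the norm it needs, so the global norm is skipped entirely when per-tensor clipping is active. A plain pytree stands in for NestedMap, and the 0.0 defaults mean "disabled".

```python
import jax
import jax.numpy as jnp


def scale_gradients(raw_grads,
                    clip_grad_norm_to_value: float = 0.0,
                    clip_grad_single_norm_to_value: float = 0.0):
  """Clips gradients, computing only the norm(s) the active branch needs."""
  if clip_grad_norm_to_value:
    # Global-norm clipping: a single norm over the whole gradient pytree.
    global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g))
                               for g in jax.tree_util.tree_leaves(raw_grads)))
    scale = jnp.minimum(1.0, clip_grad_norm_to_value / (global_norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, raw_grads)
  elif clip_grad_single_norm_to_value:
    # Per-tensor clipping: each tensor's own norm is computed; the global
    # norm is never needed in this branch.
    def clip_one(g):
      norm = jnp.sqrt(jnp.sum(jnp.square(g)))
      return g * jnp.minimum(
          1.0, clip_grad_single_norm_to_value / (norm + 1e-6))
    return jax.tree_util.tree_map(clip_one, raw_grads)
  else:
    return raw_grads
```

Branching before any norm computation keeps the single-norm path from paying for an unused global norm, which matters once clipping runs on the full accumulated batch.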
praxis/optimizers.py (Outdated)

    def scale_gradients(
        raw_grads: NestedMap,
        clip_grad_norm_to_value: Optional[float] = None,
Looking at the praxis optimizers, clip_gradient_norm_to_value and clip_gradient_single_norm_to_value default to 0.0 and not None, right?
So perhaps the types here should be float with default 0.0 instead of Optional?
        clip_grad_single_norm_to_value: Optional[float] = None):

      def clip_grads(grads):
        if clip_grad_norm_to_value:
Maybe assert that only one of them is true?
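Putting the last two suggestions together, a hypothetical signature with float defaults of 0.0 and an assertion that at most one clipping mode is enabled might look like the following; the body is elided and this is not the PR's final code.

```python
def scale_gradients(raw_grads,
                    clip_grad_norm_to_value: float = 0.0,
                    clip_grad_single_norm_to_value: float = 0.0):
  # Guard against ambiguous configs: at most one clipping mode may be set.
  assert not (clip_grad_norm_to_value and clip_grad_single_norm_to_value), (
      'Set at most one of clip_grad_norm_to_value and '
      'clip_grad_single_norm_to_value.')
  ...  # clipping logic as in the sketch above
```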
Refactoring to allow gradient clipping to be performed on the full batch rather than on subbatches when using ShardedStaticAccumulator. Note that this refactor allows us to maintain support for enable_skip_step_on_gradient_anomalies and requires x+1 grad norm calculations per global batch when using ShardedStaticAccumulator with x subbatches (once per subbatch to determine whether the step should be skipped, once when applying gradient clipping in the base optimizer update) and one grad clip per global batch. This PR should be taken together with the corresponding Paxml PR.
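As a rough, self-contained illustration of the flow described above (all names and the anomaly check are assumptions, not the actual ShardedStaticAccumulator code): raw gradients are accumulated over x subbatches, each subbatch norm is checked only for the skip-step decision, and clipping is applied once to the combined gradient.

```python
import jax
import jax.numpy as jnp


def _global_norm(grads):
  return jnp.sqrt(sum(jnp.sum(jnp.square(g))
                      for g in jax.tree_util.tree_leaves(grads)))


def accumulate_then_clip(per_subbatch_grads, clip_norm: float = 1.0):
  """per_subbatch_grads: a list of x gradient pytrees, one per subbatch."""
  num_subbatches = len(per_subbatch_grads)

  # x grad-norm computations, one per subbatch, used only to decide whether
  # the step should be skipped on a gradient anomaly (non-finite norm).
  subbatch_norms = [_global_norm(g) for g in per_subbatch_grads]
  skip_step = not all(bool(jnp.isfinite(n)) for n in subbatch_norms)

  # Average the raw (unclipped) gradients over the global batch.
  combined = jax.tree_util.tree_map(
      lambda *gs: sum(gs) / num_subbatches, *per_subbatch_grads)

  # The (x+1)-th grad-norm computation and the single clip, applied to the
  # global-batch gradient inside the base optimizer update.
  norm = _global_norm(combined)
  scale = jnp.minimum(1.0, clip_norm / (norm + 1e-6))
  clipped = jax.tree_util.tree_map(lambda g: g * scale, combined)
  return clipped, skip_step
```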