Multi-stage pipeline parallelism support #418
Conversation
…elism support. (WIP)
…odel. Also made None returns more visible in get_module_class_from_name().
…ith interleaved 1F1B.
…tack traces/views).
- Switched from an absolute tolerance of 1e-16 to a relative tolerance of 1e-2 for loss comparisons. It still needs to be investigated why this is necessary for some configurations (see the sketch below).
- Additional configs and test setups, currently commented out due to the long runtime of these tests.
- Easier configurability of expected checkpoint paths (for debugging/experimenting).
- Better error logging.
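As a point of reference, this is roughly what the relative-tolerance comparison looks like with `pytest.approx`; the loss values below are made up for illustration and are not taken from the test suite:

```python
import pytest


def assert_losses_close(actual_loss: float, expected_loss: float) -> None:
    # rel=1e-2 allows the losses to differ by up to 1% of the expected value.
    # An absolute tolerance of 1e-16 would demand near bit-exact agreement,
    # which some parallelism configurations apparently cannot provide.
    assert actual_loss == pytest.approx(expected_loss, rel=1e-2)


# Passes with rel=1e-2, would fail with abs=1e-16.
assert_losses_close(actual_loss=2.3051, expected_loss=2.3050)
```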
```python
else:
    assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
    sd = get_optimizer_state_dict(
        model=app_state.model_parts[0],
        optimizers=app_state.optimizer,
        # NOTE: Flattening is required for pipeline parallelism to work correctly.
        # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
        options=StateDictOptions(flatten_optimizer_state_dict=True),
    )
```
Should we remove this, since in case of PP we now always have an optimizer list which takes care of the flattening?
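For context, a minimal sketch of how an optimizer-list wrapper could own the flattening per model part. The `OptimizersList` name comes from the assertion above, but its actual interface in this PR is not shown here, so the class below is an assumption that only illustrates delegating to `get_optimizer_state_dict` with flattening enabled for each (model part, optimizer) pair:

```python
from typing import Any

import torch
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict


class OptimizersList:
    """Hypothetical wrapper holding one optimizer per pipeline stage (model part)."""

    def __init__(self, model_parts: list[torch.nn.Module], optimizers: list[torch.optim.Optimizer]):
        assert len(model_parts) == len(optimizers), "Expected one optimizer per model part."
        self._model_parts = model_parts
        self._optimizers = optimizers

    def state_dict(self) -> dict[str, Any]:
        # Merge the flattened optimizer state dicts of all parts into one dict,
        # mirroring the torchtitan approach referenced in the NOTE above.
        merged: dict[str, Any] = {}
        for model_part, optimizer in zip(self._model_parts, self._optimizers):
            merged.update(
                get_optimizer_state_dict(
                    model=model_part,
                    optimizers=optimizer,
                    options=StateDictOptions(flatten_optimizer_state_dict=True),
                )
            )
        return merged
```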
```python
@model_validator(mode="before")
@classmethod
def warn_deprecated_alias(cls, data: Any) -> Any:
    if isinstance(data, dict) and "wrapped_model" in data:
        warnings.warn(
            "Field 'wrapped_model' is deprecated. Use 'wrapped_model_or_parts' instead.",
            DeprecationWarning,
            stacklevel=3,
        )
    return data
```
Should we use this deprecation warning? If yes, should we also use it in the other configs where a field got renamed to its plural form?
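If the pattern is reused across configs, one option is a small shared helper so the warning logic is not copy-pasted into every validator. This is only a sketch: the helper name and the example config class are hypothetical, not code from this PR:

```python
import warnings
from typing import Any

from pydantic import BaseModel, model_validator


def warn_if_deprecated_field(data: Any, old_name: str, new_name: str) -> Any:
    """Shared helper: warn when a renamed (deprecated) field name is still used in the raw config."""
    if isinstance(data, dict) and old_name in data:
        warnings.warn(
            f"Field '{old_name}' is deprecated. Use '{new_name}' instead.",
            DeprecationWarning,
            stacklevel=4,  # adjust so the warning points at the user's config call site
        )
    return data


class SomeRenamedFieldConfig(BaseModel):
    wrapped_model_or_parts: Any

    @model_validator(mode="before")
    @classmethod
    def warn_deprecated_alias(cls, data: Any) -> Any:
        return warn_if_deprecated_field(data, "wrapped_model", "wrapped_model_or_parts")
```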
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), |
These are currently deactivated due to their long runtime. Should we activate them anyway?
The first and the third commented-out configs are the same, right?
I don't think that
`("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),`
is necessary, since we already test
`("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 8, 2),`
which is the same setup plus data parallelism, correct?
And since we have
`("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_grad_accu.yaml", 8, 1),`
we can probably skip
`# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),`
Yeah, these configs are mostly useful for debugging with fewer ranks. It probably makes sense to keep them turned off (or even to delete them in the future).
```python
(  # FIXME: wpe and drop probably should not get the higher weight
    ["transformer.wte", "transformer.wpe", "transformer.drop"],
    self._input_layer_equivalence,
),
```
I added this FIXME; does anyone have an opinion on whether I can remove wpe and drop from this list?
rrutmann left a comment:
There are some tests failing for me:

```
/workspaces/modalities/tests/conversion/gpt2/test_conversion_model.py::test_convert_model_checkpoint_produces_same_logits_as_original[gpt2_config_test.yaml-False]
TypeError: check_model_inputs.<locals>.wrapped_fn() got an unexpected keyword argument 'input_ids'

/workspaces/modalities/tests/conversion/gpt2/test_convert_gpt2.py::test_converting_gpt2_does_not_change_outputs[gpt2_config_test.yaml-False]
TypeError: check_model_inputs.<locals>.wrapped_fn() got an unexpected keyword argument 'input_ids'

/workspaces/modalities/tests/fsdp2_parallelization/test_tensor_parallelism.py::TestTensorParallelism::test_tp_sharding[swiglu-fsdp2_config_path1-tp_config_path1]
torch.multiprocessing.spawn.ProcessExitedException: process 2 terminated with signal SIGABRT
```
As well as an error importing one of the tests:

```
______ ERROR collecting tests/checkpointing/test_checkpoint_conversion.py ______
tests/checkpointing/test_checkpoint_conversion.py:59: in <module>
    @pytest.mark.skipif(
/home/richard-rutmann/.local/lib/python3.11/site-packages/_pytest/mark/structures.py:401: in __call__
    store_mark(unwrapped_func, self.mark, stacklevel=3)
/home/richard-rutmann/.local/lib/python3.11/site-packages/_pytest/mark/structures.py:466: in store_mark
    warnings.warn(MARKED_FIXTURE, stacklevel=stacklevel)
E   pytest.PytestRemovedIn9Warning: Marks applied to fixtures have no effect
E   See docs: https://docs.pytest.org/en/stable/deprecations.html#applying-a-mark-to-a-fixture-function
```
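Regarding the `PytestRemovedIn9Warning`: pytest ignores marks placed on fixtures (and will reject them in pytest 9), so the usual fix is to skip inside the fixture body instead. A minimal sketch, assuming the fixture guards on GPU availability; the actual fixture and skip condition in `test_checkpoint_conversion.py` are not shown here:

```python
import pytest
import torch


@pytest.fixture
def converted_checkpoint():
    # Do not decorate the fixture with @pytest.mark.skipif; call pytest.skip()
    # inside the fixture instead, which skips every test that requests it.
    if not torch.cuda.is_available():
        pytest.skip("This fixture requires a CUDA device.")
    return ...  # build and return the actual fixture value here
```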
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), |
There was a problem hiding this comment.
The first and the third commented-out configs are the same, right?
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), |
There was a problem hiding this comment.
I don't think that
("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
is necessary since we already test
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 8, 2),
which is the same setup + data parallelism, correct?
And since we have
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_grad_accu.yaml", 8, 1),
we can probably skip
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),
rrutmann left a comment:
Great work, thank you. A few tests are failing (see my comment), but aside from that, no major changes are required from my side.
Also enabled extra="forbid" in BaseModel to prevent accidental extra fields.
Note: Only strings are supported, not more complex path aliases.
…ecated all aliases created due to multi stage pp.
Co-authored-by: Richard Rutmann <97447451+rrutmann@users.noreply.github.com>
…s in code base. Also added missing deprecation marker for GPT2MFUCalculatorConfig.
```python
            lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
        """
        self._model = model
        self._model_parts = list(model) if isinstance(model, list) else [model]
```
Suggested change:

```diff
-        self._model_parts = list(model) if isinstance(model, list) else [model]
+        self._model_parts = model if isinstance(model, list) else [model]
```
I think creating a new list here is safer in case an outside context accidentally changes the input list.
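A tiny illustration of that aliasing concern (purely illustrative, not code from the PR):

```python
import torch.nn as nn

module_a, module_b, module_c = nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4)

parts = [module_a, module_b]   # list owned by the calling context
held_parts = list(parts)       # defensive copy, as done via list(model)

parts.append(module_c)         # the caller later mutates its own list

# The copy is unaffected; without list(...), held_parts would now silently have 3 entries.
assert len(held_parts) == 2
```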
```diff
 @staticmethod
 def get_state_dict(app_state: AppState) -> dict[str, Any]:
-    """Returns the state dict of the model in the AppState object.
+    """Returns the flattened state dicts of the model parts in the AppState object.
```
I guess you mean flattened keys. Though I'm not sure I would call it that: we are mapping from a list of dicts to a single dict, and "flattened keys" sounds more like flattening a dict of dicts.
```diff
         dict[str, Any]: The state dict of the model in the AppState object.
     """
-    return get_model_state_dict(model=app_state.model)
+    return {k: v for sd in map(get_model_state_dict, app_state.model_parts) for k, v in sd.items()}
```
Are we sure that k is always unique across model parts? Should we maybe throw an exception if k is not unique?
It is assumed that the model parts are distinct and thus have distinct keys. I'll modify the function to check that this is fulfilled.
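One possible shape for that check, as a sketch; the helper name is made up and only illustrates merging per-part state dicts while rejecting overlapping keys:

```python
from typing import Any

import torch
from torch.distributed.checkpoint.state_dict import get_model_state_dict


def merged_model_state_dict(model_parts: list[torch.nn.Module]) -> dict[str, Any]:
    merged: dict[str, Any] = {}
    for part in model_parts:
        sd = get_model_state_dict(model=part)
        duplicates = merged.keys() & sd.keys()
        if duplicates:
            # Model parts are expected to hold disjoint parameters; overlapping
            # keys would otherwise silently overwrite each other in the merged dict.
            raise ValueError(f"Duplicate state dict keys across model parts: {sorted(duplicates)}")
        merged.update(sd)
    return merged
```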
```python
            assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
            set_optimizer_state_dict(
                model=app_state.model_parts[0],
                optimizers=app_state.optimizer,
                optim_state_dict=state_dict,
                options=StateDictOptions(flatten_optimizer_state_dict=True),
            )
```
Given your comment above, I assume the else case can also be removed here?
```diff
         component_config_type=component_config_type,
     )
-    comp_config = component_config_type(**config_dict, strict=True)
+    comp_config = component_config_type.model_validate(config_dict, extra="forbid")
```
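For context on the `extra="forbid"` behaviour mentioned in the commit notes: in Pydantic v2 the rejection of unknown fields is typically configured on the model itself via `model_config`, so that validation raises on unexpected keys. A minimal sketch with a hypothetical config class (not one of the modalities configs):

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class ExampleComponentConfig(BaseModel):
    # Reject any field that is not declared on the model.
    model_config = ConfigDict(extra="forbid")

    num_layers: int


ExampleComponentConfig.model_validate({"num_layers": 12})  # ok
try:
    ExampleComponentConfig.model_validate({"num_layers": 12, "typo_field": 1})
except ValidationError as err:
    print(err)  # reports that extra inputs are not permitted for 'typo_field'
```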
```diff
     scheduled_pipeline: Pipeline | None = None,
 ):
-    if num_train_steps_done % evaluation_interval_in_steps == 0:
+    if num_train_steps_done % evaluation_interval_in_steps == 0 and num_train_steps_done > 0:
```
Here we should add a note with the details of the error that we were seeing otherwise.
Added a corresponding TODO.
… are mutually distinct.
…ge' into pp_multi_stage
What does this PR do?
Adds support for multi-stage pipeline parallelism schedules, in particular interleaved 1F1B.
Issue #408
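For readers unfamiliar with the schedule: interleaved 1F1B gives each pipeline rank several non-adjacent stages and alternates forward and backward micro-batches between them, which shrinks the pipeline bubble compared to plain 1F1B. A rough sketch of how such a schedule is typically built with `torch.distributed.pipelining`; this only illustrates the concept under assumed sizes (2 ranks, 4 stages, 8 micro-batches) and is not the wiring used in this PR:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage, ScheduleInterleaved1F1B

# Assumptions: 2 pipeline ranks and 4 stages, so each rank owns 2 non-adjacent
# model parts (rank 0 holds stages 0 and 2, rank 1 holds stages 1 and 3).
num_stages = 4
rank = dist.get_rank()
device = torch.device("cuda", rank)

# `model_parts` is assumed to be the list of nn.Module chunks owned by this rank.
model_parts: list[nn.Module] = ...

stages = [
    PipelineStage(part, stage_index=stage_idx, num_stages=num_stages, device=device)
    for part, stage_idx in zip(model_parts, (rank, rank + 2))
]

# The interleaved schedule drives all local stages at once.
schedule = ScheduleInterleaved1F1B(stages, n_microbatches=8, loss_fn=nn.CrossEntropyLoss())

# During a training step, the first rank feeds inputs and the last rank collects losses,
# e.g. schedule.step(inputs) on rank 0 vs. schedule.step(target=targets, losses=losses) on the last rank.
```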
General Changes
Breaking Changes
Checklist before submitting final PR
- Tests run via `python tests/tests.py`
- Changelog updated (`CHANGELOG_DEV.md`)