Added script for updating old checkpoints and configs. #397
BlueCrescent wants to merge 3 commits into main
Conversation
Pull Request Overview
This PR adds a utility script for migrating old model checkpoints and configuration files to a new format. The script handles both configuration file updates and checkpoint state dictionary transformations to maintain compatibility with updated model structures.
Key Changes
- Added comprehensive checkpoint and config migration script with YAML processing
- Implemented state dictionary updates for model weight key transformations
- Added validation functionality to test updated configurations
…ariables. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
le1nux left a comment
I'm a bit hesitant about whether the automated config updates are the way to go, or whether we should instead provide documentation, e.g., a diff for model_raw, explaining how to update the models.
The reason is that we currently update based on existing component names, e.g., checkpointed_model. However, the configs themselves never enforce particular component names, so if the user renames checkpointed_model to something like my_checkpointed_model, the conversion script already fails.
Also, we are deleting some components that are still used, as you can see if you diff these two configs:
Nevertheless, I think that the automated checkpoint update is still useful and I would place it in a backward_compatibility/ module.
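One way to make the conversion less dependent on user-chosen component names would be to search the config for the entry that actually carries a checkpoint_path, instead of hard-coding checkpointed_model. A minimal sketch (hypothetical helper, not part of this PR):

```python
def find_checkpointed_components(config: dict) -> list[str]:
    """Return the names of top-level components whose config contains a checkpoint_path.

    This avoids relying on the default name 'checkpointed_model', which the
    config schema does not enforce.
    """
    return [
        name
        for name, entry in config.items()
        if isinstance(entry, dict) and "checkpoint_path" in entry.get("config", {})
    ]
```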
```python
old_model_config = sys.argv[1]
new_model_config = sys.argv[2]
```

```python
config_type = dict[str, "str | config_type"]


def update_model(old_model_config: str, new_model_config: str, new_checkpoint_path: str | None):
```

```python
def add_new_keys(config: config_type):
    model_config = config["model_raw" if "model_raw" in config else "model"]["config"]
    model_config["use_weight_tying"] = False
```
Weight tying we also had before. Why are we hardcoding this to False now?
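If the intent is not to force the flag, the script could preserve a value that is already present instead of overwriting it. A hedged sketch; the old key name `weight_tying` is an assumption, not taken from the PR:

```python
def set_weight_tying_flag(model_config: dict) -> None:
    """Set use_weight_tying without clobbering an existing setting."""
    # 'weight_tying' as the old key name is a guess; adjust to the real old schema.
    old_value = model_config.pop("weight_tying", model_config.get("use_weight_tying"))
    model_config["use_weight_tying"] = bool(old_value) if old_value is not None else False
```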
| if "evaluation_subscriber" in config and "experiment_id" in config["evaluation_subscriber"]["config"]: | ||
| del config["evaluation_subscriber"]["config"]["experiment_id"] | ||
| if "settings" in config and "experiment_id" in config["settings"]: | ||
| del config["settings"]["experiment_id"] | ||
| if ( | ||
| "checkpoint_saving" in config | ||
| and "checkpoint_saving_execution" in config["checkpoint_saving"]["config"] | ||
| and "experiment_id" in config["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"] | ||
| ): | ||
| del config["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"]["experiment_id"] |
Also, in the current FSDP2 config, e.g. https://github.com/Modalities/modalities/blob/83c87b9d6d6fbbb228bab31dccf1870b12679775/config_files/training/config_lorem_ipsum_long_fsdp2.yaml, we still have all of this.
```python
def rename_keys(config: config_type):
    model_config = config["model_raw" if "model_raw" in config else "model"]["config"]
```
We could have the convention that the general model must always be named model_raw.
We are already enforcing it here:
```python
new_model_config = sys.argv[2]
new_checkpoint_path = sys.argv[3] if len(sys.argv) > 3 else None

update_model(old_model_config, new_model_config, new_checkpoint_path)
```
I would make updating the checkpoint and updating the config two separate functions that get called sequentially here.
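A rough sketch of that split; the function names and the `__main__` wiring are illustrative only:

```python
import sys


def update_config(old_config_path: str, new_config_path: str, new_checkpoint_path: str | None) -> None:
    """Apply only the YAML-level changes (add/rename/delete keys) and write the new config."""
    ...


def update_checkpoint(old_checkpoint_path: str, new_checkpoint_path: str) -> None:
    """Apply only the state-dict key remapping and save the converted checkpoint."""
    ...


if __name__ == "__main__":
    old_config_path = sys.argv[1]
    new_config_path = sys.argv[2]
    new_checkpoint_path = sys.argv[3] if len(sys.argv) > 3 else None

    update_config(old_config_path, new_config_path, new_checkpoint_path)
    if new_checkpoint_path is not None:
        # In practice the old checkpoint path would be read from the old config
        # (its checkpoint_path entry); "old.bin" is just a placeholder here.
        update_checkpoint("old.bin", new_checkpoint_path)
```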
```python
old_norm_keys = ["attention_norm", "ffn_norm", "lm_head_norm"]
new_norm_keys = ["attention_norm_config", "ffn_norm_config", "lm_head_norm_config"]
for old_key, new_key in zip(old_norm_keys, new_norm_keys):
    rename_config_key(model_config, old_key, new_key)
    rename_config_key(model_config[new_key], "variant_key", "norm_type")
```
We should delete component_key, no?
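For illustration, a self-contained sketch of the norm handling with component_key dropped; plain dict operations stand in for the PR's rename_config_key helper, and the exact nesting is assumed:

```python
def update_norm_configs(model_config: dict) -> None:
    """Rename the norm entries, convert variant_key to norm_type, and drop component_key."""
    renames = {
        "attention_norm": "attention_norm_config",
        "ffn_norm": "ffn_norm_config",
        "lm_head_norm": "lm_head_norm_config",
    }
    for old_key, new_key in renames.items():
        if old_key not in model_config:
            continue
        norm_entry = model_config.pop(old_key)
        norm_entry["norm_type"] = norm_entry.pop("variant_key")
        # component_key is assumed to be obsolete in the new format, per the review comment.
        norm_entry.pop("component_key", None)
        model_config[new_key] = norm_entry
```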
```python
if new_checkpoint_path is not None:
    if "checkpointed_model" in config:
        old_path = config["checkpointed_model"]["config"]["checkpoint_path"]
        config["checkpointed_model"]["config"]["checkpoint_path"] = new_checkpoint_path
```
I checked all the configs; where did you see checkpointed_model?
| """ | ||
| state_dict = torch.load(old_model_path) | ||
| if "lm_head.weight" in state_dict: | ||
| state_dict["transformer.lm_head.weight"] = state_dict["lm_head.weight"] |
How would this behave if we used weight tying?
Do we store the weights twice (i.e., embeddings and lm_head) and then internally replace the lm_head with a reference to the embedding weights?
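For context, a hedged sketch of how the remapping could at least detect tied weights; only the key names visible in this diff are used, and whether tied weights are stored once or twice in the old format is exactly the open question:

```python
import torch


def remap_lm_head(state_dict: dict) -> dict:
    """Rename lm_head.weight to transformer.lm_head.weight and warn on shared storage."""
    if "lm_head.weight" not in state_dict:
        # With weight tying the head may not be stored as a separate entry at all;
        # in that case there is nothing to rename (assumption to be verified).
        return state_dict
    lm_head = state_dict["lm_head.weight"]
    tied_keys = [
        key
        for key, tensor in state_dict.items()
        if key != "lm_head.weight"
        and isinstance(tensor, torch.Tensor)
        and tensor.data_ptr() == lm_head.data_ptr()
    ]
    if tied_keys:
        print(f"lm_head.weight shares storage with {tied_keys}; weight tying was likely used.")
    state_dict["transformer.lm_head.weight"] = state_dict.pop("lm_head.weight")
    return state_dict
```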
What does this PR do?
This PR adds a script for updating old checkpoints and configs.
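For orientation, a minimal sketch of the overall flow the script implements, based only on the snippets visible in the review; the function name, YAML handling details, and the exact update steps are assumptions:

```python
import sys

import torch
import yaml


def migrate(old_config_path: str, new_config_path: str, new_checkpoint_path: str | None) -> None:
    """Load the old YAML config, update it, and (optionally) convert the referenced checkpoint."""
    with open(old_config_path) as f:
        config = yaml.safe_load(f)

    # 1. Config-level updates: add new keys, rename keys, drop obsolete entries
    #    (see the diff hunks discussed in the review above).

    # 2. Checkpoint-level updates: remap renamed weight keys and re-save.
    if new_checkpoint_path is not None and "checkpointed_model" in config:
        old_checkpoint_path = config["checkpointed_model"]["config"]["checkpoint_path"]
        config["checkpointed_model"]["config"]["checkpoint_path"] = new_checkpoint_path

        state_dict = torch.load(old_checkpoint_path, map_location="cpu")
        if "lm_head.weight" in state_dict:
            state_dict["transformer.lm_head.weight"] = state_dict.pop("lm_head.weight")
        torch.save(state_dict, new_checkpoint_path)

    with open(new_config_path, "w") as f:
        yaml.safe_dump(config, f)


if __name__ == "__main__":
    migrate(sys.argv[1], sys.argv[2], sys.argv[3] if len(sys.argv) > 3 else None)
```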
General Changes
Checklist before submitting final PR
- Tests run through (`python tests/tests.py`)
- Changelog updated (`CHANGELOG_DEV.md`)