From 4f5de830167124b08131249bbf4d669b9d294a00 Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Mon, 12 Jan 2026 05:49:04 -0800 Subject: [PATCH 1/5] fix: include model channel for gated uncompressed models --- .../src/sagemaker/core/jumpstart/cache.py | 22 +++++++++++++++---- .../src/sagemaker/core/jumpstart/types.py | 16 ++++++++++---- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/jumpstart/cache.py b/sagemaker-core/src/sagemaker/core/jumpstart/cache.py index c1084dcf40..035e57359e 100644 --- a/sagemaker-core/src/sagemaker/core/jumpstart/cache.py +++ b/sagemaker-core/src/sagemaker/core/jumpstart/cache.py @@ -372,10 +372,18 @@ def _get_json_file( object and None when reading from the local file system. """ if self._is_local_metadata_mode(): - file_content, etag = self._get_json_file_from_local_override(key, filetype), None - else: - file_content, etag = self._get_json_file_and_etag_from_s3(key) - return file_content, etag + if filetype in { + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartS3FileType.OPEN_WEIGHT_SPECS, + }: + return self._get_json_file_from_local_override(key, filetype), None + else: + JUMPSTART_LOGGER.warning( + "Local metadata mode is enabled, but the file type %s is not supported " + "for local override. Falling back to s3.", + filetype, + ) + return self._get_json_file_and_etag_from_s3(key) def _get_json_md5_hash(self, key: str): """Retrieves md5 object hash for s3 objects, using `s3.head_object`. @@ -552,6 +560,12 @@ def _select_version( ) return version_str if version_str in available_versions else None + if version_str[-1] == "*": + # major or minor version is pinned, e.g 1.* or 1.0.* + return utils.get_latest_version( + [version for version in available_versions if version.startswith(version_str[:-1])] + ) + try: spec = SpecifierSet(f"=={version_str}") except InvalidSpecifier: diff --git a/sagemaker-core/src/sagemaker/core/jumpstart/types.py b/sagemaker-core/src/sagemaker/core/jumpstart/types.py index 33753c7ded..3eb9e2d432 100644 --- a/sagemaker-core/src/sagemaker/core/jumpstart/types.py +++ b/sagemaker-core/src/sagemaker/core/jumpstart/types.py @@ -1943,12 +1943,20 @@ def use_inference_script_uri(self) -> bool: def use_training_model_artifact(self) -> bool: """Returns True if the model should use a model uri when kicking off training job.""" - # gated model never use training model artifact - if self.gated_bucket: + # old models with this environment variable present don't use model channel + if any( + self.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value( + instance_type + ) + for instance_type in self.supported_training_instance_types + ): + return False + + # even older models with training model package artifact uris present also don't use model channel + if len(self.training_model_package_artifact_uris or {}) > 0: return False - # otherwise, return true is a training model package is not set - return len(self.training_model_package_artifact_uris or {}) == 0 + return getattr(self, "training_artifact_key", None) is not None def is_gated_model(self) -> bool: """Returns True if the model has a EULA key or the model bucket is gated.""" From e25bc8c22b76c3d2db63cb85dbedfc7039ce6e32 Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Mon, 12 Jan 2026 07:05:53 -0800 Subject: [PATCH 2/5] feat: Added Amazon Nova training support for ModelTrainer --- .../src/sagemaker/train/constants.py | 4 + 
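For readers of the `_select_version` change above, a minimal standalone sketch of the new trailing-`*` handling; `get_latest_version` here is a hypothetical stand-in for the `utils` helper and simply sorts with `packaging.version.Version`, and the `SpecifierSet` path that follows in the real method is omitted:

```python
from typing import List, Optional
from packaging.version import Version


def get_latest_version(versions: List[str]) -> Optional[str]:
    # Hypothetical stand-in for utils.get_latest_version: pick the highest version.
    return max(versions, key=Version) if versions else None


def select_version(version_str: str, available_versions: List[str]) -> Optional[str]:
    # A trailing "*" pins a major or minor version, e.g. "1.*" or "1.0.*".
    if version_str[-1] == "*":
        prefix = version_str[:-1]
        return get_latest_version([v for v in available_versions if v.startswith(prefix)])
    return version_str if version_str in available_versions else None


print(select_version("1.*", ["1.0.0", "1.2.0", "2.0.0"]))    # -> 1.2.0
print(select_version("1.0.*", ["1.0.0", "1.0.3", "1.2.0"]))  # -> 1.0.3
```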
.../src/sagemaker/train/model_trainer.py | 105 ++++++- .../src/sagemaker/train/sm_recipes/utils.py | 258 +++++++++++++++--- 3 files changed, 319 insertions(+), 48 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/constants.py b/sagemaker-train/src/sagemaker/train/constants.py index 309265d659..0778b2601d 100644 --- a/sagemaker-train/src/sagemaker/train/constants.py +++ b/sagemaker-train/src/sagemaker/train/constants.py @@ -29,6 +29,10 @@ os.path.dirname(os.path.abspath(__file__)), "container_drivers" ) +SM_RECIPE = "recipe" +SM_RECIPE_YAML = "recipe.yaml" +SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}" + SOURCE_CODE_JSON = "sourcecode.json" DISTRIBUTED_JSON = "distributed.json" TRAIN_SCRIPT = "sm_train.sh" diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index 1a1fcab410..49d4429db8 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -83,6 +83,9 @@ SM_CODE_CONTAINER_PATH, SM_DRIVERS, SM_DRIVERS_LOCAL_PATH, + SM_RECIPE, + SM_RECIPE_YAML, + SM_RECIPE_CONTAINER_PATH, TRAIN_SCRIPT, DEFAULT_CONTAINER_ENTRYPOINT, DEFAULT_CONTAINER_ARGUMENTS, @@ -100,7 +103,12 @@ from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature from sagemaker.train import logger -from sagemaker.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type +from sagemaker.modules.train.sm_recipes.utils import ( + _get_args_from_recipe, + _determine_device_type, + _is_nova_recipe, + _load_base_recipe, +) from sagemaker.core.jumpstart.configs import JumpStartConfig from sagemaker.core.jumpstart.document import get_hub_content_and_document @@ -248,6 +256,7 @@ class ModelTrainer(BaseModel): _session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None) _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None) _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None) + _is_nova_recipe: Optional[bool] = PrivateAttr(default=None) # Private Attributes for Recipes _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) @@ -499,6 +508,37 @@ def _validate_and_fetch_hyperparameters_file(hyperparameters_file: str): ) return hyperparameters + @staticmethod + def _validate_and_load_hyperparameters_file(hyperparameters_file: str) -> Dict[str, Any]: + """Validate the hyperparameters file.""" + if not os.path.exists(hyperparameters_file): + raise ValueError(f"Hyperparameters file not found: {hyperparameters_file}") + + logger.info(f"Loading hyperparameters from file: {hyperparameters_file}") + + with open(hyperparameters_file, "r") as f: + contents = f.read() + try: + hyperparameters = json.loads(contents) + logger.debug("Hyperparameters loaded as JSON") + return hyperparameters + except json.JSONDecodeError: + try: + logger.info(f"contents: {contents}") + hyperparameters = yaml.safe_load(contents) + if not isinstance(hyperparameters, dict): + raise ValueError("YAML contents must be a valid mapping") + + logger.info(f"hyperparameters: {hyperparameters}") + logger.debug("Hyperparameters loaded as YAML") + + return hyperparameters + except (yaml.YAMLError, ValueError): + raise ValueError( + f"Invalid hyperparameters file: {hyperparameters_file}. " + "Must be a valid JSON or YAML file." 
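A standalone, slightly simplified sketch of the JSON-then-YAML loading behaviour added in `_validate_and_load_hyperparameters_file` above; the helper name and error text here are illustrative, not the SDK's exact API:

```python
import json
import os
from typing import Any, Dict

import yaml


def load_hyperparameters_file(path: str) -> Dict[str, Any]:
    # Mirrors the patch's behaviour: try JSON first, then fall back to YAML.
    if not os.path.exists(path):
        raise ValueError(f"Hyperparameters file not found: {path}")
    with open(path, "r") as f:
        contents = f.read()
    try:
        return json.loads(contents)
    except json.JSONDecodeError:
        loaded = yaml.safe_load(contents)
        if not isinstance(loaded, dict):
            raise ValueError(f"Invalid hyperparameters file: {path}. Must be valid JSON or YAML.")
        return loaded
```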
+ ) + def model_post_init(self, __context: Any): """Post init method to perform custom validation and set default values.""" self._validate_training_image_and_algorithm_name(self.training_image, self.algorithm_name) @@ -506,8 +546,8 @@ def model_post_init(self, __context: Any): self._validate_distributed_config(self.source_code, self.distributed) if self.hyperparameters and isinstance(self.hyperparameters, str): - self.hyperparameters = self._validate_and_fetch_hyperparameters_file( - hyperparameters_file=self.hyperparameters + self.hyperparameters = self._validate_and_load_hyperparameters_file( + self.hyperparameters ) if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: @@ -573,6 +613,23 @@ def _create_training_job_args( final_input_data_config = list(existing_channels.values()) + new_channels + if self._is_nova_recipe: + for input_data in final_input_data_config: + if input_data.channel_name == SM_RECIPE: + raise ValueError( + "Cannot use reserved channel name 'recipe' as an input channel name " + " for Nova Recipe" + ) + + recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML) + recipe_channel = self.create_input_data_channel( + channel_name=SM_RECIPE, + data_source=recipe_file_path, + key_prefix=input_data_key_prefix, + ) + final_input_data_config.append(recipe_channel) + self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH}) + if final_input_data_config: final_input_data_config = self._get_input_data_config( final_input_data_config, input_data_key_prefix @@ -1039,6 +1096,7 @@ def from_recipe( checkpoint_config: Optional[shapes.CheckpointConfig] = None, training_input_mode: Optional[str] = "File", environment: Optional[Dict[str, str]] = None, + hyperparameters: Optional[Union[Dict[str, Any], str]] = {}, tags: Optional[List[Tag]] = None, sagemaker_session: Optional[Session] = None, role: Optional[str] = None, @@ -1134,19 +1192,32 @@ def from_recipe( or training image. """ if compute.instance_type is None: - raise ValueError("Must set ``instance_type`` in Compute when using training recipes.") + raise ValueError( + "Must set ``instance_type`` in ``compute`` input when using training recipes." + ) device_type = _determine_device_type(compute.instance_type) - if device_type == "cpu": + recipe = _load_base_recipe( + training_recipe=training_recipe, recipe_overrides=recipe_overrides + ) + is_nova = _is_nova_recipe(recipe=recipe) + + if device_type == "cpu" and not is_nova: raise ValueError( - "Training recipes are not supported for CPU instances. " - "Please provide a GPU or Tranium instance type." + "Training recipe is not supported for CPU instances. " + + "Please provide a GPU or Tranium instance type." ) + if training_image is None and is_nova: + raise ValueError("training_image must be provided when using recipe for Nova.") if training_image_config and training_image is None: raise ValueError("training_image must be provided when using training_image_config.") - sagemaker_session = TrainDefaults.get_sagemaker_session(sagemaker_session) - role = TrainDefaults.get_role(role=role, sagemaker_session=sagemaker_session) + if sagemaker_session is None: + sagemaker_session = Session() + logger.warning("SageMaker session not provided. Using default Session.") + if role is None: + role = get_execution_role(sagemaker_session=sagemaker_session) + logger.warning(f"Role not provided. 
Using default role:\n{role}") # The training recipe is used to prepare the following args: # - source_code @@ -1154,15 +1225,27 @@ def from_recipe( # - distributed # - compute # - hyperparameters - model_trainer_args, recipe_train_dir = _get_args_from_recipe( - training_recipe=training_recipe, + model_trainer_args, tmp_dir = _get_args_from_recipe( + training_recipe=recipe, recipe_overrides=recipe_overrides, requirements=requirements, compute=compute, region_name=sagemaker_session.boto_region_name, + role=role, ) if training_image is not None: model_trainer_args["training_image"] = training_image + if hyperparameters and not is_nova: + logger.warning( + "Hyperparameters are not supported for general training recipes. " + + "Ignoring hyperparameters input." + ) + if is_nova: + if hyperparameters and isinstance(hyperparameters, str): + hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters) + model_trainer_args["hyperparameters"].update(hyperparameters) + elif hyperparameters and isinstance(hyperparameters, dict): + model_trainer_args["hyperparameters"].update(hyperparameters) model_trainer = cls( sagemaker_session=sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py b/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py index f234d12d20..ecf4929da4 100644 --- a/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py +++ b/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py @@ -19,20 +19,21 @@ import shutil import tempfile from urllib.request import urlretrieve -from typing import Dict, Any, Optional, Tuple +from typing import Dict, Any, Optional, Tuple, Union import omegaconf -from omegaconf import OmegaConf, dictconfig +from omegaconf import OmegaConf, dictconfig, DictConfig -# from sagemaker.utils.image_uris import retrieve +from sagemaker.core.image_uris import retrieve from sagemaker.train import logger from sagemaker.train.utils import _run_clone_command_silent +from sagemaker.train.constants import SM_RECIPE_YAML from sagemaker.train.configs import Compute, SourceCode from sagemaker.train.distributed import Torchrun, SMP -def _try_resolve_recipe(recipe, key=None): +def _try_resolve_recipe(recipe: DictConfig, key=None) -> DictConfig: """Try to resolve recipe and return resolved recipe.""" if key is not None: recipe = dictconfig.DictConfig({key: recipe}) @@ -45,6 +46,19 @@ def _try_resolve_recipe(recipe, key=None): return recipe[key] +def _resolve_final_recipe(recipe: DictConfig): + """Resolve final recipe.""" + final_recipe = _try_resolve_recipe(recipe) + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "recipes") + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "training") + if final_recipe is None: + raise RuntimeError("Could not resolve provided recipe.") + + return final_recipe + + def _determine_device_type(instance_type: str) -> str: """Determine device type (gpu, cpu, trainium) based on instance type.""" instance_family = instance_type.split(".")[1] @@ -86,6 +100,8 @@ def _load_base_recipe( ) else: recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_") + if training_recipes_cfg is None: + training_recipes_cfg = _load_recipes_cfg() launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get( "launcher_repo" @@ -133,6 +149,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), "deepseek": ("deepseek", 
"deepseek_pretrain.py"), + "gpt_oss": ("custom_model", "custom_pretrain.py"), } for key in model_type_to_script: @@ -149,7 +166,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): def _configure_gpu_args( training_recipes_cfg: Dict[str, Any], region_name: str, - recipe: OmegaConf, + recipe: DictConfig, recipe_train_dir: tempfile.TemporaryDirectory, ) -> Dict[str, Any]: """Configure arguments specific to GPU.""" @@ -176,14 +193,13 @@ def _configure_gpu_args( if isinstance(gpu_image_cfg, str): training_image = gpu_image_cfg else: - # training_image = retrieve( - # gpu_image_cfg.get("framework"), - # region=region_name, - # version=gpu_image_cfg.get("version"), - # image_scope="training", - # **gpu_image_cfg.get("additional_args"), - # ) - training_image = "dummy_image" # Placeholder for actual image retrieval + training_image = retrieve( + gpu_image_cfg.get("framework"), + region=region_name, + version=gpu_image_cfg.get("version"), + image_scope="training", + **gpu_image_cfg.get("additional_args"), + ) # Setting dummy parameters for now torch_distributed = Torchrun(smp=SMP(random_seed="123456")) @@ -214,14 +230,13 @@ def _configure_trainium_args( if isinstance(neuron_image_cfg, str): training_image = neuron_image_cfg else: - # training_image = retrieve( - # neuron_image_cfg.get("framework"), - # region=region_name, - # version=neuron_image_cfg.get("version"), - # image_scope="training", - # **neuron_image_cfg.get("additional_args"), - # ) - training_image = "dummy_image" # Placeholder for actual image retrieval + training_image = retrieve( + neuron_image_cfg.get("framework"), + region=region_name, + version=neuron_image_cfg.get("version"), + image_scope="training", + **neuron_image_cfg.get("additional_args"), + ) args.update( { @@ -233,12 +248,181 @@ def _configure_trainium_args( return args +def _is_nova_recipe( + recipe: DictConfig, +) -> bool: + """Check if the recipe is a Nova recipe. + + A recipe is considered a Nova recipe if it meets either of the following conditions: + + 1. It has a run section with: + - A model_type that includes "amazon.nova" + - A model_name_or_path field + + OR + + 2. It has a training_config section with: + - A distillation_data field + + Args: + recipe (DictConfig): The loaded recipe configuration + + Returns: + bool: True if the recipe is a Nova recipe, False otherwise + """ + run_config = recipe.get("run", {}) + model_type = run_config.get("model_type", "").lower() + has_nova_model = ( + model_type and "amazon.nova" in model_type and "model_name_or_path" in run_config + ) + + # Check for distillation data + training_config = recipe.get("training_config", {}) + has_distillation = training_config.get("distillation_data") is not None + return bool(has_nova_model) or bool(has_distillation) + + +def _is_llmft_recipe( + recipe: DictConfig, +) -> bool: + """Check if the recipe is a LLMFT recipe. + + A recipe is considered a LLMFT recipe if it meets the following conditions: + 1. Having a run section + 2. The model_type in run is llmft + 3. 
Having a training_config section + + Args: + recipe (DictConfig): The loaded recipe configuration + + Returns: + bool: True if the recipe is a LLMFT recipe, False otherwise + """ + run_config = recipe.get("run", {}) + has_llmft_model = run_config.get("model_type", "").lower() == "llm_finetuning_aws" + return bool(has_llmft_model) and bool(recipe.get("training_config")) + + +def _get_args_from_nova_recipe( + recipe: DictConfig, + compute: Compute, + role: Optional[str] = None, +) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: + if not compute.instance_count and not recipe.get("run", {}).get("replicas", None): + raise ValueError("Must set ``instance_type`` in compute or ``replicas`` in recipe.") + compute.instance_count = compute.instance_count or recipe.get("run", {}).get("replicas") + + args = dict() + args.update({"hyperparameters": {}}) + + run_config = recipe.get("run", {}) + model_name_or_path = run_config.get("model_name_or_path") + if model_name_or_path: + if model_name_or_path.startswith("s3://"): + args["hyperparameters"]["base_model_location"] = model_name_or_path + else: + args["hyperparameters"]["base_model"] = model_name_or_path + + # Handle distillation configuration + training_config = recipe.get("training_config", {}) + distillation_data = training_config.get("distillation_data") + if bool(distillation_data): + args["hyperparameters"]["distillation_data"] = distillation_data + if not role: + raise ValueError("Must provide 'role' parameter when using Nova distillation") + args["hyperparameters"]["role_arn"] = role + + kms_key = training_config.get("kms_key") + if kms_key is None: + raise ValueError( + 'Nova distillation job recipe requires "kms_key" field in "training_config"' + ) + args["hyperparameters"]["kms_key"] = kms_key + + # Handle eval custom lambda configuration + if recipe.get("evaluation", {}): + processor = recipe.get("processor", {}) + lambda_arn = processor.get("lambda_arn", "") + if lambda_arn: + args["hyperparameters"]["eval_lambda_arn"] = lambda_arn + + # Handle reward lambda configuration + run_config = recipe.get("run", {}) + reward_lambda_arn = run_config.get("reward_lambda_arn", "") + if reward_lambda_arn: + args["hyperparameters"]["reward_lambda_arn"] = reward_lambda_arn + + _register_custom_resolvers() + + # Resolve Final Recipe + final_recipe = _try_resolve_recipe(recipe) + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "recipes") + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "training") + if final_recipe is None: + raise RuntimeError("Could not resolve provided recipe.") + + # Save Final Recipe to tmp dir + recipe_local_dir = tempfile.TemporaryDirectory(prefix="recipe_") + final_recipe_path = os.path.join(recipe_local_dir.name, SM_RECIPE_YAML) + OmegaConf.save(config=final_recipe, f=final_recipe_path) + + args.update( + { + "compute": compute, + "training_image": None, + "source_code": None, + "distributed": None, + } + ) + return args, recipe_local_dir + + +def _get_args_from_llmft_recipe( + recipe: DictConfig, + compute: Compute, +) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: + + if not compute.instance_count and not recipe.get("trainer", {}).get("num_nodes", None): + raise ValueError( + "Must set ``instance_count`` in compute or ``num_nodes`` in trainer in recipe." + ) + if compute.instance_count and recipe.get("trainer", {}).get("num_nodes", None) is not None: + logger.warning( + f"Using Compute to set instance_count:\n{compute}." + "\nIgnoring trainer -> num_nodes in recipe." 
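To make the Nova handling above concrete, here is a hedged example of a recipe that `_is_nova_recipe` would accept and the hyperparameters that `_get_args_from_nova_recipe` derives from it; the model identifier, S3 URIs, and ARNs are placeholders:

```python
from omegaconf import OmegaConf

# Placeholder values throughout; the structure follows the checks above.
recipe = OmegaConf.create({
    "run": {
        "model_type": "amazon.nova-lite-v1:0",          # contains "amazon.nova" -> Nova recipe
        "model_name_or_path": "s3://my-bucket/base-model/",
        "replicas": 2,                                   # used when Compute.instance_count is unset
    },
    "training_config": {
        "distillation_data": "s3://my-bucket/distill.jsonl",
        "kms_key": "arn:aws:kms:us-west-2:111122223333:key/EXAMPLE",
    },
})

hyperparameters = {}
model = recipe.run.model_name_or_path
# S3 locations map to base_model_location, plain model ids to base_model.
key = "base_model_location" if model.startswith("s3://") else "base_model"
hyperparameters[key] = model
if recipe.training_config.get("distillation_data"):
    hyperparameters["distillation_data"] = recipe.training_config.distillation_data
    hyperparameters["role_arn"] = "arn:aws:iam::111122223333:role/MyTrainingRole"  # 'role' is required
    hyperparameters["kms_key"] = recipe.training_config.kms_key
```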
+ ) + compute.instance_count = compute.instance_count or recipe.get("trainer", {}).get("num_nodes") + + args = dict() + + _register_custom_resolvers() + final_recipe = _resolve_final_recipe(recipe) + + # Save Final Recipe to tmp dir + recipe_local_dir = tempfile.TemporaryDirectory(prefix="recipe_") + final_recipe_path = os.path.join(recipe_local_dir.name, SM_RECIPE_YAML) + OmegaConf.save(config=final_recipe, f=final_recipe_path) + + args.update( + { + "compute": compute, + "training_image": None, + "source_code": None, + "distributed": None, + } + ) + return args, recipe_local_dir + + def _get_args_from_recipe( - training_recipe: str, + training_recipe: Union[str, DictConfig], compute: Compute, region_name: str, recipe_overrides: Optional[Dict[str, Any]], requirements: Optional[str], + role: Optional[str] = None, ) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: """Get arguments for ModelTrainer from a training recipe. @@ -254,8 +438,8 @@ def _get_args_from_recipe( ``` Args: - training_recipe (str): - Name of the training recipe or path to the recipe file. + training_recipe (Union[str, Dict[str, Any]]): + Name of the training recipe or path to the recipe file or loaded recipe Dict. compute (Compute): Compute configuration for training. region_name (str): @@ -269,7 +453,16 @@ def _get_args_from_recipe( raise ValueError("Must set `instance_type` in compute when using training recipes.") training_recipes_cfg = _load_recipes_cfg() - recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg) + if isinstance(training_recipe, str): + recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg) + else: + recipe = training_recipe + if _is_nova_recipe(recipe): + args, recipe_local_dir = _get_args_from_nova_recipe(recipe, compute, role=role) + return args, recipe_local_dir + if _is_llmft_recipe(recipe): + args, recipe_local_dir = _get_args_from_llmft_recipe(recipe, compute) + return args, recipe_local_dir if "trainer" not in recipe: raise ValueError("Supplied recipe does not contain required field trainer.") @@ -283,7 +476,7 @@ def _get_args_from_recipe( if compute.instance_count is None: if "num_nodes" not in recipe["trainer"]: raise ValueError( - "Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe." + "Must provide Compute with instance_count or set trainer -> num_nodes in recipe." 
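A hedged usage sketch for the `ModelTrainer.from_recipe` flow modified above; the recipe path, image URI, instance type, and hyperparameter name are placeholders, and the exact keyword set should be checked against the released signature:

```python
from sagemaker.train.configs import Compute
from sagemaker.train.model_trainer import ModelTrainer

trainer = ModelTrainer.from_recipe(
    training_recipe="recipes/my_nova_recipe.yaml",   # placeholder path
    training_image="<nova-training-image-uri>",      # required for Nova recipes (see check above)
    compute=Compute(instance_type="ml.p5.48xlarge", instance_count=2),
    hyperparameters={"max_steps": 100},              # merged into recipe-derived hyperparameters
)
trainer.train()
```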
) compute.instance_count = recipe["trainer"]["num_nodes"] @@ -313,7 +506,7 @@ def _get_args_from_recipe( # Save Final Recipe to source_dir OmegaConf.save( - config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml") + config=final_recipe, f=os.path.join(args["source_code"].source_dir, SM_RECIPE_YAML) ) # If recipe_requirements is provided, copy it to source_dir @@ -322,19 +515,10 @@ def _get_args_from_recipe( args["source_code"].requirements = os.path.basename(requirements) # Update args with compute and hyperparameters - hyperparameters = {"config-path": ".", "config-name": "recipe.yaml"} - - # Handle eval custom lambda configuration - if recipe.get("evaluation", {}): - processor = recipe.get("processor", {}) - lambda_arn = processor.get("lambda_arn", "") - if lambda_arn: - hyperparameters["lambda_arn"] = lambda_arn - args.update( { "compute": compute, - "hyperparameters": hyperparameters, + "hyperparameters": {"config-path": ".", "config-name": SM_RECIPE_YAML}, } ) From b21a0bfff7fa8754e69a454422814f408ad8e16b Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Mon, 12 Jan 2026 07:23:01 -0800 Subject: [PATCH 3/5] feat: integrate amtviz for visualization of tuning jobs --- .../src/sagemaker/core/amtviz/__init__.py | 27 + .../src/sagemaker/core/amtviz/job_metrics.py | 180 ++++ .../sagemaker/core/amtviz/visualization.py | 857 ++++++++++++++++++ sagemaker-train/src/sagemaker/train/tuner.py | 65 ++ 4 files changed, 1129 insertions(+) create mode 100644 sagemaker-core/src/sagemaker/core/amtviz/__init__.py create mode 100644 sagemaker-core/src/sagemaker/core/amtviz/job_metrics.py create mode 100644 sagemaker-core/src/sagemaker/core/amtviz/visualization.py diff --git a/sagemaker-core/src/sagemaker/core/amtviz/__init__.py b/sagemaker-core/src/sagemaker/core/amtviz/__init__.py new file mode 100644 index 0000000000..8554b32c4a --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/amtviz/__init__.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Amazon SageMaker Automatic Model Tuning Visualization module. + +This module provides visualization capabilities for SageMaker hyperparameter tuning jobs. +It enables users to create interactive visualizations to analyze and understand the +performance of hyperparameter optimization experiments. + +Example: + >>> from sagemaker.amtviz import visualize_tuning_job + >>> visualize_tuning_job('my-tuning-job') +""" +from __future__ import absolute_import + +from sagemaker.amtviz.visualization import visualize_tuning_job + +__all__ = ["visualize_tuning_job"] diff --git a/sagemaker-core/src/sagemaker/core/amtviz/job_metrics.py b/sagemaker-core/src/sagemaker/core/amtviz/job_metrics.py new file mode 100644 index 0000000000..b99886941f --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/amtviz/job_metrics.py @@ -0,0 +1,180 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Helper functions to retrieve job metrics from CloudWatch.""" +from __future__ import absolute_import + +from datetime import datetime, timedelta +from typing import Callable, List, Optional, Tuple, Dict, Any +import hashlib +import os +from pathlib import Path + +import logging +import pandas as pd +import numpy as np +import boto3 + +logger = logging.getLogger(__name__) + +cw = boto3.client("cloudwatch") +sm = boto3.client("sagemaker") + + +def disk_cache(outer: Callable) -> Callable: + """A decorator that implements disk-based caching for CloudWatch metrics data. + + This decorator caches the output of the wrapped function to disk in JSON Lines format. + It creates a cache key using MD5 hash of the function arguments and stores the data + in the user's home directory under .amtviz/cw_metrics_cache/. + + Args: + outer (Callable): The function to be wrapped. Must return a pandas DataFrame + containing CloudWatch metrics data. + + Returns: + Callable: A wrapper function that implements the caching logic. + """ + + def inner(*args: Any, **kwargs: Any) -> pd.DataFrame: + key_input = str(args) + str(kwargs) + # nosec b303 - Not used for cryptography, but to create lookup key + key = hashlib.md5(key_input.encode("utf-8")).hexdigest() + cache_dir = Path.home().joinpath(".amtviz/cw_metrics_cache") + fn = f"{cache_dir}/req_{key}.jsonl.gz" + if Path(fn).exists(): + try: + df = pd.read_json(fn, lines=True) + logger.debug("H", end="") + df["ts"] = pd.to_datetime(df["ts"]) + df["ts"] = df["ts"].dt.tz_localize(None) + # pyright: ignore [reportIndexIssue, reportOptionalSubscript] + df["rel_ts"] = pd.to_datetime(df["rel_ts"]) + df["rel_ts"] = df["rel_ts"].dt.tz_localize(None) + return df + except KeyError: + # Empty file leads to empty df, hence no df['ts'] possible + pass + # nosec b110 - doesn't matter why we could not load it. + except BaseException as e: + logger.error("\nException: %s - %s", type(e), e) + + logger.debug("M", end="") + df = outer(*args, **kwargs) + assert isinstance(df, pd.DataFrame), "Only caching Pandas DataFrames." 
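As an aside on the caching above, a small runnable sketch of how `disk_cache` derives its on-disk location: an MD5 lookup key over the call arguments, stored under `~/.amtviz/cw_metrics_cache/`:

```python
import hashlib
from pathlib import Path


def cache_file_for(*args, **kwargs) -> Path:
    # Same key scheme as disk_cache above: MD5 used as a lookup key, not for cryptography.
    key = hashlib.md5((str(args) + str(kwargs)).encode("utf-8")).hexdigest()
    return Path.home() / ".amtviz" / "cw_metrics_cache" / f"req_{key}.jsonl.gz"


print(cache_file_for([("TrainingJobName", "my-training-job")], "2026-01-12T00:00:00", None))
```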
+ + os.makedirs(cache_dir, exist_ok=True) + df.to_json(fn, orient="records", date_format="iso", lines=True) + + return df + + return inner + + +def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> Dict[str, Any]: + """Returns a CloudWatch metric data query template.""" + return { + "Id": metric_name.lower().replace(":", "_").replace("-", "_"), + "MetricStat": { + "Stat": "Average", + "Metric": { + "Namespace": "/aws/sagemaker/TrainingJobs", + "MetricName": metric_name, + "Dimensions": [ + {"Name": dim_name, "Value": dim_value}, + ], + }, + "Period": 60, + }, + "ReturnData": True, + } + + +def _get_metric_data( + queries: List[Dict[str, Any]], start_time: datetime, end_time: datetime +) -> pd.DataFrame: + """Fetches CloudWatch metrics between timestamps, returns a DataFrame with selected columns.""" + start_time = start_time - timedelta(hours=1) + end_time = end_time + timedelta(hours=1) + response = cw.get_metric_data(MetricDataQueries=queries, StartTime=start_time, EndTime=end_time) + + df = pd.DataFrame() + if "MetricDataResults" not in response: + return df + + for metric_data in response["MetricDataResults"]: + values = metric_data["Values"] + ts = np.array(metric_data["Timestamps"], dtype=np.datetime64) + labels = [metric_data["Label"]] * len(values) + + df = pd.concat([df, pd.DataFrame({"value": values, "ts": ts, "label": labels})]) + + # We now calculate the relative time based on the first actual observed + # time stamps, not the potentially start time that we used to scope our CW + # API call. The difference could be for example startup times or waiting + # for Spot. + if not df.empty: + df["rel_ts"] = datetime.fromtimestamp(1) + (df["ts"] - df["ts"].min()) # pyright: ignore + return df + + +@disk_cache +def _collect_metrics( + dimensions: List[Tuple[str, str]], start_time: datetime, end_time: Optional[datetime] +) -> pd.DataFrame: + """Collects SageMaker training job metrics from CloudWatch for dimensions and time range.""" + df = pd.DataFrame() + for dim_name, dim_value in dimensions: + response = cw.list_metrics( + Namespace="/aws/sagemaker/TrainingJobs", + Dimensions=[ + {"Name": dim_name, "Value": dim_value}, + ], + ) + if not response["Metrics"]: + continue + metric_names = [metric["MetricName"] for metric in response["Metrics"]] + if not metric_names: + # No metric data yet, or not any longer, because the data were aged out + continue + metric_data_queries = [ + _metric_data_query_tpl(metric_name, dim_name, dim_value) for metric_name in metric_names + ] + df = pd.concat([df, _get_metric_data(metric_data_queries, start_time, end_time)]) + + return df + + +def get_cw_job_metrics( + job_name: str, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None +) -> pd.DataFrame: + """Retrieves CloudWatch metrics for a SageMaker training job. + + Args: + job_name (str): Name of the SageMaker training job. + start_time (datetime, optional): Start time for metrics collection. + Defaults to now - 4 hours. + end_time (datetime, optional): End time for metrics collection. + Defaults to start_time + 4 hours. + + Returns: + pd.DataFrame: Metrics data with columns for value, timestamp, and metric name. + Results are cached to disk for improved performance. 
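For illustration, the CloudWatch `GetMetricData` query that `_metric_data_query_tpl` above produces for a hypothetical `train:loss` metric of a placeholder training job looks like this:

```python
query = {
    "Id": "train_loss",  # metric name lower-cased, ":" and "-" replaced with "_"
    "MetricStat": {
        "Stat": "Average",
        "Metric": {
            "Namespace": "/aws/sagemaker/TrainingJobs",
            "MetricName": "train:loss",
            "Dimensions": [{"Name": "TrainingJobName", "Value": "my-training-job"}],
        },
        "Period": 60,
    },
    "ReturnData": True,
}
# A list of such queries is passed to cloudwatch.get_metric_data(...) in _get_metric_data above.
```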
+ """ + dimensions = [ + ("TrainingJobName", job_name), + ("Host", job_name + "/algo-1"), + ] + # If not given, use reasonable defaults for start and end time + start_time = start_time or datetime.now() - timedelta(hours=4) + end_time = end_time or start_time + timedelta(hours=4) + return _collect_metrics(dimensions, start_time, end_time) diff --git a/sagemaker-core/src/sagemaker/core/amtviz/visualization.py b/sagemaker-core/src/sagemaker/core/amtviz/visualization.py new file mode 100644 index 0000000000..7f09117d1e --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/amtviz/visualization.py @@ -0,0 +1,857 @@ +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides visualization capabilities for SageMaker hyperparameter tuning jobs. + +It contains utilities to create interactive visualizations of hyperparameter tuning results +using Altair charts. The module enables users to analyze and understand the performance +of their hyperparameter optimization experiments through various visual representations +including: +- Progress of objective metrics over time +- Distribution of results +- Relationship between hyperparameters and objective values +- Training job metrics and instance utilization +- Comparative analysis across multiple tuning jobs + +Main Features: + - Visualize single or multiple hyperparameter tuning jobs + - Display training job metrics from CloudWatch + - Support for both completed and in-progress tuning jobs + - Interactive filtering and highlighting of data points + - CPU, memory, and GPU utilization visualization + - Advanced visualization options for detailed analysis + +Primary Classes and Functions: + - visualize_tuning_job: Main function to create visualizations for tuning jobs + - create_charts: Core chart creation functionality + - get_job_analytics_data: Retrieves and processes tuning job data + +Dependencies: + - altair: For creating interactive visualizations + - pandas: For data manipulation and analysis + - boto3: For AWS service interaction + - sagemaker: For accessing SageMaker resources +""" +from __future__ import absolute_import + +from typing import Union, List, Optional, Tuple +import os +import warnings +import logging +import altair as alt +import pandas as pd +import numpy as np +import boto3 +import sagemaker +from sagemaker.amtviz.job_metrics import get_cw_job_metrics + +warnings.filterwarnings("ignore") +logger = logging.getLogger(__name__) + +pd.set_option("display.max_rows", 500) +pd.set_option("display.max_columns", 500) +pd.set_option("display.width", 1000) +pd.set_option("display.max_colwidth", None) # Don't truncate TrainingJobName + + +alt.data_transformers.disable_max_rows() +altair_renderer = os.getenv("ALTAIR_RENDERER", "default") +logger.info("Setting altair renderer to %s.", altair_renderer) +alt.renderers.enable(altair_renderer) + + +sm = boto3.client("sagemaker") + + +def _columnize(charts: List[alt.Chart], cols: int = 2) -> alt.VConcatChart: + """Arrange charts in columns.""" + return alt.vconcat(*[alt.hconcat(*charts[i : i + cols]) for i in range(0, len(charts), cols)]) + + +def 
visualize_tuning_job( + tuning_jobs: Union[str, List[str], "sagemaker.tuner.HyperparameterTuner"], + return_dfs: bool = False, + job_metrics: Optional[List[str]] = None, + trials_only: bool = False, + advanced: bool = False, +) -> Union[alt.Chart, Tuple[alt.Chart, pd.DataFrame, pd.DataFrame]]: + """Visualize SageMaker hyperparameter tuning jobs. + + Args: + tuning_jobs: Single tuning job or list of tuning jobs (name or HyperparameterTuner object) + return_dfs: Whether to return the underlying DataFrames + job_metrics: List of additional job metrics to include + trials_only: Whether to only show trials data + advanced: Whether to show advanced visualizations + + Returns: + If return_dfs is False, returns Altair chart + If return_dfs is True, returns tuple of (chart, trials_df, full_df) + """ + + trials_df, tuned_parameters, objective_name, is_minimize = get_job_analytics_data(tuning_jobs) + + try: + from IPython import get_ipython, display + + if get_ipython(): + # Running in a Jupyter Notebook + display(trials_df.head(10)) + else: + # Running in a non-Jupyter environment + logger.info(trials_df.head(10).to_string()) + except ImportError: + # Not running in a Jupyter Notebook + logger.info(trials_df.head(10).to_string()) + + full_df = _prepare_consolidated_df(trials_df) if not trials_only else pd.DataFrame() + + trials_df.columns = trials_df.columns.map(_clean_parameter_name) + full_df.columns = full_df.columns.map(_clean_parameter_name) + tuned_parameters = [_clean_parameter_name(tp) for tp in tuned_parameters] + objective_name = _clean_parameter_name(objective_name) + + charts = create_charts( + trials_df, + tuned_parameters, + full_df, + objective_name, + minimize_objective=is_minimize, + job_metrics=job_metrics, + advanced=advanced, + ) + + if return_dfs: + return charts, trials_df, full_df + return charts + + +def create_charts( + trials_df: pd.DataFrame, + tuning_parameters: List[str], + full_df: pd.DataFrame, + objective_name: str, + minimize_objective: bool, + job_metrics: Optional[List[str]] = None, + highlight_trials: bool = True, + color_trials: bool = False, + advanced: bool = False, +) -> alt.Chart: + """Create visualization charts for hyperparameter tuning results. 
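A usage sketch for `visualize_tuning_job` as defined above; the job name and the extra metric label are placeholders, and the import path follows the one used by `tuner.py` later in this series:

```python
from sagemaker.core.amtviz import visualize_tuning_job

# "my-tuning-job" and "train:loss" are placeholders.
charts, trials_df, full_df = visualize_tuning_job(
    "my-tuning-job",
    return_dfs=True,             # also return the trials and consolidated DataFrames
    job_metrics=["train:loss"],  # additional CloudWatch metric to chart
    advanced=True,               # adds tick marks and the cumulative-objective line
)
charts  # render the Altair charts in a notebook
```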
+ + Args: + trials_df: DataFrame containing trials data + tuning_parameters: List of hyperparameter names + full_df: DataFrame with consolidated data + objective_name: Name of the objective metric + minimize_objective: Whether objective should be minimized + job_metrics: Additional job metrics to include + highlight_trials: Whether to highlight selected trials + color_trials: Whether to color trials by job + advanced: Whether to show advanced visualizations + + Returns: + Altair chart visualization + """ + + if trials_df.empty: + logger.info("No results available yet.") + return pd.DataFrame() + + if job_metrics is None: + job_metrics = [] + + multiple_tuning_jobs = len(trials_df["TuningJobName"].unique()) > 1 + multiple_job_status = len(trials_df["TrainingJobStatus"].unique()) > 1 + + # Rows, n>1 + # Detail Charts + + brush = alt.selection_interval(encodings=["x"], resolve="intersect", empty=True) + + job_highlight_selection = alt.selection_point( + on="mouseover", + nearest=False, + empty=False, + fields=["TrainingJobName", "TrainingStartTime"], + ) + + # create tooltip + detail_tooltip = [] + for trp in [objective_name] + tuning_parameters: + if trials_df[trp].dtype == np.float64: + trp = alt.Tooltip(trp, format=".2e") + detail_tooltip.append(trp) + + detail_tooltip.append(alt.Tooltip("TrainingStartTime:T", format="%H:%M:%S")) + detail_tooltip.extend(["TrainingJobName", "TrainingJobStatus", "TrainingElapsedTimeSeconds"]) + + # create stroke/stroke-width for tuning_jobs + # and color for training jobs, if wanted + # add coloring of the stroke to highlight correlated + # data points + jobs_props = {"shape": alt.Shape("TrainingJobStatus:N", legend=None)} + + if multiple_tuning_jobs: + jobs_props["strokeWidth"] = alt.StrokeWidthValue(2.0) + jobs_props["stroke"] = alt.Stroke("TuningJobName:N", legend=None) + + if color_trials: + jobs_props["color"] = alt.Color("TrainingJobName:N") + + if highlight_trials: + jobs_props["strokeWidth"] = alt.condition( + job_highlight_selection, + alt.StrokeWidthValue(2.0), + alt.StrokeWidthValue(2.0), + ) + jobs_props["stroke"] = alt.condition( + job_highlight_selection, + alt.StrokeValue("gold"), + ( + alt.Stroke("TuningJobName:N", legend=None) + if multiple_tuning_jobs + else alt.StrokeValue("white") + ), + ) + + opacity = alt.condition(brush, alt.value(1.0), alt.value(0.35)) + charts = [] + + # Min and max of the objective. This is used in filtered + # charts, so that the filtering does not make the axis + # jump, which would make comparisons harder. + objective_scale = alt.Scale( + domain=( + trials_df[objective_name].min(), + trials_df[objective_name].max(), + ) + ) + + # If we have multiple tuning jobs, we also want to be able + # to discriminate based on the individual tuning job, so + # we just treat them as an additional tuning parameter + tuning_job_param = ["TuningJobName"] if multiple_tuning_jobs else [] + tuning_parameters = tuning_parameters.copy() + tuning_job_param + + # If we use early stopping and at least some jobs were + # stopped early, we want to be able to discriminate + # those jobs. + if multiple_job_status: + tuning_parameters.append("TrainingJobStatus") + + def render_detail_charts(): + # To force a tuning job to sample a combination more than once, we + # sometimes introduce a hyperparameter that has no effect. + # It's values are random and without impact, so we omit it from analysis. 
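The interval brush and conditional opacity used throughout `create_charts` can be seen in isolation in this small sketch, using toy data and cleaned column names of the kind `_clean_parameter_name` produces:

```python
import altair as alt
import pandas as pd

# Toy trials data; column names avoid ":" and "." so Altair's shorthand parsing works.
df = pd.DataFrame({"learning_rate": [0.001, 0.01, 0.1], "valid_accuracy": [0.81, 0.87, 0.74]})
brush = alt.selection_interval(encodings=["x"], resolve="intersect", empty=True)
points = (
    alt.Chart(df)
    .add_params(brush)
    .mark_point(filled=True, size=50)
    .encode(
        x=alt.X("learning_rate:Q", scale=alt.Scale(type="log", zero=False)),
        y=alt.Y("valid_accuracy:Q", scale=alt.Scale(zero=False)),
        # Points outside the brushed x-interval are dimmed, as in create_charts above.
        opacity=alt.condition(brush, alt.value(1.0), alt.value(0.35)),
    )
)
```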
+ ignored_parameters = {"dummy"} + for tuning_parameter in tuning_parameters: + if tuning_parameter in ignored_parameters: + continue + + # Map dataframe's dtype to altair's types and + # adjust scale if necessary + scale_type = "linear" + scale_log_base = 10 + + few_values = len(trials_df[tuning_parameter].unique()) < 8 + parameter_type = "N" # Nominal + dtype = str(trials_df.dtypes[tuning_parameter]) + if "float" in dtype: + parameter_type = "Q" # Quantitative + ratio = (trials_df[tuning_parameter].max() + 1e-10) / ( + trials_df[tuning_parameter].min() + 1e-10 + ) + not_likely_discrete = ( + len(trials_df[tuning_parameter].unique()) > trials_df[tuning_parameter].count() + ) # edge case when both are equal + if few_values and not_likely_discrete: + if ratio > 50: + scale_type = "log" + elif ratio > 10: + scale_type = "log" + scale_log_base = 2 + + elif "int" in dtype or "object" in dtype: + parameter_type = "O" # Ordinal + + x_encoding = alt.X( + f"{tuning_parameter}:{parameter_type}", + scale=alt.Scale( + zero=False, + padding=1, + type=scale_type, + base=scale_log_base, + ), + ) + + # Sync the coloring for categorical hyperparameters + discrete = parameter_type in ["O", "N"] and few_values + + # Detail Chart + charts.append( + alt.Chart(trials_df) + .add_params(brush) + .add_params(job_highlight_selection) + .mark_point(filled=True, size=50) + .encode( + x=x_encoding, + y=alt.Y( + f"{objective_name}:Q", + scale=alt.Scale(zero=False, padding=1), + axis=alt.Axis(title=objective_name), + ), + opacity=opacity, + tooltip=detail_tooltip, + **jobs_props, + ) + ) + + if discrete: + # Individually coloring the values only if we don't already + # use the colors to show the different tuning jobs + logger.info("%s, %s", parameter_type, tuning_parameter) + if not multiple_tuning_jobs: + charts[-1] = charts[-1].encode(color=f"{tuning_parameter}:N") + charts[-1] = ( + ( + charts[-1] + | alt.Chart(trials_df) + .transform_filter(brush) + .transform_density( + objective_name, + bandwidth=0.01, + groupby=[tuning_parameter], + # https://github.com/vega/altair/issues/3203#issuecomment-2141558911 + # Specifying extent no longer necessary (>5.1.2). 
+ extent=[ + trials_df[objective_name].min(), + trials_df[objective_name].max(), + ], + ) + .mark_area(opacity=0.5) + .encode( + x=alt.X( + "value:Q", + title=objective_name, + scale=objective_scale, + ), + y="density:Q", + color=alt.Color( + f"{tuning_parameter}:N", + ), + tooltip=tuning_parameter, + ) + ).properties(title=tuning_parameter) + # .resolve_scale("independent") + # .resolve_legend(color="independent") + ) + + if advanced and parameter_type == "Q": + # Adding tick marks to the detail charts with quantitative hyperparameters + x_enc = x_encoding.copy() + charts[-1].encoding.x.title = None + charts[-1].encoding.x.axis = alt.Axis(labels=False) + + charts[-1] = charts[-1] & alt.Chart(trials_df).mark_tick(opacity=0.5).encode( + x=x_enc, + opacity=alt.condition(brush, alt.value(0.5), alt.value(0.1)), + ) + + return _columnize(charts) + + detail_charts = render_detail_charts() + + # First Row + # Progress Over Time Chart + + def render_progress_chart(): + # Sorting trials by training start time, so that we can track the \ + # progress of the best objective so far over time + trials_df_by_tst = trials_df.sort_values(["TuningJobName", "TrainingStartTime"]) + trials_df_by_tst["cum_objective"] = trials_df_by_tst.groupby(["TuningJobName"]).transform( + lambda x: x.cummin() if minimize_objective else x.cummax() + )[objective_name] + + progress_chart = ( + alt.Chart(trials_df_by_tst) + .add_params(brush) + .add_params(job_highlight_selection) + .mark_point(filled=True, size=50) + .encode( + x=alt.X("TrainingStartTime:T", scale=alt.Scale(nice=True)), + y=alt.Y( + f"{objective_name}:Q", + scale=alt.Scale(zero=False, padding=1), + axis=alt.Axis(title=objective_name), + ), + opacity=opacity, + tooltip=detail_tooltip, + **jobs_props, + ) + ) + + cum_obj_chart = ( + alt.Chart(trials_df_by_tst) + .mark_line( + interpolate="step-after", + opacity=1.0, + strokeDash=[3, 3], + strokeWidth=2.0, + ) + .encode( + x=alt.X("TrainingStartTime:T", scale=alt.Scale(nice=True)), + y=alt.Y("cum_objective:Q", scale=alt.Scale(zero=False, padding=1)), + stroke=alt.Stroke("TuningJobName:N", legend=None), + ) + ) + + if advanced: + return cum_obj_chart + progress_chart + return progress_chart + + progress_chart = render_progress_chart() + + # First Row + # KDE Training Objective + result_hist_chart = ( + alt.Chart(trials_df) + .transform_filter(brush) + .transform_density(objective_name, bandwidth=0.01) + .mark_area() + .encode( + x=alt.X("value:Q", scale=objective_scale, title=objective_name), + y="density:Q", + ) + ) + # Training Jobs + training_jobs_chart = ( + alt.Chart(trials_df.sort_values(objective_name), title="Training Jobs") + .mark_bar() + .add_params(brush) + .add_params(job_highlight_selection) + .encode( + y=alt.Y(f"{objective_name}:Q"), + x=alt.X("TrainingJobName:N", sort=None), + color=alt.Color("TrainingJobName:N"), + opacity=opacity, + **jobs_props, + ) + ) + + # Job Level Stats + + training_job_name_encodings = { + "color": alt.condition( + brush, + alt.Color("TrainingJobName:N", legend=None), + alt.value("grey"), + ), + "opacity": alt.condition(brush, alt.value(1.0), alt.value(0.3)), + "strokeWidth": alt.condition(brush, alt.value(2.5), alt.value(0.8)), + } + + duration_format = "%M:%S" + metrics_tooltip = [ + "TrainingJobName:N", + "value:Q", + "label:N", + alt.Tooltip("ts:T", format="%e:%H:%M"), + alt.Tooltip("rel_ts:T", format="%e:%H:%M"), + ] + + job_level_rows = alt.HConcatChart() + + # Use CW metrics + if not full_df.empty: + # Objective Progression + + objective_progression_chart = None + # 
Suppress diagram if we only have one, final, value + if ( + full_df.loc[full_df.label == objective_name] + .groupby(["TuningJobName", "TrainingJobName"])[objective_name] + .count() + .max() + > 1 + ): + objective_progression_chart = ( + alt.Chart(full_df, title=f"Progression {objective_name}", width=400) + .transform_filter(alt.FieldEqualPredicate(field="label", equal=objective_name)) + .mark_line(point=True) + .encode( + x=alt.X("rel_ts:T", axis=alt.Axis(format=duration_format)), + y=alt.Y("value:Q", scale=alt.Scale(zero=False)), + **training_job_name_encodings, + tooltip=metrics_tooltip, + ) + .interactive() + ) + + if multiple_job_status: + objective_progression_chart = objective_progression_chart.encode( + strokeDash=alt.StrokeDash("TrainingJobStatus:N", legend=None) + ) + + # Secondary chart showing the same contents, but by absolute time. + objective_progression_absolute_chart = objective_progression_chart.encode( + x=alt.X("ts:T", scale=alt.Scale(nice=True)) + ) + + objective_progression_chart = ( + objective_progression_chart | objective_progression_absolute_chart + ) + + ### + + job_metrics_charts = [] + for metric in job_metrics: + metric_chart = ( + alt.Chart(full_df, title=metric, width=400) + .transform_filter(alt.FieldEqualPredicate(field="label", equal=metric)) + .encode( + y=alt.Y("value:Q", scale=alt.Scale(zero=False)), + **training_job_name_encodings, + tooltip=metrics_tooltip, + ) + .interactive() + ) + + if ( + full_df.loc[full_df.label == metric] + .groupby(["TuningJobName", "TrainingJobName"]) + .count() + .value.max() + == 1 + ): + # single value, render as a bar over the training jobs on the x-axis + metric_chart = metric_chart.encode( + x=alt.X("TrainingJobName:N", sort=None) + ).mark_bar(interpolate="linear", point=True) + else: + # multiple values, render the values over time on the x-axis + metric_chart = metric_chart.encode( + x=alt.X("rel_ts:T", axis=alt.Axis(format=duration_format)) + ).mark_line(interpolate="linear", point=True) + + job_metrics_charts.append(metric_chart) + + job_metrics_chart = _columnize(job_metrics_charts, 3) + + # Job instance + # 'MemoryUtilization', 'CPUUtilization' + instance_metrics_chart = ( + alt.Chart(full_df, title="CPU and Memory") + .transform_filter( + alt.FieldOneOfPredicate( + field="label", + oneOf=[ + "MemoryUtilization", + "CPUUtilization", + ], + ) + ) + .mark_line() + .encode( + x=alt.X("rel_ts:T", axis=alt.Axis(format=duration_format)), + y="value:Q", + **training_job_name_encodings, + strokeDash=alt.StrokeDash("label:N", legend=alt.Legend(orient="bottom")), + tooltip=metrics_tooltip, + ) + .interactive() + ) + + if "GPUUtilization" in full_df.label.values: + instance_metrics_chart = ( + instance_metrics_chart + | alt.Chart(full_df, title="GPU and GPU Memory") + .transform_filter( + alt.FieldOneOfPredicate( + field="label", + oneOf=[ + "GPUMemoryUtilization", + "GPUUtilization", + ], + ) + ) + .mark_line() + .encode( + x=alt.X("rel_ts:T", axis=alt.Axis(format=duration_format)), + y=alt.Y("value:Q"), + **training_job_name_encodings, + strokeDash=alt.StrokeDash("label:N", legend=alt.Legend(orient="bottom")), + tooltip=metrics_tooltip, + ) + .interactive() + ) + + job_level_rows = job_metrics_chart & instance_metrics_chart + if objective_progression_chart: + job_level_rows = objective_progression_chart & job_level_rows + job_level_rows = job_level_rows.resolve_scale(strokeDash="independent").properties( + title="Job / Instance Level Metrics" + ) + + overview_row = (progress_chart | result_hist_chart).properties( + 
title="Hyper Parameter Tuning Job" + ) + detail_rows = detail_charts.properties(title="Hyper Parameter Details") + if job_level_rows: + job_level_rows = training_jobs_chart & job_level_rows + + return overview_row & detail_rows & job_level_rows + + +def _clean_parameter_name(s): + """Helper method to ensure proper parameter name characters for altair 5+""" + return s.replace(":", "_").replace(".", "_") + + +def _prepare_training_job_metrics(jobs): + """Fetches and combines CloudWatch metrics for multiple training jobs. + + Args: + jobs (list): List of (job_name, start_time, end_time) tuples. + + Returns: + pandas.DataFrame: Combined metrics DataFrame with 'TrainingJobName' column. + """ + df = pd.DataFrame() + for job_name, start_time, end_time in jobs: + job_df = get_cw_job_metrics( + job_name, + start_time=pd.Timestamp(start_time) - pd.DateOffset(hours=8), + end_time=pd.Timestamp(end_time) + pd.DateOffset(hours=8), + ) + if job_df is None: + logger.info("No CloudWatch metrics for %s. Skipping.", job_name) + continue + + job_df["TrainingJobName"] = job_name + df = pd.concat([df, job_df]) + return df + + +def _prepare_consolidated_df(trials_df): + """Merges training job metrics with trials data into a consolidated DataFrame.""" + if trials_df.empty: + return pd.DataFrame() + + logger.debug("Cache Hit/Miss: ", end="") + jobs_df = _prepare_training_job_metrics( + zip( + trials_df.TrainingJobName.values, + trials_df.TrainingStartTime.values, + trials_df.TrainingEndTime.values, + ) + ) + logger.info("") + + if jobs_df.empty: + return pd.DataFrame() + + merged_df = pd.merge(jobs_df, trials_df, on="TrainingJobName") + return merged_df + + +def _get_df(tuning_job_name, filter_out_stopped=False): + """Retrieves hyperparameter tuning job results and returns preprocessed DataFrame. + + Returns a DataFrame containing tuning metrics and parameters for the specified job. + """ + + tuner = sagemaker.HyperparameterTuningJobAnalytics(tuning_job_name) + + df = tuner.dataframe() + if df.empty: # HPO job just started; no results yet + return df + + df["TuningJobName"] = tuning_job_name + + # Filter out jobs without FinalObjectiveValue + df = df[df["FinalObjectiveValue"] > -float("inf")] + + # Jobs early stopped by AMT are reported with their last + # objective value, before they are stopped. + # However this value may not be a good representation + # of the eventual objective value we would have seen + # if run without stopping. Therefore it may be confusing + # to include those runs. + # For now, if included, we use a different mark to + # discriminate visually between a stopped and finished job + + if filter_out_stopped: + df = df[df["TrainingJobStatus"] != "Stopped"] + + # Preprocessing values for [32], [64] etc. + for tuning_range in tuner.tuning_ranges.values(): + parameter_name = tuning_range["Name"] + if df.dtypes[parameter_name] == "O": + try: + # Remove decorations, like [] + df[parameter_name] = df[parameter_name].apply( + lambda v: v.replace("[", "").replace("]", "").replace('"', "") + ) + + # Is it an int? 3 would work, 3.4 would fail. + try: + df[parameter_name] = df[parameter_name].astype(int) + except ValueError: + # A float then? 
+ df[parameter_name] = df[parameter_name].astype(float) + + except (ValueError, TypeError, AttributeError): + # Catch exceptions that might occur during string manipulation or type conversion + # - ValueError: Could not convert string to float/int + # - TypeError: Object doesn't support the operation + # - AttributeError: Object doesn't have replace method + # Leaving the value untouched + pass + + return df + + +def _get_tuning_job_names_with_parents(tuning_job_names): + """Resolve dependent jobs, one level only""" + + all_tuning_job_names = [] + for tuning_job_name in tuning_job_names: + tuning_job_result = sm.describe_hyper_parameter_tuning_job( + HyperParameterTuningJobName=tuning_job_name + ) + + # find parent jobs and retrieve all tuner dataframes + parent_jobs = [] + if "WarmStartConfig" in tuning_job_result: + parent_jobs = [ + cfg["HyperParameterTuningJobName"] + for cfg in tuning_job_result["WarmStartConfig"]["ParentHyperParameterTuningJobs"] + ] + if parent_jobs: + logger.info("Tuning job %s's parents: %s", tuning_job_name, ", ".join(parent_jobs)) + all_tuning_job_names.extend([tuning_job_name, *parent_jobs]) + + # return de-duplicated tuning job names + return list(set(all_tuning_job_names)) + + +def get_job_analytics_data(tuning_job_names): + """Retrieves and processes analytics data from hyperparameter tuning jobs. + + Args: + tuning_job_names (str or list): Single tuning job name or list of names/tuner objects. + + Returns: + tuple: (DataFrame with training results, tuned params list, objective name, is_minimize). + + Raises: + ValueError: If tuning jobs have different objectives or optimization directions. + """ + if not isinstance(tuning_job_names, list): + tuning_job_names = [tuning_job_names] + + # Ensure to create a list of tuning job names (strings) + tuning_job_names = [ + ( + tuning_job.describe()["HyperParameterTuningJobName"] + if isinstance(tuning_job, sagemaker.tuner.HyperparameterTuner) + else tuning_job + ) + for tuning_job in tuning_job_names + ] + + # Maintain combined tuner dataframe from all tuning jobs + df = pd.DataFrame() + + # maintain objective, direction of optimization and tuned parameters + objective_name = None + is_minimize = None + tuned_parameters = None + + all_tuning_job_names = _get_tuning_job_names_with_parents(tuning_job_names) + + for tuning_job_name in all_tuning_job_names: + tuning_job_result = sm.describe_hyper_parameter_tuning_job( + HyperParameterTuningJobName=tuning_job_name + ) + status = tuning_job_result["HyperParameterTuningJobStatus"] + logger.info("Tuning job %-25s status: %s", tuning_job_name, status) + + df = pd.concat([df, _get_df(tuning_job_name)]) + + # maintain objective and assure that all tuning jobs use the same + job_is_minimize = ( + tuning_job_result["HyperParameterTuningJobConfig"]["HyperParameterTuningJobObjective"][ + "Type" + ] + != "Maximize" + ) + job_objective_name = tuning_job_result["HyperParameterTuningJobConfig"][ + "HyperParameterTuningJobObjective" + ]["MetricName"] + job_tuned_parameters = [ + v["Name"] + for v in sagemaker.HyperparameterTuningJobAnalytics( + tuning_job_name + ).tuning_ranges.values() + ] + + if not objective_name: + objective_name = job_objective_name + is_minimize = job_is_minimize + tuned_parameters = job_tuned_parameters + else: + if ( + objective_name != job_objective_name + or is_minimize != job_is_minimize + or set(tuned_parameters) != set(job_tuned_parameters) + ): + raise ValueError( + "All tuning jobs must use the same objective and optimization direction." 
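The one-level parent resolution in `_get_tuning_job_names_with_parents` above boils down to the following boto3 calls; the tuning job name is a placeholder:

```python
import boto3

sm = boto3.client("sagemaker")
desc = sm.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName="my-tuning-job")
parents = [
    cfg["HyperParameterTuningJobName"]
    for cfg in desc.get("WarmStartConfig", {}).get("ParentHyperParameterTuningJobs", [])
]
# De-duplicated list of the job and its warm-start parents, resolved one level only.
job_names = list({"my-tuning-job", *parents})
```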
+ ) + + if not df.empty: + # Cleanup wrongly encoded floats, e.g. containing quotes. + for i, dtype in enumerate(df.dtypes): + column_name = str(df.columns[i]) + if column_name in [ + "TrainingJobName", + "TrainingJobStatus", + "TuningJobName", + ]: + continue + if dtype == "object": + val = df[column_name].iloc[0] + if isinstance(val, str) and val.startswith('"'): + try: + df[column_name] = df[column_name].apply(lambda x: int(x.replace('"', ""))) + except (ValueError, TypeError, AttributeError): + # noqa: E722 nosec b110 if we fail, we just continue with what we had + pass # Value is not an int, but a string + + df = df.sort_values("FinalObjectiveValue", ascending=is_minimize) + df[objective_name] = df.pop("FinalObjectiveValue") + + # Fix potential issue with dates represented as objects, instead of a timestamp + # This can in other cases lead to: + # https://www.markhneedham.com/blog/2020/01/10/altair-typeerror-object-type- + # date-not-json-serializable/ + # Seen this for TrainingEndTime, but will watch TrainingStartTime as well now. + df["TrainingEndTime"] = pd.to_datetime(df["TrainingEndTime"]) + df["TrainingStartTime"] = pd.to_datetime(df["TrainingStartTime"]) + + logger.info("") + logger.info("Number of training jobs with valid objective: %d", len(df)) + logger.info("Lowest: %s Highest %s", min(df[objective_name]), max(df[objective_name])) + + tuned_parameters = [_clean_parameter_name(tp) for tp in tuned_parameters] + + return df, tuned_parameters, objective_name, is_minimize diff --git a/sagemaker-train/src/sagemaker/train/tuner.py b/sagemaker-train/src/sagemaker/train/tuner.py index d1af08e2f1..a970029f51 100644 --- a/sagemaker-train/src/sagemaker/train/tuner.py +++ b/sagemaker-train/src/sagemaker/train/tuner.py @@ -54,6 +54,7 @@ from sagemaker.train.model_trainer import ModelTrainer from sagemaker.core.training.configs import InputData from sagemaker.core.training.utils import _is_valid_s3_uri +import importlib HYPERPARAMETER_TUNING_JOB_NAME = "HyperParameterTuningJobName" PARENT_HYPERPARAMETER_TUNING_JOBS = "ParentHyperParameterTuningJobs" @@ -1203,6 +1204,70 @@ def _add_model_trainer( if metric_definitions is not None: self.metric_definitions_dict[model_trainer_name] = metric_definitions + @staticmethod + def visualize_jobs( + tuning_jobs: Union[ + str, + "sagemaker.train.tuner.HyperparameterTuner", + List[Union[str, "sagemaker.train.tuner.HyperparameterTuner"]], + ], + return_dfs: bool = False, + job_metrics: Optional[List[str]] = None, + trials_only: bool = False, + advanced: bool = False, + ): + """Create interactive visualization via altair charts using the sagemaker.amtviz package. + Args: + tuning_jobs (str or sagemaker.tuner.HyperparameterTuner or list[str, sagemaker.tuner.HyperparameterTuner]): + One or more tuning jobs to create visualization for. + return_dfs: (bool): Option to return trials and full dataframe. + job_metrics: (list[str]): Metrics to be used in charts. + trials_only: (bool): Whether to show trials only or full dataframe. + advanced: (bool): Show a cumulative step line in the progress over time chart. + + Returns: + A collection of charts (altair.VConcatChart); or charts, trials_df (pandas.DataFrame), + full_df (pandas.DataFrame) if ``return_dfs=True``. + """ + + try: + # Check if altair is installed + importlib.import_module("altair") + except ImportError: + print("Altair is not installed. 
Install Altair to use the visualization feature:") + print(" pip install altair") + print("After installing Altair, use the methods visualize_jobs or visualize_job.") + + return None + + # If altair is installed, proceed with visualization + from sagemaker.core.amtviz import visualize_tuning_job + + return visualize_tuning_job( + tuning_jobs, + return_dfs=return_dfs, + job_metrics=job_metrics, + trials_only=trials_only, + advanced=advanced, + ) + + def visualize_job( + self, + return_dfs: bool = False, + job_metrics: Optional[List[str]] = None, + trials_only: bool = False, + advanced: bool = False, + ): + """Convenience method on instance level for visualize_jobs(). + See static method visualize_jobs(). + """ + return HyperparameterTuner.visualize_jobs( + self, + return_dfs=return_dfs, + job_metrics=job_metrics, + trials_only=trials_only, + advanced=advanced, + ) def _start_tuning_job(self, inputs): """Start a new hyperparameter tuning job using HyperParameterTuningJob.""" From 247a8f28672b10cc126c5eaaf3f23c0a23232bb0 Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Mon, 12 Jan 2026 08:15:01 -0800 Subject: [PATCH 4/5] change: When rootlessDocker is enabled, return a fixed SageMaker IP --- .../src/sagemaker/core/local/utils.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/sagemaker-core/src/sagemaker/core/local/utils.py b/sagemaker-core/src/sagemaker/core/local/utils.py index 4b8cdead66..c7e3be8782 100644 --- a/sagemaker-core/src/sagemaker/core/local/utils.py +++ b/sagemaker-core/src/sagemaker/core/local/utils.py @@ -154,7 +154,8 @@ def get_child_process_ids(pid): def get_docker_host(): """Discover remote docker host address (if applicable) or use "localhost" - Use "docker context inspect" to read current docker host endpoint url, + When rootlessDocker is enabled (Cgroup Driver: none), use fixed SageMaker IP. + Otherwise, Use "docker context inspect" to read current docker host endpoint url, url must start with "tcp://" Args: @@ -162,6 +163,27 @@ def get_docker_host(): Returns: docker_host (str): Docker host DNS or IP address """ + # Check if using SageMaker rootless Docker by examining storage driver + try: + cmd = ["docker", "info"] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, err = process.communicate() + if process.returncode == 0: # Check return code instead of stderr + output_text = output.decode("utf-8") + # Check for rootless Docker by looking at Cgroup Driver + if "Cgroup Driver: none" in output_text: + # log the result of check + logger.warning("RootlessDocker detected (Cgroup Driver: none), returning fixed IP.") + # SageMaker rootless Docker detected - return fixed IP + return "172.17.0.1" + else: + logger.warning( + "RootlessDocker not detected, falling back to remote host IP or localhost." 
+ ) + except subprocess.SubprocessError as e: + logger.warning("Failed to run 'docker info' command when checking rootlessDocker: %s.", e) + + # Fallback to existing logic for remote Docker hosts cmd = "docker context inspect".split() process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) output, err = process.communicate() From 13b4f964ee395bb7194b9f54a873d13961ea5c20 Mon Sep 17 00:00:00 2001 From: jzhaoqwa <52220743+zhaoqizqwang@users.noreply.github.com> Date: Mon, 12 Jan 2026 08:16:04 -0800 Subject: [PATCH 5/5] feat: change S3 endpoint env name --- sagemaker-core/src/sagemaker/core/local/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-core/src/sagemaker/core/local/image.py b/sagemaker-core/src/sagemaker/core/local/image.py index 4ca91a9469..94d88b0943 100644 --- a/sagemaker-core/src/sagemaker/core/local/image.py +++ b/sagemaker-core/src/sagemaker/core/local/image.py @@ -48,7 +48,7 @@ # Environment variables to be set during training REGION_ENV_NAME = "AWS_REGION" TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME" -S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL" +S3_ENDPOINT_URL_ENV_NAME = "AWS_ENDPOINT_URL_S3" SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE" # SELinux Enabled
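For illustration, a minimal usage sketch of the visualize_jobs / visualize_job API added to tuner.py in the patch above. The tuning job name is hypothetical, altair must be installed, and the returned charts render in a notebook environment:

    from sagemaker.train.tuner import HyperparameterTuner

    # Static form: pass one or more completed tuning job names (or tuner objects).
    charts = HyperparameterTuner.visualize_jobs(["example-tuning-job"])

    # Optionally return the trials and full dataframes together with the charts.
    charts, trials_df, full_df = HyperparameterTuner.visualize_jobs(
        "example-tuning-job", return_dfs=True, advanced=True
    )

    # Instance form: a tuner object can visualize its own job.
    # charts = tuner.visualize_job(trials_only=True)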