Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sagemaker-train/src/sagemaker/train/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,7 @@
"mistral.mistral-large-2402-v1:0": ["us-west-2", "us-east-1", "eu-west-1"],
"amazon.nova-pro-v1:0": ["us-east-1"]
}

# Name of the input data channel that carries the training recipe.
SM_RECIPE = "recipe"
# File name of the recipe document.
SM_RECIPE_YAML = "recipe.yaml"
# Path of the recipe file inside the training container. Assembled from the
# two names above so the channel segment and file name cannot drift apart.
SM_RECIPE_CONTAINER_PATH = "/".join(("/opt/ml/input/data", SM_RECIPE, SM_RECIPE_YAML))
62 changes: 57 additions & 5 deletions sagemaker-train/src/sagemaker/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@
SM_CODE_CONTAINER_PATH,
SM_DRIVERS,
SM_DRIVERS_LOCAL_PATH,
SM_RECIPE,
SM_RECIPE_YAML,
SM_RECIPE_CONTAINER_PATH,
TRAIN_SCRIPT,
DEFAULT_CONTAINER_ENTRYPOINT,
DEFAULT_CONTAINER_ARGUMENTS,
Expand All @@ -100,7 +103,13 @@
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train import logger
from sagemaker.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type
from sagemaker.train.sm_recipes.utils import (
_get_args_from_recipe,
_determine_device_type,
_is_nova_recipe,
_is_llmft_recipe,
_load_base_recipe,
)

from sagemaker.core.jumpstart.configs import JumpStartConfig
from sagemaker.core.jumpstart.document import get_hub_content_and_document
Expand Down Expand Up @@ -249,6 +258,8 @@ class ModelTrainer(BaseModel):
_remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
_metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)

_is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
_is_llmft_recipe: Optional[bool] = PrivateAttr(default=None)
# Private Attributes for Recipes
_temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)

Expand Down Expand Up @@ -573,6 +584,23 @@ def _create_training_job_args(

final_input_data_config = list(existing_channels.values()) + new_channels

if self._is_nova_recipe or self._is_llmft_recipe:
for input_data in final_input_data_config:
if input_data.channel_name == SM_RECIPE:
raise ValueError(
"Cannot use reserved channel name 'recipe' as an input channel name "
" for Nova or LLMFT Recipe"
)
recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML)
recipe_channel = self.create_input_data_channel(
channel_name=SM_RECIPE,
data_source=recipe_file_path,
key_prefix=input_data_key_prefix,
)
final_input_data_config.append(recipe_channel)
if self._is_nova_recipe or self._is_llmft_recipe:
self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})

if final_input_data_config:
final_input_data_config = self._get_input_data_config(
final_input_data_config, input_data_key_prefix
Expand Down Expand Up @@ -1039,6 +1067,7 @@ def from_recipe(
checkpoint_config: Optional[shapes.CheckpointConfig] = None,
training_input_mode: Optional[str] = "File",
environment: Optional[Dict[str, str]] = None,
hyperparameters: Optional[Union[Dict[str, Any], str]] = {},
tags: Optional[List[Tag]] = None,
sagemaker_session: Optional[Session] = None,
role: Optional[str] = None,
Expand Down Expand Up @@ -1136,12 +1165,20 @@ def from_recipe(
if compute.instance_type is None:
raise ValueError("Must set ``instance_type`` in Compute when using training recipes.")
device_type = _determine_device_type(compute.instance_type)
if device_type == "cpu":
recipe = _load_base_recipe(
training_recipe=training_recipe, recipe_overrides=recipe_overrides
)
is_nova = _is_nova_recipe(recipe=recipe)
is_llmft = _is_llmft_recipe(recipe=recipe)
if device_type == "cpu" and not (is_nova or is_llmft):
raise ValueError(
"Training recipes are not supported for CPU instances. "
"Please provide a GPU or Tranium instance type."
)

if training_image is None and (is_nova or is_llmft):
raise ValueError("training_image must be provided when using recipe for Nova or LLMFT")

if training_image_config and training_image is None:
raise ValueError("training_image must be provided when using training_image_config.")

Expand All @@ -1154,16 +1191,29 @@ def from_recipe(
# - distributed
# - compute
# - hyperparameters
model_trainer_args, recipe_train_dir = _get_args_from_recipe(
training_recipe=training_recipe,
model_trainer_args, tmp_dir = _get_args_from_recipe(
training_recipe=recipe,
recipe_overrides=recipe_overrides,
requirements=requirements,
compute=compute,
region_name=sagemaker_session.boto_region_name,
role=role,
)
if training_image is not None:
model_trainer_args["training_image"] = training_image

if hyperparameters and not is_nova:
logger.warning(
"Hyperparameters are not supported for general and LLMFT training recipes. "
+ "Ignoring hyperparameters input."
)
if is_nova:
if hyperparameters and isinstance(hyperparameters, str):
hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters)
model_trainer_args["hyperparameters"].update(hyperparameters)
elif hyperparameters and isinstance(hyperparameters, dict):
model_trainer_args["hyperparameters"].update(hyperparameters)

model_trainer = cls(
sagemaker_session=sagemaker_session,
role=role,
Expand All @@ -1180,7 +1230,9 @@ def from_recipe(
**model_trainer_args,
)

model_trainer._temp_recipe_train_dir = recipe_train_dir
model_trainer._is_nova_recipe = is_nova
model_trainer._is_llmft_recipe = is_llmft
model_trainer._temp_recipe_train_dir = tmp_dir
return model_trainer

@classmethod
Expand Down
Loading
Loading