diff --git a/nemo_run/core/execution/dgxcloud.py b/nemo_run/core/execution/dgxcloud.py
index 13596ebf..b6d5e855 100644
--- a/nemo_run/core/execution/dgxcloud.py
+++ b/nemo_run/core/execution/dgxcloud.py
@@ -29,8 +29,11 @@
 import requests
 from invoke.context import Context
 
-from nemo_run.config import get_nemorun_home
+from nemo_run.config import RUNDIR_NAME, get_nemorun_home
 from nemo_run.core.execution.base import Executor, ExecutorMacros
+from nemo_run.core.execution.launcher import FaultTolerance, Launcher, Torchrun
+from nemo_run.core.execution.utils import fill_template
+from nemo_run.core.frontend.console.api import CONSOLE
 from nemo_run.core.packaging.base import Packager
 from nemo_run.core.packaging.git import GitArchivePackager
@@ -461,6 +464,24 @@ def cancel(self, job_id: str):
             response.text,
         )
 
+    def _setup_launcher(self):
+        super()._setup_launcher()
+        launcher = self.launcher
+        if launcher and isinstance(launcher, (FaultTolerance, Torchrun)):
+            self.torchrun_nproc_per_node = self.nprocs_per_node
+            self.ntasks_per_node = 1
+            CONSOLE.log(
+                f"Detected {launcher.__class__.__name__} launcher, setting ntasks_per_node=1 and torchrun_nproc_per_node={self.torchrun_nproc_per_node}"
+            )
+
+        if launcher and isinstance(launcher, FaultTolerance):
+            base_dir = os.path.join(self.job_dir, Path(self.job_dir).name)
+            launcher.cfg_path = os.path.join(base_dir, f"{self.job_name}_ft_cfg.yml")
+            launcher.finished_flag_file = os.path.join(
+                "/", RUNDIR_NAME, f"{self.job_name}_finished_flag"
+            )
+            launcher.job_results_file = os.path.join(base_dir, f"{self.job_name}_job_results")
+
     def cleanup(self, handle: str): ...
 
     def assign(
@@ -556,3 +577,55 @@ def _default_headers(self, token: Optional[str] = None) -> dict:
         if token:
             headers["Authorization"] = f"Bearer {token}"
         return headers
+
+
+@dataclass(kw_only=True)
+class DGXCloudRequest:
+    launch_cmd: list[str]
+    jobs: list[str]
+    executor: DGXCloudExecutor
+    max_retries: int
+    extra_env: dict[str, str]
+    launcher: Optional[Launcher] = None
+
+    def materialize(self) -> str:
+        """Creates the content of a DGXC entrypoint script."""
+
+        # 1. Environment Variables
+        # Combine executor defaults with extra envs
+        env_vars = []
+        full_env_vars = self.executor.env_vars | self.extra_env
+        for key, value in full_env_vars.items():
+            env_vars.append(f"export {key.upper()}={value}")
+
+        # 2. Prepare Template Variables
+        vars_to_fill = {
+            "max_retries": self.max_retries,
+            "env_vars": env_vars,
+            "training_command": " ".join(self.launch_cmd),
+            "ft_enabled": bool(self.launcher and isinstance(self.launcher, FaultTolerance)),
+        }
+
+        # 3. Fault Tolerance Injection
+        if self.launcher and isinstance(self.launcher, FaultTolerance):
+            assert (
+                self.launcher.cfg_path
+                and self.launcher.finished_flag_file
+                and self.launcher.job_results_file
+            ), "Fault Tolerance requires cfg_path, finished_flag_file, and job_results_file"
+
+            vars_to_fill["fault_tol_cfg_path"] = self.launcher.cfg_path
+            vars_to_fill["fault_tol_finished_flag_file"] = self.launcher.finished_flag_file
+            vars_to_fill["fault_tol_job_results_file"] = self.launcher.job_results_file
+
+        # 4. Render the template
+        entrypoint_script = fill_template("dgxc.sh.j2", vars_to_fill)
+        return entrypoint_script
+
+    def __repr__(self) -> str:
+        return f"""# DGXC Entrypoint Script Request
+# Executor: {self.executor.__class__.__name__}
+# Jobs: {self.jobs}
+# ---------------------------------------------------
+{self.materialize()}
+"""
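Reviewer note: a minimal sketch (not part of the diff) of how the new `DGXCloudRequest` fits together. The executor values are placeholders borrowed from the test fixtures later in this PR, and the snippet assumes the executor is otherwise fully configured:

```python
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudRequest

# Placeholder configuration, mirroring the test fixtures in this diff.
executor = DGXCloudExecutor(
    base_url="https://dgxapi.example.com",
    kube_apiserver_url="https://127.0.0.1:443",
    app_id="app_id",
    app_secret="app_secret",
    project_name="my_project",
    container_image="nvcr.io/nvidia/test:latest",
    pvc_nemo_run_dir="/workspace/nemo_run",
)

req = DGXCloudRequest(
    launch_cmd=["python", "train.py"],
    jobs=["job1"],
    executor=executor,
    max_retries=3,
    extra_env={"WANDB_MODE": "offline"},
    launcher=executor.get_launcher(),
)

# __repr__ embeds materialize(), so printing the request shows the exact
# bash entrypoint that the scheduler writes out (see schedulers/dgxcloud.py
# further down in this diff).
print(req)
```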
diff --git a/nemo_run/core/execution/templates/dgxc.sh.j2 b/nemo_run/core/execution/templates/dgxc.sh.j2
new file mode 100644
index 00000000..75bdede2
--- /dev/null
+++ b/nemo_run/core/execution/templates/dgxc.sh.j2
@@ -0,0 +1,31 @@
+{%- import "ft_launcher_dgxc.j2" as fault_tolerance -%}
+#!/bin/bash
+
+set -evx # Echo commands and exit on error during setup; exit-on-error is suspended around the training command below
+export PYTHONUNBUFFERED=1
+export TORCHX_MAX_RETRIES={{max_retries}}
+
+{%- for env_var in env_vars %}
+{{env_var}}
+{%- endfor %}
+
+{%- if ft_enabled %}
+{{ fault_tolerance.ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file) }}
+{%- endif %}
+
+echo "Starting training command..."
+set +e # Turn off auto-exit so we can capture the exit code
+
+{{ training_command }}
+
+exitcode=$?
+set -e
+
+echo "Main command exited with code $exitcode"
+
+{%- if ft_enabled %}
+{{ fault_tolerance.ft_launcher_teardown() }}
+{%- else %}
+
+exit $exitcode
+{%- endif %}
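For reference, the template above is rendered through the same `fill_template` helper that `materialize()` uses. A sketch with made-up values; this assumes the template directory is on the default search path that `fill_template` resolves against:

```python
from nemo_run.core.execution.utils import fill_template

script = fill_template(
    "dgxc.sh.j2",
    {
        "max_retries": 3,
        "env_vars": ["export WANDB_MODE=offline"],
        "training_command": "python train.py",
        "ft_enabled": False,
    },
)
# `script` is a bash entrypoint that exports the env vars, runs the
# training command with exit-on-error suspended, and propagates its
# exit code via `exit $exitcode`.
```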
diff --git a/nemo_run/core/execution/templates/ft_launcher_dgxc.j2 b/nemo_run/core/execution/templates/ft_launcher_dgxc.j2
new file mode 100644
index 00000000..150d8b0c
--- /dev/null
+++ b/nemo_run/core/execution/templates/ft_launcher_dgxc.j2
@@ -0,0 +1,24 @@
+{% macro ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file) -%}
+# This script uses the experimental fault tolerance launcher
+# Fault tolerance related items
+export FAULT_TOL_CFG_PATH="{{fault_tol_cfg_path}}"
+export FAULT_TOL_FINISHED_FLAG_FILE="{{fault_tol_finished_flag_file}}"
+
+JOB_RESULTS_FILE="{{fault_tol_job_results_file}}"
+
+is_training_finished() {
+    test -f "$(dirname "$JOB_RESULTS_FILE")/$(basename "$FAULT_TOL_FINISHED_FLAG_FILE")"
+}
+
+if is_training_finished ; then
+    echo "Training is finished";
+    exit 0;
+else
+    rm -f "$FAULT_TOL_FINISHED_FLAG_FILE" "$JOB_RESULTS_FILE"
+fi
+
+{%- endmacro %}
+
+{% macro ft_launcher_teardown() -%}
+exit $exitcode
+{%- endmacro %}
diff --git a/nemo_run/core/execution/templates/ft_launcher.j2 b/nemo_run/core/execution/templates/ft_launcher_slurm.j2
similarity index 100%
rename from nemo_run/core/execution/templates/ft_launcher.j2
rename to nemo_run/core/execution/templates/ft_launcher_slurm.j2
diff --git a/nemo_run/core/execution/templates/slurm.sh.j2 b/nemo_run/core/execution/templates/slurm.sh.j2
index 26f756fa..dc2c93fa 100644
--- a/nemo_run/core/execution/templates/slurm.sh.j2
+++ b/nemo_run/core/execution/templates/slurm.sh.j2
@@ -1,4 +1,4 @@
-{%- import "ft_launcher.j2" as fault_tolerance -%}
+{%- import "ft_launcher_slurm.j2" as fault_tolerance -%}
 #!/bin/bash
 #
 # Generated by NeMo Run
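One subtlety worth calling out in `ft_launcher_dgxc.j2`: the finished-flag probe does not test `$FAULT_TOL_FINISHED_FLAG_FILE` directly; it looks for that file's basename in the directory of the job results file. A Python rendering of the same check, for clarity (function name and arguments are illustrative):

```python
import os

def is_training_finished(job_results_file: str, finished_flag_file: str) -> bool:
    # Probe for the flag's filename next to the results file, not at the
    # configured flag path itself (which points into the in-container run dir).
    candidate = os.path.join(
        os.path.dirname(job_results_file), os.path.basename(finished_flag_file)
    )
    return os.path.isfile(candidate)
```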
diff --git a/nemo_run/run/torchx_backend/components/ft_launcher.py b/nemo_run/run/torchx_backend/components/ft_launcher.py
index 3920041f..006c8321 100644
--- a/nemo_run/run/torchx_backend/components/ft_launcher.py
+++ b/nemo_run/run/torchx_backend/components/ft_launcher.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import shlex
 from typing import Optional
 
@@ -22,6 +23,8 @@
 from nemo_run.run.torchx_backend.components import torchrun
 
+logger = logging.getLogger(__name__)
+
 
 # Adapted from torchrun component
 def ft_launcher(
@@ -92,30 +95,36 @@ def ft_launcher(
     ):
         if workload_check_interval:
             ft_args += [
-                "--ft-param-workload_check_interval",
+                "--ft-workload_check_interval",
                 str(workload_check_interval),
             ]
         if initial_rank_heartbeat_timeout:
             ft_args += [
-                "--ft-param-initial_rank_heartbeat_timeout",
+                "--ft-initial_rank_heartbeat_timeout",
                 str(initial_rank_heartbeat_timeout),
             ]
         if rank_heartbeat_timeout:
             ft_args += [
-                "--ft-param-rank_heartbeat_timeout",
+                "--ft-rank_heartbeat_timeout",
                 str(rank_heartbeat_timeout),
             ]
         if rank_termination_signal:
-            ft_args += ["--ft-param-rank_termination_signal", rank_termination_signal]
+            ft_args += ["--ft-rank_termination_signal", rank_termination_signal]
         if log_level:
-            ft_args += ["--ft-param-log_level", log_level]
+            ft_args += ["--ft-log_level", log_level]
         if max_restarts:
-            ft_args += ["--max-restarts", str(max_restarts)]
+            if dgxc:
+                logger.warning("max_restarts is ignored for DGXCloudExecutor")
+            else:
+                ft_args += ["--max-restarts", str(max_restarts)]
+
+        if dgxc:
+            ft_args += ["--ft-use-infra-group-rank", "False"]
     else:
         ft_args = ["--ignore-missing-fault-tol-cfg"]
diff --git a/nemo_run/run/torchx_backend/packaging.py b/nemo_run/run/torchx_backend/packaging.py
index 84b9dc4c..8a850de4 100644
--- a/nemo_run/run/torchx_backend/packaging.py
+++ b/nemo_run/run/torchx_backend/packaging.py
@@ -203,6 +203,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
             log_level=launcher.log_level,
             max_retries=executor.retries,
             max_restarts=launcher.max_restarts,
+            dgxc=isinstance(executor, DGXCloudExecutor),
             use_env=use_env,
         )
     else:
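The net effect of the `dgxc` flag on the assembled arguments, condensed into a standalone sketch (flag values are examples; only the branching mirrors the component above, and the `print` stands in for `logger.warning`):

```python
def ft_args_for(dgxc: bool, max_restarts: int = 3) -> list[str]:
    args = ["--ft-rank_heartbeat_timeout", "10"]
    if max_restarts:
        if dgxc:
            print("max_restarts is ignored for DGXCloudExecutor")
        else:
            args += ["--max-restarts", str(max_restarts)]
    if dgxc:
        args += ["--ft-use-infra-group-rank", "False"]
    return args

assert ft_args_for(dgxc=False) == [
    "--ft-rank_heartbeat_timeout", "10",
    "--max-restarts", "3",
]
assert ft_args_for(dgxc=True) == [
    "--ft-rank_heartbeat_timeout", "10",
    "--ft-use-infra-group-rank", "False",
]
```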
diff --git a/nemo_run/run/torchx_backend/schedulers/dgxcloud.py b/nemo_run/run/torchx_backend/schedulers/dgxcloud.py
index 4377ec71..b786d3c0 100644
--- a/nemo_run/run/torchx_backend/schedulers/dgxcloud.py
+++ b/nemo_run/run/torchx_backend/schedulers/dgxcloud.py
@@ -37,7 +37,7 @@
 from nemo_run.config import get_nemorun_home
 from nemo_run.core.execution.base import Executor
-from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudState
+from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudRequest, DGXCloudState
 from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
 from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin
 
@@ -109,6 +109,23 @@ def _submit_dryrun(  # type: ignore
         role = values.apply(role)
         cmd = [role.entrypoint] + role.args
+
+        req = DGXCloudRequest(
+            launch_cmd=cmd,
+            jobs=[role.name],
+            executor=executor,
+            max_retries=role.max_retries,
+            extra_env=role.env,
+            launcher=executor.get_launcher(),
+        )
+
+        # Write the rendered DGXC entrypoint script into the experiment dir
+        path = os.path.join(executor.experiment_dir, "torchrun_job.sh")
+        script = req.materialize()
+
+        with open(path, "w") as f:
+            f.write(script)
+
         return AppDryRunInfo(
             DGXRequest(app=app, executor=executor, cmd=cmd, name=role.name),
             # Minimal function to show the config, if any
@@ -128,7 +145,9 @@ def schedule(self, dryrun_info: AppDryRunInfo[DGXRequest]) -> str:
         # The DGXExecutor's launch call typically returns (job_id, handle).
         # We'll call it without additional parameters here.
-        job_id, status = executor.launch(name=req.name, cmd=req.cmd)
+        cmd = os.path.join(executor.experiment_dir, "torchrun_job.sh")
+        req.launch_cmd = ["bash", cmd]
+        job_id, status = executor.launch(name=req.name, cmd=req.launch_cmd)
 
         if not job_id:
             raise RuntimeError("Failed scheduling run on DGX: no job_id returned")
"/workspace/jobs/my_ft_job/my_ft_job/my_ft_job_job_results" + ) + + # Verify console log was called + mock_console.log.assert_called_once() + assert "FaultTolerance" in mock_console.log.call_args[0][0] + + def test_setup_launcher_fault_tolerance_paths(self): + """Test that FaultTolerance paths are correctly constructed.""" + from nemo_run.core.execution.launcher import FaultTolerance + + ft_launcher = FaultTolerance() + executor = DGXCloudExecutor( + base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", + app_id="test_app_id", + app_secret="test_app_secret", + project_name="test_project", + container_image="nvcr.io/nvidia/test:latest", + pvc_nemo_run_dir="/workspace/nemo_run", + launcher=ft_launcher, + ) + + executor.job_name = "test_training" + executor.job_dir = "/mnt/workspace/test_training" + + with patch("nemo_run.core.execution.dgxcloud.CONSOLE"): + with patch("nemo_run.core.execution.dgxcloud.RUNDIR_NAME", "custom_rundir"): + executor._setup_launcher() + + # Check path construction + base_dir = "/mnt/workspace/test_training/test_training" + assert ft_launcher.cfg_path == f"{base_dir}/test_training_ft_cfg.yml" + assert ft_launcher.finished_flag_file == "/custom_rundir/test_training_finished_flag" + assert ft_launcher.job_results_file == f"{base_dir}/test_training_job_results" + + def test_setup_launcher_with_different_nprocs(self): + """Test _setup_launcher with different nprocs_per_node values.""" + from nemo_run.core.execution.launcher import Torchrun + + for nprocs in [1, 2, 4, 8, 16]: + executor = DGXCloudExecutor( + base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", + app_id="test_app_id", + app_secret="test_app_secret", + project_name="test_project", + container_image="nvcr.io/nvidia/test:latest", + pvc_nemo_run_dir="/workspace/nemo_run", + nprocs_per_node=nprocs, + launcher=Torchrun(), + ) + + executor.job_name = "test_job" + executor.job_dir = "/workspace/test_job" + + with patch("nemo_run.core.execution.dgxcloud.CONSOLE"): + executor._setup_launcher() + + assert executor.torchrun_nproc_per_node == nprocs + assert executor.ntasks_per_node == 1 + + def test_setup_launcher_super_called(self): + """Test that _setup_launcher calls super()._setup_launcher().""" + from nemo_run.core.execution.launcher import Torchrun + + executor = DGXCloudExecutor( + base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", + app_id="test_app_id", + app_secret="test_app_secret", + project_name="test_project", + container_image="nvcr.io/nvidia/test:latest", + pvc_nemo_run_dir="/workspace/nemo_run", + launcher=Torchrun(), + ) + + executor.job_name = "test_job" + executor.job_dir = "/workspace/test_job" + + with patch("nemo_run.core.execution.dgxcloud.CONSOLE"): + with patch.object( + executor.__class__.__bases__[0], "_setup_launcher" + ) as mock_super_setup: + executor._setup_launcher() + + # Verify super() was called + mock_super_setup.assert_called_once() + + +class TestDGXCloudRequest: + """Test DGXCloudRequest dataclass and its methods.""" + + @pytest.fixture + def basic_executor(self): + """Create a basic DGXCloudExecutor for testing.""" + return DGXCloudExecutor( + base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", + app_id="test_app_id", + app_secret="test_app_secret", + project_name="test_project", + container_image="nvcr.io/nvidia/test:latest", + pvc_nemo_run_dir="/workspace/nemo_run", + ) + + @pytest.fixture + def executor_with_env_vars(self): + """Create a 
diff --git a/test/core/execution/test_dgxcloud.py b/test/core/execution/test_dgxcloud.py
index 9098b5c5..49505c48 100644
--- a/test/core/execution/test_dgxcloud.py
+++ b/test/core/execution/test_dgxcloud.py
@@ -1143,4 +1143,501 @@ def test_default_headers_with_token(self):
         assert headers["Content-Type"] == "application/json"
         assert "Authorization" in headers
         assert headers["Authorization"] == "Bearer test_token"
-        assert headers["Authorization"] == "Bearer test_token"
+
+    def test_setup_launcher_no_launcher(self):
+        """Test _setup_launcher when no launcher is set."""
+        executor = DGXCloudExecutor(
+            base_url="https://dgxapi.example.com",
+            kube_apiserver_url="https://127.0.0.1:443",
+            app_id="test_app_id",
+            app_secret="test_app_secret",
+            project_name="test_project",
+            container_image="nvcr.io/nvidia/test:latest",
+            pvc_nemo_run_dir="/workspace/nemo_run",
+            nprocs_per_node=8,
+        )
+
+        # Set up job details required by _setup_launcher
+        executor.job_name = "test_job"
+        executor.job_dir = "/workspace/test_job"
+
+        with patch("nemo_run.core.execution.dgxcloud.CONSOLE"):
+            executor._setup_launcher()
+
+        # When no launcher is set, torchrun_nproc_per_node and ntasks_per_node
+        # should not be modified; ntasks_per_node is only set when the launcher
+        # is Torchrun or FaultTolerance.
+        assert (
+            not hasattr(executor, "torchrun_nproc_per_node")
+            or executor.torchrun_nproc_per_node is None
+        )
+
+    def test_setup_launcher_with_torchrun(self):
+        """Test _setup_launcher with Torchrun launcher."""
+        from nemo_run.core.execution.launcher import Torchrun
+
+        executor = DGXCloudExecutor(
+            base_url="https://dgxapi.example.com",
+            kube_apiserver_url="https://127.0.0.1:443",
+            app_id="test_app_id",
+            app_secret="test_app_secret",
+            project_name="test_project",
+            container_image="nvcr.io/nvidia/test:latest",
+            pvc_nemo_run_dir="/workspace/nemo_run",
+            nprocs_per_node=8,
+            launcher=Torchrun(),
+        )
+
+        executor.job_name = "test_job"
+        executor.job_dir = "/workspace/test_job"
+
+        with patch("nemo_run.core.execution.dgxcloud.CONSOLE") as mock_console:
+            executor._setup_launcher()
+
+        # With Torchrun, ntasks_per_node should be 1 and torchrun_nproc_per_node should be nprocs_per_node
+        assert executor.ntasks_per_node == 1
+        assert executor.torchrun_nproc_per_node == 8
+        mock_console.log.assert_called_once()
+        assert "Torchrun" in mock_console.log.call_args[0][0]
+
+    def test_setup_launcher_with_fault_tolerance(self):
+        """Test _setup_launcher with FaultTolerance launcher."""
+        from nemo_run.core.execution.launcher import FaultTolerance
+
+        ft_launcher = FaultTolerance()
+        executor = DGXCloudExecutor(
+            base_url="https://dgxapi.example.com",
+            kube_apiserver_url="https://127.0.0.1:443",
+            app_id="test_app_id",
+            app_secret="test_app_secret",
+            project_name="test_project",
+            container_image="nvcr.io/nvidia/test:latest",
+            pvc_nemo_run_dir="/workspace/nemo_run",
+            nprocs_per_node=4,
+            launcher=ft_launcher,
+        )
+
+        executor.job_name = "my_ft_job"
+        executor.job_dir = "/workspace/jobs/my_ft_job"
+
+        with patch("nemo_run.core.execution.dgxcloud.CONSOLE") as mock_console:
+            with patch("nemo_run.core.execution.dgxcloud.RUNDIR_NAME", "nemo_run"):
+                executor._setup_launcher()
+
+        # Verify Torchrun settings
+        assert executor.ntasks_per_node == 1
+        assert executor.torchrun_nproc_per_node == 4
+
+        # Verify FaultTolerance paths are set
+        assert ft_launcher.cfg_path == "/workspace/jobs/my_ft_job/my_ft_job/my_ft_job_ft_cfg.yml"
+        assert ft_launcher.finished_flag_file == "/nemo_run/my_ft_job_finished_flag"
+        assert (
+            ft_launcher.job_results_file
+            == "/workspace/jobs/my_ft_job/my_ft_job/my_ft_job_job_results"
+        )
+
+        # Verify console log was called
+        mock_console.log.assert_called_once()
+        assert "FaultTolerance" in mock_console.log.call_args[0][0]
+
+    def test_setup_launcher_fault_tolerance_paths(self):
+        """Test that FaultTolerance paths are correctly constructed."""
+        from nemo_run.core.execution.launcher import FaultTolerance
+
+        ft_launcher = FaultTolerance()
+        executor = DGXCloudExecutor(
+            base_url="https://dgxapi.example.com",
+            kube_apiserver_url="https://127.0.0.1:443",
+            app_id="test_app_id",
+            app_secret="test_app_secret",
+            project_name="test_project",
+            container_image="nvcr.io/nvidia/test:latest",
+            pvc_nemo_run_dir="/workspace/nemo_run",
+            launcher=ft_launcher,
+        )
+
+        executor.job_name = "test_training"
+        executor.job_dir = "/mnt/workspace/test_training"
+
+        with patch("nemo_run.core.execution.dgxcloud.CONSOLE"):
+            with patch("nemo_run.core.execution.dgxcloud.RUNDIR_NAME", "custom_rundir"):
+                executor._setup_launcher()
+
+        # Check path construction
+        base_dir = "/mnt/workspace/test_training/test_training"
+        assert ft_launcher.cfg_path == f"{base_dir}/test_training_ft_cfg.yml"
+        assert ft_launcher.finished_flag_file == "/custom_rundir/test_training_finished_flag"
+        assert ft_launcher.job_results_file == f"{base_dir}/test_training_job_results"
+
+    def test_setup_launcher_with_different_nprocs(self):
+        """Test _setup_launcher with different nprocs_per_node values."""
+        from nemo_run.core.execution.launcher import Torchrun
+
+        for nprocs in [1, 2, 4, 8, 16]:
+            executor = DGXCloudExecutor(
+                base_url="https://dgxapi.example.com",
+                kube_apiserver_url="https://127.0.0.1:443",
+                app_id="test_app_id",
+                app_secret="test_app_secret",
+                project_name="test_project",
+                container_image="nvcr.io/nvidia/test:latest",
+                pvc_nemo_run_dir="/workspace/nemo_run",
+                nprocs_per_node=nprocs,
+                launcher=Torchrun(),
+            )
+
+            executor.job_name = "test_job"
+            executor.job_dir = "/workspace/test_job"
+
+            with patch("nemo_run.core.execution.dgxcloud.CONSOLE"):
+                executor._setup_launcher()
+
+            assert executor.torchrun_nproc_per_node == nprocs
+            assert executor.ntasks_per_node == 1
+
+    def test_setup_launcher_super_called(self):
+        """Test that _setup_launcher calls super()._setup_launcher()."""
+        from nemo_run.core.execution.launcher import Torchrun
+
+        executor = DGXCloudExecutor(
+            base_url="https://dgxapi.example.com",
+            kube_apiserver_url="https://127.0.0.1:443",
+            app_id="test_app_id",
+            app_secret="test_app_secret",
+            project_name="test_project",
+            container_image="nvcr.io/nvidia/test:latest",
+            pvc_nemo_run_dir="/workspace/nemo_run",
+            launcher=Torchrun(),
+        )
+
+        executor.job_name = "test_job"
+        executor.job_dir = "/workspace/test_job"
+
+        with patch("nemo_run.core.execution.dgxcloud.CONSOLE"):
+            with patch.object(
+                executor.__class__.__bases__[0], "_setup_launcher"
+            ) as mock_super_setup:
+                executor._setup_launcher()
+
+                # Verify super() was called
+                mock_super_setup.assert_called_once()
+
+
+class TestDGXCloudRequest:
+    """Test DGXCloudRequest dataclass and its methods."""
+
+    @pytest.fixture
+    def basic_executor(self):
+        """Create a basic DGXCloudExecutor for testing."""
+        return DGXCloudExecutor(
+            base_url="https://dgxapi.example.com",
+            kube_apiserver_url="https://127.0.0.1:443",
+            app_id="test_app_id",
+            app_secret="test_app_secret",
+            project_name="test_project",
+            container_image="nvcr.io/nvidia/test:latest",
+            pvc_nemo_run_dir="/workspace/nemo_run",
+        )
+
+    @pytest.fixture
+    def executor_with_env_vars(self):
+        """Create a DGXCloudExecutor with environment variables."""
+        return DGXCloudExecutor(
+            base_url="https://dgxapi.example.com",
+            kube_apiserver_url="https://127.0.0.1:443",
+            app_id="test_app_id",
+            app_secret="test_app_secret",
+            project_name="test_project",
+            container_image="nvcr.io/nvidia/test:latest",
+            pvc_nemo_run_dir="/workspace/nemo_run",
+            env_vars={"EXECUTOR_VAR": "executor_value", "SHARED_VAR": "from_executor"},
+        )
+    def test_dgxcloud_request_init(self, basic_executor):
+        """Test basic initialization of DGXCloudRequest."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+
+        request = DGXCloudRequest(
+            launch_cmd=["python", "train.py"],
+            jobs=["job1", "job2"],
+            executor=basic_executor,
+            max_retries=3,
+            extra_env={"EXTRA_VAR": "extra_value"},
+        )
+
+        assert request.launch_cmd == ["python", "train.py"]
+        assert request.jobs == ["job1", "job2"]
+        assert request.executor == basic_executor
+        assert request.max_retries == 3
+        assert request.extra_env == {"EXTRA_VAR": "extra_value"}
+        assert request.launcher is None
+
+    def test_dgxcloud_request_with_launcher(self, basic_executor):
+        """Test DGXCloudRequest with a launcher."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+        from nemo_run.core.execution.launcher import Torchrun
+
+        launcher = Torchrun()
+        request = DGXCloudRequest(
+            launch_cmd=["python", "train.py"],
+            jobs=["job1"],
+            executor=basic_executor,
+            max_retries=5,
+            extra_env={},
+            launcher=launcher,
+        )
+
+        assert request.launcher == launcher
+        assert isinstance(request.launcher, Torchrun)
+
+    def test_materialize_basic(self, basic_executor):
+        """Test materialization of a basic request without fault tolerance."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+
+        request = DGXCloudRequest(
+            launch_cmd=["python", "train.py", "--epochs", "10"],
+            jobs=["job1"],
+            executor=basic_executor,
+            max_retries=3,
+            extra_env={"MY_VAR": "my_value"},
+        )
+
+        with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
+            mock_fill.return_value = "#!/bin/bash\necho 'test script'"
+            script = request.materialize()
+
+        # Verify fill_template was called
+        mock_fill.assert_called_once()
+        args, kwargs = mock_fill.call_args
+        assert args[0] == "dgxc.sh.j2"
+
+        template_vars = args[1]
+        assert template_vars["max_retries"] == 3
+        assert template_vars["training_command"] == "python train.py --epochs 10"
+        assert template_vars["ft_enabled"] is False
+        assert "export MY_VAR=my_value" in template_vars["env_vars"]
+
+        assert script == "#!/bin/bash\necho 'test script'"
+
+    def test_materialize_with_env_vars(self, executor_with_env_vars):
+        """Test that environment variables from executor and extra_env are merged."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+
+        request = DGXCloudRequest(
+            launch_cmd=["python", "train.py"],
+            jobs=["job1"],
+            executor=executor_with_env_vars,
+            max_retries=1,
+            extra_env={"EXTRA_VAR": "extra_value", "SHARED_VAR": "from_extra"},
+        )
+
+        with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
+            mock_fill.return_value = "mock_script"
+            request.materialize()
+
+        template_vars = mock_fill.call_args[0][1]
+        env_vars = template_vars["env_vars"]
+
+        # Check that variables are present (order may vary due to dict merge)
+        assert "export EXECUTOR_VAR=executor_value" in env_vars
+        assert "export EXTRA_VAR=extra_value" in env_vars
+        # extra_env should override executor.env_vars for SHARED_VAR
+        assert "export SHARED_VAR=from_extra" in env_vars
+        assert "export SHARED_VAR=from_executor" not in env_vars
+
+    def test_materialize_with_fault_tolerance(self, basic_executor):
+        """Test materialization with fault tolerance enabled."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+        from nemo_run.core.execution.launcher import FaultTolerance
+
+        ft_launcher = FaultTolerance(
+            cfg_path="/workspace/ft_config.yaml",
+            finished_flag_file="/workspace/.ft_finished",
+            job_results_file="/workspace/ft_results.json",
+        )
+
+        request = DGXCloudRequest(
+            launch_cmd=["python", "train.py"],
+            jobs=["job1"],
+            executor=basic_executor,
+            max_retries=5,
+            extra_env={},
+            launcher=ft_launcher,
+        )
+
+        with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
+            mock_fill.return_value = "ft_script"
+            _ = request.materialize()
+
+        template_vars = mock_fill.call_args[0][1]
+        assert template_vars["ft_enabled"] is True
+        assert template_vars["fault_tol_cfg_path"] == "/workspace/ft_config.yaml"
+        assert template_vars["fault_tol_finished_flag_file"] == "/workspace/.ft_finished"
+        assert template_vars["fault_tol_job_results_file"] == "/workspace/ft_results.json"
+
+    def test_materialize_fault_tolerance_missing_fields(self, basic_executor):
+        """Test that fault tolerance with missing required fields raises an error."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+        from nemo_run.core.execution.launcher import FaultTolerance
+
+        # Create FaultTolerance with missing required fields
+        ft_launcher = FaultTolerance(
+            cfg_path="/workspace/ft_config.yaml",
+            # Missing finished_flag_file and job_results_file
+        )
+
+        request = DGXCloudRequest(
+            launch_cmd=["python", "train.py"],
+            jobs=["job1"],
+            executor=basic_executor,
+            max_retries=5,
+            extra_env={},
+            launcher=ft_launcher,
+        )
+
+        with pytest.raises(AssertionError) as exc_info:
+            with patch("nemo_run.core.execution.dgxcloud.fill_template"):
+                request.materialize()
+
+        assert "Fault Tolerance requires" in str(exc_info.value)
+
+    def test_materialize_with_non_fault_tolerance_launcher(self, basic_executor):
+        """Test materialization with a non-FaultTolerance launcher (e.g., Torchrun)."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+        from nemo_run.core.execution.launcher import Torchrun
+
+        launcher = Torchrun()
+        request = DGXCloudRequest(
+            launch_cmd=["python", "train.py"],
+            jobs=["job1"],
+            executor=basic_executor,
+            max_retries=2,
+            extra_env={},
+            launcher=launcher,
+        )
+
+        with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
+            mock_fill.return_value = "torchrun_script"
+            _ = request.materialize()
+
+        template_vars = mock_fill.call_args[0][1]
+        # FT should be disabled for non-FaultTolerance launchers
+        assert template_vars["ft_enabled"] is False
+        # FT-specific fields should not be in template_vars
+        assert "fault_tol_cfg_path" not in template_vars
+        assert "fault_tol_finished_flag_file" not in template_vars
+        assert "fault_tol_job_results_file" not in template_vars
+
+    def test_materialize_empty_extra_env(self, basic_executor):
+        """Test materialization with empty extra_env."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+
+        request = DGXCloudRequest(
+            launch_cmd=["python", "train.py"],
+            jobs=["job1"],
+            executor=basic_executor,
+            max_retries=1,
+            extra_env={},
+        )
+
+        with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
+            mock_fill.return_value = "script"
+            request.materialize()
+
+        template_vars = mock_fill.call_args[0][1]
+        assert template_vars["env_vars"] == []
+
+    def test_materialize_uppercase_env_vars(self, basic_executor):
+        """Test that environment variable keys are uppercased."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+
+        request = DGXCloudRequest(
+            launch_cmd=["python", "train.py"],
+            jobs=["job1"],
+            executor=basic_executor,
+            max_retries=1,
+            extra_env={"lowercase_var": "value", "MixedCase": "value2"},
+        )
+
+        with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
+            mock_fill.return_value = "script"
+            request.materialize()
+
+        template_vars = mock_fill.call_args[0][1]
+        env_vars = template_vars["env_vars"]
+
+        # Keys should be uppercased
+        assert "export LOWERCASE_VAR=value" in env_vars
+        assert "export MIXEDCASE=value2" in env_vars
+    def test_repr(self, basic_executor):
+        """Test the __repr__ method."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+
+        request = DGXCloudRequest(
+            launch_cmd=["python", "train.py"],
+            jobs=["job1", "job2"],
+            executor=basic_executor,
+            max_retries=3,
+            extra_env={},
+        )
+
+        with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
+            mock_fill.return_value = "#!/bin/bash\necho 'script content'"
+            repr_str = repr(request)
+
+        assert "# DGXC Entrypoint Script Request" in repr_str
+        assert "# Executor: DGXCloudExecutor" in repr_str
+        assert "# Jobs: ['job1', 'job2']" in repr_str
+        assert "#!/bin/bash" in repr_str
+        assert "echo 'script content'" in repr_str
+
+    def test_complex_launch_command(self, basic_executor):
+        """Test materialization with a complex multi-argument launch command."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+
+        request = DGXCloudRequest(
+            launch_cmd=[
+                "torchrun",
+                "--nproc_per_node=8",
+                "--nnodes=2",
+                "train.py",
+                "--batch-size",
+                "32",
+                "--lr",
+                "0.001",
+            ],
+            jobs=["job1"],
+            executor=basic_executor,
+            max_retries=1,
+            extra_env={},
+        )
+
+        with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
+            mock_fill.return_value = "script"
+            request.materialize()
+
+        template_vars = mock_fill.call_args[0][1]
+        expected_cmd = (
+            "torchrun --nproc_per_node=8 --nnodes=2 train.py --batch-size 32 --lr 0.001"
+        )
+        assert template_vars["training_command"] == expected_cmd
+
+    def test_max_retries_values(self, basic_executor):
+        """Test different max_retries values."""
+        from nemo_run.core.execution.dgxcloud import DGXCloudRequest
+
+        for retries in [0, 1, 10, 100]:
+            request = DGXCloudRequest(
+                launch_cmd=["python", "train.py"],
+                jobs=["job1"],
+                executor=basic_executor,
+                max_retries=retries,
+                extra_env={},
+            )
+
+            with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
+                mock_fill.return_value = "script"
+                request.materialize()
+
+            template_vars = mock_fill.call_args[0][1]
+            assert template_vars["max_retries"] == retries
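The merge-and-uppercase behavior asserted in `test_materialize_with_env_vars` above comes straight from plain dict semantics; in isolation:

```python
executor_env = {"EXECUTOR_VAR": "executor_value", "SHARED_VAR": "from_executor"}
extra_env = {"EXTRA_VAR": "extra_value", "SHARED_VAR": "from_extra"}

# PEP 584 dict union: the right-hand operand wins on key conflicts,
# which is why extra_env overrides the executor's defaults.
merged = executor_env | extra_env
exports = [f"export {key.upper()}={value}" for key, value in merged.items()]

assert "export SHARED_VAR=from_extra" in exports
assert "export SHARED_VAR=from_executor" not in exports
```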
- "--ft-param-workload_check_interval", + "--ft-workload_check_interval", "10", - "--ft-param-initial_rank_heartbeat_timeout", + "--ft-initial_rank_heartbeat_timeout", "5", - "--ft-param-rank_heartbeat_timeout", + "--ft-rank_heartbeat_timeout", "5", - "--ft-param-rank_termination_signal", + "--ft-rank_termination_signal", "SIGINT", - "--ft-param-log_level", + "--ft-log_level", "INFO", "--rdzv-backend", "etcd",