Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1012aa0
agent hyperparam interface
alexmanle Jan 27, 2026
bd551e4
fix formatting
alexmanle Jan 27, 2026
c1ecf25
default pass no kwargs
alexmanle Jan 27, 2026
9e7d6c3
update error logging
alexmanle Jan 29, 2026
e7fee0c
fix docstring
alexmanle Jan 29, 2026
cfb38f5
make agent interface abstract
alexmanle Jan 31, 2026
ab3df17
fix copyright
alexmanle Feb 3, 2026
4d2667f
better kwargs handling
alexmanle Feb 3, 2026
5afe35c
Merge pull request #792 from alexmanle/agent-params
srivatsankrishnan Feb 3, 2026
09ee673
add agent seed parameter support/validation
alexmanle Feb 11, 2026
db8f259
fix unneeded dict modification and error printing
alexmanle Feb 11, 2026
3c39487
Merge pull request #804 from alexmanle/agent-dev
srivatsankrishnan Feb 11, 2026
8583261
m-bridge for nemo container 26.02
srivatsankrishnan Jan 7, 2026
357c00a
fix pyright issue
srivatsankrishnan Jan 7, 2026
9d3d454
remove domain flag
srivatsankrishnan Jan 12, 2026
18fa653
fix the stderr/stdout
srivatsankrishnan Jan 13, 2026
c814974
make cuda_graph_impl dse'ble flag
srivatsankrishnan Jan 17, 2026
c256e98
fix regex
srivatsankrishnan Jan 17, 2026
b8cdf37
add constraint for moe_overlap
srivatsankrishnan Jan 21, 2026
89e1159
add vp/pp constraint check for deepseek
srivatsankrishnan Jan 27, 2026
37053e5
fix the new flags
srivatsankrishnan Feb 12, 2026
a7172ce
fix
srivatsankrishnan Feb 12, 2026
a5a1399
more fixes
srivatsankrishnan Feb 12, 2026
69bc44d
remove -cm and update all flags in m-bridge
srivatsankrishnan Feb 13, 2026
fe1a186
avoid mounting the install repo inside the container
srivatsankrishnan Feb 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ mount_as = "/opt/Megatron-Bridge"
[cmd_args]
gpu_type = "b200"
container_image = "nvcr.io#nvidia/nemo:25.11.01"
model_name = "qwen3"
model_size = "30b_a3b"
gpus_per_node = 8
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
gpus_per_node = 4
num_gpus = 8
domain = "llm"
task = "pretrain"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge"
[cmd_args]
gpu_type = "gb200"
container_image = "nvcr.io#nvidia/nemo:25.11.01"
model_name = "qwen3"
model_size = "30b_a3b"
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
gpus_per_node = 4
num_gpus = 8
domain = "llm"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge"
[cmd_args]
gpu_type = "gb300"
container_image = "nvcr.io#nvidia/nemo:25.11.01"
model_name = "qwen3"
model_size = "30b_a3b"
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
gpus_per_node = 4
num_gpus = 8
domain = "llm"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge"
[cmd_args]
gpu_type = "h100"
container_image = "nvcr.io#nvidia/nemo:25.11.01"
model_name = "qwen3"
model_size = "30b_a3b"
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
gpus_per_node = 8
num_gpus = 16
domain = "llm"
Expand Down
10 changes: 5 additions & 5 deletions doc/workloads/megatron_bridge.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ Test TOML example:
[cmd_args]
# Container can be an NGC/enroot URL (nvcr.io#...) or a local .sqsh path.
container_image = "nvcr.io#nvidia/nemo:25.11.01"

model_name = "qwen3"
model_size = "30b_a3b"
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
task = "pretrain"
domain = "llm"
compute_dtype = "fp8_mx"
Expand Down Expand Up @@ -55,8 +55,8 @@ Test-in-Scenario example:

[Tests.cmd_args]
container_image = "nvcr.io#nvidia/nemo:25.11.01"
model_name = "qwen3"
model_size = "30b_a3b"
model_family_name = "qwen3"
model_recipe_name = "30b_a3b"
task = "pretrain"
domain = "llm"
compute_dtype = "fp8_mx"
Expand Down
87 changes: 82 additions & 5 deletions src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -21,11 +21,12 @@
import signal
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional
from typing import Any, Callable, List, Optional
from unittest.mock import Mock

import toml
import yaml
from pydantic import ValidationError

from cloudai.core import (
BaseInstaller,
Expand All @@ -41,10 +42,10 @@
TestScenario,
)
from cloudai.models.scenario import ReportConfig
from cloudai.models.workload import TestDefinition
from cloudai.models.workload import TestDefinition, TestRun
from cloudai.parser import HOOK_ROOT
from cloudai.systems.slurm import SingleSbatchRunner, SlurmSystem
from cloudai.util import prepare_output_dir
from cloudai.util import flatten_dict, prepare_output_dir


def _log_installation_dirs(prefix: str, system: System) -> None:
Expand Down Expand Up @@ -145,7 +146,23 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
continue

env = CloudAIGymEnv(test_run=test_run, runner=runner.runner)
agent = agent_class(env)

try:
agent_config = test_run.test.agent_config
agent_overrides = (
validate_agent_overrides(test_run, agent_type, agent_config) if agent_config is not None else None
)
except ValidationError as e:
logging.error(f"Invalid agent_config for agent '{agent_type}': ")
for error in e.errors():
logging.error(f" - {'.'.join(str(var_name) for var_name in error['loc'])}: {error['msg']}")
logging.error("Valid overrides: ")
for item, desc in validate_agent_overrides(test_run, agent_type).items():
logging.error(f" - {item}: {desc}")
err = 1
continue
agent = agent_class(env, **agent_overrides) if agent_overrides is not None else agent_class(env)

for step in range(agent.max_steps):
result = agent.select_action()
if result is None:
Expand All @@ -166,6 +183,66 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
return err


def validate_agent_overrides(
    test_run: TestRun, agent_type: str, agent_config: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
    """
    Validate agent configuration overrides and convert them into agent kwargs.

    Two modes, selected by ``agent_config``:

    * Non-empty ``agent_config``: any ``seed_parameters`` entry is checked against the test run's
      DSE-able command-line arguments, the whole dict is validated by the agent's config model, and
      the validated fields (``None`` values dropped) are returned.
    * Empty or missing ``agent_config``: nothing is validated; the result maps every field name of
      the agent's config model to that field's description.

    Raises:
        ValueError: if no config class is available for ``agent_type``.
        Exceptions raised by ``validate_seed_parameters`` or by the config model's ``model_validate``
        are not handled here and reach the caller.
    """
    # Only agents that expose a truthy ``config`` attribute accept overrides.
    configurable_agents = {name: cls.config for name, cls in Registry().agents_map.items() if cls.config}

    config_class = configurable_agents.get(agent_type)
    if not config_class:
        valid_types = ", ".join(f"'{agent_name}'" for agent_name in configurable_agents)
        raise ValueError(
            f"Agent type '{agent_type}' does not support configuration overrides. "
            f"Valid agent types are: {valid_types}. "
        )

    # No overrides given: describe what could be overridden instead.
    if not agent_config:
        return {name: info.description for name, info in config_class.model_fields.items()}

    seed_parameters = agent_config.get("seed_parameters", None)
    if seed_parameters:
        validate_seed_parameters(test_run, seed_parameters)

    agent_kwargs = config_class.model_validate(agent_config).model_dump(exclude_none=True)
    logging.debug(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}")
    return agent_kwargs


def validate_seed_parameters(test_run: TestRun, seed_parameters: dict[str, Any]) -> None:
    """
    Check that every seed parameter names a DSE-able command-line argument and one of its values.

    An argument counts as DSE-able here when its value in the flattened ``cmd_args`` dump is a list.

    Raises:
        KeyError: if a seed parameter key is not a DSE-able command-line argument.
        ValueError: if a seed parameter value is not among the values listed for that argument.
    """
    dumped_cmd_args = test_run.test.cmd_args.model_dump(exclude_none=True)
    dse_cmd_args = {
        arg_name: candidates
        for arg_name, candidates in flatten_dict(dumped_cmd_args).items()
        if isinstance(candidates, list)
    }

    logging.debug("Validating seed parameters against DSE-able command-line arguments:")
    logging.debug(f"\t{dse_cmd_args}")

    for key, value in seed_parameters.items():
        # The key check must come first: the value check below indexes by this key.
        if key not in dse_cmd_args:
            raise KeyError(
                f"Seed parameter '{key}' not found in DSE-able command-line arguments. "
                f"Ensure that the key is one of the following available keys: {list(dse_cmd_args.keys())}"
            )
        if value not in dse_cmd_args[key]:
            raise ValueError(
                f"Seed parameter '{key}' value '{value}' not found in DSE-able command-line arguments. "
                f"Ensure that the value is one of the following available values: {dse_cmd_args[key]}"
            )

    logging.debug("Seed parameters validated successfully.")


def generate_reports(system: System, test_scenario: TestScenario, result_dir: Path) -> None:
registry = Registry()

Expand Down
8 changes: 6 additions & 2 deletions src/cloudai/configurator/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -15,7 +15,9 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple

from cloudai.models.agent_config import AgentConfig

from .base_gym import BaseGym

Expand All @@ -28,6 +30,8 @@ class BaseAgent(ABC):
Automatically infers parameter types from TestRun's cmd_args.
"""

config: Optional[AgentConfig] = None

def __init__(self, env: BaseGym):
"""
Initialize the agent with the environment.
Expand Down
28 changes: 28 additions & 0 deletions src/cloudai/models/agent_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field


class AgentConfig(BaseModel, ABC):
    """
    Base configuration for agent overrides.

    Unknown keys are rejected (``extra="forbid"``), so a misspelled override fails validation instead
    of being silently ignored. Every field is optional and defaults to ``None``.
    """

    model_config = ConfigDict(extra="forbid")

    random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility")
    # The previous description duplicated the random_seed text; field descriptions are shown to
    # users as help for valid overrides, so this one states what the mapping must contain.
    seed_parameters: Optional[dict[str, Any]] = Field(
        default=None,
        description="Seed values for DSE-able command-line arguments (argument name -> one of its candidate values)",
    )
3 changes: 2 additions & 1 deletion src/cloudai/models/workload.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -107,6 +107,7 @@ class TestDefinition(BaseModel, ABC):
agent_steps: int = 1
agent_metrics: list[str] = Field(default=["default"])
agent_reward_function: str = "inverse"
agent_config: Optional[dict[str, Any]] = None

@property
def cmd_args_dict(self) -> Dict[str, Union[str, List[str]]]:
Expand Down
Loading