2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
@@ -28,7 +28,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
-    uv sync --extra dev --extra docs --extra llm
+    uv sync --extra dev --extra docs --extra llm --extra mcp
uv run python -m ensurepip
- name: Check types
run: |
6 changes: 3 additions & 3 deletions app/api/utils.py
@@ -347,10 +347,10 @@ async def init_vllm_engine(app: FastAPI,
)

tokenizer = await engine.get_tokenizer()
- vllm_config = await engine.get_vllm_config()
- model_config = await engine.get_model_config()
+ vllm_config = await engine.get_vllm_config()  # type: ignore
+ model_config = await engine.get_model_config()  # type: ignore

- await init_app_state(engine, vllm_config, app.state, args)
+ await init_app_state(engine, vllm_config, app.state, args)  # type: ignore

async def generate_text(
request: Request,
63 changes: 58 additions & 5 deletions app/cli/README.md
@@ -10,6 +10,9 @@ $ cms [OPTIONS] COMMAND [ARGS]...

**Options**:

+ * `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner|huggingface_llm]`
+ * `--host TEXT`
+ * `--port TEXT`
* `--install-completion`: Install completion for the current shell.
* `--show-completion`: Show completion for the current shell, to copy it or customize the installation.
* `--help`: Show this message and exit.
@@ -24,6 +27,7 @@ $ cms [OPTIONS] COMMAND [ARGS]...
* `export-openapi-spec`: This generates an API document for all...
* `stream`: This groups various stream operations
* `package`: This groups various package operations
+ * `mcp`: Run the MCP server for accessing CMS...

## `cms serve`

@@ -37,14 +41,17 @@ $ cms serve [OPTIONS]

**Options**:

- * `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner]`: The type of the model to serve [required]
- * `--model-path TEXT`: The file path to the model package
+ * `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner|huggingface_llm]`: The type of the model to serve [required]
+ * `--model-path TEXT`: Either the file path to the local model package or the URL to the remote one
* `--mlflow-model-uri models:/MODEL_NAME/ENV`: The URI of the MLflow model to serve
* `--host TEXT`: The hostname of the server [default: 127.0.0.1]
* `--port TEXT`: The port of the server [default: 8000]
* `--model-name TEXT`: The string representation of the model name
* `--streamable / --no-streamable`: Serve the streamable endpoints only [default: no-streamable]
* `--device [default|cpu|cuda|mps]`: The device to serve the model on [default: default]
+ * `--llm-engine [CMS|vLLM]`: The engine to use for text generation [default: CMS]
+ * `--load-in-4bit / --no-load-in-4bit`: Load the model in 4-bit precision, used by 'huggingface_llm' models [default: no-load-in-4bit]
+ * `--load-in-8bit / --no-load-in-8bit`: Load the model in 8-bit precision, used by 'huggingface_llm' models [default: no-load-in-8bit]
* `--debug / --no-debug`: Run in the debug mode
* `--help`: Show this message and exit.
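
A hypothetical invocation serving a local LLM package in 4-bit precision (the model path and name below are illustrative, not taken from this PR):

```console
$ cms serve --model-type huggingface_llm \
    --model-path ./models/llm_model.zip \
    --host 127.0.0.1 --port 8000 \
    --model-name "demo-llm" \
    --load-in-4bit
```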

@@ -60,7 +67,7 @@ $ cms train [OPTIONS]

**Options**:

- * `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner]`: The type of the model to train [required]
+ * `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner|huggingface_llm]`: The type of the model to train [required]
* `--base-model-path TEXT`: The file path to the base model package to be trained on
* `--mlflow-model-uri models:/MODEL_NAME/ENV`: The URI of the MLflow model to train
* `--training-type [supervised|unsupervised|meta_supervised]`: The type of training [required]
@@ -71,6 +78,8 @@ $ cms train [OPTIONS]
* `--description TEXT`: The description of the training or change logs
* `--model-name TEXT`: The string representation of the model name
* `--device [default|cpu|cuda|mps]`: The device to train the model on [default: default]
+ * `--load-in-4bit / --no-load-in-4bit`: Load the model in 4-bit precision, used by 'huggingface_llm' models [default: no-load-in-4bit]
+ * `--load-in-8bit / --no-load-in-8bit`: Load the model in 8-bit precision, used by 'huggingface_llm' models [default: no-load-in-8bit]
* `--debug / --no-debug`: Run in the debug mode
* `--help`: Show this message and exit.
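
A sketch of a supervised training run with 8-bit loading (paths and names are illustrative; any data-input options hidden in the collapsed part of this diff are omitted here):

```console
$ cms train --model-type huggingface_llm \
    --base-model-path ./models/base_llm.zip \
    --training-type supervised \
    --model-name "demo-llm-ft" \
    --load-in-8bit
```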

@@ -86,7 +95,7 @@ $ cms register [OPTIONS]

**Options**:

- * `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner]`: The type of the model to register [required]
+ * `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner|huggingface_llm]`: The type of the model to register [required]
* `--model-path TEXT`: The file path to the model package [required]
* `--model-name TEXT`: The string representation of the registered model [required]
* `--training-type [supervised|unsupervised|meta_supervised]`: The type of training the model went through
@@ -108,7 +117,7 @@ $ cms export-model-apis [OPTIONS]

**Options**:

- * `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner]`: The type of the model to serve [required]
+ * `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner|huggingface_llm]`: The type of the model to serve [required]
* `--add-training-apis / --no-add-training-apis`: Add training APIs to the doc [default: no-add-training-apis]
* `--add-evaluation-apis / --no-add-evaluation-apis`: Add evaluation APIs to the doc [default: no-add-evaluation-apis]
* `--add-previews-apis / --no-add-previews-apis`: Add preview APIs to the doc [default: no-add-previews-apis]
@@ -269,3 +278,47 @@ $ cms package hf-dataset [OPTIONS]
* `--remove-cached / --no-remove-cached`: Whether to remove the downloaded cache after the dataset package is saved [default: no-remove-cached]
* `--trust-remote-code / --no-trust-remote-code`: Whether to trust and use the remote script of the dataset [default: no-trust-remote-code]
* `--help`: Show this message and exit.

## `cms mcp`

Run the MCP server for accessing CMS capabilities

**Usage**:

```console
$ cms mcp [OPTIONS] COMMAND [ARGS]...
```

**Options**:

* `--help`: Show this message and exit.

**Commands**:

* `run`: Run the MCP server for accessing CMS...

### `cms mcp run`

Run the MCP server for accessing CMS capabilities

**Usage**:

```console
$ cms mcp run [OPTIONS]
```

**Options**:

* `--host TEXT`: The hostname of the MCP server [default: 127.0.0.1]
* `--port INTEGER`: The port of the MCP server [default: 8080]
* `--transport TEXT`: The transport type (either 'stdio', 'sse' or 'http') [default: http]
* `--cms-base-url TEXT`: The base URL of the CMS API [default: http://127.0.0.1:8000]
* `--cms-api-key TEXT`: The API key for authenticating with the CMS API
* `--mcp-api-keys TEXT`: Comma-separated API keys for authenticating MCP clients
* `--cms-mcp-oauth-enabled / --no-cms-mcp-oauth-enabled`: Whether to enable OAuth2 authentication for MCP clients
* `--github-client-id TEXT`: The GitHub OAuth2 client ID
* `--github-client-secret TEXT`: The GitHub OAuth2 client secret
* `--google-client-id TEXT`: The Google OAuth2 client ID
* `--google-client-secret TEXT`: The Google OAuth2 client secret
* `--debug / --no-debug`: Run in debug mode
* `--help`: Show this message and exit.
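
A minimal sketch of running the MCP server over HTTP against a locally served model, with static API keys (the key values are illustrative):

```console
$ cms mcp run --transport http --host 127.0.0.1 --port 8080 \
    --cms-base-url http://127.0.0.1:8000 \
    --mcp-api-keys "key-one,key-two"
```

To use browser-based OAuth2 instead of static keys, pass `--cms-mcp-oauth-enabled` together with the GitHub or Google client ID and secret options listed above.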
106 changes: 99 additions & 7 deletions app/cli/cli.py
@@ -7,6 +7,7 @@
import uuid
import inspect
import warnings
+ import multiprocessing
import subprocess

current_frame = inspect.currentframe()
@@ -52,9 +53,11 @@

cmd_app = typer.Typer(name="cms", help="CLI for various CogStack ModelServe operations", add_completion=True)
stream_app = typer.Typer(name="stream", help="This groups various stream operations", add_completion=True)
- cmd_app.add_typer(stream_app, name="stream")
+ mcp_app = typer.Typer(name="mcp", help="Run the MCP server for accessing CMS capabilities", add_completion=True)
package_app = typer.Typer(name="package", help="This groups various package operations", add_completion=True)
+ cmd_app.add_typer(stream_app, name="stream")
cmd_app.add_typer(package_app, name="package")
+ cmd_app.add_typer(mcp_app, name="mcp")
logging.config.fileConfig(os.path.join(parent_dir, "logging.ini"), disable_existing_loggers=False)

@cmd_app.command("serve", help="This serves various CogStack NLP models")
@@ -69,6 +72,7 @@ def serve_model(
device: Device = typer.Option(Device.DEFAULT.value, help="The device to serve the model on"),
llm_engine: Optional[LlmEngine] = typer.Option(LlmEngine.CMS.value, help="The engine to use for text generation"),
load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"),
+ load_in_8bit: Optional[bool] = typer.Option(False, help="Load the model in 8-bit precision, used by 'huggingface_llm' models"),
debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"),
) -> None:
"""
@@ -87,6 +91,7 @@ def serve_model(
device (Device): The device to serve the model on. Defaults to Device.DEFAULT.
llm_engine (LlmEngine): The inference engine to use. Defaults to LlmEngine.CMS.
load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False.
+ load_in_8bit (bool): Load the model in 8-bit precision, used by 'huggingface_llm' models. Defaults to False.
debug (Optional[bool]): Run in debug mode if set to True.
"""

@@ -138,7 +143,7 @@ def serve_model(
if model_path:
model_service = model_service_dep()
model_service.model_name = model_name
- model_service.init_model(load_in_4bit=load_in_4bit)
+ model_service.init_model(load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit)
cms_globals.model_manager_dep = ModelManagerDep(model_service)
elif mlflow_model_uri:
model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path)
@@ -191,6 +196,7 @@ def train_model(
model_name: Optional[str] = typer.Option(None, help="The string representation of the model name"),
device: Device = typer.Option(Device.DEFAULT.value, help="The device to train the model on"),
load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"),
+ load_in_8bit: Optional[bool] = typer.Option(False, help="Load the model in 8-bit precision, used by 'huggingface_llm' models"),
debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"),
) -> None:
"""
@@ -211,6 +217,7 @@ def train_model(
model_name (Optional[str]): The optional string representation of the model name.
device (Device): The device to train the model on. Defaults to Device.DEFAULT.
load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False.
+ load_in_8bit (bool): Load the model in 8-bit precision, used by 'huggingface_llm' models. Defaults to False.
debug (Optional[bool]): Run in debug mode if set to True.
"""

@@ -232,7 +239,7 @@ def train_model(
pass
model_service = model_service_dep()
model_service.model_name = model_name if model_name is not None else "CMS model"
- model_service.init_model(load_in_4bit=load_in_4bit)
+ model_service.init_model(load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit)
elif mlflow_model_uri:
model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path)
model_service.model_name = model_name if model_name is not None else "CMS model"
@@ -495,6 +502,7 @@ def package_model(

model_package_archive = os.path.abspath(os.path.expanduser(output_model_package))
if hf_repo_id:
+ download_path = None
try:
with tempfile.TemporaryDirectory() as tmp_dir:
if not hf_repo_revision:
@@ -510,15 +518,14 @@
local_dir=tmp_dir,
local_dir_use_symlinks=False,
)

- shutil.make_archive(model_package_archive, archive_format.value, download_path)
+ _make_archive_file(model_package_archive, archive_format.value, download_path)
finally:
- if remove_cached:
+ if remove_cached and download_path:
cached_model_path = os.path.abspath(os.path.join(download_path, "..", ".."))
shutil.rmtree(cached_model_path)
elif cached_model_dir:
cached_model_path = os.path.abspath(os.path.expanduser(cached_model_dir))
- shutil.make_archive(model_package_archive, archive_format.value, cached_model_path)
+ _make_archive_file(model_package_archive, archive_format.value, cached_model_path)

typer.echo(f"Model package saved to {model_package_archive}.{'zip' if archive_format == ArchiveFormat.ZIP else 'tar.gz'}")

@@ -585,6 +592,73 @@ def package_dataset(
typer.echo(f"Dataset package saved to {dataset_package_archive}.{'zip' if archive_format == ArchiveFormat.ZIP else 'tar.gz'}")


@mcp_app.command("run", help="Run the MCP server for accessing CMS capabilities")
def run_mcp_server(
host: str = typer.Option("127.0.0.1", help="The hostname of the MCP server"),
port: int = typer.Option(8080, help="The port of the MCP server"),
transport: str = typer.Option("http", help="The transport type (either 'stdio', 'sse' or 'http')"),
cms_base_url: str = typer.Option("http://127.0.0.1:8000", help="The base URL of the CMS API"),
cms_api_key: str = typer.Option("Bearer", help="The API key for authenticating with the CMS API"),
mcp_api_keys: str = typer.Option("", help="Comma-separated API keys for authenticating MCP clients"),
cms_mcp_oauth_enabled: Optional[bool] = typer.Option(None, help="Whether to enable OAuth2 authentication for MCP clients"),
github_client_id: str = typer.Option("", help="The GitHub OAuth2 client ID"),
github_client_secret: str = typer.Option("", help="The GitHub OAuth2 client secret"),
google_client_id: str = typer.Option("", help="The Google OAuth2 client ID"),
google_client_secret: str = typer.Option("", help="The Google OAuth2 client secret"),
debug: Optional[bool] = typer.Option(None, help="Run in debug mode"),
) -> None:
"""
Runs the CogStack ModelServe MCP server.

This function starts an MCP server that provides AI assistants with tools to interact
with deployed CMS models through the Model Context Protocol interface.

Args:
host (str): The hostname of the MCP server. Defaults to "127.0.0.1".
port (int): The port of the MCP server. Defaults to 8080.
transport (str): The transport type for the MCP server. Can be "stdio", "sse" or "http". Defaults to "http".
cms_base_url (str): The base URL of the CMS API endpoint. Defaults to "http://127.0.0.1:8000".
cms_api_key (str): The API key for authenticating with the CMS API.
mcp_api_keys (str): Comma-separated API keys for authenticating MCP clients.
cms_mcp_oauth_enabled (Optional[bool]): Whether to enable OAuth2 authentication for MCP clients.
github_client_id (str): The GitHub OAuth2 client ID.
github_client_secret (str): The GitHub OAuth2 client secret.
google_client_id (str): The Google OAuth2 client ID.
google_client_secret (str): The Google OAuth2 client secret.
debug (Optional[bool]): Run in debug mode if set to True.
"""

logger = _get_logger(debug)
logger.info("Starting CMS MCP server...")

os.environ["CMS_BASE_URL"] = cms_base_url
os.environ["CMS_MCP_SERVER_HOST"] = host
os.environ["CMS_MCP_SERVER_PORT"] = str(port)
os.environ["CMS_MCP_TRANSPORT"] = transport.lower()
os.environ["CMS_API_KEY"] = cms_api_key
os.environ["MCP_API_KEYS"] = mcp_api_keys
os.environ["CMS_MCP_OAUTH_ENABLED"] = "true" if cms_mcp_oauth_enabled else "false"
os.environ["GITHUB_CLIENT_ID"] = github_client_id
os.environ["GITHUB_CLIENT_SECRET"] = github_client_secret
os.environ["GOOGLE_CLIENT_ID"] = google_client_id
os.environ["GOOGLE_CLIENT_SECRET"] = google_client_secret

if debug:
os.environ["CMS_MCP_DEV"] = "1"

try:
from app.mcp.server import main
logger.info(f"MCP server starting with transport: {transport}")
logger.info(f"Connected to CMS API at {cms_base_url}")
main()
except ImportError as e:
logger.error(f"Cannot import MCP. Please install it with `pip install cms[mcp]`: {e}")
typer.echo(f"ERROR: Cannot import MCP: {e}")
typer.echo("Please install it with `pip install cms[mcp]`.")
raise typer.Exit(code=1)
except KeyboardInterrupt:
logger.info("MCP server stopped by the user")
typer.echo("MCP server stopped.")
raise typer.Exit(code=0)
except Exception as e:
logger.error(f"Failed to start MCP server: {e}")
typer.echo(f"ERROR: Failed to start MCP server: {e}")
raise typer.Exit(code=1)
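
The command above hands its configuration to `app.mcp.server` entirely through environment variables. A minimal sketch of how the server side might read them back (the `McpSettings` class is hypothetical; only the variable names and defaults come from the code above):

```python
import os


class McpSettings:
    """Hypothetical settings holder mirroring the environment bridge above."""

    def __init__(self) -> None:
        self.cms_base_url = os.environ.get("CMS_BASE_URL", "http://127.0.0.1:8000")
        self.host = os.environ.get("CMS_MCP_SERVER_HOST", "127.0.0.1")
        self.port = int(os.environ.get("CMS_MCP_SERVER_PORT", "8080"))
        self.transport = os.environ.get("CMS_MCP_TRANSPORT", "http")
        # An empty MCP_API_KEYS means no static keys are configured
        self.mcp_api_keys = [key for key in os.environ.get("MCP_API_KEYS", "").split(",") if key]
        self.oauth_enabled = os.environ.get("CMS_MCP_OAUTH_ENABLED", "false") == "true"
```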


@cmd_app.command("build", help="This builds an OCI-compliant image to containerise CMS")
def build_image(
dockerfile_path: str = typer.Option(..., help="The path to the Dockerfile"),
@@ -798,6 +872,24 @@ def _ensure_dst_model_path(model_path: str, parent_dir: str, config: Settings) -
return dst_model_path


def _make_archive_file(base_name: str, format: str, root_dir: str) -> None:
    if format == ArchiveFormat.TAR_GZ.value:
        try:
            # check=True raises CalledProcessError when pigz is not on the PATH,
            # so a separate returncode test is unnecessary
            subprocess.run(["which", "pigz"], capture_output=True, text=True, check=True)
            # Leave one core free so the host stays responsive while compressing
            num_cores = max(1, multiprocessing.cpu_count() - 1)
            compress_program = f"pigz -p {num_cores}"
            subprocess.run(
                ["tar", f"--use-compress-program={compress_program}", "-cf", f"{base_name}.tar.gz", "-C", root_dir, "."],
                check=True,
            )
            return
        except (subprocess.CalledProcessError, FileNotFoundError):
            typer.echo("Falling back to non-parallel compression...")

    shutil.make_archive(base_name, format, root_dir)
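
pigz parallelises gzip compression across cores, which shortens the packaging of large model directories; the `shutil.make_archive` fallback keeps behaviour identical on systems without it. A usage sketch, assuming `ArchiveFormat.TAR_GZ.value` is the `shutil`-style format name used elsewhere in this file (paths are illustrative):

```python
# Packages /tmp/cached_model into /tmp/model_package.tar.gz, using pigz
# when it is on the PATH and single-threaded gzip otherwise.
_make_archive_file("/tmp/model_package", ArchiveFormat.TAR_GZ.value, "/tmp/cached_model")
```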


def _get_logger(
debug: Optional[bool] = None,
model_type: Optional[ModelType] = None,