diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py
index e577784cbf..ad8307c156 100644
--- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py
+++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py
@@ -24,7 +24,7 @@
 from genkit.ai import ActionKind, GenkitRegistry
 from genkit.core.action._action import ActionRunContext
-from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_MODELS
+from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_COMPAT_MODELS, SUPPORTED_OPENAI_MODELS
 from genkit.plugins.compat_oai.models.utils import DictMessageAdapter, MessageAdapter, MessageConverter
 from genkit.plugins.compat_oai.typing import OpenAIConfig, SupportedOutputFormat
 from genkit.types import (
@@ -120,8 +120,8 @@ def _get_response_format(self, output: OutputConfig) -> dict | None:
                 },
             }
 
-        model = SUPPORTED_OPENAI_MODELS[self._model]
-        if SupportedOutputFormat.JSON_MODE in model.supports.output:
+        model = SUPPORTED_OPENAI_MODELS.get(self._model, SUPPORTED_OPENAI_COMPAT_MODELS.get(self._model))
+        if model and SupportedOutputFormat.JSON_MODE in model.supports.output:
             return {'type': 'json_object'}
         return {'type': 'text'}
diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model_info.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model_info.py
index 572de622a8..d539c1928a 100644
--- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model_info.py
+++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model_info.py
@@ -48,6 +48,7 @@ class PluginSource(StrEnum):
 LLAMA_3_1 = 'meta/llama-3.1-405b-instruct-maas'
 LLAMA_3_2 = 'meta/llama-3.2-90b-vision-instruct-maas'
+MISTRAL_LARGE = 'mistralai/mistral-large'
 
 SUPPORTED_OPENAI_MODELS: dict[str, ModelInfo] = {
     GPT_3_5_TURBO: ModelInfo(
@@ -165,6 +166,16 @@ class PluginSource(StrEnum):
             output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
         ),
     ),
+    MISTRAL_LARGE: ModelInfo(
+        label='ModelGarden - Mistral - Large',
+        supports=Supports(
+            multiturn=True,
+            media=False,
+            tools=True,
+            systemRole=True,
+            output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
+        ),
+    ),
 }
diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py
index e6204627fd..532151317a 100644
--- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py
+++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py
@@ -70,10 +70,7 @@ def __init__(
         )
 
         self.location = (
-            location
-            or os.getenv('GOOGLE_CLOUD_LOCATION')
-            or os.getenv('GOOGLE_CLOUD_REGION')
-            or const.DEFAULT_REGION
+            location or os.getenv('GOOGLE_CLOUD_LOCATION') or os.getenv('GOOGLE_CLOUD_REGION') or const.DEFAULT_REGION
         )
         self.models = models
diff --git a/py/samples/model-garden/src/main.py b/py/samples/model-garden/src/main.py
index f0f013478d..3b6658bdfe 100644
--- a/py/samples/model-garden/src/main.py
+++ b/py/samples/model-garden/src/main.py
@@ -23,50 +23,50 @@
 
 ai = Genkit(
     plugins=[
-        VertexAIModelGarden(),
+        VertexAIModelGarden(location='us-central1'),
    ],
 )
 
-# @ai.flow()
-# async def say_hi(name: str) -> str:
-#     """Generate a greeting for the given name.
-#
-#     Args:
-#         name: The name of the person to greet.
-#
-#     Returns:
-#         The generated greeting response.
-#     """
-#     response = await ai.generate(
-#         model=model_garden_name('meta/llama-3.2-90b-vision-instruct-maas'),
-#         config={'temperature': 1},
-#         prompt=f'hi {name}',
-#     )
-#
-#     return response.message.content[0].root.text
+@ai.flow()
+async def say_hi(name: str) -> str:
+    """Generate a greeting for the given name.
 
+    Args:
+        name: The name of the person to greet.
 
-# @ai.flow()
-# async def say_hi_stream(name: str) -> str:
-#     """Say hi to a name and stream the response.
-#
-#     Args:
-#         name: The name to say hi to.
-#
-#     Returns:
-#         The response from the OpenAI API.
-#     """
-#     stream, _ = ai.generate_stream(
-#         model=model_garden_name('meta/llama-3.2-90b-vision-instruct-maas'),
-#         config={'temperature': 1},
-#         prompt=f'hi {name}',
-#     )
-#     result = ''
-#     async for data in stream:
-#         for part in data.content:
-#             result += part.root.text
-#     return result
+    Returns:
+        The generated greeting response.
+    """
+    response = await ai.generate(
+        model=model_garden_name('meta/llama-3.2-90b-vision-instruct-maas'),
+        config={'temperature': 1},
+        prompt=f'hi {name}',
+    )
+
+    return response.message.content[0].root.text
+
+
+@ai.flow()
+async def say_hi_stream(name: str) -> str:
+    """Say hi to a name and stream the response.
+
+    Args:
+        name: The name to say hi to.
+
+    Returns:
+        The response from the Mistral API.
+    """
+    stream, _ = ai.generate_stream(
+        model=model_garden_name('mistralai/mistral-large'),
+        config={'temperature': 1},
+        prompt=f'hi {name}',
+    )
+    result = ''
+    async for data in stream:
+        for part in data.content:
+            result += part.root.text
+    return result
 
 
 @ai.flow()
@@ -90,8 +90,8 @@ async def jokes_flow(subject: str) -> str:
 
 async def main() -> None:
-    # await logger.ainfo(await say_hi('John Doe'))
-    # await logger.ainfo(await say_hi_stream('John Doe'))
+    await logger.ainfo(await say_hi('John Doe'))
+    await logger.ainfo(await say_hi_stream('John Doe'))
     await logger.ainfo(await jokes_flow('banana'))