Draft

Changes from all commits
@@ -24,7 +24,7 @@

 from genkit.ai import ActionKind, GenkitRegistry
 from genkit.core.action._action import ActionRunContext
-from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_MODELS
+from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_COMPAT_MODELS, SUPPORTED_OPENAI_MODELS
 from genkit.plugins.compat_oai.models.utils import DictMessageAdapter, MessageAdapter, MessageConverter
 from genkit.plugins.compat_oai.typing import OpenAIConfig, SupportedOutputFormat
 from genkit.types import (
@@ -120,8 +120,8 @@ def _get_response_format(self, output: OutputConfig) -> dict | None:
             },
         }

-        model = SUPPORTED_OPENAI_MODELS[self._model]
-        if SupportedOutputFormat.JSON_MODE in model.supports.output:
+        model = SUPPORTED_OPENAI_MODELS.get(self._model, SUPPORTED_OPENAI_COMPAT_MODELS.get(self._model))
Contributor


medium

For improved readability and to make the logic clearer, you could simplify this line. Using or to chain the dictionary lookups is a more common and idiomatic Python pattern for this scenario. It also has the minor benefit of short-circuiting, so the second lookup is only performed if the first one fails.

Suggested change
-        model = SUPPORTED_OPENAI_MODELS.get(self._model, SUPPORTED_OPENAI_COMPAT_MODELS.get(self._model))
+        model = SUPPORTED_OPENAI_MODELS.get(self._model) or SUPPORTED_OPENAI_COMPAT_MODELS.get(self._model)
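To make the short-circuiting point concrete, here is a standalone sketch; the dicts and key below are stand-ins for illustration, not the plugin's real tables:

primary = {'gpt-4': 'native'}
compat = {'mistralai/mistral-large': 'compat'}
key = 'gpt-4'

# dict.get with a default evaluates the fallback eagerly:
# compat.get(key) runs even though key is present in primary.
model = primary.get(key, compat.get(key))

# or-chaining short-circuits: compat.get(key) only runs when the
# first lookup returns None (i.e. the key is missing from primary).
model = primary.get(key) or compat.get(key)

# Caveat: or-chaining also falls through on present-but-falsy values
# ('', 0, {}). That is harmless here, since the looked-up values are
# ModelInfo objects, which are always truthy.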

+        if model and SupportedOutputFormat.JSON_MODE in model.supports.output:
             return {'type': 'json_object'}

         return {'type': 'text'}
@@ -48,6 +48,7 @@ class PluginSource(StrEnum):

 LLAMA_3_1 = 'meta/llama-3.1-405b-instruct-maas'
 LLAMA_3_2 = 'meta/llama-3.2-90b-vision-instruct-maas'
+MISTRAL_LARGE = 'mistralai/mistral-large'

 SUPPORTED_OPENAI_MODELS: dict[str, ModelInfo] = {
     GPT_3_5_TURBO: ModelInfo(
@@ -165,6 +166,16 @@ class PluginSource(StrEnum):
             output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
         ),
     ),
+    MISTRAL_LARGE: ModelInfo(
+        label='ModelGarden - Mistral - Large',
+        supports=Supports(
+            multiturn=True,
+            media=False,
+            tools=True,
+            systemRole=True,
+            output=[SupportedOutputFormat.JSON_MODE, SupportedOutputFormat.TEXT],
+        ),
+    ),
 }


@@ -70,10 +70,7 @@ def __init__(
         )

         self.location = (
-            location
-            or os.getenv('GOOGLE_CLOUD_LOCATION')
-            or os.getenv('GOOGLE_CLOUD_REGION')
-            or const.DEFAULT_REGION
+            location or os.getenv('GOOGLE_CLOUD_LOCATION') or os.getenv('GOOGLE_CLOUD_REGION') or const.DEFAULT_REGION
Contributor


medium

This line is quite long (over 100 characters) and can be difficult to read. The previous multi-line formatting was more readable and aligned better with Python style guides like PEP 8, which recommend limiting line length. I'd suggest reverting to the multi-line format for better maintainability.

            location
            or os.getenv('GOOGLE_CLOUD_LOCATION')
            or os.getenv('GOOGLE_CLOUD_REGION')
            or const.DEFAULT_REGION

         )

         self.models = models
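For reference, the fallback chain this expression encodes (explicit argument, then GOOGLE_CLOUD_LOCATION, then GOOGLE_CLOUD_REGION, then the library default) can be sketched in isolation. resolve_location and the default value below are hypothetical stand-ins, not the plugin's actual API:

import os

DEFAULT_REGION = 'us-central1'  # stand-in for const.DEFAULT_REGION

def resolve_location(location: str | None = None) -> str:
    # First truthy value wins; None and empty strings fall through
    # to the next candidate in the chain.
    return (
        location
        or os.getenv('GOOGLE_CLOUD_LOCATION')
        or os.getenv('GOOGLE_CLOUD_REGION')
        or DEFAULT_REGION
    )

# An explicit argument always takes precedence over the environment.
assert resolve_location('europe-west4') == 'europe-west4'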
80 changes: 40 additions & 40 deletions py/samples/model-garden/src/main.py
@@ -23,50 +23,50 @@

 ai = Genkit(
     plugins=[
-        VertexAIModelGarden(),
+        VertexAIModelGarden(location='us-central1'),
     ],
 )


-# @ai.flow()
-# async def say_hi(name: str) -> str:
-#     """Generate a greeting for the given name.
-#
-#     Args:
-#         name: The name of the person to greet.
-#
-#     Returns:
-#         The generated greeting response.
-#     """
-#     response = await ai.generate(
-#         model=model_garden_name('meta/llama-3.2-90b-vision-instruct-maas'),
-#         config={'temperature': 1},
-#         prompt=f'hi {name}',
-#     )
-#
-#     return response.message.content[0].root.text
-
-
-# @ai.flow()
-# async def say_hi_stream(name: str) -> str:
-#     """Say hi to a name and stream the response.
-#
-#     Args:
-#         name: The name to say hi to.
-#
-#     Returns:
-#         The response from the OpenAI API.
-#     """
-#     stream, _ = ai.generate_stream(
-#         model=model_garden_name('meta/llama-3.2-90b-vision-instruct-maas'),
-#         config={'temperature': 1},
-#         prompt=f'hi {name}',
-#     )
-#     result = ''
-#     async for data in stream:
-#         for part in data.content:
-#             result += part.root.text
-#     return result
+@ai.flow()
+async def say_hi(name: str) -> str:
+    """Generate a greeting for the given name.
+
+    Args:
+        name: The name of the person to greet.
+
+    Returns:
+        The generated greeting response.
+    """
+    response = await ai.generate(
+        model=model_garden_name('meta/llama-3.2-90b-vision-instruct-maas'),
+        config={'temperature': 1},
+        prompt=f'hi {name}',
+    )
+
+    return response.message.content[0].root.text
+
+
+@ai.flow()
+async def say_hi_stream(name: str) -> str:
+    """Say hi to a name and stream the response.
+
+    Args:
+        name: The name to say hi to.
+
+    Returns:
+        The response from the Mistral API.
+    """
+    stream, _ = ai.generate_stream(
+        model=model_garden_name('mistralai/mistral-large'),
+        config={'temperature': 1},
+        prompt=f'hi {name}',
+    )
+    result = ''
+    async for data in stream:
+        for part in data.content:
+            result += part.root.text
+    return result


 @ai.flow()
@@ -90,8 +90,8 @@ async def jokes_flow(subject: str) -> str:


 async def main() -> None:
-    # await logger.ainfo(await say_hi('John Doe'))
-    # await logger.ainfo(await say_hi_stream('John Doe'))
+    await logger.ainfo(await say_hi('John Doe'))
+    await logger.ainfo(await say_hi_stream('John Doe'))
     await logger.ainfo(await jokes_flow('banana'))

