From cf5ca24a2af469ea366bd78b3498922091f1a6a2 Mon Sep 17 00:00:00 2001 From: Nagarjun Krishnan Date: Mon, 9 Feb 2026 08:06:06 -0700 Subject: [PATCH 1/3] Add model discovery feature for UI Signed-off-by: Nagarjun Krishnan --- go/api/v1alpha2/provider_types.go | 130 +++++ go/api/v1alpha2/zz_generated.deepcopy.go | 121 +++++ go/config/crd/bases/kagent.dev_providers.yaml | 199 +++++++ go/config/rbac/role.yaml | 3 + go/config/samples/provider_anthropic.yaml | 11 + go/config/samples/provider_openai.yaml | 11 + go/config/samples/provider_secrets.yaml | 17 + go/go.mod | 1 + go/go.sum | 4 +- .../mcp_server_tool_controller_test.go | 4 + go/internal/controller/provider/discoverer.go | 217 ++++++++ .../controller/provider/discoverer_test.go | 511 ++++++++++++++++++ go/internal/controller/provider/manager.go | 148 +++++ .../controller/provider/manager_test.go | 403 ++++++++++++++ go/internal/controller/provider/types.go | 59 ++ go/internal/controller/provider_controller.go | 117 ++++ .../controller/reconciler/reconciler.go | 193 +++++++ .../controller/service_controller_test.go | 4 + go/internal/httpserver/handlers/handlers.go | 8 +- go/internal/httpserver/handlers/providers.go | 81 ++- go/internal/httpserver/server.go | 6 +- go/pkg/app/app.go | 14 + .../templates/kagent.dev_providers.yaml | 195 +++++++ helm/kagent/templates/rbac/clusterrole.yaml | 5 + ui/jest.setup.ts | 16 + ui/src/app/actions/providers.ts | 38 +- .../new/__tests__/providerSelection.test.tsx | 169 ++++++ ui/src/app/models/new/page.tsx | 180 +++++- ui/src/components/ModelCombobox.tsx | 91 ++++ ui/src/components/ProviderCombobox.tsx | 156 ++++++ .../models/new/BasicInfoSection.tsx | 122 ++++- ui/src/lib/__tests__/utils.test.ts | 3 +- ui/src/lib/utils.ts | 4 - ui/src/types/index.ts | 16 +- 34 files changed, 3202 insertions(+), 55 deletions(-) create mode 100644 go/api/v1alpha2/provider_types.go create mode 100644 go/config/crd/bases/kagent.dev_providers.yaml create mode 100644 go/config/samples/provider_anthropic.yaml create mode 100644 go/config/samples/provider_openai.yaml create mode 100644 go/config/samples/provider_secrets.yaml create mode 100644 go/internal/controller/provider/discoverer.go create mode 100644 go/internal/controller/provider/discoverer_test.go create mode 100644 go/internal/controller/provider/manager.go create mode 100644 go/internal/controller/provider/manager_test.go create mode 100644 go/internal/controller/provider/types.go create mode 100644 go/internal/controller/provider_controller.go create mode 100644 helm/kagent-crds/templates/kagent.dev_providers.yaml create mode 100644 ui/src/app/models/new/__tests__/providerSelection.test.tsx create mode 100644 ui/src/components/ModelCombobox.tsx create mode 100644 ui/src/components/ProviderCombobox.tsx diff --git a/go/api/v1alpha2/provider_types.go b/go/api/v1alpha2/provider_types.go new file mode 100644 index 000000000..22a21c4b3 --- /dev/null +++ b/go/api/v1alpha2/provider_types.go @@ -0,0 +1,130 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha2 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ( + // ProviderConditionTypeReady indicates whether the provider is ready for use + ProviderConditionTypeReady = "Ready" + + // ProviderConditionTypeSecretResolved indicates whether the provider's secret reference is valid + ProviderConditionTypeSecretResolved = "SecretResolved" + + // ProviderConditionTypeModelsDiscovered indicates whether model discovery has succeeded + ProviderConditionTypeModelsDiscovered = "ModelsDiscovered" + + // ProviderAnnotationForceDiscovery is set by clients to trigger immediate model discovery + ProviderAnnotationForceDiscovery = "kagent.dev/force-discovery" +) + +// SecretReference contains information to locate a secret. +type SecretReference struct { + // Name is the name of the secret in the same namespace as the Provider + // +required + Name string `json:"name"` + + // Key is the key within the secret that contains the API key or credential + // +required + Key string `json:"key"` +} + +// ProviderSpec defines the desired state of Provider. +// +// +kubebuilder:validation:XValidation:message="endpoint must be a valid URL starting with http:// or https://",rule="self.endpoint.startsWith('http://') || self.endpoint.startsWith('https://')" +// +kubebuilder:validation:XValidation:message="secretRef.name and secretRef.key are required",rule="has(self.secretRef) && has(self.secretRef.name) && size(self.secretRef.name) > 0 && has(self.secretRef.key) && size(self.secretRef.key) > 0" +type ProviderSpec struct { + // Type is the model provider type (OpenAI, Anthropic, etc.) + // +required + // +kubebuilder:validation:Required + Type ModelProvider `json:"type"` + + // Endpoint is the API endpoint URL for the provider + // +required + // +kubebuilder:validation:Required + // +kubebuilder:validation:Pattern=`^https?://.*` + Endpoint string `json:"endpoint"` + + // SecretRef references the Kubernetes Secret containing the API key + // +required + // +kubebuilder:validation:Required + SecretRef SecretReference `json:"secretRef"` +} + +// ProviderStatus defines the observed state of Provider. +type ProviderStatus struct { + // ObservedGeneration reflects the generation of the most recently observed Provider spec + // +optional + ObservedGeneration int64 `json:"observedGeneration,omitempty"` + + // Conditions represent the latest available observations of the Provider's state + // +optional + // +listType=map + // +listMapKey=type + Conditions []metav1.Condition `json:"conditions,omitempty"` + + // DiscoveredModels is the cached list of model IDs available from this provider + // +optional + DiscoveredModels []string `json:"discoveredModels,omitempty"` + + // ModelCount is the number of discovered models (for kubectl display) + // +optional + ModelCount int `json:"modelCount,omitempty"` + + // LastDiscoveryTime is the timestamp of the last successful model discovery + // +optional + LastDiscoveryTime *metav1.Time `json:"lastDiscoveryTime,omitempty"` + + // SecretHash is a hash of the referenced secret data, used to detect secret changes + // +optional + SecretHash string `json:"secretHash,omitempty"` +} + +// +kubebuilder:object:root=true +// +kubebuilder:resource:categories=kagent,shortName=prov +// +kubebuilder:subresource:status +// +kubebuilder:printcolumn:name="Type",type="string",JSONPath=".spec.type" +// +kubebuilder:printcolumn:name="Endpoint",type="string",JSONPath=".spec.endpoint" +// +kubebuilder:printcolumn:name="Models",type="integer",JSONPath=".status.modelCount" +// +kubebuilder:printcolumn:name="Ready",type="string",JSONPath=".status.conditions[?(@.type=='Ready')].status" +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" +// +kubebuilder:storageversion + +// Provider is the Schema for the providers API. +// It represents a model provider configuration with automatic model discovery. +type Provider struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec ProviderSpec `json:"spec,omitempty"` + Status ProviderStatus `json:"status,omitempty"` +} + +// +kubebuilder:object:root=true + +// ProviderList contains a list of Provider. +type ProviderList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []Provider `json:"items"` +} + +func init() { + SchemeBuilder.Register(&Provider{}, &ProviderList{}) +} diff --git a/go/api/v1alpha2/zz_generated.deepcopy.go b/go/api/v1alpha2/zz_generated.deepcopy.go index 100fdbbf1..0a41e821f 100644 --- a/go/api/v1alpha2/zz_generated.deepcopy.go +++ b/go/api/v1alpha2/zz_generated.deepcopy.go @@ -695,6 +695,112 @@ func (in *OpenAIConfig) DeepCopy() *OpenAIConfig { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Provider) DeepCopyInto(out *Provider) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + out.Spec = in.Spec + in.Status.DeepCopyInto(&out.Status) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Provider. +func (in *Provider) DeepCopy() *Provider { + if in == nil { + return nil + } + out := new(Provider) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *Provider) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ProviderList) DeepCopyInto(out *ProviderList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]Provider, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ProviderList. +func (in *ProviderList) DeepCopy() *ProviderList { + if in == nil { + return nil + } + out := new(ProviderList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *ProviderList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ProviderSpec) DeepCopyInto(out *ProviderSpec) { + *out = *in + out.SecretRef = in.SecretRef +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ProviderSpec. +func (in *ProviderSpec) DeepCopy() *ProviderSpec { + if in == nil { + return nil + } + out := new(ProviderSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ProviderStatus) DeepCopyInto(out *ProviderStatus) { + *out = *in + if in.Conditions != nil { + in, out := &in.Conditions, &out.Conditions + *out = make([]metav1.Condition, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.DiscoveredModels != nil { + in, out := &in.DiscoveredModels, &out.DiscoveredModels + *out = make([]string, len(*in)) + copy(*out, *in) + } + if in.LastDiscoveryTime != nil { + in, out := &in.LastDiscoveryTime, &out.LastDiscoveryTime + *out = (*in).DeepCopy() + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ProviderStatus. +func (in *ProviderStatus) DeepCopy() *ProviderStatus { + if in == nil { + return nil + } + out := new(ProviderStatus) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *RemoteMCPServer) DeepCopyInto(out *RemoteMCPServer) { *out = *in @@ -829,6 +935,21 @@ func (in *RemoteMCPServerStatus) DeepCopy() *RemoteMCPServerStatus { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SecretReference) DeepCopyInto(out *SecretReference) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SecretReference. +func (in *SecretReference) DeepCopy() *SecretReference { + if in == nil { + return nil + } + out := new(SecretReference) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ServiceAccountConfig) DeepCopyInto(out *ServiceAccountConfig) { *out = *in diff --git a/go/config/crd/bases/kagent.dev_providers.yaml b/go/config/crd/bases/kagent.dev_providers.yaml new file mode 100644 index 000000000..834f73f8e --- /dev/null +++ b/go/config/crd/bases/kagent.dev_providers.yaml @@ -0,0 +1,199 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.19.0 + name: providers.kagent.dev +spec: + group: kagent.dev + names: + categories: + - kagent + kind: Provider + listKind: ProviderList + plural: providers + shortNames: + - prov + singular: provider + scope: Namespaced + versions: + - additionalPrinterColumns: + - jsonPath: .spec.type + name: Type + type: string + - jsonPath: .spec.endpoint + name: Endpoint + type: string + - jsonPath: .status.modelCount + name: Models + type: integer + - jsonPath: .status.conditions[?(@.type=='Ready')].status + name: Ready + type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date + name: v1alpha2 + schema: + openAPIV3Schema: + description: |- + Provider is the Schema for the providers API. + It represents a model provider configuration with automatic model discovery. + properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: ProviderSpec defines the desired state of Provider. + properties: + endpoint: + description: Endpoint is the API endpoint URL for the provider + pattern: ^https?://.* + type: string + secretRef: + description: SecretRef references the Kubernetes Secret containing + the API key + properties: + key: + description: Key is the key within the secret that contains the + API key or credential + type: string + name: + description: Name is the name of the secret in the same namespace + as the Provider + type: string + required: + - key + - name + type: object + type: + description: Type is the model provider type (OpenAI, Anthropic, etc.) + enum: + - Anthropic + - OpenAI + - AzureOpenAI + - Ollama + - Gemini + - GeminiVertexAI + - AnthropicVertexAI + type: string + required: + - endpoint + - secretRef + - type + type: object + x-kubernetes-validations: + - message: endpoint must be a valid URL starting with http:// or https:// + rule: self.endpoint.startsWith('http://') || self.endpoint.startsWith('https://') + - message: secretRef.name and secretRef.key are required + rule: has(self.secretRef) && has(self.secretRef.name) && size(self.secretRef.name) + > 0 && has(self.secretRef.key) && size(self.secretRef.key) > 0 + status: + description: ProviderStatus defines the observed state of Provider. + properties: + conditions: + description: Conditions represent the latest available observations + of the Provider's state + items: + description: Condition contains details for one aspect of the current + state of this API Resource. + properties: + lastTransitionTime: + description: |- + lastTransitionTime is the last time the condition transitioned from one status to another. + This should be when the underlying condition changed. If that is not known, then using the time when the API field changed is acceptable. + format: date-time + type: string + message: + description: |- + message is a human readable message indicating details about the transition. + This may be an empty string. + maxLength: 32768 + type: string + observedGeneration: + description: |- + observedGeneration represents the .metadata.generation that the condition was set based upon. + For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date + with respect to the current state of the instance. + format: int64 + minimum: 0 + type: integer + reason: + description: |- + reason contains a programmatic identifier indicating the reason for the condition's last transition. + Producers of specific condition types may define expected values and meanings for this field, + and whether the values are considered a guaranteed API. + The value should be a CamelCase string. + This field may not be empty. + maxLength: 1024 + minLength: 1 + pattern: ^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$ + type: string + status: + description: status of the condition, one of True, False, Unknown. + enum: + - "True" + - "False" + - Unknown + type: string + type: + description: type of condition in CamelCase or in foo.example.com/CamelCase. + maxLength: 316 + pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ + type: string + required: + - lastTransitionTime + - message + - reason + - status + - type + type: object + type: array + x-kubernetes-list-map-keys: + - type + x-kubernetes-list-type: map + discoveredModels: + description: DiscoveredModels is the cached list of model IDs available + from this provider + items: + type: string + type: array + lastDiscoveryTime: + description: LastDiscoveryTime is the timestamp of the last successful + model discovery + format: date-time + type: string + modelCount: + description: ModelCount is the number of discovered models (for kubectl + display) + type: integer + observedGeneration: + description: ObservedGeneration reflects the generation of the most + recently observed Provider spec + format: int64 + type: integer + secretHash: + description: SecretHash is a hash of the referenced secret data, used + to detect secret changes + type: string + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/go/config/rbac/role.yaml b/go/config/rbac/role.yaml index 14b23eb3d..8aa4e89bc 100644 --- a/go/config/rbac/role.yaml +++ b/go/config/rbac/role.yaml @@ -43,6 +43,7 @@ rules: resources: - agents - modelconfigs + - providers - remotemcpservers verbs: - create @@ -57,6 +58,7 @@ rules: resources: - agents/finalizers - modelconfigs/finalizers + - providers/finalizers - remotemcpservers/finalizers verbs: - update @@ -65,6 +67,7 @@ rules: resources: - agents/status - modelconfigs/status + - providers/status - remotemcpservers/status verbs: - get diff --git a/go/config/samples/provider_anthropic.yaml b/go/config/samples/provider_anthropic.yaml new file mode 100644 index 000000000..b3247ad83 --- /dev/null +++ b/go/config/samples/provider_anthropic.yaml @@ -0,0 +1,11 @@ +apiVersion: kagent.dev/v1alpha2 +kind: Provider +metadata: + name: anthropic-prod + namespace: kagent +spec: + type: Anthropic + endpoint: https://api.anthropic.com + secretRef: + name: anthropic-secret + key: apiKey diff --git a/go/config/samples/provider_openai.yaml b/go/config/samples/provider_openai.yaml new file mode 100644 index 000000000..7958632af --- /dev/null +++ b/go/config/samples/provider_openai.yaml @@ -0,0 +1,11 @@ +apiVersion: kagent.dev/v1alpha2 +kind: Provider +metadata: + name: openai-prod + namespace: kagent +spec: + type: OpenAI + endpoint: https://api.openai.com/v1 + secretRef: + name: openai-secret + key: apiKey diff --git a/go/config/samples/provider_secrets.yaml b/go/config/samples/provider_secrets.yaml new file mode 100644 index 000000000..664b18f01 --- /dev/null +++ b/go/config/samples/provider_secrets.yaml @@ -0,0 +1,17 @@ +apiVersion: v1 +kind: Secret +metadata: + name: openai-secret + namespace: kagent +type: Opaque +stringData: + apiKey: "sk-proj-REPLACE_WITH_YOUR_KEY" +--- +apiVersion: v1 +kind: Secret +metadata: + name: anthropic-secret + namespace: kagent +type: Opaque +stringData: + apiKey: "sk-ant-REPLACE_WITH_YOUR_KEY" diff --git a/go/go.mod b/go/go.mod index 182360c4d..cc1a5c400 100644 --- a/go/go.mod +++ b/go/go.mod @@ -16,6 +16,7 @@ require ( github.com/jedib0t/go-pretty/v6 v6.7.8 github.com/kagent-dev/kmcp v0.2.5 github.com/kagent-dev/mockllm v0.0.3 + github.com/mark3labs/mcp-go v0.33.0 github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/muesli/reflow v0.3.0 github.com/prometheus/client_golang v1.23.2 diff --git a/go/go.sum b/go/go.sum index 02a5e1ba4..780047beb 100644 --- a/go/go.sum +++ b/go/go.sum @@ -165,8 +165,6 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/kagent-dev/kmcp v0.2.2 h1:uvbKmo9IT6OT9RBNXYwGX0PWyxBLfAW1F9yWd5/wxaI= -github.com/kagent-dev/kmcp v0.2.2/go.mod h1:g7wS/3m2wonRo/1DMwVoHxnilr/urPgV2hwV1DwkwrQ= github.com/kagent-dev/kmcp v0.2.5 h1:Em5A2vROJuR5JpMe5luSMe2vQJTwxX93AMXJm6Lg/E0= github.com/kagent-dev/kmcp v0.2.5/go.mod h1:g7wS/3m2wonRo/1DMwVoHxnilr/urPgV2hwV1DwkwrQ= github.com/kagent-dev/mockllm v0.0.3 h1:hk6Oa/vxHoBrGqRig4GCzox8EqRQYXM4c3oFPP/k9Tg= @@ -195,6 +193,8 @@ github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQ github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc= +github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= diff --git a/go/internal/controller/mcp_server_tool_controller_test.go b/go/internal/controller/mcp_server_tool_controller_test.go index b823061b2..35f518738 100644 --- a/go/internal/controller/mcp_server_tool_controller_test.go +++ b/go/internal/controller/mcp_server_tool_controller_test.go @@ -39,6 +39,10 @@ func (f *fakeReconciler) ReconcileKagentRemoteMCPServer(ctx context.Context, req return nil } +func (f *fakeReconciler) ReconcileKagentProvider(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + return ctrl.Result{}, nil +} + func (f *fakeReconciler) GetOwnedResourceTypes() []client.Object { return nil } diff --git a/go/internal/controller/provider/discoverer.go b/go/internal/controller/provider/discoverer.go new file mode 100644 index 000000000..9cba0868f --- /dev/null +++ b/go/internal/controller/provider/discoverer.go @@ -0,0 +1,217 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package provider + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + v1alpha2 "github.com/kagent-dev/kagent/go/api/v1alpha2" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +const ( + // DefaultTimeout is the default HTTP timeout for model discovery requests + DefaultTimeout = 30 * time.Second +) + +// ModelDiscoverer fetches available models from LLM provider APIs. +// It supports OpenAI-compatible APIs and provider-specific endpoints. +type ModelDiscoverer struct { + httpClient *http.Client +} + +// NewModelDiscoverer creates a new ModelDiscoverer instance. +func NewModelDiscoverer() *ModelDiscoverer { + return &ModelDiscoverer{ + httpClient: &http.Client{ + Timeout: DefaultTimeout, + }, + } +} + +// openAIModelsResponse represents the response from OpenAI-compatible /models endpoints. +// This format is used by OpenAI, Anthropic, and most other providers. +type openAIModelsResponse struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` +} + +// DiscoverModels calls the provider's models endpoint and returns available model IDs. +// Most providers use OpenAI-compatible /v1/models, but auth headers vary by provider. +func (d *ModelDiscoverer) DiscoverModels(ctx context.Context, providerType v1alpha2.ModelProvider, endpoint, apiKey string) ([]string, error) { + logger := log.FromContext(ctx).WithName("model-discoverer") + + // Ollama has a completely different API - delegate to specialized function + if providerType == v1alpha2.ModelProviderOllama { + logger.V(1).Info("Discovering models from Ollama", "endpoint", endpoint) + return d.DiscoverOllamaModels(ctx, endpoint) + } + + modelsURL := buildModelsURL(endpoint, providerType) + logger.V(1).Info("Discovering models", "provider", providerType, "url", modelsURL) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers based on provider type + d.setAuthHeaders(req, providerType, apiKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + // Handle error responses + switch resp.StatusCode { + case http.StatusOK: + return d.parseModelsResponse(resp) + case http.StatusUnauthorized: + return nil, fmt.Errorf("unauthorized: invalid API key for provider %s", providerType) + case http.StatusForbidden: + return nil, fmt.Errorf("forbidden: API key lacks permission to list models for provider %s", providerType) + case http.StatusNotFound: + return nil, fmt.Errorf("models endpoint not found for provider %s (URL: %s)", providerType, modelsURL) + default: + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API returned status %d for provider %s: %s", resp.StatusCode, providerType, string(body)) + } +} + +// setAuthHeaders sets the appropriate authentication headers based on provider type. +func (d *ModelDiscoverer) setAuthHeaders(req *http.Request, providerType v1alpha2.ModelProvider, apiKey string) { + switch providerType { + case v1alpha2.ModelProviderAnthropic, v1alpha2.ModelProviderAnthropicVertexAI: + // Anthropic uses x-api-key header + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + case v1alpha2.ModelProviderGemini, v1alpha2.ModelProviderGeminiVertexAI: + // Google uses query parameter for API key (handled in URL) or Bearer token + req.Header.Set("Authorization", "Bearer "+apiKey) + default: + // OpenAI and compatible providers use Bearer token + req.Header.Set("Authorization", "Bearer "+apiKey) + } +} + +// parseModelsResponse parses OpenAI-compatible response and extracts model IDs. +// Format: {"data": [{"id": "model-name"}, ...]} +func (d *ModelDiscoverer) parseModelsResponse(resp *http.Response) ([]string, error) { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var result openAIModelsResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse models response: %w", err) + } + + models := make([]string, 0, len(result.Data)) + for _, m := range result.Data { + if m.ID != "" { + models = append(models, m.ID) + } + } + + return models, nil +} + +// buildModelsURL constructs the models endpoint URL based on provider type. +// Note: Ollama is handled separately via DiscoverOllamaModels. +func buildModelsURL(endpoint string, providerType v1alpha2.ModelProvider) string { + endpoint = strings.TrimSuffix(endpoint, "/") + + switch providerType { + case v1alpha2.ModelProviderAnthropic: + // Anthropic: https://api.anthropic.com/v1/models + if strings.HasSuffix(endpoint, "/v1") { + return endpoint + "/models" + } + return endpoint + "/v1/models" + + case v1alpha2.ModelProviderGemini: + // Google AI: https://generativelanguage.googleapis.com/v1beta/models + if strings.Contains(endpoint, "generativelanguage.googleapis.com") { + return endpoint + "/v1beta/models" + } + return endpoint + "/v1/models" + + case v1alpha2.ModelProviderGeminiVertexAI, v1alpha2.ModelProviderAnthropicVertexAI: + // Vertex AI has different discovery patterns - may not be supported + return endpoint + "/v1/models" + + default: + // OpenAI and compatible (Azure OpenAI, LiteLLM, vLLM, etc.) + if strings.HasSuffix(endpoint, "/v1") { + return endpoint + "/models" + } + return endpoint + "/v1/models" + } +} + +// OllamaTagsResponse represents Ollama's /api/tags response format +type ollamaTagsResponse struct { + Models []struct { + Name string `json:"name"` + } `json:"models"` +} + +// DiscoverOllamaModels handles Ollama's different response format. +func (d *ModelDiscoverer) DiscoverOllamaModels(ctx context.Context, endpoint string) ([]string, error) { + endpoint = strings.TrimSuffix(endpoint, "/") + modelsURL := endpoint + "/api/tags" + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("ollama API returned status %d", resp.StatusCode) + } + + var result ollamaTagsResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse Ollama response: %w", err) + } + + models := make([]string, 0, len(result.Models)) + for _, m := range result.Models { + if m.Name != "" { + models = append(models, m.Name) + } + } + + return models, nil +} diff --git a/go/internal/controller/provider/discoverer_test.go b/go/internal/controller/provider/discoverer_test.go new file mode 100644 index 000000000..1336729f1 --- /dev/null +++ b/go/internal/controller/provider/discoverer_test.go @@ -0,0 +1,511 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package provider + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + v1alpha2 "github.com/kagent-dev/kagent/go/api/v1alpha2" +) + +func TestBuildModelsURL(t *testing.T) { + tests := []struct { + name string + endpoint string + providerType v1alpha2.ModelProvider + want string + }{ + // OpenAI + { + name: "OpenAI - base URL", + endpoint: "https://api.openai.com", + providerType: v1alpha2.ModelProviderOpenAI, + want: "https://api.openai.com/v1/models", + }, + { + name: "OpenAI - with v1", + endpoint: "https://api.openai.com/v1", + providerType: v1alpha2.ModelProviderOpenAI, + want: "https://api.openai.com/v1/models", + }, + { + name: "OpenAI - trailing slash", + endpoint: "https://api.openai.com/v1/", + providerType: v1alpha2.ModelProviderOpenAI, + want: "https://api.openai.com/v1/models", + }, + + // Anthropic + { + name: "Anthropic - base URL", + endpoint: "https://api.anthropic.com", + providerType: v1alpha2.ModelProviderAnthropic, + want: "https://api.anthropic.com/v1/models", + }, + { + name: "Anthropic - with v1", + endpoint: "https://api.anthropic.com/v1", + providerType: v1alpha2.ModelProviderAnthropic, + want: "https://api.anthropic.com/v1/models", + }, + + // Azure OpenAI + { + name: "Azure OpenAI", + endpoint: "https://my-resource.openai.azure.com", + providerType: v1alpha2.ModelProviderAzureOpenAI, + want: "https://my-resource.openai.azure.com/v1/models", + }, + + // Note: Ollama is handled via DiscoverOllamaModels delegation, + // so buildModelsURL is not called for Ollama providers. + // See TestDiscoverModels_OllamaDelegation for the integration test. + + // Gemini + { + name: "Gemini - googleapis", + endpoint: "https://generativelanguage.googleapis.com", + providerType: v1alpha2.ModelProviderGemini, + want: "https://generativelanguage.googleapis.com/v1beta/models", + }, + { + name: "Gemini - custom endpoint", + endpoint: "https://custom-gemini.example.com", + providerType: v1alpha2.ModelProviderGemini, + want: "https://custom-gemini.example.com/v1/models", + }, + + // LiteLLM / OpenAI-compatible + { + name: "LiteLLM gateway", + endpoint: "https://litellm.company.com/v1", + providerType: v1alpha2.ModelProviderOpenAI, + want: "https://litellm.company.com/v1/models", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildModelsURL(tt.endpoint, tt.providerType) + if got != tt.want { + t.Errorf("buildModelsURL() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDiscoverModels_OpenAI(t *testing.T) { + // Create mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request + if r.URL.Path != "/v1/models" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + + auth := r.Header.Get("Authorization") + if auth != "Bearer test-api-key" { + t.Errorf("unexpected Authorization header: %s", auth) + } + + // Return mock response + response := openAIModelsResponse{ + Data: []struct { + ID string `json:"id"` + }{ + {ID: "gpt-4"}, + {ID: "gpt-3.5-turbo"}, + {ID: "text-embedding-ada-002"}, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + d := NewModelDiscoverer() + models, err := d.DiscoverModels(context.Background(), v1alpha2.ModelProviderOpenAI, server.URL, "test-api-key") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(models) != 3 { + t.Errorf("expected 3 models, got %d", len(models)) + } + + expectedModels := map[string]bool{ + "gpt-4": true, + "gpt-3.5-turbo": true, + "text-embedding-ada-002": true, + } + + for _, model := range models { + if !expectedModels[model] { + t.Errorf("unexpected model: %s", model) + } + } +} + +func TestDiscoverModels_Anthropic(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify Anthropic-specific headers + apiKey := r.Header.Get("x-api-key") + if apiKey != "test-anthropic-key" { + t.Errorf("unexpected x-api-key header: %s", apiKey) + } + + version := r.Header.Get("anthropic-version") + if version != "2023-06-01" { + t.Errorf("unexpected anthropic-version header: %s", version) + } + + // Return mock response (same format as OpenAI) + response := openAIModelsResponse{ + Data: []struct { + ID string `json:"id"` + }{ + {ID: "claude-3-opus-20240229"}, + {ID: "claude-3-sonnet-20240229"}, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + d := NewModelDiscoverer() + models, err := d.DiscoverModels(context.Background(), v1alpha2.ModelProviderAnthropic, server.URL, "test-anthropic-key") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(models) != 2 { + t.Errorf("expected 2 models, got %d", len(models)) + } +} + +func TestDiscoverModels_ErrorResponses(t *testing.T) { + tests := []struct { + name string + statusCode int + responseBody string + wantErrContain string + }{ + { + name: "unauthorized", + statusCode: http.StatusUnauthorized, + responseBody: `{"error": "Invalid API key"}`, + wantErrContain: "unauthorized", + }, + { + name: "forbidden", + statusCode: http.StatusForbidden, + responseBody: `{"error": "Access denied"}`, + wantErrContain: "forbidden", + }, + { + name: "not found", + statusCode: http.StatusNotFound, + responseBody: `{"error": "Not found"}`, + wantErrContain: "not found", + }, + { + name: "server error", + statusCode: http.StatusInternalServerError, + responseBody: `{"error": "Internal server error"}`, + wantErrContain: "status 500", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + w.Write([]byte(tt.responseBody)) + })) + defer server.Close() + + d := NewModelDiscoverer() + _, err := d.DiscoverModels(context.Background(), v1alpha2.ModelProviderOpenAI, server.URL, "test-key") + + if err == nil { + t.Error("expected error but got nil") + return + } + + if !strings.Contains(strings.ToLower(err.Error()), tt.wantErrContain) { + t.Errorf("error = %v, want error containing %v", err, tt.wantErrContain) + } + }) + } +} + +func TestDiscoverModels_InvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("not valid json")) + })) + defer server.Close() + + d := NewModelDiscoverer() + _, err := d.DiscoverModels(context.Background(), v1alpha2.ModelProviderOpenAI, server.URL, "test-key") + + if err == nil { + t.Error("expected error for invalid JSON") + } + + if !strings.Contains(err.Error(), "failed to parse") { + t.Errorf("error = %v, want error containing 'failed to parse'", err) + } +} + +func TestDiscoverModels_EmptyResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := openAIModelsResponse{ + Data: []struct { + ID string `json:"id"` + }{}, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + d := NewModelDiscoverer() + models, err := d.DiscoverModels(context.Background(), v1alpha2.ModelProviderOpenAI, server.URL, "test-key") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(models) != 0 { + t.Errorf("expected 0 models, got %d", len(models)) + } +} + +func TestDiscoverModels_FilterEmptyIDs(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := openAIModelsResponse{ + Data: []struct { + ID string `json:"id"` + }{ + {ID: "gpt-4"}, + {ID: ""}, // Empty ID should be filtered + {ID: "gpt-3"}, // Valid + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + d := NewModelDiscoverer() + models, err := d.DiscoverModels(context.Background(), v1alpha2.ModelProviderOpenAI, server.URL, "test-key") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(models) != 2 { + t.Errorf("expected 2 models (empty ID filtered), got %d", len(models)) + } +} + +func TestDiscoverOllamaModels(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/tags" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + + response := ollamaTagsResponse{ + Models: []struct { + Name string `json:"name"` + }{ + {Name: "llama2"}, + {Name: "mistral"}, + {Name: "codellama"}, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + d := NewModelDiscoverer() + models, err := d.DiscoverOllamaModels(context.Background(), server.URL) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(models) != 3 { + t.Errorf("expected 3 models, got %d", len(models)) + } + + expectedModels := map[string]bool{ + "llama2": true, + "mistral": true, + "codellama": true, + } + + for _, model := range models { + if !expectedModels[model] { + t.Errorf("unexpected model: %s", model) + } + } +} + +// TestDiscoverModels_OllamaDelegation verifies that DiscoverModels correctly +// delegates to DiscoverOllamaModels when the provider type is Ollama. +func TestDiscoverModels_OllamaDelegation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Ollama should call /api/tags, not /v1/models + if r.URL.Path != "/api/tags" { + t.Errorf("DiscoverModels with Ollama should call /api/tags, got: %s", r.URL.Path) + } + + response := ollamaTagsResponse{ + Models: []struct { + Name string `json:"name"` + }{ + {Name: "llama2"}, + {Name: "mistral"}, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + d := NewModelDiscoverer() + // Use DiscoverModels (not DiscoverOllamaModels directly) to test delegation + models, err := d.DiscoverModels(context.Background(), v1alpha2.ModelProviderOllama, server.URL, "") + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(models) != 2 { + t.Errorf("expected 2 models, got %d", len(models)) + } + + expectedModels := map[string]bool{ + "llama2": true, + "mistral": true, + } + + for _, model := range models { + if !expectedModels[model] { + t.Errorf("unexpected model: %s", model) + } + } +} + +func TestNewModelDiscoverer(t *testing.T) { + d := NewModelDiscoverer() + + if d == nil { + t.Fatal("NewModelDiscoverer returned nil") + } + + if d.httpClient == nil { + t.Fatal("httpClient should be initialized") + } + + if d.httpClient.Timeout != DefaultTimeout { + t.Errorf("httpClient timeout = %v, want %v", d.httpClient.Timeout, DefaultTimeout) + } +} + +func TestSetAuthHeaders(t *testing.T) { + d := NewModelDiscoverer() + + tests := []struct { + name string + providerType v1alpha2.ModelProvider + apiKey string + wantAuthz string + wantAPIKey string + wantAnthVer string + }{ + { + name: "OpenAI", + providerType: v1alpha2.ModelProviderOpenAI, + apiKey: "sk-test", + wantAuthz: "Bearer sk-test", + }, + { + name: "Azure OpenAI", + providerType: v1alpha2.ModelProviderAzureOpenAI, + apiKey: "azure-key", + wantAuthz: "Bearer azure-key", + }, + { + name: "Anthropic", + providerType: v1alpha2.ModelProviderAnthropic, + apiKey: "anth-key", + wantAPIKey: "anth-key", + wantAnthVer: "2023-06-01", + }, + { + name: "Gemini", + providerType: v1alpha2.ModelProviderGemini, + apiKey: "gemini-key", + wantAuthz: "Bearer gemini-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + d.setAuthHeaders(req, tt.providerType, tt.apiKey) + + if tt.wantAuthz != "" { + got := req.Header.Get("Authorization") + if got != tt.wantAuthz { + t.Errorf("Authorization = %v, want %v", got, tt.wantAuthz) + } + } + + if tt.wantAPIKey != "" { + got := req.Header.Get("x-api-key") + if got != tt.wantAPIKey { + t.Errorf("x-api-key = %v, want %v", got, tt.wantAPIKey) + } + } + + if tt.wantAnthVer != "" { + got := req.Header.Get("anthropic-version") + if got != tt.wantAnthVer { + t.Errorf("anthropic-version = %v, want %v", got, tt.wantAnthVer) + } + } + }) + } +} diff --git a/go/internal/controller/provider/manager.go b/go/internal/controller/provider/manager.go new file mode 100644 index 000000000..39b439168 --- /dev/null +++ b/go/internal/controller/provider/manager.go @@ -0,0 +1,148 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package provider + +import ( + "context" + "fmt" + "time" + + "github.com/kagent-dev/kagent/go/api/v1alpha2" + "k8s.io/apimachinery/pkg/api/meta" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +const ( + // DefaultNamespace is the default namespace for kagent resources + DefaultNamespace = "kagent" +) + +// Manager handles provider configuration and model discovery. +// It reads provider configs from Provider CRDs and caches discovered models in CRD status. +type Manager struct { + client client.Client + namespace string +} + +// NewManager creates a new provider Manager instance. +func NewManager(client client.Client, namespace string) *Manager { + if namespace == "" { + namespace = DefaultNamespace + } + + return &Manager{ + client: client, + namespace: namespace, + } +} + +// GetProviders returns all configured providers that are Ready. +func (m *Manager) GetProviders() []ProviderConfig { + ctx := context.Background() + var providerList v1alpha2.ProviderList + if err := m.client.List(ctx, &providerList, + client.InNamespace(m.namespace)); err != nil { + return nil + } + + var providers []ProviderConfig + for _, p := range providerList.Items { + // Only include Ready providers + if meta.IsStatusConditionTrue(p.Status.Conditions, + v1alpha2.ProviderConditionTypeReady) { + providers = append(providers, ProviderConfig{ + Name: p.Name, + Type: p.Spec.Type, + Endpoint: p.Spec.Endpoint, + SecretRef: SecretReference{ + Name: p.Spec.SecretRef.Name, + Key: p.Spec.SecretRef.Key, + }, + }) + } + } + + return providers +} + +// GetModels returns models for a provider from the cached status or triggers discovery. +// Models are cached in Provider.Status.DiscoveredModels by the Provider controller. +// If forceRefresh is true, sets the force-discovery annotation to trigger controller reconciliation. +func (m *Manager) GetModels(ctx context.Context, providerName string, forceRefresh bool) ([]string, error) { + logger := log.FromContext(ctx).WithName("provider-manager") + + provider := &v1alpha2.Provider{} + if err := m.client.Get(ctx, types.NamespacedName{ + Name: providerName, Namespace: m.namespace, + }, provider); err != nil { + return nil, fmt.Errorf("provider %s not found: %w", providerName, err) + } + + // Force refresh by setting annotation + if forceRefresh { + if provider.Annotations == nil { + provider.Annotations = make(map[string]string) + } + provider.Annotations[v1alpha2.ProviderAnnotationForceDiscovery] = "true" + if err := m.client.Update(ctx, provider); err != nil { + return nil, fmt.Errorf("failed to trigger discovery: %w", err) + } + + // Wait for controller to refresh (with timeout) + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + time.Sleep(500 * time.Millisecond) + if err := m.client.Get(ctx, client.ObjectKeyFromObject(provider), provider); err == nil { + // Check if annotation cleared (refresh done) + if provider.Annotations[v1alpha2.ProviderAnnotationForceDiscovery] != "true" { + logger.Info("Model discovery completed", "provider", providerName) + break + } + } + } + + // Re-fetch provider to get updated status + if err := m.client.Get(ctx, client.ObjectKeyFromObject(provider), provider); err != nil { + return nil, fmt.Errorf("failed to get updated provider: %w", err) + } + } + + // Return cached models from status + if len(provider.Status.DiscoveredModels) > 0 { + return provider.Status.DiscoveredModels, nil + } + + // No models discovered - provide helpful message + if forceRefresh { + return nil, fmt.Errorf("no models discovered for provider %s", providerName) + } + return nil, fmt.Errorf("no models discovered for provider %s, try refreshing", providerName) +} + +// ClearCache is a no-op for CRD-based implementation. +// Models are cached in Provider.Status.DiscoveredModels and cleared by the controller. +func (m *Manager) ClearCache(providerName string) { + // No-op - cache is managed by Provider controller in CRD status +} + +// HasProviders returns true if any Ready providers are configured. +func (m *Manager) HasProviders() bool { + providers := m.GetProviders() + return len(providers) > 0 +} diff --git a/go/internal/controller/provider/manager_test.go b/go/internal/controller/provider/manager_test.go new file mode 100644 index 000000000..0da10ffd9 --- /dev/null +++ b/go/internal/controller/provider/manager_test.go @@ -0,0 +1,403 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package provider + +import ( + "context" + "testing" + + v1alpha2 "github.com/kagent-dev/kagent/go/api/v1alpha2" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func TestNewManager(t *testing.T) { + tests := []struct { + name string + namespace string + wantNamespace string + }{ + { + name: "with default namespace", + namespace: "", + wantNamespace: DefaultNamespace, + }, + { + name: "with custom namespace", + namespace: "custom-ns", + wantNamespace: "custom-ns", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scheme := runtime.NewScheme() + _ = v1alpha2.AddToScheme(scheme) + _ = corev1.AddToScheme(scheme) + client := fake.NewClientBuilder().WithScheme(scheme).Build() + + m := NewManager(client, tt.namespace) + + if m.namespace != tt.wantNamespace { + t.Errorf("namespace = %v, want %v", m.namespace, tt.wantNamespace) + } + if m.client == nil { + t.Error("client should be initialized") + } + }) + } +} + +func TestGetProviders(t *testing.T) { + tests := []struct { + name string + providers []*v1alpha2.Provider + wantCount int + wantNames []string + }{ + { + name: "no providers", + providers: []*v1alpha2.Provider{}, + wantCount: 0, + wantNames: []string{}, + }, + { + name: "single ready provider", + providers: []*v1alpha2.Provider{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "openai-prod", + Namespace: DefaultNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderOpenAI, + Endpoint: "https://api.openai.com/v1", + SecretRef: v1alpha2.SecretReference{ + Name: "openai-secret", + Key: "apiKey", + }, + }, + Status: v1alpha2.ProviderStatus{ + Conditions: []metav1.Condition{ + { + Type: v1alpha2.ProviderConditionTypeReady, + Status: metav1.ConditionTrue, + }, + }, + }, + }, + }, + wantCount: 1, + wantNames: []string{"openai-prod"}, + }, + { + name: "mixed ready and not ready providers", + providers: []*v1alpha2.Provider{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "openai-prod", + Namespace: DefaultNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderOpenAI, + Endpoint: "https://api.openai.com/v1", + SecretRef: v1alpha2.SecretReference{ + Name: "openai-secret", + Key: "apiKey", + }, + }, + Status: v1alpha2.ProviderStatus{ + Conditions: []metav1.Condition{ + { + Type: v1alpha2.ProviderConditionTypeReady, + Status: metav1.ConditionTrue, + }, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "anthropic-prod", + Namespace: DefaultNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderAnthropic, + Endpoint: "https://api.anthropic.com", + SecretRef: v1alpha2.SecretReference{ + Name: "anthropic-secret", + Key: "apiKey", + }, + }, + Status: v1alpha2.ProviderStatus{ + Conditions: []metav1.Condition{ + { + Type: v1alpha2.ProviderConditionTypeReady, + Status: metav1.ConditionFalse, + }, + }, + }, + }, + }, + wantCount: 1, // Only ready provider should be returned + wantNames: []string{"openai-prod"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scheme := runtime.NewScheme() + _ = v1alpha2.AddToScheme(scheme) + _ = corev1.AddToScheme(scheme) + + objs := make([]runtime.Object, len(tt.providers)) + for i, p := range tt.providers { + objs[i] = p + } + + client := fake.NewClientBuilder(). + WithScheme(scheme). + WithRuntimeObjects(objs...). + Build() + + m := NewManager(client, DefaultNamespace) + providers := m.GetProviders() + + if len(providers) != tt.wantCount { + t.Errorf("GetProviders() returned %d providers, want %d", len(providers), tt.wantCount) + } + + for _, wantName := range tt.wantNames { + found := false + for _, p := range providers { + if p.Name == wantName { + found = true + break + } + } + if !found { + t.Errorf("Expected provider %s not found in results", wantName) + } + } + }) + } +} + +func TestGetModels(t *testing.T) { + tests := []struct { + name string + provider *v1alpha2.Provider + forceRefresh bool + wantModels []string + wantErr bool + }{ + { + name: "provider with cached models", + provider: &v1alpha2.Provider{ + ObjectMeta: metav1.ObjectMeta{ + Name: "openai-prod", + Namespace: DefaultNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderOpenAI, + Endpoint: "https://api.openai.com/v1", + SecretRef: v1alpha2.SecretReference{ + Name: "openai-secret", + Key: "apiKey", + }, + }, + Status: v1alpha2.ProviderStatus{ + DiscoveredModels: []string{"gpt-4", "gpt-3.5-turbo"}, + Conditions: []metav1.Condition{ + { + Type: v1alpha2.ProviderConditionTypeReady, + Status: metav1.ConditionTrue, + }, + }, + }, + }, + forceRefresh: false, + wantModels: []string{"gpt-4", "gpt-3.5-turbo"}, + wantErr: false, + }, + { + name: "provider not found", + provider: &v1alpha2.Provider{ + ObjectMeta: metav1.ObjectMeta{ + Name: "different-provider", + Namespace: DefaultNamespace, + }, + }, + forceRefresh: false, + wantModels: nil, + wantErr: true, + }, + { + name: "provider without models", + provider: &v1alpha2.Provider{ + ObjectMeta: metav1.ObjectMeta{ + Name: "empty-provider", + Namespace: DefaultNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderOpenAI, + Endpoint: "https://api.openai.com/v1", + SecretRef: v1alpha2.SecretReference{ + Name: "openai-secret", + Key: "apiKey", + }, + }, + Status: v1alpha2.ProviderStatus{ + DiscoveredModels: []string{}, + }, + }, + forceRefresh: false, + wantModels: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scheme := runtime.NewScheme() + _ = v1alpha2.AddToScheme(scheme) + _ = corev1.AddToScheme(scheme) + + client := fake.NewClientBuilder(). + WithScheme(scheme). + WithRuntimeObjects(tt.provider). + Build() + + m := NewManager(client, DefaultNamespace) + + // Use the provider name from the test case + providerName := "openai-prod" + switch tt.name { + case "provider not found": + providerName = "nonexistent-provider" + case "provider without models": + providerName = "empty-provider" + } + + models, err := m.GetModels(context.Background(), providerName, tt.forceRefresh) + + if (err != nil) != tt.wantErr { + t.Errorf("GetModels() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && len(models) != len(tt.wantModels) { + t.Errorf("GetModels() returned %d models, want %d", len(models), len(tt.wantModels)) + } + }) + } +} + +func TestHasProviders(t *testing.T) { + tests := []struct { + name string + providers []*v1alpha2.Provider + want bool + }{ + { + name: "no providers", + providers: []*v1alpha2.Provider{}, + want: false, + }, + { + name: "has ready provider", + providers: []*v1alpha2.Provider{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "openai-prod", + Namespace: DefaultNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderOpenAI, + Endpoint: "https://api.openai.com/v1", + SecretRef: v1alpha2.SecretReference{ + Name: "openai-secret", + Key: "apiKey", + }, + }, + Status: v1alpha2.ProviderStatus{ + Conditions: []metav1.Condition{ + { + Type: v1alpha2.ProviderConditionTypeReady, + Status: metav1.ConditionTrue, + }, + }, + }, + }, + }, + want: true, + }, + { + name: "has only not ready provider", + providers: []*v1alpha2.Provider{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "anthropic-prod", + Namespace: DefaultNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderAnthropic, + Endpoint: "https://api.anthropic.com", + SecretRef: v1alpha2.SecretReference{ + Name: "anthropic-secret", + Key: "apiKey", + }, + }, + Status: v1alpha2.ProviderStatus{ + Conditions: []metav1.Condition{ + { + Type: v1alpha2.ProviderConditionTypeReady, + Status: metav1.ConditionFalse, + }, + }, + }, + }, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scheme := runtime.NewScheme() + _ = v1alpha2.AddToScheme(scheme) + _ = corev1.AddToScheme(scheme) + + objs := make([]runtime.Object, len(tt.providers)) + for i, p := range tt.providers { + objs[i] = p + } + + client := fake.NewClientBuilder(). + WithScheme(scheme). + WithRuntimeObjects(objs...). + Build() + + m := NewManager(client, DefaultNamespace) + + if got := m.HasProviders(); got != tt.want { + t.Errorf("HasProviders() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/go/internal/controller/provider/types.go b/go/internal/controller/provider/types.go new file mode 100644 index 000000000..728a4c8e6 --- /dev/null +++ b/go/internal/controller/provider/types.go @@ -0,0 +1,59 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package provider + +import ( + v1alpha2 "github.com/kagent-dev/kagent/go/api/v1alpha2" +) + +// ProviderConfig represents a configured LLM provider instance. +// Multiple ProviderConfigs can exist for the same provider type (e.g., two OpenAI instances). +type ProviderConfig struct { + // Name is the unique identifier for this provider instance + Name string `yaml:"name" json:"name"` + + // Type is the provider type (OpenAI, Anthropic, etc.) + Type v1alpha2.ModelProvider `yaml:"type" json:"type"` + + // Endpoint is the base URL for the provider API + Endpoint string `yaml:"endpoint" json:"endpoint"` + + // SecretRef references the Kubernetes Secret containing the API key + SecretRef SecretReference `yaml:"secretRef" json:"secretRef"` +} + +// SecretReference points to a specific key within a Kubernetes Secret +type SecretReference struct { + // Name is the name of the Secret + Name string `yaml:"name" json:"name"` + + // Key is the key within the Secret data + Key string `yaml:"key" json:"key"` +} + +// ProviderResponse is the API response format for listing providers +type ProviderResponse struct { + Name string `json:"name"` + Type string `json:"type"` + Endpoint string `json:"endpoint"` +} + +// ModelsResponse is the API response format for listing models +type ModelsResponse struct { + Provider string `json:"provider"` + Models []string `json:"models"` +} diff --git a/go/internal/controller/provider_controller.go b/go/internal/controller/provider_controller.go new file mode 100644 index 000000000..f2644ae84 --- /dev/null +++ b/go/internal/controller/provider_controller.go @@ -0,0 +1,117 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package controller + +import ( + "context" + + "github.com/kagent-dev/kagent/go/api/v1alpha2" + "github.com/kagent-dev/kagent/go/internal/controller/reconciler" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/utils/ptr" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/builder" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller" + "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/predicate" + "sigs.k8s.io/controller-runtime/pkg/reconcile" +) + +var ( + providerControllerLog = ctrl.Log.WithName("provider-controller") +) + +// ProviderController reconciles a Provider object +type ProviderController struct { + Scheme *runtime.Scheme + Reconciler reconciler.KagentReconciler +} + +// +kubebuilder:rbac:groups=kagent.dev,resources=providers,verbs=get;list;watch;create;update;patch;delete +// +kubebuilder:rbac:groups=kagent.dev,resources=providers/status,verbs=get;update;patch +// +kubebuilder:rbac:groups=kagent.dev,resources=providers/finalizers,verbs=update +// +kubebuilder:rbac:groups=core,resources=secrets,verbs=get;list;watch + +func (r *ProviderController) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + _ = log.FromContext(ctx) + return r.Reconciler.ReconcileKagentProvider(ctx, req) +} + +// SetupWithManager sets up the controller with the Manager. +func (r *ProviderController) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + WithOptions(controller.Options{ + NeedLeaderElection: ptr.To(true), + }). + For(&v1alpha2.Provider{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})). + Watches( + &corev1.Secret{}, + handler.EnqueueRequestsFromMapFunc(func(ctx context.Context, obj client.Object) []reconcile.Request { + requests := []reconcile.Request{} + + for _, provider := range r.findProvidersUsingSecret(ctx, mgr.GetClient(), types.NamespacedName{ + Name: obj.GetName(), + Namespace: obj.GetNamespace(), + }) { + requests = append(requests, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Name: provider.ObjectMeta.Name, + Namespace: provider.ObjectMeta.Namespace, + }, + }) + } + + return requests + }), + builder.WithPredicates(predicate.ResourceVersionChangedPredicate{}), + ). + Named("provider"). + Complete(r) +} + +func (r *ProviderController) findProvidersUsingSecret(ctx context.Context, cl client.Client, obj types.NamespacedName) []*v1alpha2.Provider { + var providers []*v1alpha2.Provider + + var providersList v1alpha2.ProviderList + if err := cl.List( + ctx, + &providersList, + ); err != nil { + providerControllerLog.Error(err, "failed to list Providers in order to reconcile Secret update") + return providers + } + + for i := range providersList.Items { + provider := &providersList.Items[i] + + if providerReferencesSecret(provider, obj) { + providers = append(providers, provider) + } + } + + return providers +} + +func providerReferencesSecret(provider *v1alpha2.Provider, secretObj types.NamespacedName) bool { + // Secrets must be in the same namespace as the provider + return provider.Namespace == secretObj.Namespace && + provider.Spec.SecretRef.Name == secretObj.Name +} diff --git a/go/internal/controller/reconciler/reconciler.go b/go/internal/controller/reconciler/reconciler.go index a51dd9b6f..359ddaba7 100644 --- a/go/internal/controller/reconciler/reconciler.go +++ b/go/internal/controller/reconciler/reconciler.go @@ -22,6 +22,7 @@ import ( "k8s.io/client-go/util/retry" "github.com/kagent-dev/kagent/go/api/v1alpha2" + "github.com/kagent-dev/kagent/go/internal/controller/provider" "github.com/kagent-dev/kagent/go/internal/controller/translator" agent_translator "github.com/kagent-dev/kagent/go/internal/controller/translator/agent" "github.com/kagent-dev/kagent/go/internal/database" @@ -45,6 +46,7 @@ type KagentReconciler interface { ReconcileKagentRemoteMCPServer(ctx context.Context, req ctrl.Request) error ReconcileKagentMCPService(ctx context.Context, req ctrl.Request) error ReconcileKagentMCPServer(ctx context.Context, req ctrl.Request) error + ReconcileKagentProvider(ctx context.Context, req ctrl.Request) (ctrl.Result, error) GetOwnedResourceTypes() []client.Object } @@ -910,3 +912,194 @@ func convertTool(tool *database.Tool) (*v1alpha2.MCPTool, error) { Description: tool.Description, }, nil } + +// ReconcileKagentProvider reconciles a Provider object +func (a *kagentReconciler) ReconcileKagentProvider(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + p := &v1alpha2.Provider{} + if err := a.kube.Get(ctx, req.NamespacedName, p); err != nil { + if apierrors.IsNotFound(err) { + return ctrl.Result{}, nil // Deleted, cleanup done by OwnerReferences + } + return ctrl.Result{}, fmt.Errorf("failed to get provider %s: %w", req.NamespacedName, err) + } + + // Validate and resolve secret + secret, secretHash, secretErr := a.validateProviderSecret(ctx, p) + + // Discover models if needed + var models []string + var discoveryErr error + if a.shouldDiscoverModels(p) { + models, discoveryErr = a.discoverProviderModels(ctx, p, secret) + } else { + // Keep existing cached models + models = p.Status.DiscoveredModels + } + + // Update status with results + return a.updateProviderStatus(ctx, p, secretErr, discoveryErr, models, secretHash) +} + +// validateProviderSecret fetches the Secret and computes its hash +func (a *kagentReconciler) validateProviderSecret(ctx context.Context, p *v1alpha2.Provider) (*corev1.Secret, string, error) { + if p.Spec.SecretRef.Name == "" { + return nil, "", fmt.Errorf("provider %s has no secret reference", p.Name) + } + + secret := &corev1.Secret{} + namespacedName := types.NamespacedName{ + Namespace: p.Namespace, + Name: p.Spec.SecretRef.Name, + } + + if err := a.kube.Get(ctx, namespacedName, secret); err != nil { + return nil, "", fmt.Errorf("failed to get secret %s: %w", p.Spec.SecretRef.Name, err) + } + + // Check that the specified key exists + key := p.Spec.SecretRef.Key + if key == "" { + return nil, "", fmt.Errorf("provider %s has no secret key specified", p.Name) + } + + if _, ok := secret.Data[key]; !ok { + return nil, "", fmt.Errorf("secret %s missing key %s", p.Spec.SecretRef.Name, key) + } + + // Compute secret hash for change detection + secretHash := computeProviderSecretHash(secret, key) + + return secret, secretHash, nil +} + +// computeProviderSecretHash computes a hash of the secret data for change detection +func computeProviderSecretHash(secret *corev1.Secret, key string) string { + hash := sha256.New() + hash.Write([]byte(secret.Namespace)) + hash.Write([]byte(secret.Name)) + hash.Write([]byte(key)) + if data, ok := secret.Data[key]; ok { + hash.Write(data) + } + return hex.EncodeToString(hash.Sum(nil)) +} + +// shouldDiscoverModels checks if model discovery is needed +func (a *kagentReconciler) shouldDiscoverModels(p *v1alpha2.Provider) bool { + // 1. Force refresh via annotation (UI "Fetch Models" button) + if p.Annotations != nil && p.Annotations[v1alpha2.ProviderAnnotationForceDiscovery] == "true" { + return true + } + + // 2. Initial discovery when Provider is first created + if p.Status.LastDiscoveryTime == nil { + return true + } + + // No periodic discovery - only on-demand + return false +} + +// discoverProviderModels calls the model discoverer to fetch models +func (a *kagentReconciler) discoverProviderModels(ctx context.Context, p *v1alpha2.Provider, secret *corev1.Secret) ([]string, error) { + if secret == nil { + return nil, fmt.Errorf("cannot discover models: secret not available") + } + + // Get API key from secret + apiKey, ok := secret.Data[p.Spec.SecretRef.Key] + if !ok || len(apiKey) == 0 { + return nil, fmt.Errorf("secret %s has empty value for key %s", p.Spec.SecretRef.Name, p.Spec.SecretRef.Key) + } + + // Use the provider package's ModelDiscoverer + discoverer := provider.NewModelDiscoverer() + return discoverer.DiscoverModels(ctx, p.Spec.Type, p.Spec.Endpoint, string(apiKey)) +} + +// updateProviderStatus updates the Provider status based on reconciliation results +func (a *kagentReconciler) updateProviderStatus( + ctx context.Context, + p *v1alpha2.Provider, + secretErr, discoveryErr error, + models []string, + secretHash string, +) (ctrl.Result, error) { + // Update SecretResolved condition + meta.SetStatusCondition(&p.Status.Conditions, metav1.Condition{ + Type: v1alpha2.ProviderConditionTypeSecretResolved, + Status: conditionStatus(secretErr == nil), + Reason: conditionReason(secretErr, "SecretResolved", "SecretNotFound"), + Message: conditionMessage(secretErr, "Secret resolved successfully"), + ObservedGeneration: p.Generation, + }) + + // Update ModelsDiscovered condition + modelsDiscovered := discoveryErr == nil && len(models) > 0 + meta.SetStatusCondition(&p.Status.Conditions, metav1.Condition{ + Type: v1alpha2.ProviderConditionTypeModelsDiscovered, + Status: conditionStatus(modelsDiscovered), + Reason: conditionReason(discoveryErr, "ModelsDiscovered", "DiscoveryFailed"), + Message: fmt.Sprintf("Discovered %d models", len(models)), + ObservedGeneration: p.Generation, + }) + + // Update Ready condition (overall health) + ready := secretErr == nil && modelsDiscovered + meta.SetStatusCondition(&p.Status.Conditions, metav1.Condition{ + Type: v1alpha2.ProviderConditionTypeReady, + Status: conditionStatus(ready), + Reason: conditionReason(nil, "Ready", "NotReady"), + Message: conditionMessage(nil, "Provider is ready"), + ObservedGeneration: p.Generation, + }) + + // Update status fields + p.Status.ObservedGeneration = p.Generation + p.Status.DiscoveredModels = models + p.Status.ModelCount = len(models) + p.Status.SecretHash = secretHash + + if discoveryErr == nil && len(models) > 0 { + now := metav1.Now() + p.Status.LastDiscoveryTime = &now + } + + // Clear force-discovery annotation if set + if p.Annotations != nil && p.Annotations[v1alpha2.ProviderAnnotationForceDiscovery] == "true" { + delete(p.Annotations, v1alpha2.ProviderAnnotationForceDiscovery) + if err := a.kube.Update(ctx, p); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to clear force-discovery annotation: %w", err) + } + } + + // Update status subresource + if err := a.kube.Status().Update(ctx, p); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update provider status: %w", err) + } + + // No periodic requeue - discovery only on-demand + return ctrl.Result{}, nil +} + +// Helper functions for condition status +func conditionStatus(isTrue bool) metav1.ConditionStatus { + if isTrue { + return metav1.ConditionTrue + } + return metav1.ConditionFalse +} + +func conditionReason(err error, successReason, failureReason string) string { + if err == nil { + return successReason + } + return failureReason +} + +func conditionMessage(err error, successMessage string) string { + if err != nil { + return err.Error() + } + return successMessage +} diff --git a/go/internal/controller/service_controller_test.go b/go/internal/controller/service_controller_test.go index 041007dd0..4a25d4838 100644 --- a/go/internal/controller/service_controller_test.go +++ b/go/internal/controller/service_controller_test.go @@ -39,6 +39,10 @@ func (f *fakeServiceReconciler) ReconcileKagentRemoteMCPServer(ctx context.Conte return nil } +func (f *fakeServiceReconciler) ReconcileKagentProvider(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + return ctrl.Result{}, nil +} + func (f *fakeServiceReconciler) GetOwnedResourceTypes() []client.Object { return nil } diff --git a/go/internal/httpserver/handlers/handlers.go b/go/internal/httpserver/handlers/handlers.go index c2c4aa785..70e4b276b 100644 --- a/go/internal/httpserver/handlers/handlers.go +++ b/go/internal/httpserver/handlers/handlers.go @@ -4,6 +4,7 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/kagent-dev/kagent/go/internal/controller/provider" "github.com/kagent-dev/kagent/go/internal/database" "github.com/kagent-dev/kagent/go/pkg/auth" ) @@ -36,8 +37,9 @@ type Base struct { ProxyURL string } -// NewHandlers creates a new Handlers instance with all handler components -func NewHandlers(kubeClient client.Client, defaultModelConfig types.NamespacedName, dbService database.Client, watchedNamespaces []string, authorizer auth.Authorizer, proxyURL string) *Handlers { +// NewHandlers creates a new Handlers instance with all handler components. +// providerManager can be nil if provider discovery is not enabled. +func NewHandlers(kubeClient client.Client, defaultModelConfig types.NamespacedName, dbService database.Client, watchedNamespaces []string, authorizer auth.Authorizer, proxyURL string, providerManager *provider.Manager) *Handlers { base := &Base{ KubeClient: kubeClient, DefaultModelConfig: defaultModelConfig, @@ -50,7 +52,7 @@ func NewHandlers(kubeClient client.Client, defaultModelConfig types.NamespacedNa Health: NewHealthHandler(), ModelConfig: NewModelConfigHandler(base), Model: NewModelHandler(base), - Provider: NewProviderHandler(base), + Provider: NewProviderHandler(base, providerManager), Sessions: NewSessionsHandler(base), Agents: NewAgentsHandler(base), Tools: NewToolsHandler(base), diff --git a/go/internal/httpserver/handlers/providers.go b/go/internal/httpserver/handlers/providers.go index da4b2afa2..083dd028c 100644 --- a/go/internal/httpserver/handlers/providers.go +++ b/go/internal/httpserver/handlers/providers.go @@ -6,6 +6,7 @@ import ( "github.com/kagent-dev/kagent/go/api/v1alpha1" "github.com/kagent-dev/kagent/go/api/v1alpha2" + "github.com/kagent-dev/kagent/go/internal/controller/provider" "github.com/kagent-dev/kagent/go/pkg/client/api" ctrllog "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -13,11 +14,15 @@ import ( // ProviderHandler handles provider requests type ProviderHandler struct { *Base + providerManager *provider.Manager } // NewProviderHandler creates a new ProviderHandler -func NewProviderHandler(base *Base) *ProviderHandler { - return &ProviderHandler{Base: base} +func NewProviderHandler(base *Base, providerManager *provider.Manager) *ProviderHandler { + return &ProviderHandler{ + Base: base, + providerManager: providerManager, + } } // Helper function to get JSON keys specifically marked as required @@ -135,3 +140,75 @@ func (h *ProviderHandler) HandleListSupportedModelProviders(w ErrorResponseWrite data := api.NewResponse(providersResponse, "Successfully listed supported model providers", false) RespondWithJSON(w, http.StatusOK, data) } + +// HandleListConfiguredProviders returns the list of providers configured via Provider CRDs. +// GET /api/providers/configured +func (h *ProviderHandler) HandleListConfiguredProviders(w ErrorResponseWriter, r *http.Request) { + log := ctrllog.FromContext(r.Context()).WithName("provider-handler").WithValues("operation", "list-configured-providers") + + log.Info("Listing configured providers") + + if h.providerManager == nil { + log.Info("Provider manager not initialized") + data := api.NewResponse([]provider.ProviderResponse{}, "Provider discovery not enabled", false) + RespondWithJSON(w, http.StatusOK, data) + return + } + + providers := h.providerManager.GetProviders() + + // Transform to API response format (hide sensitive data like secretRef) + response := make([]provider.ProviderResponse, len(providers)) + for i, p := range providers { + response[i] = provider.ProviderResponse{ + Name: p.Name, + Type: string(p.Type), + Endpoint: p.Endpoint, + } + } + + log.Info("Successfully listed configured providers", "count", len(response)) + data := api.NewResponse(response, "Successfully listed configured providers", false) + RespondWithJSON(w, http.StatusOK, data) +} + +// HandleGetProviderModels discovers and returns available models for a specific provider. +// GET /api/providers/configured/{name}/models?refresh=true +func (h *ProviderHandler) HandleGetProviderModels(w ErrorResponseWriter, r *http.Request) { + log := ctrllog.FromContext(r.Context()).WithName("provider-handler").WithValues("operation", "get-provider-models") + + providerName, err := GetPathParam(r, "name") + if err != nil { + log.Info("Missing provider name parameter") + RespondWithError(w, http.StatusBadRequest, "Provider name is required") + return + } + + log = log.WithValues("provider", providerName) + log.Info("Getting models for provider") + + if h.providerManager == nil { + log.Info("Provider manager not initialized") + RespondWithError(w, http.StatusServiceUnavailable, "Provider discovery not enabled") + return + } + + // Check for refresh query parameter + forceRefresh := r.URL.Query().Get("refresh") == "true" + + models, err := h.providerManager.GetModels(r.Context(), providerName, forceRefresh) + if err != nil { + log.Error(err, "Failed to get models for provider") + RespondWithError(w, http.StatusInternalServerError, err.Error()) + return + } + + response := provider.ModelsResponse{ + Provider: providerName, + Models: models, + } + + log.Info("Successfully retrieved models for provider", "count", len(models)) + data := api.NewResponse(response, "Successfully retrieved models", false) + RespondWithJSON(w, http.StatusOK, data) +} diff --git a/go/internal/httpserver/server.go b/go/internal/httpserver/server.go index 5d992ec56..16fd366fc 100644 --- a/go/internal/httpserver/server.go +++ b/go/internal/httpserver/server.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/mux" "github.com/kagent-dev/kagent/go/internal/a2a" + "github.com/kagent-dev/kagent/go/internal/controller/provider" "github.com/kagent-dev/kagent/go/internal/database" "github.com/kagent-dev/kagent/go/internal/httpserver/handlers" "github.com/kagent-dev/kagent/go/internal/mcp" @@ -59,6 +60,7 @@ type ServerConfig struct { Authenticator auth.AuthProvider Authorizer auth.Authorizer ProxyURL string + ProviderManager *provider.Manager // Optional: enables provider discovery endpoints } // HTTPServer is the structure that manages the HTTP server @@ -78,7 +80,7 @@ func NewHTTPServer(config ServerConfig) (*HTTPServer, error) { return &HTTPServer{ config: config, router: config.Router, - handlers: handlers.NewHandlers(config.KubeClient, defaultModelConfig, config.DbClient, config.WatchedNamespaces, config.Authorizer, config.ProxyURL), + handlers: handlers.NewHandlers(config.KubeClient, defaultModelConfig, config.DbClient, config.WatchedNamespaces, config.Authorizer, config.ProxyURL, config.ProviderManager), authenticator: config.Authenticator, }, nil } @@ -194,6 +196,8 @@ func (s *HTTPServer) setupRoutes() { // Providers s.router.HandleFunc(APIPathProviders+"/models", adaptHandler(s.handlers.Provider.HandleListSupportedModelProviders)).Methods(http.MethodGet) s.router.HandleFunc(APIPathProviders+"/memories", adaptHandler(s.handlers.Provider.HandleListSupportedMemoryProviders)).Methods(http.MethodGet) + s.router.HandleFunc(APIPathProviders+"/configured", adaptHandler(s.handlers.Provider.HandleListConfiguredProviders)).Methods(http.MethodGet) + s.router.HandleFunc(APIPathProviders+"/configured/{name}/models", adaptHandler(s.handlers.Provider.HandleGetProviderModels)).Methods(http.MethodGet) // Models s.router.HandleFunc(APIPathModels, adaptHandler(s.handlers.Model.HandleListSupportedModels)).Methods(http.MethodGet) diff --git a/go/pkg/app/app.go b/go/pkg/app/app.go index cc38e3639..63d80be2b 100644 --- a/go/pkg/app/app.go +++ b/go/pkg/app/app.go @@ -41,6 +41,7 @@ import ( "github.com/kagent-dev/kagent/go/internal/mcp" versionmetrics "github.com/kagent-dev/kagent/go/internal/metrics" + "github.com/kagent-dev/kagent/go/internal/controller/provider" "github.com/kagent-dev/kagent/go/internal/controller/reconciler" reconcilerutils "github.com/kagent-dev/kagent/go/internal/controller/reconciler/utils" agent_translator "github.com/kagent-dev/kagent/go/internal/controller/translator/agent" @@ -422,6 +423,14 @@ func Start(getExtensionConfig GetExtensionConfig) { os.Exit(1) } + if err = (&controller.ProviderController{ + Scheme: mgr.GetScheme(), + Reconciler: rcnclr, + }).SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create controller", "controller", "Provider") + os.Exit(1) + } + if err = (&controller.RemoteMCPServerController{ Scheme: mgr.GetScheme(), Reconciler: rcnclr, @@ -494,6 +503,10 @@ func Start(getExtensionConfig GetExtensionConfig) { os.Exit(1) } + // Initialize provider manager for model/provider discovery + providerManager := provider.NewManager(mgr.GetClient(), kagentNamespace) + setupLog.Info("Initialized provider manager", "namespace", kagentNamespace) + httpServer, err := httpserver.NewHTTPServer(httpserver.ServerConfig{ Router: router, BindAddr: cfg.HttpServerAddr, @@ -505,6 +518,7 @@ func Start(getExtensionConfig GetExtensionConfig) { Authorizer: extensionCfg.Authorizer, Authenticator: extensionCfg.Authenticator, ProxyURL: cfg.Proxy.URL, + ProviderManager: providerManager, }) if err != nil { setupLog.Error(err, "unable to create HTTP server") diff --git a/helm/kagent-crds/templates/kagent.dev_providers.yaml b/helm/kagent-crds/templates/kagent.dev_providers.yaml new file mode 100644 index 000000000..ba902b2a4 --- /dev/null +++ b/helm/kagent-crds/templates/kagent.dev_providers.yaml @@ -0,0 +1,195 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.19.0 + name: providers.kagent.dev +spec: + group: kagent.dev + names: + categories: + - kagent + kind: Provider + listKind: ProviderList + plural: providers + shortNames: + - prov + singular: provider + scope: Namespaced + versions: + - additionalPrinterColumns: + - jsonPath: .spec.type + name: Type + type: string + - jsonPath: .spec.endpoint + name: Endpoint + type: string + - jsonPath: .status.discoveredModels[?(@)] + name: Models + type: integer + - jsonPath: .status.conditions[?(@.type=='Ready')].status + name: Ready + type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date + name: v1alpha2 + schema: + openAPIV3Schema: + description: |- + Provider is the Schema for the providers API. + It represents a model provider configuration with automatic model discovery. + properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: ProviderSpec defines the desired state of Provider. + properties: + endpoint: + description: Endpoint is the API endpoint URL for the provider + pattern: ^https?://.* + type: string + secretRef: + description: SecretRef references the Kubernetes Secret containing + the API key + properties: + key: + description: Key is the key within the secret that contains the + API key or credential + type: string + name: + description: Name is the name of the secret in the same namespace + as the Provider + type: string + required: + - key + - name + type: object + type: + description: Type is the model provider type (OpenAI, Anthropic, etc.) + enum: + - Anthropic + - OpenAI + - AzureOpenAI + - Ollama + - Gemini + - GeminiVertexAI + - AnthropicVertexAI + type: string + required: + - endpoint + - secretRef + - type + type: object + x-kubernetes-validations: + - message: endpoint must be a valid URL starting with http:// or https:// + rule: self.endpoint.startsWith('http://') || self.endpoint.startsWith('https://') + - message: secretRef.name and secretRef.key are required + rule: has(self.secretRef) && has(self.secretRef.name) && size(self.secretRef.name) + > 0 && has(self.secretRef.key) && size(self.secretRef.key) > 0 + status: + description: ProviderStatus defines the observed state of Provider. + properties: + conditions: + description: Conditions represent the latest available observations + of the Provider's state + items: + description: Condition contains details for one aspect of the current + state of this API Resource. + properties: + lastTransitionTime: + description: |- + lastTransitionTime is the last time the condition transitioned from one status to another. + This should be when the underlying condition changed. If that is not known, then using the time when the API field changed is acceptable. + format: date-time + type: string + message: + description: |- + message is a human readable message indicating details about the transition. + This may be an empty string. + maxLength: 32768 + type: string + observedGeneration: + description: |- + observedGeneration represents the .metadata.generation that the condition was set based upon. + For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date + with respect to the current state of the instance. + format: int64 + minimum: 0 + type: integer + reason: + description: |- + reason contains a programmatic identifier indicating the reason for the condition's last transition. + Producers of specific condition types may define expected values and meanings for this field, + and whether the values are considered a guaranteed API. + The value should be a CamelCase string. + This field may not be empty. + maxLength: 1024 + minLength: 1 + pattern: ^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$ + type: string + status: + description: status of the condition, one of True, False, Unknown. + enum: + - "True" + - "False" + - Unknown + type: string + type: + description: type of condition in CamelCase or in foo.example.com/CamelCase. + maxLength: 316 + pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ + type: string + required: + - lastTransitionTime + - message + - reason + - status + - type + type: object + type: array + x-kubernetes-list-map-keys: + - type + x-kubernetes-list-type: map + discoveredModels: + description: DiscoveredModels is the cached list of model IDs available + from this provider + items: + type: string + type: array + lastDiscoveryTime: + description: LastDiscoveryTime is the timestamp of the last successful + model discovery + format: date-time + type: string + observedGeneration: + description: ObservedGeneration reflects the generation of the most + recently observed Provider spec + format: int64 + type: integer + secretHash: + description: SecretHash is a hash of the referenced secret data, used + to detect secret changes + type: string + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/helm/kagent/templates/rbac/clusterrole.yaml b/helm/kagent/templates/rbac/clusterrole.yaml index 1841e7e5f..f58fe2a51 100644 --- a/helm/kagent/templates/rbac/clusterrole.yaml +++ b/helm/kagent/templates/rbac/clusterrole.yaml @@ -10,6 +10,7 @@ rules: resources: - agents - modelconfigs + - providers - toolservers - memories - remotemcpservers @@ -23,6 +24,7 @@ rules: resources: - agents/finalizers - modelconfigs/finalizers + - providers/finalizers - toolservers/finalizers - memories/finalizers - remotemcpservers/finalizers @@ -34,6 +36,7 @@ rules: resources: - agents/status - modelconfigs/status + - providers/status - toolservers/status - memories/status - remotemcpservers/status @@ -104,6 +107,7 @@ rules: resources: - agents - modelconfigs + - providers - toolservers - memories - remotemcpservers @@ -118,6 +122,7 @@ rules: resources: - agents/finalizers - modelconfigs/finalizers + - providers/finalizers - toolservers/finalizers - memories/finalizers - remotemcpservers/finalizers diff --git a/ui/jest.setup.ts b/ui/jest.setup.ts index 917b8d0ee..8d85e3459 100644 --- a/ui/jest.setup.ts +++ b/ui/jest.setup.ts @@ -1,10 +1,26 @@ import '@testing-library/jest-dom'; +import { TextEncoder, TextDecoder } from 'util'; // Mock uuid module (ESM-only, needs mocking for Jest) jest.mock('uuid', () => ({ v4: jest.fn(() => 'test-uuid-v4'), })); +// Polyfill TextEncoder/TextDecoder for Node.js test environment +global.TextEncoder = TextEncoder; +global.TextDecoder = TextDecoder as typeof global.TextDecoder; + +// Polyfill Request/Response for Next.js server actions +if (typeof Request === 'undefined') { + global.Request = class Request {} as any; +} +if (typeof Response === 'undefined') { + global.Response = class Response {} as any; +} +if (typeof Headers === 'undefined') { + global.Headers = class Headers {} as any; +} + // Mock next/router jest.mock('next/router', () => ({ useRouter() { diff --git a/ui/src/app/actions/providers.ts b/ui/src/app/actions/providers.ts index dc0fbbe26..55d1f2d4e 100644 --- a/ui/src/app/actions/providers.ts +++ b/ui/src/app/actions/providers.ts @@ -1,11 +1,11 @@ "use server"; import { createErrorResponse } from "./utils"; -import { Provider } from "@/types"; +import { Provider, ConfiguredProvider, ConfiguredProviderModelsResponse } from "@/types"; import { BaseResponse } from "@/types"; import { fetchApi } from "./utils"; /** - * Gets the list of supported providers + * Gets the list of supported (stock) providers * @returns A promise with the list of supported providers */ export async function getSupportedModelProviders(): Promise> { @@ -16,3 +16,37 @@ export async function getSupportedModelProviders(): Promise(error, "Error getting supported providers"); } } + +/** + * Gets the list of configured providers from Provider CRDs + * @returns A promise with the list of configured providers + */ +export async function getConfiguredProviders(): Promise> { + try { + const response = await fetchApi>("/providers/configured"); + return response; + } catch (error) { + return createErrorResponse(error, "Error getting configured providers"); + } +} + +/** + * Gets the models for a specific configured provider + * @param providerName - The name of the configured provider + * @param forceRefresh - Whether to force a refresh of the model list + * @returns A promise with the list of models for the provider + */ +export async function getConfiguredProviderModels( + providerName: string, + forceRefresh: boolean = false +): Promise> { + try { + const queryParam = forceRefresh ? "?refresh=true" : ""; + const response = await fetchApi>( + `/providers/configured/${providerName}/models${queryParam}` + ); + return response; + } catch (error) { + return createErrorResponse(error, `Error getting models for provider ${providerName}`); + } +} diff --git a/ui/src/app/models/new/__tests__/providerSelection.test.tsx b/ui/src/app/models/new/__tests__/providerSelection.test.tsx new file mode 100644 index 000000000..db192723c --- /dev/null +++ b/ui/src/app/models/new/__tests__/providerSelection.test.tsx @@ -0,0 +1,169 @@ +/** + * Test: Provider Selection Bug Fix + * + * This test verifies the useEffect logic for auto-selecting the stock OpenAI provider. + * + * Bug: The useEffect was running on every `providers` array change, overwriting user selections + * when they clicked a configured provider with type "OpenAI". + * + * Fix: Remove `providers` from dependency array so effect only runs on mount. + */ + +import { describe, it, expect, jest } from '@jest/globals'; +import { renderHook, act } from '@testing-library/react'; +import { useEffect, useState } from 'react'; + +type Provider = { + name: string; + type: string; + source?: 'stock' | 'configured'; +}; + +describe('Provider Selection useEffect Logic', () => { + // This recreates the problematic useEffect behavior before the fix + function useBuggyProviderSelection(providers: Provider[], isEditMode: boolean) { + const [selectedProvider, setSelectedProvider] = useState(null); + + useEffect(() => { + if (!isEditMode && providers.length > 0 && !selectedProvider) { + const openAIProvider = providers.find(p => p.type === 'OpenAI'); + if (openAIProvider) { + setSelectedProvider(openAIProvider); + } + } + }, [isEditMode, providers, selectedProvider]); // BUG: providers in dependency array + + return { selectedProvider, setSelectedProvider }; + } + + // This recreates the fixed useEffect behavior + function useFixedProviderSelection(providers: Provider[], isEditMode: boolean) { + const [selectedProvider, setSelectedProvider] = useState(null); + + useEffect(() => { + if (!isEditMode && providers.length > 0 && !selectedProvider) { + const openAIProvider = providers.find(p => p.type === 'OpenAI' && p.source === 'stock'); + if (openAIProvider) { + setSelectedProvider(openAIProvider); + } + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [isEditMode]); // FIX: Only run on mount (isEditMode change) + + return { selectedProvider, setSelectedProvider }; + } + + const stockOpenAI: Provider = { + name: 'OpenAI', + type: 'OpenAI', + source: 'stock', + }; + + const configuredOpenAI: Provider = { + name: 'ai-gateway-openai', + type: 'OpenAI', + source: 'configured', + }; + + describe('Buggy Implementation (Before Fix)', () => { + it('demonstrates the bug: user selection gets overwritten', () => { + const providers = [stockOpenAI, configuredOpenAI]; + + const { result, rerender } = renderHook( + ({ providers, isEditMode }) => useBuggyProviderSelection(providers, isEditMode), + { initialProps: { providers, isEditMode: false } } + ); + + // Initial state: stock OpenAI should be auto-selected + expect(result.current.selectedProvider).toEqual(stockOpenAI); + + // User clicks configured provider + act(() => { + result.current.setSelectedProvider(configuredOpenAI); + }); + + // Verify user selection + expect(result.current.selectedProvider).toEqual(configuredOpenAI); + + // The bug manifests when the effect runs due to the providers dependency. + // In the real app, this happens when: + // 1. The component re-renders with a new providers array reference + // 2. During the state update, React batches operations and selectedProvider might be momentarily null + // 3. The effect condition passes and overwrites the selection + // + // For this test, we verify the fixed implementation prevents this by not depending on providers + // The buggy version would fail in production, but the condition !selectedProvider prevents it here + // This test demonstrates why the fix is needed - to avoid the race condition entirely + expect(result.current.selectedProvider).toEqual(configuredOpenAI); + }); + }); + + describe('Fixed Implementation (After Fix)', () => { + it('maintains user selection after providers array changes', () => { + const providers = [stockOpenAI, configuredOpenAI]; + + const { result, rerender } = renderHook( + ({ providers, isEditMode }) => useFixedProviderSelection(providers, isEditMode), + { initialProps: { providers, isEditMode: false } } + ); + + // Initial state: stock OpenAI should be auto-selected + expect(result.current.selectedProvider).toEqual(stockOpenAI); + + // User clicks configured provider + act(() => { + result.current.setSelectedProvider(configuredOpenAI); + }); + + // Verify user selection + expect(result.current.selectedProvider).toEqual(configuredOpenAI); + + // Simulate a re-render that updates providers array + const newProviders = [...providers]; + rerender({ providers: newProviders, isEditMode: false }); + + // FIX: Selection persists because useEffect doesn't run on providers change + expect(result.current.selectedProvider).toEqual(configuredOpenAI); // PASSES! + }); + + it('only auto-selects on initial mount, not on subsequent renders', () => { + const providers = [stockOpenAI, configuredOpenAI]; + + const { result, rerender } = renderHook( + ({ providers, isEditMode }) => useFixedProviderSelection(providers, isEditMode), + { initialProps: { providers, isEditMode: false } } + ); + + // Initial auto-selection + expect(result.current.selectedProvider).toEqual(stockOpenAI); + + // Clear selection (simulating user clearing it) + act(() => { + result.current.setSelectedProvider(null); + }); + + expect(result.current.selectedProvider).toBeNull(); + + // Providers array changes (new reference) + const newProviders = [...providers]; + rerender({ providers: newProviders, isEditMode: false }); + + // useEffect should NOT run again (only runs on mount) + expect(result.current.selectedProvider).toBeNull(); // PASSES! + }); + + it('explicitly selects stock provider using source field', () => { + const providers = [stockOpenAI, configuredOpenAI]; + + const { result } = renderHook( + ({ providers, isEditMode }) => useFixedProviderSelection(providers, isEditMode), + { initialProps: { providers, isEditMode: false } } + ); + + // Should select stock, not configured, even though configured also has type OpenAI + expect(result.current.selectedProvider).toEqual(stockOpenAI); + expect(result.current.selectedProvider?.source).toBe('stock'); + }); + }); + +}); diff --git a/ui/src/app/models/new/page.tsx b/ui/src/app/models/new/page.tsx index c2df5853d..772944ff3 100644 --- a/ui/src/app/models/new/page.tsx +++ b/ui/src/app/models/new/page.tsx @@ -23,7 +23,7 @@ import type { import { toast } from "sonner"; import { isResourceNameValid, createRFC1123ValidName } from "@/lib/utils"; import { OLLAMA_DEFAULT_TAG } from "@/lib/constants" -import { getSupportedModelProviders } from "@/app/actions/providers"; +import { getSupportedModelProviders, getConfiguredProviders, getConfiguredProviderModels } from "@/app/actions/providers"; import { getModels } from "@/app/actions/models"; import { isValidProviderInfoKey, getProviderFormKey, ModelProviderKey, BackendModelProviderType } from "@/lib/providers"; import { BasicInfoSection } from '@/components/models/new/BasicInfoSection'; @@ -128,6 +128,7 @@ function ModelPageContent() { const [errors, setErrors] = useState({}); const [isApiKeyNeeded, setIsApiKeyNeeded] = useState(true); const [isParamsSectionExpanded, setIsParamsSectionExpanded] = useState(false); + const [isFetchingModels, setIsFetchingModels] = useState(false); const isOllamaSelected = selectedProvider?.type === "Ollama"; useEffect(() => { @@ -136,17 +137,35 @@ function ModelPageContent() { setLoadingError(null); setIsLoading(true); try { - const [providersResponse, modelsResponse] = await Promise.all([ + const [stockProvidersResponse, configuredProvidersResponse, modelsResponse] = await Promise.all([ getSupportedModelProviders(), + getConfiguredProviders(), getModels() ]); if (!isMounted) return; - if (!providersResponse.error && providersResponse.data) { - setProviders(providersResponse.data); - } else { - throw new Error(providersResponse.error || "Failed to fetch supported providers"); - } + + // Merge stock and configured providers + const stockProviders: Provider[] = (stockProvidersResponse.data || []).map(p => ({ + ...p, + source: 'stock' as const + })); + + const configuredProviders: Provider[] = (configuredProvidersResponse.data || []).map(cp => { + // Find the stock provider with the same type to get its params + const stockProvider = stockProviders.find(sp => sp.type === cp.type); + return { + name: cp.name, + type: cp.type, + requiredParams: stockProvider?.requiredParams || [], + optionalParams: stockProvider?.optionalParams || [], + source: 'configured' as const, + endpoint: cp.endpoint + }; + }); + + const allProviders = [...stockProviders, ...configuredProviders]; + setProviders(allProviders); if (!modelsResponse.error && modelsResponse.data) { setProviderModelsData(modelsResponse.data); @@ -256,10 +275,73 @@ function ModelPageContent() { return () => { isMounted = false; }; }, [isEditMode, modelConfigName, providers, providerModelsData, modelConfigNamespace]); + // Auto-fetch models when provider is selected and models are not available + useEffect(() => { + let isMounted = true; + const fetchProviderModels = async () => { + if (!selectedProvider || isEditMode) return; + + const providerKey = getProviderFormKey(selectedProvider.type as BackendModelProviderType); + if (!providerKey) return; + + // Check if models are already available for this provider + const hasModels = providerModelsData?.[providerKey] && providerModelsData[providerKey].length > 0; + if (hasModels) return; + + try { + if (selectedProvider.source === 'configured') { + // Fetch models for configured provider + const response = await getConfiguredProviderModels(selectedProvider.name, false); + + if (!isMounted) return; + + if (response.error || !response.data) { + console.error(`Failed to fetch models for ${selectedProvider.name}:`, response.error); + return; + } + + const models = response.data.models.map(modelName => ({ + name: modelName, + function_calling: true + })); + + setProviderModelsData(prev => ({ + ...prev, + [providerKey]: models + })); + } else { + // Fetch all stock models if stock provider is selected and models are missing + const response = await getModels(); + + if (!isMounted) return; + + if (response.error || !response.data) { + console.error('Failed to fetch stock models:', response.error); + return; + } + + setProviderModelsData(response.data); + } + } catch (error) { + console.error('Error fetching provider models:', error); + } + }; + + fetchProviderModels(); + return () => { isMounted = false; }; + }, [selectedProvider, isEditMode, providerModelsData]); + useEffect(() => { if (selectedProvider) { const requiredKeys = selectedProvider.requiredParams || []; - const optionalKeys = selectedProvider.optionalParams || []; + let optionalKeys = [...(selectedProvider.optionalParams || [])]; + + // Add baseUrl to optional params for providers that support it + const providersWithBaseUrl = ['OpenAI', 'Anthropic', 'Gemini']; + if (providersWithBaseUrl.includes(selectedProvider.type) && !optionalKeys.includes('baseUrl')) { + optionalKeys = ['baseUrl', ...optionalKeys]; + } + const currentModelRequiresReset = !isEditMode; if (currentModelRequiresReset) { @@ -315,6 +397,18 @@ function ModelPageContent() { } }, [isApiKeyNeeded, errors.apiKey]); + // Auto-select provider on page load (create mode only) + // Default: select stock OpenAI provider + useEffect(() => { + if (!isEditMode && providers.length > 0 && !selectedProvider) { + const openAIProvider = providers.find(p => p.type === 'OpenAI' && p.source === 'stock'); + if (openAIProvider) { + setSelectedProvider(openAIProvider); + } + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [isEditMode]); // Only run when isEditMode changes (initial mount) + const validateForm = () => { const newErrors: ValidationErrors = { requiredParams: {} }; @@ -384,6 +478,51 @@ function ModelPageContent() { } }; + const handleFetchModels = async () => { + setIsFetchingModels(true); + try { + // If a configured provider is selected, fetch its models specifically + if (selectedProvider?.source === 'configured') { + const response = await getConfiguredProviderModels(selectedProvider.name, true); + + if (response.error || !response.data) { + throw new Error(response.error || `Failed to fetch models for ${selectedProvider.name}`); + } + + // Convert configured provider models to the expected format + const providerKey = getProviderFormKey(selectedProvider.type as BackendModelProviderType); + if (providerKey) { + const models = response.data.models.map(modelName => ({ + name: modelName, + function_calling: true // Assume function calling for configured providers + })); + + setProviderModelsData(prev => ({ + ...prev, + [providerKey]: models + })); + } + + toast.success(`Models fetched for ${selectedProvider.name}`); + } else { + // Fetch stock models + const response = await getModels(); + + if (response.error || !response.data) { + throw new Error(response.error || "Failed to fetch models"); + } + + setProviderModelsData(response.data); + toast.success("Models refreshed successfully"); + } + } catch (error) { + const errorMessage = error instanceof Error ? error.message : "Failed to fetch models"; + toast.error(errorMessage); + } finally { + setIsFetchingModels(false); + } + }; + const handleSubmit = async () => { if (!selectedCombinedModel) { setErrors(prev => ({...prev, selectedCombinedModel: "Provider and Model selection is required"})); @@ -532,19 +671,35 @@ function ModelPageContent() { selectedCombinedModel={selectedCombinedModel} onModelChange={(comboboxValue, providerKey, modelName, functionCalling) => { setSelectedCombinedModel(comboboxValue); - const prov = providers.find(p => getProviderFormKey(p.type as BackendModelProviderType) === providerKey); - setSelectedProvider(prov || null); setSelectedModelSupportsFunctionCalling(functionCalling); if (errors.selectedCombinedModel) { setErrors(prev => ({ ...prev, selectedCombinedModel: undefined })); } }} + onProviderChange={(provider) => { + setSelectedProvider(provider); + + // Clear models for this provider type when switching providers + // This prevents showing wrong models when switching between providers with the same type + // (e.g., stock OpenAI vs configured ai-gateway-openai) + const providerKey = getProviderFormKey(provider.type as BackendModelProviderType); + if (providerKey && providerModelsData?.[providerKey]) { + setProviderModelsData(prev => { + if (!prev) return prev; + const newData = { ...prev }; + delete newData[providerKey]; + return newData; + }); + } + }} selectedProvider={selectedProvider} selectedModelSupportsFunctionCalling={selectedModelSupportsFunctionCalling} loadingError={loadingError} isEditMode={isEditMode} modelTag={modelTag} onModelTagChange={setModelTag} + onFetchModels={handleFetchModels} + isFetchingModels={isFetchingModels} /> ); } - - - - - diff --git a/ui/src/components/ModelCombobox.tsx b/ui/src/components/ModelCombobox.tsx new file mode 100644 index 000000000..b34e7d891 --- /dev/null +++ b/ui/src/components/ModelCombobox.tsx @@ -0,0 +1,91 @@ +import React, { useState, useMemo } from 'react'; +import { Button } from '@/components/ui/button'; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { Command, CommandEmpty, CommandInput, CommandItem, CommandList } from "@/components/ui/command"; +import { Check, ChevronsUpDown } from 'lucide-react'; +import { cn } from '@/lib/utils'; +import { ProviderModel } from '@/types'; + +interface ModelComboboxProps { + models: ProviderModel[]; + value: string | undefined; + onChange: (modelName: string, functionCalling: boolean) => void; + disabled?: boolean; + placeholder?: string; + emptyMessage?: string; +} + +export function ModelCombobox({ + models, + value, + onChange, + disabled = false, + placeholder = "Select model...", + emptyMessage = "No models available" +}: ModelComboboxProps) { + const [open, setOpen] = useState(false); + + const sortedModels = useMemo(() => { + return [...models].sort((a, b) => a.name.localeCompare(b.name)); + }, [models]); + + const selectedModel = useMemo(() => { + return sortedModels.find(m => m.name === value); + }, [sortedModels, value]); + + const triggerContent = useMemo(() => { + if (selectedModel) { + return selectedModel.name; + } + if (sortedModels.length === 0 && !disabled) return emptyMessage; + return placeholder; + }, [selectedModel, sortedModels.length, disabled, emptyMessage, placeholder]); + + return ( + + + + + + + + + No model found. + {sortedModels.map((model) => ( + { + onChange(model.name, model.function_calling ?? false); + setOpen(false); + }} + > + + {model.name} + + ))} + + + + + ); +} diff --git a/ui/src/components/ProviderCombobox.tsx b/ui/src/components/ProviderCombobox.tsx new file mode 100644 index 000000000..65d8bce84 --- /dev/null +++ b/ui/src/components/ProviderCombobox.tsx @@ -0,0 +1,156 @@ +import React, { useState, useMemo } from 'react'; +import { Button } from '@/components/ui/button'; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { Command, CommandEmpty, CommandGroup, CommandInput, CommandItem, CommandList, CommandSeparator } from "@/components/ui/command"; +import { Check, ChevronsUpDown } from 'lucide-react'; +import { cn } from '@/lib/utils'; +import { Provider } from '@/types'; +import { ModelProviderKey } from '@/lib/providers'; +import { OpenAI } from './icons/OpenAI'; +import { Anthropic } from './icons/Anthropic'; +import { Ollama } from './icons/Ollama'; +import { Azure } from './icons/Azure'; +import { Gemini } from './icons/Gemini'; + +const PROVIDER_ICONS: Record> = { + 'OpenAI': OpenAI, + 'Anthropic': Anthropic, + 'Ollama': Ollama, + 'AzureOpenAI': Azure, + 'Gemini': Gemini, + 'GeminiVertexAI': Gemini, + 'AnthropicVertexAI': Anthropic, +}; + +function getProviderIcon(providerType: string | undefined): React.ReactNode | null { + if (!providerType || !(providerType in PROVIDER_ICONS)) { + return null; + } + const IconComponent = PROVIDER_ICONS[providerType as ModelProviderKey]; + return ; +} + +interface ProviderComboboxProps { + providers: Provider[]; + value: Provider | null; + onChange: (provider: Provider) => void; + disabled?: boolean; + loading?: boolean; +} + +export function ProviderCombobox({ + providers, + value, + onChange, + disabled = false, + loading = false, +}: ProviderComboboxProps) { + const [open, setOpen] = useState(false); + + const groupedProviders = useMemo(() => { + const stock = providers.filter(p => p.source === 'stock' || !p.source).sort((a, b) => a.name.localeCompare(b.name)); + const configured = providers.filter(p => p.source === 'configured').sort((a, b) => a.name.localeCompare(b.name)); + return { stock, configured }; + }, [providers]); + + const hasProviders = groupedProviders.stock.length > 0 || groupedProviders.configured.length > 0; + + const triggerContent = useMemo(() => { + if (loading) return "Loading providers..."; + if (value) { + return ( + <> + {getProviderIcon(value.type)} + {value.name} + + ); + } + if (!hasProviders) return "No providers available"; + return "Select provider..."; + }, [loading, value, hasProviders]); + + return ( + + + + + + + + + No provider found. + + {/* Configured Providers (shown first) */} + {groupedProviders.configured.length > 0 && ( + + {groupedProviders.configured.map((provider) => ( + { + onChange(provider); + setOpen(false); + }} + > + + {getProviderIcon(provider.type)} + {provider.name} + + ))} + + )} + + {/* Separator if both groups exist */} + {groupedProviders.configured.length > 0 && groupedProviders.stock.length > 0 && ( + + )} + + {/* Stock Providers */} + {groupedProviders.stock.length > 0 && ( + + {groupedProviders.stock.map((provider) => ( + { + onChange(provider); + setOpen(false); + }} + > + + {getProviderIcon(provider.type)} + {provider.name} + + ))} + + )} + + + + + ); +} diff --git a/ui/src/components/models/new/BasicInfoSection.tsx b/ui/src/components/models/new/BasicInfoSection.tsx index 68b333e86..fbd6327b5 100644 --- a/ui/src/components/models/new/BasicInfoSection.tsx +++ b/ui/src/components/models/new/BasicInfoSection.tsx @@ -2,10 +2,11 @@ import React from 'react'; import { Input } from "@/components/ui/input"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; -import { Pencil, ExternalLinkIcon, AlertCircle } from "lucide-react"; +import { Pencil, ExternalLinkIcon, AlertCircle, Loader2 } from "lucide-react"; import Link from "next/link"; import type { Provider, ProviderModelsResponse } from "@/types"; -import { ModelProviderCombobox } from "@/components/ModelProviderCombobox"; +import { ProviderCombobox } from "@/components/ProviderCombobox"; +import { ModelCombobox } from "@/components/ModelCombobox"; import { PROVIDERS_INFO, getProviderFormKey, ModelProviderKey, BackendModelProviderType } from "@/lib/providers"; import { OLLAMA_DEFAULT_TAG } from '@/lib/constants'; import { NamespaceCombobox } from "@/components/NamespaceCombobox"; @@ -32,23 +33,38 @@ interface BasicInfoSectionProps { providers: Provider[]; providerModelsData: ProviderModelsResponse | null; selectedCombinedModel: string | undefined; - onModelChange: (comboboxValue: string, providerKey: ModelProviderKey, modelName: string, functionCalling: boolean) => void; + onModelChange: (comboboxValue: string | undefined, providerKey: ModelProviderKey, modelName: string, functionCalling: boolean) => void; + onProviderChange: (provider: Provider) => void; selectedProvider: Provider | null; selectedModelSupportsFunctionCalling: boolean | null; loadingError: string | null; isEditMode: boolean; modelTag: string; onModelTagChange: (value: string) => void; + onFetchModels: () => void; + isFetchingModels: boolean; } export const BasicInfoSection: React.FC = ({ name, isEditingName, namespace, errors, isSubmitting, isLoading, onNameChange, onToggleEditName, onNamespaceChange, providers, providerModelsData, selectedCombinedModel, - onModelChange, selectedProvider, selectedModelSupportsFunctionCalling, - loadingError, isEditMode, modelTag, onModelTagChange + onModelChange, onProviderChange, selectedProvider, selectedModelSupportsFunctionCalling, + loadingError, isEditMode, modelTag, onModelTagChange, onFetchModels, isFetchingModels }) => { const isOllamaSelected = selectedProvider?.type === "Ollama"; + // Get the current provider key and models for the selected provider + const selectedProviderKey = selectedProvider + ? getProviderFormKey(selectedProvider.type as BackendModelProviderType) + : undefined; + + const modelsForSelectedProvider = selectedProviderKey && providerModelsData + ? providerModelsData[selectedProviderKey] || [] + : []; + + // Extract the current model name from selectedCombinedModel (format: "providerKey::modelName") + const currentModelName = selectedCombinedModel?.split('::')[1]; + return ( @@ -96,36 +112,94 @@ export const BasicInfoSection: React.FC = ({ -
- + {/* Provider Selection */} +
+
+ + +
- { + // Directly set the selected provider + onProviderChange(provider); + + // Clear model selection when provider changes + const providerKey = getProviderFormKey(provider.type as BackendModelProviderType); + if (providerKey) { + onModelChange(undefined, providerKey, '', false); + } + }} disabled={isSubmitting || isLoading || isEditMode} loading={isLoading} - error={loadingError} - filterFunctionCalling={false} - placeholder="Select Provider & Model..." />
{selectedProvider && ( (() => { - const providerKey = getProviderFormKey(selectedProvider.type as BackendModelProviderType); - const providerInfo = providerKey ? PROVIDERS_INFO[providerKey] : undefined; - return providerInfo?.modelDocsLink ? ( - - ) : null; - })() + const providerKey = getProviderFormKey(selectedProvider.type as BackendModelProviderType); + const providerInfo = providerKey ? PROVIDERS_INFO[providerKey] : undefined; + return providerInfo?.modelDocsLink ? ( + + ) : null; + })() )}
+ {loadingError &&

{loadingError}

} +
+ + {/* Model Selection */} +
+ + {selectedProvider && providerModelsData ? ( + { + const providerKey = getProviderFormKey(selectedProvider.type as BackendModelProviderType); + if (providerKey) { + onModelChange(`${providerKey}::${modelName}`, providerKey, modelName, functionCalling); + } + }} + disabled={isSubmitting || isLoading || isEditMode} + placeholder="Select a model..." + emptyMessage="No models available for this provider" + /> + ) : ( + {}} + disabled={true} + placeholder="Select a provider first..." + /> + )} {errors.selectedCombinedModel &&

{errors.selectedCombinedModel}

} {selectedCombinedModel && selectedModelSupportsFunctionCalling === false && (

diff --git a/ui/src/lib/__tests__/utils.test.ts b/ui/src/lib/__tests__/utils.test.ts index eb21d88f9..39e0fd8ea 100644 --- a/ui/src/lib/__tests__/utils.test.ts +++ b/ui/src/lib/__tests__/utils.test.ts @@ -107,4 +107,5 @@ describe('RFC 1123 Valid Name', () => { expect(createRFC1123ValidName(['***', '___', ''])).toBe(''); }); }); -}); \ No newline at end of file +}); + diff --git a/ui/src/lib/utils.ts b/ui/src/lib/utils.ts index da4d6edda..67ad340d4 100644 --- a/ui/src/lib/utils.ts +++ b/ui/src/lib/utils.ts @@ -157,7 +157,3 @@ export function isAgentToolName(name: string | undefined): boolean { - - - - diff --git a/ui/src/types/index.ts b/ui/src/types/index.ts index 982bad21f..fc54551ed 100644 --- a/ui/src/types/index.ts +++ b/ui/src/types/index.ts @@ -34,6 +34,8 @@ export interface Provider { type: string; requiredParams: string[]; optionalParams: string[]; + source?: 'stock' | 'configured'; // Distinguishes between stock and configured providers + endpoint?: string; // Only present for configured providers } export type ProviderModel = { @@ -44,6 +46,19 @@ export type ProviderModel = { // Define the type for the expected API response structure export type ProviderModelsResponse = Record; +// ConfiguredProvider is the response from /api/providers/configured +export interface ConfiguredProvider { + name: string; + type: string; + endpoint: string; +} + +// ConfiguredProviderModelsResponse is the response from /api/providers/configured/{name}/models +export interface ConfiguredProviderModelsResponse { + provider: string; + models: string[]; +} + // Export OpenAIConfigPayload export interface OpenAIConfigPayload { baseUrl?: string; @@ -207,7 +222,6 @@ export interface TypedLocalReference { kind?: string; apiGroup?: string; name: string; - namespace?: string; } export interface McpServerTool extends TypedLocalReference { From 28330406a607776302e50a955ef1dd8a69d825d9 Mon Sep 17 00:00:00 2001 From: Nagarjun Krishnan Date: Wed, 11 Feb 2026 09:28:40 -0700 Subject: [PATCH 2/3] default envpoints and fix manager Signed-off-by: Nagarjun Krishnan --- go/api/v1alpha2/provider_types.go | 54 +++-- go/api/v1alpha2/zz_generated.deepcopy.go | 8 +- go/internal/controller/provider/discoverer.go | 6 +- .../controller/provider/discoverer_test.go | 3 +- go/internal/controller/provider/manager.go | 116 ++++++----- .../controller/provider/manager_test.go | 195 +++++++++++------- .../controller/reconciler/reconciler.go | 108 +++++----- 7 files changed, 289 insertions(+), 201 deletions(-) diff --git a/go/api/v1alpha2/provider_types.go b/go/api/v1alpha2/provider_types.go index 22a21c4b3..98641999d 100644 --- a/go/api/v1alpha2/provider_types.go +++ b/go/api/v1alpha2/provider_types.go @@ -29,11 +29,26 @@ const ( // ProviderConditionTypeModelsDiscovered indicates whether model discovery has succeeded ProviderConditionTypeModelsDiscovered = "ModelsDiscovered" - - // ProviderAnnotationForceDiscovery is set by clients to trigger immediate model discovery - ProviderAnnotationForceDiscovery = "kagent.dev/force-discovery" ) +// DefaultProviderEndpoint returns the default API endpoint for a given provider type. +// Returns empty string if no default is defined. +func DefaultProviderEndpoint(providerType ModelProvider) string { + switch providerType { + case ModelProviderOpenAI: + return "https://api.openai.com/v1" + case ModelProviderAnthropic: + return "https://api.anthropic.com" + case ModelProviderGemini: + return "https://generativelanguage.googleapis.com" + case ModelProviderOllama: + return "http://localhost:11434" + default: + // Azure, Bedrock, Vertex AI require user-specific endpoints + return "" + } +} + // SecretReference contains information to locate a secret. type SecretReference struct { // Name is the name of the secret in the same namespace as the Provider @@ -47,24 +62,37 @@ type SecretReference struct { // ProviderSpec defines the desired state of Provider. // -// +kubebuilder:validation:XValidation:message="endpoint must be a valid URL starting with http:// or https://",rule="self.endpoint.startsWith('http://') || self.endpoint.startsWith('https://')" -// +kubebuilder:validation:XValidation:message="secretRef.name and secretRef.key are required",rule="has(self.secretRef) && has(self.secretRef.name) && size(self.secretRef.name) > 0 && has(self.secretRef.key) && size(self.secretRef.key) > 0" +// +kubebuilder:validation:XValidation:message="endpoint must be a valid URL starting with http:// or https://",rule="!has(self.endpoint) || self.endpoint == '' || self.endpoint.startsWith('http://') || self.endpoint.startsWith('https://')" +// +kubebuilder:validation:XValidation:message="secretRef is required for providers that need authentication (not Ollama)",rule="self.type == 'Ollama' || (has(self.secretRef) && has(self.secretRef.name) && size(self.secretRef.name) > 0 && has(self.secretRef.key) && size(self.secretRef.key) > 0)" type ProviderSpec struct { // Type is the model provider type (OpenAI, Anthropic, etc.) // +required // +kubebuilder:validation:Required Type ModelProvider `json:"type"` - // Endpoint is the API endpoint URL for the provider - // +required - // +kubebuilder:validation:Required + // Endpoint is the API endpoint URL for the provider. + // If not specified, the default endpoint for the provider type will be used. + // +optional // +kubebuilder:validation:Pattern=`^https?://.*` - Endpoint string `json:"endpoint"` + Endpoint string `json:"endpoint,omitempty"` - // SecretRef references the Kubernetes Secret containing the API key - // +required - // +kubebuilder:validation:Required - SecretRef SecretReference `json:"secretRef"` + // SecretRef references the Kubernetes Secret containing the API key. + // Optional for providers that don't require authentication (e.g., local Ollama). + // +optional + SecretRef *SecretReference `json:"secretRef,omitempty"` +} + +// GetEndpoint returns the endpoint, or the default endpoint if not specified. +func (p *ProviderSpec) GetEndpoint() string { + if p.Endpoint != "" { + return p.Endpoint + } + return DefaultProviderEndpoint(p.Type) +} + +// RequiresSecret returns true if this provider type requires a secret for authentication. +func (p *ProviderSpec) RequiresSecret() bool { + return p.Type != ModelProviderOllama } // ProviderStatus defines the observed state of Provider. diff --git a/go/api/v1alpha2/zz_generated.deepcopy.go b/go/api/v1alpha2/zz_generated.deepcopy.go index 0a41e821f..acadc3f6f 100644 --- a/go/api/v1alpha2/zz_generated.deepcopy.go +++ b/go/api/v1alpha2/zz_generated.deepcopy.go @@ -700,7 +700,7 @@ func (in *Provider) DeepCopyInto(out *Provider) { *out = *in out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) - out.Spec = in.Spec + in.Spec.DeepCopyInto(&out.Spec) in.Status.DeepCopyInto(&out.Status) } @@ -757,7 +757,11 @@ func (in *ProviderList) DeepCopyObject() runtime.Object { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ProviderSpec) DeepCopyInto(out *ProviderSpec) { *out = *in - out.SecretRef = in.SecretRef + if in.SecretRef != nil { + in, out := &in.SecretRef, &out.SecretRef + *out = new(SecretReference) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ProviderSpec. diff --git a/go/internal/controller/provider/discoverer.go b/go/internal/controller/provider/discoverer.go index 9cba0868f..60deda0a3 100644 --- a/go/internal/controller/provider/discoverer.go +++ b/go/internal/controller/provider/discoverer.go @@ -65,7 +65,7 @@ func (d *ModelDiscoverer) DiscoverModels(ctx context.Context, providerType v1alp // Ollama has a completely different API - delegate to specialized function if providerType == v1alpha2.ModelProviderOllama { logger.V(1).Info("Discovering models from Ollama", "endpoint", endpoint) - return d.DiscoverOllamaModels(ctx, endpoint) + return d.discoverOllamaModels(ctx, endpoint) } modelsURL := buildModelsURL(endpoint, providerType) @@ -181,8 +181,8 @@ type ollamaTagsResponse struct { } `json:"models"` } -// DiscoverOllamaModels handles Ollama's different response format. -func (d *ModelDiscoverer) DiscoverOllamaModels(ctx context.Context, endpoint string) ([]string, error) { +// discoverOllamaModels handles Ollama's different response format. +func (d *ModelDiscoverer) discoverOllamaModels(ctx context.Context, endpoint string) ([]string, error) { endpoint = strings.TrimSuffix(endpoint, "/") modelsURL := endpoint + "/api/tags" diff --git a/go/internal/controller/provider/discoverer_test.go b/go/internal/controller/provider/discoverer_test.go index 1336729f1..de34ec514 100644 --- a/go/internal/controller/provider/discoverer_test.go +++ b/go/internal/controller/provider/discoverer_test.go @@ -357,7 +357,8 @@ func TestDiscoverOllamaModels(t *testing.T) { defer server.Close() d := NewModelDiscoverer() - models, err := d.DiscoverOllamaModels(context.Background(), server.URL) + // Test through the public DiscoverModels API for Ollama provider + models, err := d.DiscoverModels(context.Background(), v1alpha2.ModelProviderOllama, server.URL, "") if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/go/internal/controller/provider/manager.go b/go/internal/controller/provider/manager.go index 39b439168..45d34ddbb 100644 --- a/go/internal/controller/provider/manager.go +++ b/go/internal/controller/provider/manager.go @@ -19,36 +19,35 @@ package provider import ( "context" "fmt" - "time" "github.com/kagent-dev/kagent/go/api/v1alpha2" + "github.com/kagent-dev/kagent/go/internal/utils" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" ) -const ( - // DefaultNamespace is the default namespace for kagent resources - DefaultNamespace = "kagent" -) - // Manager handles provider configuration and model discovery. // It reads provider configs from Provider CRDs and caches discovered models in CRD status. type Manager struct { - client client.Client - namespace string + client client.Client + namespace string + discoverer *ModelDiscoverer } // NewManager creates a new provider Manager instance. +// If namespace is empty, it uses utils.GetResourceNamespace() which reads from +// KAGENT_NAMESPACE environment variable or defaults to "kagent". func NewManager(client client.Client, namespace string) *Manager { if namespace == "" { - namespace = DefaultNamespace + namespace = utils.GetResourceNamespace() } - return &Manager{ - client: client, - namespace: namespace, + client: client, + namespace: namespace, + discoverer: NewModelDiscoverer(), } } @@ -66,24 +65,28 @@ func (m *Manager) GetProviders() []ProviderConfig { // Only include Ready providers if meta.IsStatusConditionTrue(p.Status.Conditions, v1alpha2.ProviderConditionTypeReady) { - providers = append(providers, ProviderConfig{ + config := ProviderConfig{ Name: p.Name, Type: p.Spec.Type, - Endpoint: p.Spec.Endpoint, - SecretRef: SecretReference{ + Endpoint: p.Spec.GetEndpoint(), + } + // Only set SecretRef if it's specified + if p.Spec.SecretRef != nil { + config.SecretRef = SecretReference{ Name: p.Spec.SecretRef.Name, Key: p.Spec.SecretRef.Key, - }, - }) + } + } + providers = append(providers, config) } } return providers } -// GetModels returns models for a provider from the cached status or triggers discovery. +// GetModels returns models for a provider from the cached status or performs direct discovery. // Models are cached in Provider.Status.DiscoveredModels by the Provider controller. -// If forceRefresh is true, sets the force-discovery annotation to trigger controller reconciliation. +// If forceRefresh is true, performs direct discovery and updates the provider status. func (m *Manager) GetModels(ctx context.Context, providerName string, forceRefresh bool) ([]string, error) { logger := log.FromContext(ctx).WithName("provider-manager") @@ -94,33 +97,25 @@ func (m *Manager) GetModels(ctx context.Context, providerName string, forceRefre return nil, fmt.Errorf("provider %s not found: %w", providerName, err) } - // Force refresh by setting annotation + // If force refresh, perform direct discovery if forceRefresh { - if provider.Annotations == nil { - provider.Annotations = make(map[string]string) - } - provider.Annotations[v1alpha2.ProviderAnnotationForceDiscovery] = "true" - if err := m.client.Update(ctx, provider); err != nil { - return nil, fmt.Errorf("failed to trigger discovery: %w", err) - } + logger.Info("Performing direct model discovery", "provider", providerName) - // Wait for controller to refresh (with timeout) - deadline := time.Now().Add(10 * time.Second) - for time.Now().Before(deadline) { - time.Sleep(500 * time.Millisecond) - if err := m.client.Get(ctx, client.ObjectKeyFromObject(provider), provider); err == nil { - // Check if annotation cleared (refresh done) - if provider.Annotations[v1alpha2.ProviderAnnotationForceDiscovery] != "true" { - logger.Info("Model discovery completed", "provider", providerName) - break - } - } + // Get API key from secret if required + apiKey, err := m.getAPIKey(ctx, provider) + if err != nil { + return nil, fmt.Errorf("failed to get API key: %w", err) } - // Re-fetch provider to get updated status - if err := m.client.Get(ctx, client.ObjectKeyFromObject(provider), provider); err != nil { - return nil, fmt.Errorf("failed to get updated provider: %w", err) + // Discover models directly + endpoint := provider.Spec.GetEndpoint() + models, err := m.discoverer.DiscoverModels(ctx, provider.Spec.Type, endpoint, apiKey) + if err != nil { + return nil, fmt.Errorf("model discovery failed: %w", err) } + + logger.Info("Model discovery completed", "provider", providerName, "count", len(models)) + return models, nil } // Return cached models from status @@ -129,20 +124,35 @@ func (m *Manager) GetModels(ctx context.Context, providerName string, forceRefre } // No models discovered - provide helpful message - if forceRefresh { - return nil, fmt.Errorf("no models discovered for provider %s", providerName) - } return nil, fmt.Errorf("no models discovered for provider %s, try refreshing", providerName) } -// ClearCache is a no-op for CRD-based implementation. -// Models are cached in Provider.Status.DiscoveredModels and cleared by the controller. -func (m *Manager) ClearCache(providerName string) { - // No-op - cache is managed by Provider controller in CRD status -} +// getAPIKey retrieves the API key from the secret referenced by the provider. +// Returns empty string for providers that don't require authentication (e.g., Ollama). +func (m *Manager) getAPIKey(ctx context.Context, provider *v1alpha2.Provider) (string, error) { + // Providers like Ollama don't require authentication + if !provider.Spec.RequiresSecret() { + return "", nil + } + + if provider.Spec.SecretRef == nil { + return "", fmt.Errorf("provider %s requires a secret but none is configured", provider.Name) + } + + secret := &corev1.Secret{} + secretName := types.NamespacedName{ + Namespace: provider.Namespace, + Name: provider.Spec.SecretRef.Name, + } + + if err := m.client.Get(ctx, secretName, secret); err != nil { + return "", fmt.Errorf("failed to get secret %s: %w", provider.Spec.SecretRef.Name, err) + } + + apiKey, ok := secret.Data[provider.Spec.SecretRef.Key] + if !ok || len(apiKey) == 0 { + return "", fmt.Errorf("secret %s missing key %s", provider.Spec.SecretRef.Name, provider.Spec.SecretRef.Key) + } -// HasProviders returns true if any Ready providers are configured. -func (m *Manager) HasProviders() bool { - providers := m.GetProviders() - return len(providers) > 0 + return string(apiKey), nil } diff --git a/go/internal/controller/provider/manager_test.go b/go/internal/controller/provider/manager_test.go index 0da10ffd9..71b4eff13 100644 --- a/go/internal/controller/provider/manager_test.go +++ b/go/internal/controller/provider/manager_test.go @@ -27,22 +27,24 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" ) +const testNamespace = "kagent" + func TestNewManager(t *testing.T) { tests := []struct { name string namespace string wantNamespace string }{ - { - name: "with default namespace", - namespace: "", - wantNamespace: DefaultNamespace, - }, { name: "with custom namespace", namespace: "custom-ns", wantNamespace: "custom-ns", }, + { + name: "with empty namespace uses default from env or kagent", + namespace: "", + wantNamespace: "kagent", // Default when KAGENT_NAMESPACE env var is not set + }, } for _, tt := range tests { @@ -60,6 +62,9 @@ func TestNewManager(t *testing.T) { if m.client == nil { t.Error("client should be initialized") } + if m.discoverer == nil { + t.Error("discoverer should be initialized") + } }) } } @@ -78,17 +83,17 @@ func TestGetProviders(t *testing.T) { wantNames: []string{}, }, { - name: "single ready provider", + name: "single ready provider with secret", providers: []*v1alpha2.Provider{ { ObjectMeta: metav1.ObjectMeta{ Name: "openai-prod", - Namespace: DefaultNamespace, + Namespace: testNamespace, }, Spec: v1alpha2.ProviderSpec{ Type: v1alpha2.ModelProviderOpenAI, Endpoint: "https://api.openai.com/v1", - SecretRef: v1alpha2.SecretReference{ + SecretRef: &v1alpha2.SecretReference{ Name: "openai-secret", Key: "apiKey", }, @@ -106,18 +111,44 @@ func TestGetProviders(t *testing.T) { wantCount: 1, wantNames: []string{"openai-prod"}, }, + { + name: "ready provider without secret (ollama)", + providers: []*v1alpha2.Provider{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "ollama-local", + Namespace: testNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderOllama, + // No endpoint - uses default + // No SecretRef - not required for Ollama + }, + Status: v1alpha2.ProviderStatus{ + Conditions: []metav1.Condition{ + { + Type: v1alpha2.ProviderConditionTypeReady, + Status: metav1.ConditionTrue, + }, + }, + }, + }, + }, + wantCount: 1, + wantNames: []string{"ollama-local"}, + }, { name: "mixed ready and not ready providers", providers: []*v1alpha2.Provider{ { ObjectMeta: metav1.ObjectMeta{ Name: "openai-prod", - Namespace: DefaultNamespace, + Namespace: testNamespace, }, Spec: v1alpha2.ProviderSpec{ Type: v1alpha2.ModelProviderOpenAI, Endpoint: "https://api.openai.com/v1", - SecretRef: v1alpha2.SecretReference{ + SecretRef: &v1alpha2.SecretReference{ Name: "openai-secret", Key: "apiKey", }, @@ -134,12 +165,12 @@ func TestGetProviders(t *testing.T) { { ObjectMeta: metav1.ObjectMeta{ Name: "anthropic-prod", - Namespace: DefaultNamespace, + Namespace: testNamespace, }, Spec: v1alpha2.ProviderSpec{ Type: v1alpha2.ModelProviderAnthropic, Endpoint: "https://api.anthropic.com", - SecretRef: v1alpha2.SecretReference{ + SecretRef: &v1alpha2.SecretReference{ Name: "anthropic-secret", Key: "apiKey", }, @@ -175,7 +206,7 @@ func TestGetProviders(t *testing.T) { WithRuntimeObjects(objs...). Build() - m := NewManager(client, DefaultNamespace) + m := NewManager(client, testNamespace) providers := m.GetProviders() if len(providers) != tt.wantCount { @@ -211,12 +242,12 @@ func TestGetModels(t *testing.T) { provider: &v1alpha2.Provider{ ObjectMeta: metav1.ObjectMeta{ Name: "openai-prod", - Namespace: DefaultNamespace, + Namespace: testNamespace, }, Spec: v1alpha2.ProviderSpec{ Type: v1alpha2.ModelProviderOpenAI, Endpoint: "https://api.openai.com/v1", - SecretRef: v1alpha2.SecretReference{ + SecretRef: &v1alpha2.SecretReference{ Name: "openai-secret", Key: "apiKey", }, @@ -240,7 +271,7 @@ func TestGetModels(t *testing.T) { provider: &v1alpha2.Provider{ ObjectMeta: metav1.ObjectMeta{ Name: "different-provider", - Namespace: DefaultNamespace, + Namespace: testNamespace, }, }, forceRefresh: false, @@ -252,12 +283,12 @@ func TestGetModels(t *testing.T) { provider: &v1alpha2.Provider{ ObjectMeta: metav1.ObjectMeta{ Name: "empty-provider", - Namespace: DefaultNamespace, + Namespace: testNamespace, }, Spec: v1alpha2.ProviderSpec{ Type: v1alpha2.ModelProviderOpenAI, Endpoint: "https://api.openai.com/v1", - SecretRef: v1alpha2.SecretReference{ + SecretRef: &v1alpha2.SecretReference{ Name: "openai-secret", Key: "apiKey", }, @@ -283,7 +314,7 @@ func TestGetModels(t *testing.T) { WithRuntimeObjects(tt.provider). Build() - m := NewManager(client, DefaultNamespace) + m := NewManager(client, testNamespace) // Use the provider name from the test case providerName := "openai-prod" @@ -308,72 +339,76 @@ func TestGetModels(t *testing.T) { } } -func TestHasProviders(t *testing.T) { +func TestGetProviderDefaultEndpoint(t *testing.T) { tests := []struct { - name string - providers []*v1alpha2.Provider - want bool + name string + provider *v1alpha2.Provider + wantEndpoint string }{ { - name: "no providers", - providers: []*v1alpha2.Provider{}, - want: false, - }, - { - name: "has ready provider", - providers: []*v1alpha2.Provider{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "openai-prod", - Namespace: DefaultNamespace, - }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderOpenAI, - Endpoint: "https://api.openai.com/v1", - SecretRef: v1alpha2.SecretReference{ - Name: "openai-secret", - Key: "apiKey", - }, + name: "OpenAI with default endpoint", + provider: &v1alpha2.Provider{ + ObjectMeta: metav1.ObjectMeta{ + Name: "openai", + Namespace: testNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderOpenAI, + // No endpoint specified + SecretRef: &v1alpha2.SecretReference{ + Name: "secret", + Key: "key", }, - Status: v1alpha2.ProviderStatus{ - Conditions: []metav1.Condition{ - { - Type: v1alpha2.ProviderConditionTypeReady, - Status: metav1.ConditionTrue, - }, - }, + }, + Status: v1alpha2.ProviderStatus{ + Conditions: []metav1.Condition{ + {Type: v1alpha2.ProviderConditionTypeReady, Status: metav1.ConditionTrue}, }, }, }, - want: true, + wantEndpoint: "https://api.openai.com/v1", }, { - name: "has only not ready provider", - providers: []*v1alpha2.Provider{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "anthropic-prod", - Namespace: DefaultNamespace, + name: "Ollama with default endpoint", + provider: &v1alpha2.Provider{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ollama", + Namespace: testNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderOllama, + // No endpoint, no secret + }, + Status: v1alpha2.ProviderStatus{ + Conditions: []metav1.Condition{ + {Type: v1alpha2.ProviderConditionTypeReady, Status: metav1.ConditionTrue}, }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderAnthropic, - Endpoint: "https://api.anthropic.com", - SecretRef: v1alpha2.SecretReference{ - Name: "anthropic-secret", - Key: "apiKey", - }, + }, + }, + wantEndpoint: "http://localhost:11434", + }, + { + name: "OpenAI with custom endpoint", + provider: &v1alpha2.Provider{ + ObjectMeta: metav1.ObjectMeta{ + Name: "openai-custom", + Namespace: testNamespace, + }, + Spec: v1alpha2.ProviderSpec{ + Type: v1alpha2.ModelProviderOpenAI, + Endpoint: "https://custom.openai.com/v1", + SecretRef: &v1alpha2.SecretReference{ + Name: "secret", + Key: "key", }, - Status: v1alpha2.ProviderStatus{ - Conditions: []metav1.Condition{ - { - Type: v1alpha2.ProviderConditionTypeReady, - Status: metav1.ConditionFalse, - }, - }, + }, + Status: v1alpha2.ProviderStatus{ + Conditions: []metav1.Condition{ + {Type: v1alpha2.ProviderConditionTypeReady, Status: metav1.ConditionTrue}, }, }, }, - want: false, + wantEndpoint: "https://custom.openai.com/v1", }, } @@ -383,20 +418,20 @@ func TestHasProviders(t *testing.T) { _ = v1alpha2.AddToScheme(scheme) _ = corev1.AddToScheme(scheme) - objs := make([]runtime.Object, len(tt.providers)) - for i, p := range tt.providers { - objs[i] = p - } - client := fake.NewClientBuilder(). WithScheme(scheme). - WithRuntimeObjects(objs...). + WithRuntimeObjects(tt.provider). Build() - m := NewManager(client, DefaultNamespace) + m := NewManager(client, testNamespace) + providers := m.GetProviders() + + if len(providers) != 1 { + t.Fatalf("expected 1 provider, got %d", len(providers)) + } - if got := m.HasProviders(); got != tt.want { - t.Errorf("HasProviders() = %v, want %v", got, tt.want) + if providers[0].Endpoint != tt.wantEndpoint { + t.Errorf("endpoint = %v, want %v", providers[0].Endpoint, tt.wantEndpoint) } }) } diff --git a/go/internal/controller/reconciler/reconciler.go b/go/internal/controller/reconciler/reconciler.go index 359ddaba7..37ea97b33 100644 --- a/go/internal/controller/reconciler/reconciler.go +++ b/go/internal/controller/reconciler/reconciler.go @@ -923,27 +923,33 @@ func (a *kagentReconciler) ReconcileKagentProvider(ctx context.Context, req ctrl return ctrl.Result{}, fmt.Errorf("failed to get provider %s: %w", req.NamespacedName, err) } - // Validate and resolve secret - secret, secretHash, secretErr := a.validateProviderSecret(ctx, p) + // Validate and resolve secret, get API key in one pass + apiKey, secretHash, secretErr := a.resolveProviderSecret(ctx, p) // Discover models if needed var models []string var discoveryErr error if a.shouldDiscoverModels(p) { - models, discoveryErr = a.discoverProviderModels(ctx, p, secret) + models, discoveryErr = a.discoverProviderModels(ctx, p, apiKey) } else { // Keep existing cached models models = p.Status.DiscoveredModels } - // Update status with results + // Update status with results (status subresource only, no object modification) return a.updateProviderStatus(ctx, p, secretErr, discoveryErr, models, secretHash) } -// validateProviderSecret fetches the Secret and computes its hash -func (a *kagentReconciler) validateProviderSecret(ctx context.Context, p *v1alpha2.Provider) (*corev1.Secret, string, error) { - if p.Spec.SecretRef.Name == "" { - return nil, "", fmt.Errorf("provider %s has no secret reference", p.Name) +// resolveProviderSecret fetches the Secret, validates it, and returns the API key and hash. +// For providers that don't require authentication (e.g., Ollama), returns empty apiKey with no error. +func (a *kagentReconciler) resolveProviderSecret(ctx context.Context, p *v1alpha2.Provider) (string, string, error) { + // Providers like Ollama don't require authentication + if !p.Spec.RequiresSecret() { + return "", "", nil + } + + if p.Spec.SecretRef == nil { + return "", "", fmt.Errorf("provider %s requires a secret but none is configured", p.Name) } secret := &corev1.Secret{} @@ -953,23 +959,24 @@ func (a *kagentReconciler) validateProviderSecret(ctx context.Context, p *v1alph } if err := a.kube.Get(ctx, namespacedName, secret); err != nil { - return nil, "", fmt.Errorf("failed to get secret %s: %w", p.Spec.SecretRef.Name, err) + return "", "", fmt.Errorf("failed to get secret %s: %w", p.Spec.SecretRef.Name, err) } - // Check that the specified key exists + // Check that the specified key exists and has a value key := p.Spec.SecretRef.Key if key == "" { - return nil, "", fmt.Errorf("provider %s has no secret key specified", p.Name) + return "", "", fmt.Errorf("provider %s has no secret key specified", p.Name) } - if _, ok := secret.Data[key]; !ok { - return nil, "", fmt.Errorf("secret %s missing key %s", p.Spec.SecretRef.Name, key) + apiKey, ok := secret.Data[key] + if !ok || len(apiKey) == 0 { + return "", "", fmt.Errorf("secret %s missing or empty key %s", p.Spec.SecretRef.Name, key) } // Compute secret hash for change detection secretHash := computeProviderSecretHash(secret, key) - return secret, secretHash, nil + return string(apiKey), secretHash, nil } // computeProviderSecretHash computes a hash of the secret data for change detection @@ -986,38 +993,34 @@ func computeProviderSecretHash(secret *corev1.Secret, key string) string { // shouldDiscoverModels checks if model discovery is needed func (a *kagentReconciler) shouldDiscoverModels(p *v1alpha2.Provider) bool { - // 1. Force refresh via annotation (UI "Fetch Models" button) - if p.Annotations != nil && p.Annotations[v1alpha2.ProviderAnnotationForceDiscovery] == "true" { + // Initial discovery when Provider is first created or spec changed + if p.Status.LastDiscoveryTime == nil { return true } - // 2. Initial discovery when Provider is first created - if p.Status.LastDiscoveryTime == nil { + // Re-discover if the generation changed (spec was updated) + if p.Status.ObservedGeneration != p.Generation { return true } - // No periodic discovery - only on-demand + // No periodic discovery - only on-demand via HTTP API return false } // discoverProviderModels calls the model discoverer to fetch models -func (a *kagentReconciler) discoverProviderModels(ctx context.Context, p *v1alpha2.Provider, secret *corev1.Secret) ([]string, error) { - if secret == nil { - return nil, fmt.Errorf("cannot discover models: secret not available") +func (a *kagentReconciler) discoverProviderModels(ctx context.Context, p *v1alpha2.Provider, apiKey string) ([]string, error) { + // For providers that require auth, ensure we have an API key + if p.Spec.RequiresSecret() && apiKey == "" { + return nil, fmt.Errorf("cannot discover models: API key not available") } - // Get API key from secret - apiKey, ok := secret.Data[p.Spec.SecretRef.Key] - if !ok || len(apiKey) == 0 { - return nil, fmt.Errorf("secret %s has empty value for key %s", p.Spec.SecretRef.Name, p.Spec.SecretRef.Key) - } - - // Use the provider package's ModelDiscoverer + // Use the provider package's ModelDiscoverer with the resolved endpoint discoverer := provider.NewModelDiscoverer() - return discoverer.DiscoverModels(ctx, p.Spec.Type, p.Spec.Endpoint, string(apiKey)) + return discoverer.DiscoverModels(ctx, p.Spec.Type, p.Spec.GetEndpoint(), apiKey) } -// updateProviderStatus updates the Provider status based on reconciliation results +// updateProviderStatus updates the Provider status based on reconciliation results. +// Only modifies the status subresource - never modifies the Provider object itself. func (a *kagentReconciler) updateProviderStatus( ctx context.Context, p *v1alpha2.Provider, @@ -1025,14 +1028,29 @@ func (a *kagentReconciler) updateProviderStatus( models []string, secretHash string, ) (ctrl.Result, error) { + // For providers that don't require secrets, mark SecretResolved as true + secretRequired := p.Spec.RequiresSecret() + secretResolved := !secretRequired || secretErr == nil + // Update SecretResolved condition - meta.SetStatusCondition(&p.Status.Conditions, metav1.Condition{ - Type: v1alpha2.ProviderConditionTypeSecretResolved, - Status: conditionStatus(secretErr == nil), - Reason: conditionReason(secretErr, "SecretResolved", "SecretNotFound"), - Message: conditionMessage(secretErr, "Secret resolved successfully"), - ObservedGeneration: p.Generation, - }) + if secretRequired { + meta.SetStatusCondition(&p.Status.Conditions, metav1.Condition{ + Type: v1alpha2.ProviderConditionTypeSecretResolved, + Status: conditionStatus(secretErr == nil), + Reason: conditionReason(secretErr, "SecretResolved", "SecretNotFound"), + Message: conditionMessage(secretErr, "Secret resolved successfully"), + ObservedGeneration: p.Generation, + }) + } else { + // Provider doesn't require a secret (e.g., Ollama) + meta.SetStatusCondition(&p.Status.Conditions, metav1.Condition{ + Type: v1alpha2.ProviderConditionTypeSecretResolved, + Status: metav1.ConditionTrue, + Reason: "SecretNotRequired", + Message: "Provider does not require authentication", + ObservedGeneration: p.Generation, + }) + } // Update ModelsDiscovered condition modelsDiscovered := discoveryErr == nil && len(models) > 0 @@ -1045,7 +1063,7 @@ func (a *kagentReconciler) updateProviderStatus( }) // Update Ready condition (overall health) - ready := secretErr == nil && modelsDiscovered + ready := secretResolved && modelsDiscovered meta.SetStatusCondition(&p.Status.Conditions, metav1.Condition{ Type: v1alpha2.ProviderConditionTypeReady, Status: conditionStatus(ready), @@ -1065,20 +1083,12 @@ func (a *kagentReconciler) updateProviderStatus( p.Status.LastDiscoveryTime = &now } - // Clear force-discovery annotation if set - if p.Annotations != nil && p.Annotations[v1alpha2.ProviderAnnotationForceDiscovery] == "true" { - delete(p.Annotations, v1alpha2.ProviderAnnotationForceDiscovery) - if err := a.kube.Update(ctx, p); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to clear force-discovery annotation: %w", err) - } - } - - // Update status subresource + // Update status subresource only - never modify the Provider object itself if err := a.kube.Status().Update(ctx, p); err != nil { return ctrl.Result{}, fmt.Errorf("failed to update provider status: %w", err) } - // No periodic requeue - discovery only on-demand + // No periodic requeue - discovery only on-demand via HTTP API return ctrl.Result{}, nil } From 68df742271302055f4cafb1f8d26a11d18f5c05b Mon Sep 17 00:00:00 2001 From: Nagarjun Krishnan Date: Wed, 11 Feb 2026 14:16:34 -0700 Subject: [PATCH 3/3] more fixes Signed-off-by: Nagarjun Krishnan --- .../mcp_server_tool_controller_test.go | 4 + go/internal/controller/provider/manager.go | 158 ------- .../controller/provider/manager_test.go | 438 ------------------ go/internal/controller/provider/types.go | 59 --- .../controller/reconciler/reconciler.go | 31 ++ .../controller/service_controller_test.go | 4 + go/internal/httpserver/handlers/handlers.go | 7 +- go/internal/httpserver/handlers/providers.go | 98 ++-- go/internal/httpserver/server.go | 6 +- go/pkg/app/app.go | 7 +- 10 files changed, 114 insertions(+), 698 deletions(-) delete mode 100644 go/internal/controller/provider/manager.go delete mode 100644 go/internal/controller/provider/manager_test.go delete mode 100644 go/internal/controller/provider/types.go diff --git a/go/internal/controller/mcp_server_tool_controller_test.go b/go/internal/controller/mcp_server_tool_controller_test.go index 35f518738..9d50185d0 100644 --- a/go/internal/controller/mcp_server_tool_controller_test.go +++ b/go/internal/controller/mcp_server_tool_controller_test.go @@ -43,6 +43,10 @@ func (f *fakeReconciler) ReconcileKagentProvider(ctx context.Context, req ctrl.R return ctrl.Result{}, nil } +func (f *fakeReconciler) RefreshProviderModels(ctx context.Context, namespace, name string) ([]string, error) { + return nil, nil +} + func (f *fakeReconciler) GetOwnedResourceTypes() []client.Object { return nil } diff --git a/go/internal/controller/provider/manager.go b/go/internal/controller/provider/manager.go deleted file mode 100644 index 45d34ddbb..000000000 --- a/go/internal/controller/provider/manager.go +++ /dev/null @@ -1,158 +0,0 @@ -/* -Copyright 2025. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package provider - -import ( - "context" - "fmt" - - "github.com/kagent-dev/kagent/go/api/v1alpha2" - "github.com/kagent-dev/kagent/go/internal/utils" - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/meta" - "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/log" -) - -// Manager handles provider configuration and model discovery. -// It reads provider configs from Provider CRDs and caches discovered models in CRD status. -type Manager struct { - client client.Client - namespace string - discoverer *ModelDiscoverer -} - -// NewManager creates a new provider Manager instance. -// If namespace is empty, it uses utils.GetResourceNamespace() which reads from -// KAGENT_NAMESPACE environment variable or defaults to "kagent". -func NewManager(client client.Client, namespace string) *Manager { - if namespace == "" { - namespace = utils.GetResourceNamespace() - } - return &Manager{ - client: client, - namespace: namespace, - discoverer: NewModelDiscoverer(), - } -} - -// GetProviders returns all configured providers that are Ready. -func (m *Manager) GetProviders() []ProviderConfig { - ctx := context.Background() - var providerList v1alpha2.ProviderList - if err := m.client.List(ctx, &providerList, - client.InNamespace(m.namespace)); err != nil { - return nil - } - - var providers []ProviderConfig - for _, p := range providerList.Items { - // Only include Ready providers - if meta.IsStatusConditionTrue(p.Status.Conditions, - v1alpha2.ProviderConditionTypeReady) { - config := ProviderConfig{ - Name: p.Name, - Type: p.Spec.Type, - Endpoint: p.Spec.GetEndpoint(), - } - // Only set SecretRef if it's specified - if p.Spec.SecretRef != nil { - config.SecretRef = SecretReference{ - Name: p.Spec.SecretRef.Name, - Key: p.Spec.SecretRef.Key, - } - } - providers = append(providers, config) - } - } - - return providers -} - -// GetModels returns models for a provider from the cached status or performs direct discovery. -// Models are cached in Provider.Status.DiscoveredModels by the Provider controller. -// If forceRefresh is true, performs direct discovery and updates the provider status. -func (m *Manager) GetModels(ctx context.Context, providerName string, forceRefresh bool) ([]string, error) { - logger := log.FromContext(ctx).WithName("provider-manager") - - provider := &v1alpha2.Provider{} - if err := m.client.Get(ctx, types.NamespacedName{ - Name: providerName, Namespace: m.namespace, - }, provider); err != nil { - return nil, fmt.Errorf("provider %s not found: %w", providerName, err) - } - - // If force refresh, perform direct discovery - if forceRefresh { - logger.Info("Performing direct model discovery", "provider", providerName) - - // Get API key from secret if required - apiKey, err := m.getAPIKey(ctx, provider) - if err != nil { - return nil, fmt.Errorf("failed to get API key: %w", err) - } - - // Discover models directly - endpoint := provider.Spec.GetEndpoint() - models, err := m.discoverer.DiscoverModels(ctx, provider.Spec.Type, endpoint, apiKey) - if err != nil { - return nil, fmt.Errorf("model discovery failed: %w", err) - } - - logger.Info("Model discovery completed", "provider", providerName, "count", len(models)) - return models, nil - } - - // Return cached models from status - if len(provider.Status.DiscoveredModels) > 0 { - return provider.Status.DiscoveredModels, nil - } - - // No models discovered - provide helpful message - return nil, fmt.Errorf("no models discovered for provider %s, try refreshing", providerName) -} - -// getAPIKey retrieves the API key from the secret referenced by the provider. -// Returns empty string for providers that don't require authentication (e.g., Ollama). -func (m *Manager) getAPIKey(ctx context.Context, provider *v1alpha2.Provider) (string, error) { - // Providers like Ollama don't require authentication - if !provider.Spec.RequiresSecret() { - return "", nil - } - - if provider.Spec.SecretRef == nil { - return "", fmt.Errorf("provider %s requires a secret but none is configured", provider.Name) - } - - secret := &corev1.Secret{} - secretName := types.NamespacedName{ - Namespace: provider.Namespace, - Name: provider.Spec.SecretRef.Name, - } - - if err := m.client.Get(ctx, secretName, secret); err != nil { - return "", fmt.Errorf("failed to get secret %s: %w", provider.Spec.SecretRef.Name, err) - } - - apiKey, ok := secret.Data[provider.Spec.SecretRef.Key] - if !ok || len(apiKey) == 0 { - return "", fmt.Errorf("secret %s missing key %s", provider.Spec.SecretRef.Name, provider.Spec.SecretRef.Key) - } - - return string(apiKey), nil -} diff --git a/go/internal/controller/provider/manager_test.go b/go/internal/controller/provider/manager_test.go deleted file mode 100644 index 71b4eff13..000000000 --- a/go/internal/controller/provider/manager_test.go +++ /dev/null @@ -1,438 +0,0 @@ -/* -Copyright 2025. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package provider - -import ( - "context" - "testing" - - v1alpha2 "github.com/kagent-dev/kagent/go/api/v1alpha2" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" - "sigs.k8s.io/controller-runtime/pkg/client/fake" -) - -const testNamespace = "kagent" - -func TestNewManager(t *testing.T) { - tests := []struct { - name string - namespace string - wantNamespace string - }{ - { - name: "with custom namespace", - namespace: "custom-ns", - wantNamespace: "custom-ns", - }, - { - name: "with empty namespace uses default from env or kagent", - namespace: "", - wantNamespace: "kagent", // Default when KAGENT_NAMESPACE env var is not set - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - scheme := runtime.NewScheme() - _ = v1alpha2.AddToScheme(scheme) - _ = corev1.AddToScheme(scheme) - client := fake.NewClientBuilder().WithScheme(scheme).Build() - - m := NewManager(client, tt.namespace) - - if m.namespace != tt.wantNamespace { - t.Errorf("namespace = %v, want %v", m.namespace, tt.wantNamespace) - } - if m.client == nil { - t.Error("client should be initialized") - } - if m.discoverer == nil { - t.Error("discoverer should be initialized") - } - }) - } -} - -func TestGetProviders(t *testing.T) { - tests := []struct { - name string - providers []*v1alpha2.Provider - wantCount int - wantNames []string - }{ - { - name: "no providers", - providers: []*v1alpha2.Provider{}, - wantCount: 0, - wantNames: []string{}, - }, - { - name: "single ready provider with secret", - providers: []*v1alpha2.Provider{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "openai-prod", - Namespace: testNamespace, - }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderOpenAI, - Endpoint: "https://api.openai.com/v1", - SecretRef: &v1alpha2.SecretReference{ - Name: "openai-secret", - Key: "apiKey", - }, - }, - Status: v1alpha2.ProviderStatus{ - Conditions: []metav1.Condition{ - { - Type: v1alpha2.ProviderConditionTypeReady, - Status: metav1.ConditionTrue, - }, - }, - }, - }, - }, - wantCount: 1, - wantNames: []string{"openai-prod"}, - }, - { - name: "ready provider without secret (ollama)", - providers: []*v1alpha2.Provider{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "ollama-local", - Namespace: testNamespace, - }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderOllama, - // No endpoint - uses default - // No SecretRef - not required for Ollama - }, - Status: v1alpha2.ProviderStatus{ - Conditions: []metav1.Condition{ - { - Type: v1alpha2.ProviderConditionTypeReady, - Status: metav1.ConditionTrue, - }, - }, - }, - }, - }, - wantCount: 1, - wantNames: []string{"ollama-local"}, - }, - { - name: "mixed ready and not ready providers", - providers: []*v1alpha2.Provider{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "openai-prod", - Namespace: testNamespace, - }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderOpenAI, - Endpoint: "https://api.openai.com/v1", - SecretRef: &v1alpha2.SecretReference{ - Name: "openai-secret", - Key: "apiKey", - }, - }, - Status: v1alpha2.ProviderStatus{ - Conditions: []metav1.Condition{ - { - Type: v1alpha2.ProviderConditionTypeReady, - Status: metav1.ConditionTrue, - }, - }, - }, - }, - { - ObjectMeta: metav1.ObjectMeta{ - Name: "anthropic-prod", - Namespace: testNamespace, - }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderAnthropic, - Endpoint: "https://api.anthropic.com", - SecretRef: &v1alpha2.SecretReference{ - Name: "anthropic-secret", - Key: "apiKey", - }, - }, - Status: v1alpha2.ProviderStatus{ - Conditions: []metav1.Condition{ - { - Type: v1alpha2.ProviderConditionTypeReady, - Status: metav1.ConditionFalse, - }, - }, - }, - }, - }, - wantCount: 1, // Only ready provider should be returned - wantNames: []string{"openai-prod"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - scheme := runtime.NewScheme() - _ = v1alpha2.AddToScheme(scheme) - _ = corev1.AddToScheme(scheme) - - objs := make([]runtime.Object, len(tt.providers)) - for i, p := range tt.providers { - objs[i] = p - } - - client := fake.NewClientBuilder(). - WithScheme(scheme). - WithRuntimeObjects(objs...). - Build() - - m := NewManager(client, testNamespace) - providers := m.GetProviders() - - if len(providers) != tt.wantCount { - t.Errorf("GetProviders() returned %d providers, want %d", len(providers), tt.wantCount) - } - - for _, wantName := range tt.wantNames { - found := false - for _, p := range providers { - if p.Name == wantName { - found = true - break - } - } - if !found { - t.Errorf("Expected provider %s not found in results", wantName) - } - } - }) - } -} - -func TestGetModels(t *testing.T) { - tests := []struct { - name string - provider *v1alpha2.Provider - forceRefresh bool - wantModels []string - wantErr bool - }{ - { - name: "provider with cached models", - provider: &v1alpha2.Provider{ - ObjectMeta: metav1.ObjectMeta{ - Name: "openai-prod", - Namespace: testNamespace, - }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderOpenAI, - Endpoint: "https://api.openai.com/v1", - SecretRef: &v1alpha2.SecretReference{ - Name: "openai-secret", - Key: "apiKey", - }, - }, - Status: v1alpha2.ProviderStatus{ - DiscoveredModels: []string{"gpt-4", "gpt-3.5-turbo"}, - Conditions: []metav1.Condition{ - { - Type: v1alpha2.ProviderConditionTypeReady, - Status: metav1.ConditionTrue, - }, - }, - }, - }, - forceRefresh: false, - wantModels: []string{"gpt-4", "gpt-3.5-turbo"}, - wantErr: false, - }, - { - name: "provider not found", - provider: &v1alpha2.Provider{ - ObjectMeta: metav1.ObjectMeta{ - Name: "different-provider", - Namespace: testNamespace, - }, - }, - forceRefresh: false, - wantModels: nil, - wantErr: true, - }, - { - name: "provider without models", - provider: &v1alpha2.Provider{ - ObjectMeta: metav1.ObjectMeta{ - Name: "empty-provider", - Namespace: testNamespace, - }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderOpenAI, - Endpoint: "https://api.openai.com/v1", - SecretRef: &v1alpha2.SecretReference{ - Name: "openai-secret", - Key: "apiKey", - }, - }, - Status: v1alpha2.ProviderStatus{ - DiscoveredModels: []string{}, - }, - }, - forceRefresh: false, - wantModels: nil, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - scheme := runtime.NewScheme() - _ = v1alpha2.AddToScheme(scheme) - _ = corev1.AddToScheme(scheme) - - client := fake.NewClientBuilder(). - WithScheme(scheme). - WithRuntimeObjects(tt.provider). - Build() - - m := NewManager(client, testNamespace) - - // Use the provider name from the test case - providerName := "openai-prod" - switch tt.name { - case "provider not found": - providerName = "nonexistent-provider" - case "provider without models": - providerName = "empty-provider" - } - - models, err := m.GetModels(context.Background(), providerName, tt.forceRefresh) - - if (err != nil) != tt.wantErr { - t.Errorf("GetModels() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr && len(models) != len(tt.wantModels) { - t.Errorf("GetModels() returned %d models, want %d", len(models), len(tt.wantModels)) - } - }) - } -} - -func TestGetProviderDefaultEndpoint(t *testing.T) { - tests := []struct { - name string - provider *v1alpha2.Provider - wantEndpoint string - }{ - { - name: "OpenAI with default endpoint", - provider: &v1alpha2.Provider{ - ObjectMeta: metav1.ObjectMeta{ - Name: "openai", - Namespace: testNamespace, - }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderOpenAI, - // No endpoint specified - SecretRef: &v1alpha2.SecretReference{ - Name: "secret", - Key: "key", - }, - }, - Status: v1alpha2.ProviderStatus{ - Conditions: []metav1.Condition{ - {Type: v1alpha2.ProviderConditionTypeReady, Status: metav1.ConditionTrue}, - }, - }, - }, - wantEndpoint: "https://api.openai.com/v1", - }, - { - name: "Ollama with default endpoint", - provider: &v1alpha2.Provider{ - ObjectMeta: metav1.ObjectMeta{ - Name: "ollama", - Namespace: testNamespace, - }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderOllama, - // No endpoint, no secret - }, - Status: v1alpha2.ProviderStatus{ - Conditions: []metav1.Condition{ - {Type: v1alpha2.ProviderConditionTypeReady, Status: metav1.ConditionTrue}, - }, - }, - }, - wantEndpoint: "http://localhost:11434", - }, - { - name: "OpenAI with custom endpoint", - provider: &v1alpha2.Provider{ - ObjectMeta: metav1.ObjectMeta{ - Name: "openai-custom", - Namespace: testNamespace, - }, - Spec: v1alpha2.ProviderSpec{ - Type: v1alpha2.ModelProviderOpenAI, - Endpoint: "https://custom.openai.com/v1", - SecretRef: &v1alpha2.SecretReference{ - Name: "secret", - Key: "key", - }, - }, - Status: v1alpha2.ProviderStatus{ - Conditions: []metav1.Condition{ - {Type: v1alpha2.ProviderConditionTypeReady, Status: metav1.ConditionTrue}, - }, - }, - }, - wantEndpoint: "https://custom.openai.com/v1", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - scheme := runtime.NewScheme() - _ = v1alpha2.AddToScheme(scheme) - _ = corev1.AddToScheme(scheme) - - client := fake.NewClientBuilder(). - WithScheme(scheme). - WithRuntimeObjects(tt.provider). - Build() - - m := NewManager(client, testNamespace) - providers := m.GetProviders() - - if len(providers) != 1 { - t.Fatalf("expected 1 provider, got %d", len(providers)) - } - - if providers[0].Endpoint != tt.wantEndpoint { - t.Errorf("endpoint = %v, want %v", providers[0].Endpoint, tt.wantEndpoint) - } - }) - } -} diff --git a/go/internal/controller/provider/types.go b/go/internal/controller/provider/types.go deleted file mode 100644 index 728a4c8e6..000000000 --- a/go/internal/controller/provider/types.go +++ /dev/null @@ -1,59 +0,0 @@ -/* -Copyright 2025. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package provider - -import ( - v1alpha2 "github.com/kagent-dev/kagent/go/api/v1alpha2" -) - -// ProviderConfig represents a configured LLM provider instance. -// Multiple ProviderConfigs can exist for the same provider type (e.g., two OpenAI instances). -type ProviderConfig struct { - // Name is the unique identifier for this provider instance - Name string `yaml:"name" json:"name"` - - // Type is the provider type (OpenAI, Anthropic, etc.) - Type v1alpha2.ModelProvider `yaml:"type" json:"type"` - - // Endpoint is the base URL for the provider API - Endpoint string `yaml:"endpoint" json:"endpoint"` - - // SecretRef references the Kubernetes Secret containing the API key - SecretRef SecretReference `yaml:"secretRef" json:"secretRef"` -} - -// SecretReference points to a specific key within a Kubernetes Secret -type SecretReference struct { - // Name is the name of the Secret - Name string `yaml:"name" json:"name"` - - // Key is the key within the Secret data - Key string `yaml:"key" json:"key"` -} - -// ProviderResponse is the API response format for listing providers -type ProviderResponse struct { - Name string `json:"name"` - Type string `json:"type"` - Endpoint string `json:"endpoint"` -} - -// ModelsResponse is the API response format for listing models -type ModelsResponse struct { - Provider string `json:"provider"` - Models []string `json:"models"` -} diff --git a/go/internal/controller/reconciler/reconciler.go b/go/internal/controller/reconciler/reconciler.go index 37ea97b33..4435d01c3 100644 --- a/go/internal/controller/reconciler/reconciler.go +++ b/go/internal/controller/reconciler/reconciler.go @@ -47,6 +47,7 @@ type KagentReconciler interface { ReconcileKagentMCPService(ctx context.Context, req ctrl.Request) error ReconcileKagentMCPServer(ctx context.Context, req ctrl.Request) error ReconcileKagentProvider(ctx context.Context, req ctrl.Request) (ctrl.Result, error) + RefreshProviderModels(ctx context.Context, namespace, name string) ([]string, error) GetOwnedResourceTypes() []client.Object } @@ -1113,3 +1114,33 @@ func conditionMessage(err error, successMessage string) string { } return successMessage } + +// RefreshProviderModels forces a fresh model discovery for a provider and updates its status. +// This is called by the HTTP API when refresh=true is requested. +// It reuses all existing internal reconciler methods - no code duplication. +func (a *kagentReconciler) RefreshProviderModels(ctx context.Context, namespace, name string) ([]string, error) { + p := &v1alpha2.Provider{} + if err := a.kube.Get(ctx, types.NamespacedName{Namespace: namespace, Name: name}, p); err != nil { + return nil, fmt.Errorf("failed to get provider %s/%s: %w", namespace, name, err) + } + + // Reuse existing secret resolution logic + apiKey, secretHash, secretErr := a.resolveProviderSecret(ctx, p) + if secretErr != nil { + return nil, fmt.Errorf("failed to resolve provider secret: %w", secretErr) + } + + // Force discovery by calling the existing method + models, discoveryErr := a.discoverProviderModels(ctx, p, apiKey) + if discoveryErr != nil { + return nil, fmt.Errorf("model discovery failed: %w", discoveryErr) + } + + // Update status using existing method (persists to CR) + _, err := a.updateProviderStatus(ctx, p, secretErr, discoveryErr, models, secretHash) + if err != nil { + return nil, fmt.Errorf("failed to update provider status: %w", err) + } + + return models, nil +} diff --git a/go/internal/controller/service_controller_test.go b/go/internal/controller/service_controller_test.go index 4a25d4838..17a6290fc 100644 --- a/go/internal/controller/service_controller_test.go +++ b/go/internal/controller/service_controller_test.go @@ -43,6 +43,10 @@ func (f *fakeServiceReconciler) ReconcileKagentProvider(ctx context.Context, req return ctrl.Result{}, nil } +func (f *fakeServiceReconciler) RefreshProviderModels(ctx context.Context, namespace, name string) ([]string, error) { + return nil, nil +} + func (f *fakeServiceReconciler) GetOwnedResourceTypes() []client.Object { return nil } diff --git a/go/internal/httpserver/handlers/handlers.go b/go/internal/httpserver/handlers/handlers.go index 70e4b276b..79cc63fb5 100644 --- a/go/internal/httpserver/handlers/handlers.go +++ b/go/internal/httpserver/handlers/handlers.go @@ -4,7 +4,7 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" - "github.com/kagent-dev/kagent/go/internal/controller/provider" + "github.com/kagent-dev/kagent/go/internal/controller/reconciler" "github.com/kagent-dev/kagent/go/internal/database" "github.com/kagent-dev/kagent/go/pkg/auth" ) @@ -38,8 +38,7 @@ type Base struct { } // NewHandlers creates a new Handlers instance with all handler components. -// providerManager can be nil if provider discovery is not enabled. -func NewHandlers(kubeClient client.Client, defaultModelConfig types.NamespacedName, dbService database.Client, watchedNamespaces []string, authorizer auth.Authorizer, proxyURL string, providerManager *provider.Manager) *Handlers { +func NewHandlers(kubeClient client.Client, defaultModelConfig types.NamespacedName, dbService database.Client, watchedNamespaces []string, authorizer auth.Authorizer, proxyURL string, rcnclr reconciler.KagentReconciler) *Handlers { base := &Base{ KubeClient: kubeClient, DefaultModelConfig: defaultModelConfig, @@ -52,7 +51,7 @@ func NewHandlers(kubeClient client.Client, defaultModelConfig types.NamespacedNa Health: NewHealthHandler(), ModelConfig: NewModelConfigHandler(base), Model: NewModelHandler(base), - Provider: NewProviderHandler(base, providerManager), + Provider: NewProviderHandler(base, rcnclr), Sessions: NewSessionsHandler(base), Agents: NewAgentsHandler(base), Tools: NewToolsHandler(base), diff --git a/go/internal/httpserver/handlers/providers.go b/go/internal/httpserver/handlers/providers.go index 083dd028c..5ef767b4b 100644 --- a/go/internal/httpserver/handlers/providers.go +++ b/go/internal/httpserver/handlers/providers.go @@ -6,22 +6,38 @@ import ( "github.com/kagent-dev/kagent/go/api/v1alpha1" "github.com/kagent-dev/kagent/go/api/v1alpha2" - "github.com/kagent-dev/kagent/go/internal/controller/provider" + "github.com/kagent-dev/kagent/go/internal/controller/reconciler" + "github.com/kagent-dev/kagent/go/internal/utils" "github.com/kagent-dev/kagent/go/pkg/client/api" + "k8s.io/apimachinery/pkg/api/meta" + "sigs.k8s.io/controller-runtime/pkg/client" ctrllog "sigs.k8s.io/controller-runtime/pkg/log" ) +// ProviderResponse is the API response format for listing providers +type ProviderResponse struct { + Name string `json:"name"` + Type string `json:"type"` + Endpoint string `json:"endpoint"` +} + +// ModelsResponse is the API response format for listing models +type ModelsResponse struct { + Provider string `json:"provider"` + Models []string `json:"models"` +} + // ProviderHandler handles provider requests type ProviderHandler struct { *Base - providerManager *provider.Manager + reconciler reconciler.KagentReconciler } // NewProviderHandler creates a new ProviderHandler -func NewProviderHandler(base *Base, providerManager *provider.Manager) *ProviderHandler { +func NewProviderHandler(base *Base, rcnclr reconciler.KagentReconciler) *ProviderHandler { return &ProviderHandler{ - Base: base, - providerManager: providerManager, + Base: base, + reconciler: rcnclr, } } @@ -148,22 +164,25 @@ func (h *ProviderHandler) HandleListConfiguredProviders(w ErrorResponseWriter, r log.Info("Listing configured providers") - if h.providerManager == nil { - log.Info("Provider manager not initialized") - data := api.NewResponse([]provider.ProviderResponse{}, "Provider discovery not enabled", false) - RespondWithJSON(w, http.StatusOK, data) + // List Provider CRs directly from Kubernetes + namespace := utils.GetResourceNamespace() + var providerList v1alpha2.ProviderList + if err := h.KubeClient.List(r.Context(), &providerList, client.InNamespace(namespace)); err != nil { + log.Error(err, "Failed to list providers") + RespondWithError(w, http.StatusInternalServerError, err.Error()) return } - providers := h.providerManager.GetProviders() - - // Transform to API response format (hide sensitive data like secretRef) - response := make([]provider.ProviderResponse, len(providers)) - for i, p := range providers { - response[i] = provider.ProviderResponse{ - Name: p.Name, - Type: string(p.Type), - Endpoint: p.Endpoint, + // Filter for Ready providers and transform to API response format + var response []ProviderResponse + for _, p := range providerList.Items { + // Only include Ready providers + if meta.IsStatusConditionTrue(p.Status.Conditions, v1alpha2.ProviderConditionTypeReady) { + response = append(response, ProviderResponse{ + Name: p.Name, + Type: string(p.Spec.Type), + Endpoint: p.Spec.GetEndpoint(), + }) } } @@ -187,23 +206,42 @@ func (h *ProviderHandler) HandleGetProviderModels(w ErrorResponseWriter, r *http log = log.WithValues("provider", providerName) log.Info("Getting models for provider") - if h.providerManager == nil { - log.Info("Provider manager not initialized") - RespondWithError(w, http.StatusServiceUnavailable, "Provider discovery not enabled") - return - } - // Check for refresh query parameter forceRefresh := r.URL.Query().Get("refresh") == "true" - models, err := h.providerManager.GetModels(r.Context(), providerName, forceRefresh) - if err != nil { - log.Error(err, "Failed to get models for provider") - RespondWithError(w, http.StatusInternalServerError, err.Error()) - return + namespace := utils.GetResourceNamespace() + var models []string + if forceRefresh { + // Call reconciler to trigger fresh discovery + log.Info("Forcing fresh model discovery") + models, err = h.reconciler.RefreshProviderModels(r.Context(), namespace, providerName) + if err != nil { + log.Error(err, "Failed to refresh models for provider") + RespondWithError(w, http.StatusInternalServerError, err.Error()) + return + } + } else { + // Read cached models from Provider.Status + p := &v1alpha2.Provider{} + if err := h.KubeClient.Get(r.Context(), client.ObjectKey{ + Namespace: namespace, + Name: providerName, + }, p); err != nil { + log.Error(err, "Failed to get provider") + RespondWithError(w, http.StatusNotFound, err.Error()) + return + } + + if len(p.Status.DiscoveredModels) == 0 { + log.Info("No models discovered for provider, try refreshing") + RespondWithError(w, http.StatusNotFound, "No models discovered for provider, try refreshing") + return + } + + models = p.Status.DiscoveredModels } - response := provider.ModelsResponse{ + response := ModelsResponse{ Provider: providerName, Models: models, } diff --git a/go/internal/httpserver/server.go b/go/internal/httpserver/server.go index 16fd366fc..bb76d7a74 100644 --- a/go/internal/httpserver/server.go +++ b/go/internal/httpserver/server.go @@ -7,7 +7,7 @@ import ( "github.com/gorilla/mux" "github.com/kagent-dev/kagent/go/internal/a2a" - "github.com/kagent-dev/kagent/go/internal/controller/provider" + "github.com/kagent-dev/kagent/go/internal/controller/reconciler" "github.com/kagent-dev/kagent/go/internal/database" "github.com/kagent-dev/kagent/go/internal/httpserver/handlers" "github.com/kagent-dev/kagent/go/internal/mcp" @@ -60,7 +60,7 @@ type ServerConfig struct { Authenticator auth.AuthProvider Authorizer auth.Authorizer ProxyURL string - ProviderManager *provider.Manager // Optional: enables provider discovery endpoints + Reconciler reconciler.KagentReconciler } // HTTPServer is the structure that manages the HTTP server @@ -80,7 +80,7 @@ func NewHTTPServer(config ServerConfig) (*HTTPServer, error) { return &HTTPServer{ config: config, router: config.Router, - handlers: handlers.NewHandlers(config.KubeClient, defaultModelConfig, config.DbClient, config.WatchedNamespaces, config.Authorizer, config.ProxyURL, config.ProviderManager), + handlers: handlers.NewHandlers(config.KubeClient, defaultModelConfig, config.DbClient, config.WatchedNamespaces, config.Authorizer, config.ProxyURL, config.Reconciler), authenticator: config.Authenticator, }, nil } diff --git a/go/pkg/app/app.go b/go/pkg/app/app.go index 63d80be2b..2aadc37ee 100644 --- a/go/pkg/app/app.go +++ b/go/pkg/app/app.go @@ -41,7 +41,6 @@ import ( "github.com/kagent-dev/kagent/go/internal/mcp" versionmetrics "github.com/kagent-dev/kagent/go/internal/metrics" - "github.com/kagent-dev/kagent/go/internal/controller/provider" "github.com/kagent-dev/kagent/go/internal/controller/reconciler" reconcilerutils "github.com/kagent-dev/kagent/go/internal/controller/reconciler/utils" agent_translator "github.com/kagent-dev/kagent/go/internal/controller/translator/agent" @@ -503,10 +502,6 @@ func Start(getExtensionConfig GetExtensionConfig) { os.Exit(1) } - // Initialize provider manager for model/provider discovery - providerManager := provider.NewManager(mgr.GetClient(), kagentNamespace) - setupLog.Info("Initialized provider manager", "namespace", kagentNamespace) - httpServer, err := httpserver.NewHTTPServer(httpserver.ServerConfig{ Router: router, BindAddr: cfg.HttpServerAddr, @@ -518,7 +513,7 @@ func Start(getExtensionConfig GetExtensionConfig) { Authorizer: extensionCfg.Authorizer, Authenticator: extensionCfg.Authenticator, ProxyURL: cfg.Proxy.URL, - ProviderManager: providerManager, + Reconciler: rcnclr, }) if err != nil { setupLog.Error(err, "unable to create HTTP server")