From c9ea6ba0158a2d77846073fc79066db642f1cf5e Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Thu, 29 Jan 2026 16:06:08 -0500 Subject: [PATCH] feat: migrate Cerebras provider to @ai-sdk/cerebras - Rewrote cerebras.ts from 362 lines to ~170 lines using @ai-sdk/cerebras - Uses createCerebras from the dedicated Cerebras AI SDK package - Uses streamText/generateText from 'ai' package - Preserved X-Cerebras-3rd-Party-Integration: roocode attribution header - Updated tests to mock @ai-sdk/cerebras - Removed manual fetch/SSE parsing, TagMatcher, tool schema stripping --- pnpm-lock.yaml | 51 +- src/api/providers/__tests__/cerebras.spec.ts | 575 +++++++++++++------ src/api/providers/cerebras.ts | 428 ++++---------- src/package.json | 1 + 4 files changed, 583 insertions(+), 472 deletions(-) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 7930fe5352c..a7eebc66f10 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -743,9 +743,12 @@ importers: src: dependencies: + '@ai-sdk/cerebras': + specifier: ^1.0.0 + version: 1.0.35(zod@3.25.76) '@ai-sdk/deepseek': specifier: ^2.0.14 - version: 2.0.14(zod@3.25.76) + version: 2.0.15(zod@3.25.76) '@anthropic-ai/bedrock-sdk': specifier: ^0.10.2 version: 0.10.4 @@ -1390,8 +1393,14 @@ packages: '@adobe/css-tools@4.4.2': resolution: {integrity: sha512-baYZExFpsdkBNuvGKTKWCwKH57HRZLVtycZS05WTQNVOiXVSeAki3nU35zlRbToeMW8aHlJfyS+1C4BOv27q0A==} - '@ai-sdk/deepseek@2.0.14': - resolution: {integrity: sha512-1vXh8sVwRJYd1JO57qdy1rACucaNLDoBRCwOER3EbPgSF2vNVPcdJywGutA01Bhn7Cta+UJQ+k5y/yzMAIpP2w==} + '@ai-sdk/cerebras@1.0.35': + resolution: {integrity: sha512-JrNdMYptrOUjNthibgBeAcBjZ/H+fXb49sSrWhOx5Aq8eUcrYvwQ2DtSAi8VraHssZu78NAnBMrgFWSUOTXFxw==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + + '@ai-sdk/deepseek@2.0.15': + resolution: {integrity: sha512-3wJUjNjGrTZS3K8OEfHD1PZYhzkcXuoL8KIVtzi6WrC5xrDQPjCBPATmdKPV7DgDCF+wujQOaMz5cv40Yg+hog==} engines: {node: '>=18'} peerDependencies: zod: 3.25.76 @@ -1420,6 +1429,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/provider-utils@4.0.11': + resolution: {integrity: sha512-y/WOPpcZaBjvNaogy83mBsCRPvbtaK0y1sY9ckRrrbTGMvG2HC/9Y/huqNXKnLAxUIME2PGa2uvF2CDwIsxoXQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/provider@2.0.1': resolution: {integrity: sha512-KCUwswvsC5VsW2PWFqF8eJgSCu5Ysj7m1TxiHTVA6g7k360bk0RNQENT8KTMAYEs+8fWPD3Uu4dEmzGHc+jGng==} engines: {node: '>=18'} @@ -1428,6 +1443,10 @@ packages: resolution: {integrity: sha512-2Xmoq6DBJqmSl80U6V9z5jJSJP7ehaJJQMy2iFUqTay06wdCqTnPVBBQbtEL8RCChenL+q5DC5H5WzU3vV3v8w==} engines: {node: '>=18'} + '@ai-sdk/provider@3.0.6': + resolution: {integrity: sha512-hSfoJtLtpMd7YxKM+iTqlJ0ZB+kJ83WESMiWuWrNVey3X8gg97x0OdAAaeAeclZByCX3UdPOTqhvJdK8qYA3ww==} + engines: {node: '>=18'} + '@alcalzone/ansi-tokenize@0.2.3': resolution: {integrity: sha512-jsElTJ0sQ4wHRz+C45tfect76BwbTbgkgKByOzpCN9xG61N5V6u/glvg1CsNJhq2xJIFpKHSwG3D2wPPuEYOrQ==} engines: {node: '>=18'} @@ -10819,10 +10838,17 @@ snapshots: '@adobe/css-tools@4.4.2': {} - '@ai-sdk/deepseek@2.0.14(zod@3.25.76)': + '@ai-sdk/cerebras@1.0.35(zod@3.25.76)': dependencies: - '@ai-sdk/provider': 3.0.5 - '@ai-sdk/provider-utils': 4.0.10(zod@3.25.76) + '@ai-sdk/openai-compatible': 1.0.31(zod@3.25.76) + '@ai-sdk/provider': 2.0.1 + '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) + zod: 3.25.76 + + '@ai-sdk/deepseek@2.0.15(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.6 + '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) zod: 3.25.76 '@ai-sdk/gateway@3.0.25(zod@3.25.76)': @@ -10852,6 +10878,13 @@ 
snapshots: eventsource-parser: 3.0.6 zod: 3.25.76 + '@ai-sdk/provider-utils@4.0.11(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.6 + '@standard-schema/spec': 1.1.0 + eventsource-parser: 3.0.6 + zod: 3.25.76 + '@ai-sdk/provider@2.0.1': dependencies: json-schema: 0.4.0 @@ -10860,6 +10893,10 @@ snapshots: dependencies: json-schema: 0.4.0 + '@ai-sdk/provider@3.0.6': + dependencies: + json-schema: 0.4.0 + '@alcalzone/ansi-tokenize@0.2.3': dependencies: ansi-styles: 6.2.3 @@ -14686,7 +14723,7 @@ snapshots: sirv: 3.0.1 tinyglobby: 0.2.14 tinyrainbow: 2.0.0 - vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.17.50)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@24.2.1)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) '@vitest/utils@3.2.4': dependencies: diff --git a/src/api/providers/__tests__/cerebras.spec.ts b/src/api/providers/__tests__/cerebras.spec.ts index 0915f449d0d..aefb8a599ce 100644 --- a/src/api/providers/__tests__/cerebras.spec.ts +++ b/src/api/providers/__tests__/cerebras.spec.ts @@ -1,249 +1,502 @@ -// Mock i18n -vi.mock("../../i18n", () => ({ - t: vi.fn((key: string, params?: Record<string, any>) => { - // Return a simplified mock translation for testing - if (key.startsWith("common:errors.cerebras.")) { - return `Mocked: ${key.replace("common:errors.cerebras.", "")}` - } - return key - }), +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), })) -// Mock DEFAULT_HEADERS -vi.mock("../constants", () => ({ - DEFAULT_HEADERS: { - "HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline", - "X-Title": "Roo Code", - "User-Agent": "RooCode/1.0.0", - }, +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal<typeof import("ai")>() + return { + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, + } +}) + +vi.mock("@ai-sdk/cerebras", () => ({ + createCerebras: vi.fn(() => { + // Return a function that returns a mock language model + return vi.fn(() => ({ + modelId: "llama-3.3-70b", + provider: "cerebras", + })) + }), })) -import { CerebrasHandler } from "../cerebras" -import { cerebrasModels, type CerebrasModelId } from "@roo-code/types" +import type { Anthropic } from "@anthropic-ai/sdk" + +import { cerebrasDefaultModelId, cerebrasModels, type CerebrasModelId } from "@roo-code/types" -// Mock fetch globally -global.fetch = vi.fn() +import type { ApiHandlerOptions } from "../../../shared/api" + +import { CerebrasHandler } from "../cerebras" describe("CerebrasHandler", () => { let handler: CerebrasHandler - const mockOptions = { - cerebrasApiKey: "test-api-key", - apiModelId: "llama-3.3-70b" as CerebrasModelId, - } + let mockOptions: ApiHandlerOptions beforeEach(() => { - vi.clearAllMocks() + mockOptions = { + cerebrasApiKey: "test-api-key", + apiModelId: "llama-3.3-70b" as CerebrasModelId, + } handler = new CerebrasHandler(mockOptions) + vi.clearAllMocks() }) describe("constructor", () => { - it("should throw error when API key is missing", () => { - expect(() => new CerebrasHandler({ cerebrasApiKey: "" })).toThrow("Cerebras API key is required") + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(CerebrasHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) }) - it("should initialize with valid API key", () => { - expect(() 
=> new CerebrasHandler(mockOptions)).not.toThrow() + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new CerebrasHandler({ + ...mockOptions, + apiModelId: undefined, + }) + expect(handlerWithoutModel.getModel().id).toBe(cerebrasDefaultModelId) }) }) describe("getModel", () => { - it("should return correct model info", () => { - const { id, info } = handler.getModel() - expect(id).toBe("llama-3.3-70b") - expect(info).toEqual(cerebrasModels["llama-3.3-70b"]) + it("should return model info for valid model ID", () => { + const model = handler.getModel() + expect(model.id).toBe(mockOptions.apiModelId) + expect(model.info).toBeDefined() + expect(model.info.maxTokens).toBe(16384) + expect(model.info.contextWindow).toBe(64000) + expect(model.info.supportsImages).toBe(false) + expect(model.info.supportsPromptCache).toBe(false) }) - it("should fallback to default model when apiModelId is not provided", () => { - const handlerWithoutModel = new CerebrasHandler({ cerebrasApiKey: "test" }) - const { id } = handlerWithoutModel.getModel() - expect(id).toBe("gpt-oss-120b") // cerebrasDefaultModelId - }) - }) - - describe("message conversion", () => { - it("should strip thinking tokens from assistant messages", () => { - // This would test the stripThinkingTokens function - // Implementation details would test the regex functionality + it("should return provided model ID with default model info if model does not exist", () => { + const handlerWithInvalidModel = new CerebrasHandler({ + ...mockOptions, + apiModelId: "invalid-model", + }) + const model = handlerWithInvalidModel.getModel() + expect(model.id).toBe("invalid-model") // Returns provided ID + expect(model.info).toBeDefined() + // Should have the same base properties as default model + expect(model.info.contextWindow).toBe(cerebrasModels[cerebrasDefaultModelId].contextWindow) }) - it("should flatten complex message content to strings", () => { - // This would test the flattenMessageContent function - // Test various content types: strings, arrays, image objects + it("should return default model if no model ID is provided", () => { + const handlerWithoutModel = new CerebrasHandler({ + ...mockOptions, + apiModelId: undefined, + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe(cerebrasDefaultModelId) + expect(model.info).toBeDefined() }) - it("should convert OpenAI messages to Cerebras format", () => { - // This would test the convertToCerebrasMessages function - // Ensure all messages have string content and proper role/content structure + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") }) }) describe("createMessage", () => { - it("should make correct API request", async () => { - // Mock successful API response - const mockResponse = { - ok: true, - body: { - getReader: () => ({ - read: vi.fn().mockResolvedValueOnce({ done: true, value: new Uint8Array() }), - releaseLock: vi.fn(), - }), + const systemPrompt = "You are a helpful assistant." 
+ const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] + + it("should handle streaming responses", async () => { + // Mock the fullStream async generator + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + // Mock usage promise + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response") + }) + + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(5) + }) + + it("should handle reasoning content in streaming responses", async () => { + // Mock the fullStream async generator with reasoning content + async function* mockFullStream() { + yield { type: "reasoning", text: "Let me think about this..." } + yield { type: "reasoning", text: " I'll analyze step by step." 
} + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + details: { + reasoningTokens: 15, }, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } - vi.mocked(fetch).mockResolvedValueOnce(mockResponse as any) - const generator = handler.createMessage("System prompt", []) - await generator.next() // Actually start the generator to trigger the fetch call + // Should have reasoning chunks + const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning") + expect(reasoningChunks.length).toBe(2) + expect(reasoningChunks[0].text).toBe("Let me think about this...") + expect(reasoningChunks[1].text).toBe(" I'll analyze step by step.") + + // Should also have text chunks + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks.length).toBe(1) + expect(textChunks[0].text).toBe("Test response") + }) + }) + + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", + }) + + const result = await handler.completePrompt("Test prompt") - // Test that fetch was called with correct parameters - expect(fetch).toHaveBeenCalledWith( - "https://api.cerebras.ai/v1/chat/completions", + expect(result).toBe("Test completion") + expect(mockGenerateText).toHaveBeenCalledWith( expect.objectContaining({ - method: "POST", - headers: expect.objectContaining({ - "Content-Type": "application/json", - Authorization: "Bearer test-api-key", - "HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline", - "X-Title": "Roo Code", - "User-Agent": "RooCode/1.0.0", - }), + prompt: "Test prompt", }), ) }) + }) - it("should handle API errors properly", async () => { - const mockErrorResponse = { - ok: false, - status: 400, - text: () => Promise.resolve('{"error": {"message": "Bad Request"}}'), + describe("processUsageMetrics", () => { + it("should correctly process usage metrics", () => { + // We need to access the protected method, so we'll create a test subclass + class TestCerebrasHandler extends CerebrasHandler { + public testProcessUsageMetrics(usage: any) { + return this.processUsageMetrics(usage) + } } - vi.mocked(fetch).mockResolvedValueOnce(mockErrorResponse as any) - const generator = handler.createMessage("System prompt", []) - // Since the mock isn't working, let's just check that an error is thrown - await expect(generator.next()).rejects.toThrow() - }) + const testHandler = new TestCerebrasHandler(mockOptions) - it("should parse streaming responses correctly", async () => { - // Test streaming response parsing - // Mock ReadableStream with various data chunks - // Verify thinking token extraction and usage tracking + const usage = { + inputTokens: 100, + outputTokens: 50, + details: { + cachedInputTokens: 20, + reasoningTokens: 30, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheReadTokens).toBe(20) + expect(result.reasoningTokens).toBe(30) }) - it("should handle temperature clamping", async () => { - const handlerWithTemp = new CerebrasHandler({ - ...mockOptions, - modelTemperature: 2.0, // Above Cerebras max of 1.5 - }) + it("should 
handle missing cache metrics gracefully", () => { + class TestCerebrasHandler extends CerebrasHandler { + public testProcessUsageMetrics(usage: any) { + return this.processUsageMetrics(usage) + } + } + + const testHandler = new TestCerebrasHandler(mockOptions) - vi.mocked(fetch).mockResolvedValueOnce({ - ok: true, - body: { getReader: () => ({ read: () => Promise.resolve({ done: true }), releaseLock: vi.fn() }) }, - } as any) + const usage = { + inputTokens: 100, + outputTokens: 50, + } - await handlerWithTemp.createMessage("test", []).next() + const result = testHandler.testProcessUsageMetrics(usage) - const requestBody = JSON.parse(vi.mocked(fetch).mock.calls[0][1]?.body as string) - expect(requestBody.temperature).toBe(1.5) // Should be clamped + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheReadTokens).toBeUndefined() + expect(result.reasoningTokens).toBeUndefined() }) }) - describe("completePrompt", () => { - it("should handle non-streaming completion", async () => { - const mockResponse = { - ok: true, - json: () => - Promise.resolve({ - choices: [{ message: { content: "Test response" } }], - }), + describe("getMaxOutputTokens", () => { + it("should return maxTokens from model info", () => { + class TestCerebrasHandler extends CerebrasHandler { + public testGetMaxOutputTokens() { + return this.getMaxOutputTokens() + } } - vi.mocked(fetch).mockResolvedValueOnce(mockResponse as any) - const result = await handler.completePrompt("Test prompt") - expect(result).toBe("Test response") + const testHandler = new TestCerebrasHandler(mockOptions) + const result = testHandler.testGetMaxOutputTokens() + + // llama-3.3-70b maxTokens is 16384 + expect(result).toBe(16384) }) - }) - describe("token usage and cost calculation", () => { - it("should track token usage properly", () => { - // Test that lastUsage is updated correctly - // Test getApiCost returns calculated cost based on actual usage + it("should use modelMaxTokens when provided", () => { + class TestCerebrasHandler extends CerebrasHandler { + public testGetMaxOutputTokens() { + return this.getMaxOutputTokens() + } + } + + const customMaxTokens = 5000 + const testHandler = new TestCerebrasHandler({ + ...mockOptions, + modelMaxTokens: customMaxTokens, + }) + + const result = testHandler.testGetMaxOutputTokens() + expect(result).toBe(customMaxTokens) }) - it("should provide usage estimates when API doesn't return usage", () => { - // Test fallback token estimation logic + it("should fall back to modelInfo.maxTokens when modelMaxTokens is not provided", () => { + class TestCerebrasHandler extends CerebrasHandler { + public testGetMaxOutputTokens() { + return this.getMaxOutputTokens() + } + } + + const testHandler = new TestCerebrasHandler(mockOptions) + const result = testHandler.testGetMaxOutputTokens() + + // llama-3.3-70b has maxTokens of 16384 + expect(result).toBe(16384) }) }) - describe("convertToolsForOpenAI", () => { - it("should set all tools to strict: false for Cerebras API consistency", () => { - // Access the protected method through a test subclass - const regularTool = { - type: "function", - function: { - name: "read_file", - parameters: { - type: "object", - properties: { - path: { type: "string" }, + describe("tool handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" 
}], + }, + ] + + it("should handle tool calls in streaming", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "tool-call-1", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-call-1", + delta: '{"path":"test.ts"}', + } + yield { + type: "tool-input-end", + id: "tool-call-1", + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, }, - required: ["path"], }, - }, + ], + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } - // MCP tool with the 'mcp--' prefix - const mcpTool = { - type: "function", - function: { - name: "mcp--server--tool", - parameters: { - type: "object", - properties: { - arg: { type: "string" }, + const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start") + const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end") + + expect(toolCallStartChunks.length).toBe(1) + expect(toolCallStartChunks[0].id).toBe("tool-call-1") + expect(toolCallStartChunks[0].name).toBe("read_file") + + expect(toolCallDeltaChunks.length).toBe(1) + expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}') + + expect(toolCallEndChunks.length).toBe(1) + expect(toolCallEndChunks[0].id).toBe("tool-call-1") + }) + + it("should ignore tool-call events to prevent duplicate tools in UI", async () => { + // tool-call events are intentionally ignored because tool-input-start/delta/end + // already provide complete tool call information. Emitting tool-call would cause + // duplicate tools in the UI for AI SDK providers (e.g., DeepSeek, Moonshot, Cerebras). 
+ async function* mockFullStream() { + yield { + type: "tool-call", + toolCallId: "tool-call-1", + toolName: "read_file", + input: { path: "test.ts" }, + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, }, }, - }, + ], + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } - // Create a test wrapper to access protected method + // tool-call events are ignored, so no tool_call chunks should be emitted + const toolCallChunks = chunks.filter((c) => c.type === "tool_call") + expect(toolCallChunks.length).toBe(0) + }) + }) + + describe("mapToolChoice", () => { + it("should handle string tool choices", () => { class TestCerebrasHandler extends CerebrasHandler { - public testConvertToolsForOpenAI(tools: any[]) { - return this.convertToolsForOpenAI(tools) + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) } } - const testHandler = new TestCerebrasHandler({ cerebrasApiKey: "test" }) - const converted = testHandler.testConvertToolsForOpenAI([regularTool, mcpTool]) + const testHandler = new TestCerebrasHandler(mockOptions) - // Both tools should have strict: false - expect(converted).toHaveLength(2) - expect(converted![0].function.strict).toBe(false) - expect(converted![1].function.strict).toBe(false) + expect(testHandler.testMapToolChoice("auto")).toBe("auto") + expect(testHandler.testMapToolChoice("none")).toBe("none") + expect(testHandler.testMapToolChoice("required")).toBe("required") + expect(testHandler.testMapToolChoice("unknown")).toBe("auto") }) - it("should return undefined when tools is undefined", () => { + it("should handle object tool choice with function name", () => { class TestCerebrasHandler extends CerebrasHandler { - public testConvertToolsForOpenAI(tools: any[] | undefined) { - return this.convertToolsForOpenAI(tools) + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) } } - const testHandler = new TestCerebrasHandler({ cerebrasApiKey: "test" }) - expect(testHandler.testConvertToolsForOpenAI(undefined)).toBeUndefined() + const testHandler = new TestCerebrasHandler(mockOptions) + + const result = testHandler.testMapToolChoice({ + type: "function", + function: { name: "my_tool" }, + }) + + expect(result).toEqual({ type: "tool", toolName: "my_tool" }) }) - it("should pass through non-function tools unchanged", () => { + it("should return undefined for null or undefined", () => { class TestCerebrasHandler extends CerebrasHandler { - public testConvertToolsForOpenAI(tools: any[]) { - return this.convertToolsForOpenAI(tools) + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) } } - const nonFunctionTool = { type: "other", data: "test" } - const testHandler = new TestCerebrasHandler({ cerebrasApiKey: "test" }) - const converted = testHandler.testConvertToolsForOpenAI([nonFunctionTool]) + const testHandler = new TestCerebrasHandler(mockOptions) - expect(converted![0]).toEqual(nonFunctionTool) + expect(testHandler.testMapToolChoice(null)).toBeUndefined() + expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() }) }) 
}) diff --git a/src/api/providers/cerebras.ts b/src/api/providers/cerebras.ts index 8ca30af36f1..0fbc375bdcd 100644 --- a/src/api/providers/cerebras.ts +++ b/src/api/providers/cerebras.ts @@ -1,362 +1,182 @@ import { Anthropic } from "@anthropic-ai/sdk" +import { createCerebras } from "@ai-sdk/cerebras" +import { streamText, generateText, ToolSet } from "ai" -import { type CerebrasModelId, cerebrasDefaultModelId, cerebrasModels } from "@roo-code/types" +import { cerebrasModels, cerebrasDefaultModelId, type CerebrasModelId, type ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { calculateApiCostOpenAI } from "../../shared/cost" -import { ApiStream } from "../transform/stream" -import { convertToOpenAiMessages } from "../transform/openai-format" -import { TagMatcher } from "../../utils/tag-matcher" -import type { ApiHandlerCreateMessageMetadata, SingleCompletionHandler } from "../index" -import { BaseProvider } from "./base-provider" -import { DEFAULT_HEADERS } from "./constants" -import { t } from "../../i18n" +import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { getModelParams } from "../transform/model-params" -const CEREBRAS_BASE_URL = "https://api.cerebras.ai/v1" -const CEREBRAS_DEFAULT_TEMPERATURE = 0 +import { DEFAULT_HEADERS } from "./constants" +import { BaseProvider } from "./base-provider" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" const CEREBRAS_INTEGRATION_HEADER = "X-Cerebras-3rd-Party-Integration" const CEREBRAS_INTEGRATION_NAME = "roocode" +const CEREBRAS_DEFAULT_TEMPERATURE = 0 +/** + * Cerebras provider using the dedicated @ai-sdk/cerebras package. + * Provides high-speed inference powered by Wafer-Scale Engines. + */ export class CerebrasHandler extends BaseProvider implements SingleCompletionHandler { - private apiKey: string - private providerModels: typeof cerebrasModels - private defaultProviderModelId: CerebrasModelId - private options: ApiHandlerOptions - private lastUsage: { inputTokens: number; outputTokens: number } = { inputTokens: 0, outputTokens: 0 } + protected options: ApiHandlerOptions + protected provider: ReturnType<typeof createCerebras> constructor(options: ApiHandlerOptions) { super() this.options = options - this.apiKey = options.cerebrasApiKey || "" - this.providerModels = cerebrasModels - this.defaultProviderModelId = cerebrasDefaultModelId - if (!this.apiKey) { - throw new Error("Cerebras API key is required") - } + // Create the Cerebras provider using AI SDK + this.provider = createCerebras({ + apiKey: options.cerebrasApiKey ?? "not-provided", + headers: { + ...DEFAULT_HEADERS, + [CEREBRAS_INTEGRATION_HEADER]: CEREBRAS_INTEGRATION_NAME, + }, + }) } - getModel(): { id: CerebrasModelId; info: (typeof cerebrasModels)[CerebrasModelId] } { - const modelId = this.options.apiModelId as CerebrasModelId - const validModelId = modelId && this.providerModels[modelId] ? modelId : this.defaultProviderModelId - - return { - id: validModelId, - info: this.providerModels[validModelId], - } + override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } { + const id = (this.options.apiModelId ?? 
cerebrasDefaultModelId) as CerebrasModelId + const info = cerebrasModels[id as keyof typeof cerebrasModels] || cerebrasModels[cerebrasDefaultModelId] + const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options }) + return { id, info, ...params } } /** - * Override convertToolSchemaForOpenAI to remove unsupported schema fields for Cerebras. - * Cerebras doesn't support minItems/maxItems in array schemas with strict mode. + * Get the language model for the configured model ID. */ - protected override convertToolSchemaForOpenAI(schema: any): any { - const converted = super.convertToolSchemaForOpenAI(schema) - return this.stripUnsupportedSchemaFields(converted) + protected getLanguageModel() { + const { id } = this.getModel() + return this.provider(id) } /** - * Recursively strips unsupported schema fields for Cerebras. - * Cerebras strict mode doesn't support minItems, maxItems on arrays. + * Process usage metrics from the AI SDK response. */ - private stripUnsupportedSchemaFields(schema: any): any { - if (!schema || typeof schema !== "object") { - return schema + protected processUsageMetrics(usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number } + }): ApiStreamUsageChunk { + return { + type: "usage", + inputTokens: usage.inputTokens || 0, + outputTokens: usage.outputTokens || 0, + cacheReadTokens: usage.details?.cachedInputTokens, + reasoningTokens: usage.details?.reasoningTokens, + } + } - const result = { ...schema } - - // Remove unsupported array constraints - if (result.type === "array" || (Array.isArray(result.type) && result.type.includes("array"))) { - delete result.minItems - delete result.maxItems + /** + * Map OpenAI tool_choice to AI SDK toolChoice format. + */ + protected mapToolChoice( + toolChoice: any, + ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { + if (!toolChoice) { + return undefined } - // Recursively process properties - if (result.properties) { - const newProps = { ...result.properties } - for (const key of Object.keys(newProps)) { - newProps[key] = this.stripUnsupportedSchemaFields(newProps[key]) + // Handle string values + if (typeof toolChoice === "string") { + switch (toolChoice) { + case "auto": + return "auto" + case "none": + return "none" + case "required": + return "required" + default: + return "auto" } - result.properties = newProps } - // Recursively process array items - if (result.items) { - result.items = this.stripUnsupportedSchemaFields(result.items) + // Handle object values (OpenAI ChatCompletionNamedToolChoice format) + if (typeof toolChoice === "object" && "type" in toolChoice) { + if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { + return { type: "tool", toolName: toolChoice.function.name } + } } - return result + return undefined } /** - * Override convertToolsForOpenAI to ensure all tools have consistent strict values. - * Cerebras API requires all tools to have the same strict mode setting. - * We use strict: false for all tools since MCP tools cannot use strict mode - * (they have optional parameters from the MCP server schema). + * Get the max tokens parameter to include in the request. 
*/ - protected override convertToolsForOpenAI(tools: any[] | undefined): any[] | undefined { - if (!tools) { - return undefined - } - - return tools.map((tool) => { - if (tool.type !== "function") { - return tool - } - - return { - ...tool, - function: { - ...tool.function, - strict: false, - parameters: this.convertToolSchemaForOpenAI(tool.function.parameters), - }, - } - }) + protected getMaxOutputTokens(): number | undefined { + const { info } = this.getModel() + return this.options.modelMaxTokens || info.maxTokens || undefined } - async *createMessage( + /** + * Create a message stream using the AI SDK. + */ + override async *createMessage( systemPrompt: string, messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { id: model, info: modelInfo } = this.getModel() - const max_tokens = modelInfo.maxTokens - const temperature = this.options.modelTemperature ?? CEREBRAS_DEFAULT_TEMPERATURE - - // Convert Anthropic messages to OpenAI format (Cerebras is OpenAI-compatible) - const openaiMessages = convertToOpenAiMessages(messages) - - // Prepare request body following Cerebras API specification exactly - const requestBody: Record<string, any> = { - model, - messages: [{ role: "system", content: systemPrompt }, ...openaiMessages], - stream: true, - // Use max_completion_tokens (Cerebras-specific parameter) - ...(max_tokens && max_tokens > 0 && max_tokens <= 32768 ? { max_completion_tokens: max_tokens } : {}), - // Clamp temperature to Cerebras range (0 to 1.5) - ...(temperature !== undefined && temperature !== CEREBRAS_DEFAULT_TEMPERATURE - ? { - temperature: Math.max(0, Math.min(1.5, temperature)), - } - : {}), - // Native tool calling support - tools: this.convertToolsForOpenAI(metadata?.tools), - tool_choice: metadata?.tool_choice, - parallel_tool_calls: metadata?.parallelToolCalls ?? true, + const { temperature } = this.getModel() + const languageModel = this.getLanguageModel() + + // Convert messages to AI SDK format + const aiSdkMessages = convertToAiSdkMessages(messages) + + // Convert tools to OpenAI format first, then to AI SDK format + const openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined + + // Build the request options + const requestOptions: Parameters<typeof streamText>[0] = { + model: languageModel, + system: systemPrompt, + messages: aiSdkMessages, + temperature: this.options.modelTemperature ?? temperature ?? 
CEREBRAS_DEFAULT_TEMPERATURE, + maxOutputTokens: this.getMaxOutputTokens(), + tools: aiSdkTools, + toolChoice: this.mapToolChoice(metadata?.tool_choice), } - try { - const response = await fetch(`${CEREBRAS_BASE_URL}/chat/completions`, { - method: "POST", - headers: { - ...DEFAULT_HEADERS, - "Content-Type": "application/json", - Authorization: `Bearer ${this.apiKey}`, - [CEREBRAS_INTEGRATION_HEADER]: CEREBRAS_INTEGRATION_NAME, - }, - body: JSON.stringify(requestBody), - }) - - if (!response.ok) { - const errorText = await response.text() - - let errorMessage = "Unknown error" - try { - const errorJson = JSON.parse(errorText) - errorMessage = errorJson.error?.message || errorJson.message || JSON.stringify(errorJson, null, 2) - } catch { - errorMessage = errorText || `HTTP ${response.status}` - } - - // Provide more actionable error messages - if (response.status === 401) { - throw new Error(t("common:errors.cerebras.authenticationFailed")) - } else if (response.status === 403) { - throw new Error(t("common:errors.cerebras.accessForbidden")) - } else if (response.status === 429) { - throw new Error(t("common:errors.cerebras.rateLimitExceeded")) - } else if (response.status >= 500) { - throw new Error(t("common:errors.cerebras.serverError", { status: response.status })) - } else { - throw new Error( - t("common:errors.cerebras.genericError", { status: response.status, message: errorMessage }), - ) - } - } - - if (!response.body) { - throw new Error(t("common:errors.cerebras.noResponseBody")) - } - - // Initialize TagMatcher to parse ... tags - const matcher = new TagMatcher( - "think", - (chunk) => - ({ - type: chunk.matched ? "reasoning" : "text", - text: chunk.data, - }) as const, - ) - - const reader = response.body.getReader() - const decoder = new TextDecoder() - let buffer = "" - let inputTokens = 0 - let outputTokens = 0 - - try { - while (true) { - const { done, value } = await reader.read() - if (done) break - - buffer += decoder.decode(value, { stream: true }) - const lines = buffer.split("\n") - buffer = lines.pop() || "" // Keep the last incomplete line in the buffer - - for (const line of lines) { - if (line.trim() === "") continue - - try { - if (line.startsWith("data: ")) { - const jsonStr = line.slice(6).trim() - if (jsonStr === "[DONE]") { - continue - } - - const parsed = JSON.parse(jsonStr) + // Use streamText for streaming responses + const result = streamText(requestOptions) - const delta = parsed.choices?.[0]?.delta - - // Handle text content - parse for thinking tokens - if (delta?.content) { - const content = delta.content - - // Use TagMatcher to parse ... 
tags - for (const chunk of matcher.update(content)) { - yield chunk - } - } - - // Handle tool calls in stream - emit partial chunks for NativeToolCallParser - if (delta?.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } - - // Handle usage information if available - if (parsed.usage) { - inputTokens = parsed.usage.prompt_tokens || 0 - outputTokens = parsed.usage.completion_tokens || 0 - } - } - } catch (error) { - // Silently ignore malformed streaming data lines - } - } - } - } finally { - reader.releaseLock() - } - - // Process any remaining content in the matcher - for (const chunk of matcher.final()) { + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { yield chunk } + } - // Provide token usage estimate if not available from API - if (inputTokens === 0 || outputTokens === 0) { - const inputText = - systemPrompt + - openaiMessages - .map((m: any) => (typeof m.content === "string" ? m.content : JSON.stringify(m.content))) - .join("") - inputTokens = inputTokens || Math.ceil(inputText.length / 4) // Rough estimate: 4 chars per token - outputTokens = outputTokens || Math.ceil((max_tokens || 1000) / 10) // Rough estimate - } - - // Store usage for cost calculation - this.lastUsage = { inputTokens, outputTokens } - - yield { - type: "usage", - inputTokens, - outputTokens, - } - } catch (error) { - if (error instanceof Error) { - throw new Error(t("common:errors.cerebras.completionError", { error: error.message })) - } - throw error + // Yield usage metrics at the end + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage) } } + /** + * Complete a prompt using the AI SDK generateText. 
+ */ async completePrompt(prompt: string): Promise<string> { - const { id: model } = this.getModel() - - // Prepare request body for non-streaming completion - const requestBody = { - model, - messages: [{ role: "user", content: prompt }], - stream: false, - } - - try { - const response = await fetch(`${CEREBRAS_BASE_URL}/chat/completions`, { - method: "POST", - headers: { - ...DEFAULT_HEADERS, - "Content-Type": "application/json", - Authorization: `Bearer ${this.apiKey}`, - [CEREBRAS_INTEGRATION_HEADER]: CEREBRAS_INTEGRATION_NAME, - }, - body: JSON.stringify(requestBody), - }) - - if (!response.ok) { - const errorText = await response.text() - - // Provide consistent error handling with createMessage - if (response.status === 401) { - throw new Error(t("common:errors.cerebras.authenticationFailed")) - } else if (response.status === 403) { - throw new Error(t("common:errors.cerebras.accessForbidden")) - } else if (response.status === 429) { - throw new Error(t("common:errors.cerebras.rateLimitExceeded")) - } else if (response.status >= 500) { - throw new Error(t("common:errors.cerebras.serverError", { status: response.status })) - } else { - throw new Error( - t("common:errors.cerebras.genericError", { status: response.status, message: errorText }), - ) - } - } - - const result = await response.json() - return result.choices?.[0]?.message?.content || "" - } catch (error) { - if (error instanceof Error) { - throw new Error(t("common:errors.cerebras.completionError", { error: error.message })) - } - throw error - } - } + const { temperature } = this.getModel() + const languageModel = this.getLanguageModel() + + const { text } = await generateText({ + model: languageModel, + prompt, + maxOutputTokens: this.getMaxOutputTokens(), + temperature: this.options.modelTemperature ?? temperature ?? CEREBRAS_DEFAULT_TEMPERATURE, + }) - getApiCost(metadata: ApiHandlerCreateMessageMetadata): number { - const { info } = this.getModel() - // Use actual token usage from the last request - const { inputTokens, outputTokens } = this.lastUsage - const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens) - return totalCost + return text } } diff --git a/src/package.json b/src/package.json index 0640ed00d1d..edacfd403d2 100644 --- a/src/package.json +++ b/src/package.json @@ -450,6 +450,7 @@ "clean": "rimraf README.md CHANGELOG.md LICENSE dist logs mock .turbo" }, "dependencies": { + "@ai-sdk/cerebras": "^1.0.0", "@ai-sdk/deepseek": "^2.0.14", "@anthropic-ai/bedrock-sdk": "^0.10.2", "@anthropic-ai/sdk": "^0.37.0",
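Reviewer note: a minimal standalone sketch of the streaming pattern this migration adopts, kept outside the handler plumbing above. It is not code from this PR; the model id, env var, and console output are illustrative assumptions, while createCerebras, streamText, result.fullStream, and result.usage are the same @ai-sdk/cerebras and ai APIs exercised in the diff.

import { createCerebras } from "@ai-sdk/cerebras"
import { streamText } from "ai"

// Assumed setup: CEREBRAS_API_KEY is set in the environment; "llama-3.3-70b" is illustrative.
const cerebras = createCerebras({
	apiKey: process.env.CEREBRAS_API_KEY,
	// Same attribution header the handler preserves above.
	headers: { "X-Cerebras-3rd-Party-Integration": "roocode" },
})

async function main() {
	const result = streamText({
		model: cerebras("llama-3.3-70b"),
		system: "You are a helpful assistant.",
		messages: [{ role: "user", content: "Hello!" }],
		temperature: 0,
	})

	// fullStream surfaces text-delta, reasoning, and tool events;
	// the handler maps these to ApiStream chunks via processAiSdkStreamPart.
	for await (const part of result.fullStream) {
		if (part.type === "text-delta") {
			process.stdout.write(part.text)
		}
	}

	// Usage resolves once the stream finishes, mirroring the trailing usage chunk in createMessage.
	const usage = await result.usage
	console.log(`\ntokens in/out: ${usage.inputTokens}/${usage.outputTokens}`)
}

main().catch(console.error)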