diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/tool-input/tool-input.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/tool-input/tool-input.tsx index a38bff4e4f3..ff931551c85 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/tool-input/tool-input.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/tool-input/tool-input.tsx @@ -14,8 +14,7 @@ import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip import type { OAuthProvider } from '@/lib/oauth/oauth' import { cn } from '@/lib/utils' import { getAllBlocks } from '@/blocks' -import { supportsToolUsageControl } from '@/providers/model-capabilities' -import { getProviderFromModel } from '@/providers/utils' +import { getProviderFromModel, supportsToolUsageControl } from '@/providers/utils' import { useCustomToolsStore } from '@/stores/custom-tools/store' import { useGeneralStore } from '@/stores/settings/general/store' import { useSubBlockStore } from '@/stores/workflows/subblock/store' diff --git a/apps/sim/app/w/[id]/hooks/use-workflow-execution.ts b/apps/sim/app/w/[id]/hooks/use-workflow-execution.ts index 1f7c78f47e2..31e3994107b 100644 --- a/apps/sim/app/w/[id]/hooks/use-workflow-execution.ts +++ b/apps/sim/app/w/[id]/hooks/use-workflow-execution.ts @@ -92,9 +92,13 @@ export function useWorkflowExecution() { const streamingBlockId = (result.metadata as any)?.streamingBlockId || null for (const log of enrichedResult.logs) { - // Only update the specific agent block that was streamed + // Only update the specific LLM block (agent/router) that was streamed const isStreamingBlock = streamingBlockId && log.blockId === streamingBlockId - if (isStreamingBlock && log.blockType === 'agent' && log.output?.response) { + if ( + isStreamingBlock && + (log.blockType === 'agent' || log.blockType === 'router') && + log.output?.response + ) { log.output.response.content = streamContent } } diff --git a/apps/sim/app/w/logs/components/sidebar/sidebar.tsx b/apps/sim/app/w/logs/components/sidebar/sidebar.tsx index e690fcac628..a47131e42d5 100644 --- a/apps/sim/app/w/logs/components/sidebar/sidebar.tsx +++ b/apps/sim/app/w/logs/components/sidebar/sidebar.tsx @@ -556,7 +556,8 @@ export function Sidebar({ {isWorkflowWithCost && (

- This is the total cost for all agent blocks in this workflow execution. + This is the total cost for all LLM-based blocks in this workflow + execution.

)} diff --git a/apps/sim/blocks/blocks/agent.ts b/apps/sim/blocks/blocks/agent.ts index 08fb61fb24c..44f950300d9 100644 --- a/apps/sim/blocks/blocks/agent.ts +++ b/apps/sim/blocks/blocks/agent.ts @@ -1,8 +1,14 @@ import { AgentIcon } from '@/components/icons' import { isHosted } from '@/lib/environment' import { createLogger } from '@/lib/logs/console-logger' -import { MODELS_TEMP_RANGE_0_1, MODELS_TEMP_RANGE_0_2 } from '@/providers/model-capabilities' -import { getAllModelProviders, getBaseModelProviders } from '@/providers/utils' +import { + getAllModelProviders, + getBaseModelProviders, + getHostedModels, + MODELS_TEMP_RANGE_0_1, + MODELS_TEMP_RANGE_0_2, + providers, +} from '@/providers/utils' import { useOllamaStore } from '@/stores/ollama/store' import type { ToolResponse } from '@/tools/types' import type { BlockConfig } from '../types' @@ -121,29 +127,11 @@ export const AgentBlock: BlockConfig = { placeholder: 'Enter your API key', password: true, connectionDroppable: false, - // Hide API key for all OpenAI and Claude models when running on hosted version + // Hide API key for all hosted models when running on hosted version condition: isHosted ? { field: 'model', - // Include all OpenAI models and Claude models for which we don't show the API key field - value: [ - // OpenAI models - 'gpt-4o', - 'o1', - 'o1-mini', - 'o1-preview', - 'o3', - 'o3-preview', - 'o4-mini', - 'gpt-4.1', - 'gpt-4.1-nano', - 'gpt-4.1-mini', - // Claude models - 'claude-sonnet-4-0', - 'claude-opus-4-0', - 'claude-3-7-sonnet-latest', - 'claude-3-5-sonnet-latest', - ], + value: getHostedModels(), not: true, // Show for all models EXCEPT those listed } : undefined, // Show for all models in non-hosted environments @@ -158,7 +146,7 @@ export const AgentBlock: BlockConfig = { connectionDroppable: false, condition: { field: 'model', - value: ['azure/gpt-4o', 'azure/o3', 'azure/o4-mini', 'azure/gpt-4.1', 'azure/model-router'], + value: providers['azure-openai'].models, }, }, { @@ -170,7 +158,7 @@ export const AgentBlock: BlockConfig = { connectionDroppable: false, condition: { field: 'model', - value: ['azure/gpt-4o', 'azure/o3', 'azure/o4-mini', 'azure/gpt-4.1', 'azure/model-router'], + value: providers['azure-openai'].models, }, }, { diff --git a/apps/sim/blocks/blocks/evaluator.ts b/apps/sim/blocks/blocks/evaluator.ts index d7a41d52965..a10dfd4094f 100644 --- a/apps/sim/blocks/blocks/evaluator.ts +++ b/apps/sim/blocks/blocks/evaluator.ts @@ -1,7 +1,8 @@ import { ChartBarIcon } from '@/components/icons' +import { isHosted } from '@/lib/environment' import { createLogger } from '@/lib/logs/console-logger' import type { ProviderId } from '@/providers/types' -import { getAllModelProviders, getBaseModelProviders } from '@/providers/utils' +import { getAllModelProviders, getBaseModelProviders, getHostedModels } from '@/providers/utils' import { useOllamaStore } from '@/stores/ollama/store' import type { ToolResponse } from '@/tools/types' import type { BlockConfig, ParamType } from '../types' @@ -26,6 +27,11 @@ interface EvaluatorResponse extends ToolResponse { completion?: number total?: number } + cost?: { + input: number + output: number + total: number + } [metricName: string]: any // Allow dynamic metric fields } } @@ -181,6 +187,13 @@ export const EvaluatorBlock: BlockConfig = { placeholder: 'Enter your API key', password: true, connectionDroppable: false, + condition: isHosted + ? { + field: 'model', + value: getHostedModels(), + not: true, + } + : undefined, }, { id: 'systemPrompt', @@ -299,6 +312,7 @@ export const EvaluatorBlock: BlockConfig = { content: 'string', model: 'string', tokens: 'any', + cost: 'any', }, dependsOn: { subBlockId: 'metrics', @@ -307,6 +321,7 @@ export const EvaluatorBlock: BlockConfig = { content: 'string', model: 'string', tokens: 'any', + cost: 'any', }, whenFilled: 'json', }, diff --git a/apps/sim/blocks/blocks/router.ts b/apps/sim/blocks/blocks/router.ts index 61201d59370..56f19515376 100644 --- a/apps/sim/blocks/blocks/router.ts +++ b/apps/sim/blocks/blocks/router.ts @@ -1,6 +1,7 @@ import { ConnectIcon } from '@/components/icons' +import { isHosted } from '@/lib/environment' import type { ProviderId } from '@/providers/types' -import { getAllModelProviders, getBaseModelProviders } from '@/providers/utils' +import { getAllModelProviders, getBaseModelProviders, getHostedModels } from '@/providers/utils' import { useOllamaStore } from '@/stores/ollama/store' import type { ToolResponse } from '@/tools/types' import type { BlockConfig } from '../types' @@ -14,6 +15,11 @@ interface RouterResponse extends ToolResponse { completion?: number total?: number } + cost?: { + input: number + output: number + total: number + } selectedPath: { blockId: string blockType: string @@ -125,6 +131,14 @@ export const RouterBlock: BlockConfig = { placeholder: 'Enter your API key', password: true, connectionDroppable: false, + // Hide API key for all hosted models when running on hosted version + condition: isHosted + ? { + field: 'model', + value: getHostedModels(), + not: true, // Show for all models EXCEPT those listed + } + : undefined, // Show for all models in non-hosted environments }, { id: 'systemPrompt', @@ -171,6 +185,7 @@ export const RouterBlock: BlockConfig = { content: 'string', model: 'string', tokens: 'any', + cost: 'any', selectedPath: 'json', }, }, diff --git a/apps/sim/executor/handlers/evaluator/evaluator-handler.test.ts b/apps/sim/executor/handlers/evaluator/evaluator-handler.test.ts index f58ab5859b7..cd436c457d5 100644 --- a/apps/sim/executor/handlers/evaluator/evaluator-handler.test.ts +++ b/apps/sim/executor/handlers/evaluator/evaluator-handler.test.ts @@ -91,6 +91,11 @@ describe('EvaluatorBlockHandler', () => { content: 'This is the content to evaluate.', model: 'mock-model', tokens: { prompt: 50, completion: 10, total: 60 }, + cost: { + input: 0, + output: 0, + total: 0, + }, score1: 5, score2: 8, }, diff --git a/apps/sim/executor/handlers/evaluator/evaluator-handler.ts b/apps/sim/executor/handlers/evaluator/evaluator-handler.ts index 82f26f54092..642a7b13461 100644 --- a/apps/sim/executor/handlers/evaluator/evaluator-handler.ts +++ b/apps/sim/executor/handlers/evaluator/evaluator-handler.ts @@ -1,7 +1,7 @@ import { env } from '@/lib/env' import { createLogger } from '@/lib/logs/console-logger' import type { BlockOutput } from '@/blocks/types' -import { getProviderFromModel } from '@/providers/utils' +import { calculateCost, getProviderFromModel } from '@/providers/utils' import type { SerializedBlock } from '@/serializer/types' import type { BlockHandler, ExecutionContext } from '../../types' @@ -241,6 +241,14 @@ export class EvaluatorBlockHandler implements BlockHandler { logger.error('Error extracting metric scores:', e) } + // Calculate cost based on token usage, similar to how providers do it + const costCalculation = calculateCost( + result.model, + result.tokens?.prompt || 0, + result.tokens?.completion || 0, + false // Evaluator blocks don't typically use cached input + ) + // Create result with metrics as direct fields for easy access const outputResult = { response: { @@ -251,6 +259,11 @@ export class EvaluatorBlockHandler implements BlockHandler { completion: result.tokens?.completion || 0, total: result.tokens?.total || 0, }, + cost: { + input: costCalculation.input, + output: costCalculation.output, + total: costCalculation.total, + }, ...metricScores, }, } diff --git a/apps/sim/executor/handlers/router/router-handler.test.ts b/apps/sim/executor/handlers/router/router-handler.test.ts index 8c9acc1d401..2ad38cf7178 100644 --- a/apps/sim/executor/handlers/router/router-handler.test.ts +++ b/apps/sim/executor/handlers/router/router-handler.test.ts @@ -152,6 +152,11 @@ describe('RouterBlockHandler', () => { content: 'Choose the best option.', model: 'mock-model', tokens: { prompt: 100, completion: 5, total: 105 }, + cost: { + input: 0, + output: 0, + total: 0, + }, selectedPath: { blockId: 'target-block-1', blockType: 'target', diff --git a/apps/sim/executor/handlers/router/router-handler.ts b/apps/sim/executor/handlers/router/router-handler.ts index 965e6e903c6..40a88781e25 100644 --- a/apps/sim/executor/handlers/router/router-handler.ts +++ b/apps/sim/executor/handlers/router/router-handler.ts @@ -2,7 +2,7 @@ import { env } from '@/lib/env' import { createLogger } from '@/lib/logs/console-logger' import { generateRouterPrompt } from '@/blocks/blocks/router' import type { BlockOutput } from '@/blocks/types' -import { getProviderFromModel } from '@/providers/utils' +import { calculateCost, getProviderFromModel } from '@/providers/utils' import type { SerializedBlock } from '@/serializer/types' import type { PathTracker } from '../../path' import type { BlockHandler, ExecutionContext } from '../../types' @@ -92,6 +92,14 @@ export class RouterBlockHandler implements BlockHandler { const tokens = result.tokens || { prompt: 0, completion: 0, total: 0 } + // Calculate cost based on token usage, similar to how providers do it + const cost = calculateCost( + result.model, + tokens.prompt || 0, + tokens.completion || 0, + false // Router blocks don't typically use cached input + ) + return { response: { content: inputs.prompt, @@ -101,6 +109,11 @@ export class RouterBlockHandler implements BlockHandler { completion: tokens.completion || 0, total: tokens.total || 0, }, + cost: { + input: cost.input, + output: cost.output, + total: cost.total, + }, selectedPath: { blockId: chosenBlock.id, blockType: chosenBlock.type || 'unknown', diff --git a/apps/sim/lib/email/mailer.test.ts b/apps/sim/lib/email/mailer.test.ts index 05901b1c2ce..f21b4e47f02 100644 --- a/apps/sim/lib/email/mailer.test.ts +++ b/apps/sim/lib/email/mailer.test.ts @@ -119,7 +119,7 @@ describe('mailer', () => { expect(mockSend).not.toHaveBeenCalled() }) - it('should handle Resend API errors', async () => { + it.concurrent('should handle Resend API errors', async () => { mockSend.mockResolvedValue({ data: null, error: { message: 'API rate limit exceeded' }, @@ -131,7 +131,7 @@ describe('mailer', () => { expect(result.message).toBe('API rate limit exceeded') }) - it('should handle unexpected errors', async () => { + it.concurrent('should handle unexpected errors', async () => { mockSend.mockRejectedValue(new Error('Network error')) const result = await sendEmail(testEmailOptions) @@ -140,7 +140,7 @@ describe('mailer', () => { expect(result.message).toBe('Failed to send email') }) - it('should use custom from address when provided', async () => { + it.concurrent('should use custom from address when provided', async () => { await sendEmail({ ...testEmailOptions, from: 'custom@example.com', @@ -168,7 +168,7 @@ describe('mailer', () => { ) }) - it('should replace unsubscribe token placeholders in HTML', async () => { + it.concurrent('should replace unsubscribe token placeholders in HTML', async () => { const htmlWithPlaceholder = '

Content

Unsubscribe' await sendEmail({ diff --git a/apps/sim/lib/email/unsubscribe.test.ts b/apps/sim/lib/email/unsubscribe.test.ts index 89abcb9fa5c..55ca63fd260 100644 --- a/apps/sim/lib/email/unsubscribe.test.ts +++ b/apps/sim/lib/email/unsubscribe.test.ts @@ -13,7 +13,7 @@ describe('unsubscribe utilities', () => { const testEmailType = 'marketing' describe('generateUnsubscribeToken', () => { - it('should generate a token with salt:hash:emailType format', () => { + it.concurrent('should generate a token with salt:hash:emailType format', () => { const token = generateUnsubscribeToken(testEmail, testEmailType) const parts = token.split(':') @@ -23,21 +23,24 @@ describe('unsubscribe utilities', () => { expect(parts[2]).toBe(testEmailType) }) - it('should generate different tokens for the same email (due to random salt)', () => { - const token1 = generateUnsubscribeToken(testEmail, testEmailType) - const token2 = generateUnsubscribeToken(testEmail, testEmailType) + it.concurrent( + 'should generate different tokens for the same email (due to random salt)', + () => { + const token1 = generateUnsubscribeToken(testEmail, testEmailType) + const token2 = generateUnsubscribeToken(testEmail, testEmailType) - expect(token1).not.toBe(token2) - }) + expect(token1).not.toBe(token2) + } + ) - it('should default to marketing email type', () => { + it.concurrent('should default to marketing email type', () => { const token = generateUnsubscribeToken(testEmail) const parts = token.split(':') expect(parts[2]).toBe('marketing') }) - it('should generate different tokens for different email types', () => { + it.concurrent('should generate different tokens for different email types', () => { const marketingToken = generateUnsubscribeToken(testEmail, 'marketing') const updatesToken = generateUnsubscribeToken(testEmail, 'updates') @@ -46,7 +49,7 @@ describe('unsubscribe utilities', () => { }) describe('verifyUnsubscribeToken', () => { - it('should verify a valid token', () => { + it.concurrent('should verify a valid token', () => { const token = generateUnsubscribeToken(testEmail, testEmailType) const result = verifyUnsubscribeToken(testEmail, token) @@ -54,7 +57,7 @@ describe('unsubscribe utilities', () => { expect(result.emailType).toBe(testEmailType) }) - it('should reject an invalid token', () => { + it.concurrent('should reject an invalid token', () => { const invalidToken = 'invalid:token:format' const result = verifyUnsubscribeToken(testEmail, invalidToken) @@ -62,14 +65,14 @@ describe('unsubscribe utilities', () => { expect(result.emailType).toBe('format') }) - it('should reject a token for wrong email', () => { + it.concurrent('should reject a token for wrong email', () => { const token = generateUnsubscribeToken(testEmail, testEmailType) const result = verifyUnsubscribeToken('wrong@example.com', token) expect(result.valid).toBe(false) }) - it('should handle legacy tokens (2 parts) and default to marketing', () => { + it.concurrent('should handle legacy tokens (2 parts) and default to marketing', () => { // Generate a real legacy token using the actual hashing logic to ensure backward compatibility const salt = 'abc123' const { createHash } = require('crypto') @@ -84,7 +87,7 @@ describe('unsubscribe utilities', () => { expect(result.emailType).toBe('marketing') // Should default to marketing for legacy tokens }) - it('should reject malformed tokens', () => { + it.concurrent('should reject malformed tokens', () => { const malformedTokens = ['', 'single-part', 'too:many:parts:here:invalid', ':empty:parts:'] malformedTokens.forEach((token) => { @@ -95,11 +98,11 @@ describe('unsubscribe utilities', () => { }) describe('isTransactionalEmail', () => { - it('should identify transactional emails correctly', () => { + it.concurrent('should identify transactional emails correctly', () => { expect(isTransactionalEmail('transactional')).toBe(true) }) - it('should identify non-transactional emails correctly', () => { + it.concurrent('should identify non-transactional emails correctly', () => { const nonTransactionalTypes: EmailType[] = ['marketing', 'updates', 'notifications'] nonTransactionalTypes.forEach((type) => { diff --git a/apps/sim/lib/logs/execution-logger.ts b/apps/sim/lib/logs/execution-logger.ts index e4833db7033..a726d7d210d 100644 --- a/apps/sim/lib/logs/execution-logger.ts +++ b/apps/sim/lib/logs/execution-logger.ts @@ -87,7 +87,7 @@ export async function persistExecutionLogs( const userId = workflowRecord.userId - // Track accumulated cost data across all agent blocks + // Track accumulated cost data across all LLM blocks (agent, router, and evaluator) let totalCost = 0 let totalInputCost = 0 let totalOutputCost = 0 @@ -102,11 +102,17 @@ export async function persistExecutionLogs( // Check for agent block and tool calls let metadata: ToolCallMetadata | undefined - // If this is an agent block - if (log.blockType === 'agent' && log.output) { - logger.debug('Processing agent block output for tool calls', { + // If this is an agent, router, or evaluator block (all use LLM providers and generate costs) + if ( + (log.blockType === 'agent' || + log.blockType === 'router' || + log.blockType === 'evaluator') && + log.output + ) { + logger.debug('Processing LLM-based block output for tool calls and cost tracking', { blockId: log.blockId, blockName: log.blockName, + blockType: log.blockType, outputKeys: Object.keys(log.output), hasToolCalls: !!log.output.toolCalls, hasResponse: !!log.output.response, @@ -165,7 +171,7 @@ export async function persistExecutionLogs( } } - // Special case for streaming responses from agent blocks + // Special case for streaming responses from LLM blocks (agent, router, and evaluator) // This format has both stream and executionData properties if (log.output.stream && log.output.executionData) { logger.debug('Found streaming response with executionData', { diff --git a/apps/sim/lib/schedules/utils.test.ts b/apps/sim/lib/schedules/utils.test.ts index 7988408b5f8..c9b4b63c0cd 100644 --- a/apps/sim/lib/schedules/utils.test.ts +++ b/apps/sim/lib/schedules/utils.test.ts @@ -15,27 +15,27 @@ import { describe('Schedule Utilities', () => { describe('parseTimeString', () => { - it('should parse valid time strings', () => { + it.concurrent('should parse valid time strings', () => { expect(parseTimeString('09:30')).toEqual([9, 30]) expect(parseTimeString('23:45')).toEqual([23, 45]) expect(parseTimeString('00:00')).toEqual([0, 0]) }) - it('should return default values for invalid inputs', () => { + it.concurrent('should return default values for invalid inputs', () => { expect(parseTimeString('')).toEqual([9, 0]) expect(parseTimeString(null)).toEqual([9, 0]) expect(parseTimeString(undefined)).toEqual([9, 0]) expect(parseTimeString('invalid')).toEqual([9, 0]) }) - it('should handle malformed time strings', () => { + it.concurrent('should handle malformed time strings', () => { expect(parseTimeString('9:30')).toEqual([9, 30]) expect(parseTimeString('9:3')).toEqual([9, 3]) expect(parseTimeString('9:')).toEqual([9, 0]) expect(parseTimeString(':30')).toEqual([0, 30]) // Only has minutes }) - it('should handle out-of-range time values', () => { + it.concurrent('should handle out-of-range time values', () => { expect(parseTimeString('25:30')).toEqual([25, 30]) // Hours > 24 expect(parseTimeString('10:75')).toEqual([10, 75]) // Minutes > 59 expect(parseTimeString('99:99')).toEqual([99, 99]) // Both out of range @@ -43,7 +43,7 @@ describe('Schedule Utilities', () => { }) describe('getSubBlockValue', () => { - it('should get values from block subBlocks', () => { + it.concurrent('should get values from block subBlocks', () => { const block: BlockState = { type: 'starter', subBlocks: { @@ -61,7 +61,7 @@ describe('Schedule Utilities', () => { expect(getSubBlockValue(block, 'nonExistent')).toBe('') }) - it('should handle missing subBlocks', () => { + it.concurrent('should handle missing subBlocks', () => { const block = { type: 'starter', subBlocks: {}, // Empty subBlocks @@ -72,7 +72,7 @@ describe('Schedule Utilities', () => { }) describe('getScheduleTimeValues', () => { - it('should extract all time values from a block', () => { + it.concurrent('should extract all time values from a block', () => { const block: BlockState = { type: 'starter', subBlocks: { @@ -105,7 +105,7 @@ describe('Schedule Utilities', () => { }) }) - it('should use default values for missing fields', () => { + it.concurrent('should use default values for missing fields', () => { const block: BlockState = { type: 'starter', subBlocks: { @@ -132,7 +132,7 @@ describe('Schedule Utilities', () => { }) describe('generateCronExpression', () => { - it('should generate correct cron expressions for different schedule types', () => { + it.concurrent('should generate correct cron expressions for different schedule types', () => { const scheduleValues = { scheduleTime: '09:30', minutesInterval: 15, @@ -161,7 +161,7 @@ describe('Schedule Utilities', () => { expect(generateCronExpression('monthly', scheduleValues)).toBe('30 14 15 * *') }) - it('should handle custom cron expressions', () => { + it.concurrent('should handle custom cron expressions', () => { // For this simplified test, let's skip the complex mocking // and just verify the 'custom' case is in the switch statement @@ -201,7 +201,7 @@ describe('Schedule Utilities', () => { expect(generateCronExpression('minutes', standardScheduleValues)).toBe('*/15 * * * *') }) - it('should throw for invalid schedule types', () => { + it.concurrent('should throw for invalid schedule types', () => { const scheduleValues = {} as any expect(() => generateCronExpression('invalid-type', scheduleValues)).toThrow() }) @@ -218,7 +218,7 @@ describe('Schedule Utilities', () => { vi.useRealTimers() }) - it('should calculate next run for minutes schedule', () => { + it.concurrent('should calculate next run for minutes schedule', () => { const scheduleValues = { scheduleTime: '', scheduleStartAt: '', @@ -242,7 +242,7 @@ describe('Schedule Utilities', () => { expect(nextRun.getMinutes() % 15).toBe(0) }) - it('should respect scheduleTime for minutes schedule', () => { + it.concurrent('should respect scheduleTime for minutes schedule', () => { const scheduleValues = { scheduleTime: '14:30', // Specific start time scheduleStartAt: '', @@ -263,7 +263,7 @@ describe('Schedule Utilities', () => { expect(nextRun.getMinutes()).toBe(30) }) - it('should calculate next run for hourly schedule', () => { + it.concurrent('should calculate next run for hourly schedule', () => { const scheduleValues = { scheduleTime: '', scheduleStartAt: '', @@ -285,7 +285,7 @@ describe('Schedule Utilities', () => { expect(nextRun.getMinutes()).toBe(30) }) - it('should calculate next run for daily schedule', () => { + it.concurrent('should calculate next run for daily schedule', () => { const scheduleValues = { scheduleTime: '', scheduleStartAt: '', @@ -308,7 +308,7 @@ describe('Schedule Utilities', () => { expect(nextRun.getMinutes()).toBe(0) }) - it('should calculate next run for weekly schedule', () => { + it.concurrent('should calculate next run for weekly schedule', () => { const scheduleValues = { scheduleTime: '', scheduleStartAt: '', @@ -330,7 +330,7 @@ describe('Schedule Utilities', () => { expect(nextRun.getMinutes()).toBe(0) }) - it('should calculate next run for monthly schedule', () => { + it.concurrent('should calculate next run for monthly schedule', () => { const scheduleValues = { scheduleTime: '', scheduleStartAt: '', @@ -354,7 +354,7 @@ describe('Schedule Utilities', () => { expect(nextRun.getMinutes()).toBe(30) }) - it('should consider lastRanAt for better interval calculation', () => { + it.concurrent('should consider lastRanAt for better interval calculation', () => { const scheduleValues = { scheduleTime: '', scheduleStartAt: '', @@ -381,7 +381,7 @@ describe('Schedule Utilities', () => { expect(nextRun.getMinutes()).toBe(expectedNextRun.getMinutes()) }) - it('should respect future scheduleStartAt date', () => { + it.concurrent('should respect future scheduleStartAt date', () => { const scheduleValues = { scheduleStartAt: '2025-04-22T20:50:00.000Z', // April 22, 2025 at 8:50 PM scheduleTime: '', @@ -401,7 +401,7 @@ describe('Schedule Utilities', () => { expect(nextRun.toISOString()).toBe('2025-04-22T20:50:00.000Z') }) - it('should ignore past scheduleStartAt date', () => { + it.concurrent('should ignore past scheduleStartAt date', () => { const scheduleValues = { scheduleStartAt: '2025-04-10T20:50:00.000Z', // April 10, 2025 at 8:50 PM (in the past) scheduleTime: '', @@ -424,7 +424,7 @@ describe('Schedule Utilities', () => { }) describe('parseCronToHumanReadable', () => { - it('should parse common cron patterns', () => { + it.concurrent('should parse common cron patterns', () => { expect(parseCronToHumanReadable('* * * * *')).toBe('Every minute') expect(parseCronToHumanReadable('*/15 * * * *')).toBe('Every 15 minutes') expect(parseCronToHumanReadable('30 * * * *')).toBe('Hourly at 30 minutes past the hour') @@ -434,14 +434,14 @@ describe('Schedule Utilities', () => { expect(parseCronToHumanReadable('30 14 15 * *')).toMatch(/Monthly on the 15th at 2:30 PM/) }) - it('should handle complex patterns', () => { + it.concurrent('should handle complex patterns', () => { // Test with various combinations expect(parseCronToHumanReadable('* */2 * * *')).toMatch(/Runs/) expect(parseCronToHumanReadable('0 9 * * 1-5')).toMatch(/Runs/) expect(parseCronToHumanReadable('0 9 1,15 * *')).toMatch(/Runs/) }) - it('should return a fallback for unrecognized patterns', () => { + it.concurrent('should return a fallback for unrecognized patterns', () => { const result = parseCronToHumanReadable('*/10 */6 31 2 *') // Invalid (Feb 31) // Just check that we get something back that's not empty expect(result.length).toBeGreaterThan(5) @@ -449,7 +449,7 @@ describe('Schedule Utilities', () => { }) describe('createDateWithTimezone', () => { - it('should correctly handle UTC timezone', () => { + it.concurrent('should correctly handle UTC timezone', () => { const date = createDateWithTimezone( '2025-04-21T00:00:00.000Z', '14:00', // 2:00 PM @@ -458,7 +458,7 @@ describe('Schedule Utilities', () => { expect(date.toISOString()).toBe('2025-04-21T14:00:00.000Z') }) - it('should correctly handle America/Los_Angeles (UTC-7 during DST)', () => { + it.concurrent('should correctly handle America/Los_Angeles (UTC-7 during DST)', () => { // April 21, 2025 is during DST for Los Angeles (PDT = UTC-7) const date = createDateWithTimezone( '2025-04-21', // Using date string without time/zone @@ -469,7 +469,7 @@ describe('Schedule Utilities', () => { expect(date.toISOString()).toBe('2025-04-21T21:00:00.000Z') }) - it('should correctly handle America/Los_Angeles (UTC-8 outside DST)', () => { + it.concurrent('should correctly handle America/Los_Angeles (UTC-8 outside DST)', () => { // January 10, 2025 is outside DST for Los Angeles (PST = UTC-8) const date = createDateWithTimezone( '2025-01-10', @@ -480,7 +480,7 @@ describe('Schedule Utilities', () => { expect(date.toISOString()).toBe('2025-01-10T22:00:00.000Z') }) - it('should correctly handle America/New_York (UTC-4 during DST)', () => { + it.concurrent('should correctly handle America/New_York (UTC-4 during DST)', () => { // June 15, 2025 is during DST for New York (EDT = UTC-4) const date = createDateWithTimezone( '2025-06-15', @@ -491,7 +491,7 @@ describe('Schedule Utilities', () => { expect(date.toISOString()).toBe('2025-06-15T14:30:00.000Z') }) - it('should correctly handle America/New_York (UTC-5 outside DST)', () => { + it.concurrent('should correctly handle America/New_York (UTC-5 outside DST)', () => { // December 20, 2025 is outside DST for New York (EST = UTC-5) const date = createDateWithTimezone( '2025-12-20', @@ -502,7 +502,7 @@ describe('Schedule Utilities', () => { expect(date.toISOString()).toBe('2025-12-20T15:30:00.000Z') }) - it('should correctly handle Europe/London (UTC+1 during DST)', () => { + it.concurrent('should correctly handle Europe/London (UTC+1 during DST)', () => { // August 5, 2025 is during DST for London (BST = UTC+1) const date = createDateWithTimezone( '2025-08-05', @@ -513,7 +513,7 @@ describe('Schedule Utilities', () => { expect(date.toISOString()).toBe('2025-08-05T08:15:00.000Z') }) - it('should correctly handle Europe/London (UTC+0 outside DST)', () => { + it.concurrent('should correctly handle Europe/London (UTC+0 outside DST)', () => { // February 10, 2025 is outside DST for London (GMT = UTC+0) const date = createDateWithTimezone( '2025-02-10', @@ -524,7 +524,7 @@ describe('Schedule Utilities', () => { expect(date.toISOString()).toBe('2025-02-10T09:15:00.000Z') }) - it('should correctly handle Asia/Tokyo (UTC+9)', () => { + it.concurrent('should correctly handle Asia/Tokyo (UTC+9)', () => { // Tokyo does not observe DST (JST = UTC+9) const date = createDateWithTimezone( '2025-07-01', @@ -535,14 +535,14 @@ describe('Schedule Utilities', () => { expect(date.toISOString()).toBe('2025-07-01T08:00:00.000Z') }) - it('should handle date object input', () => { + it.concurrent('should handle date object input', () => { // Using a Date object that represents midnight UTC on the target day const dateInput = new Date(Date.UTC(2025, 3, 21)) // April 21, 2025 const date = createDateWithTimezone(dateInput, '14:00', 'America/Los_Angeles') expect(date.toISOString()).toBe('2025-04-21T21:00:00.000Z') }) - it('should handle time crossing midnight due to timezone offset', () => { + it.concurrent('should handle time crossing midnight due to timezone offset', () => { // Test case: 1:00 AM local time in Sydney (UTC+10/11) // This might result in a UTC date that is the *previous* day. const date = createDateWithTimezone( diff --git a/apps/sim/lib/subscription/utils.test.ts b/apps/sim/lib/subscription/utils.test.ts index 95d45a0dc58..de6f78fc683 100644 --- a/apps/sim/lib/subscription/utils.test.ts +++ b/apps/sim/lib/subscription/utils.test.ts @@ -16,41 +16,41 @@ afterAll(() => { describe('Subscription Utilities', () => { describe('checkEnterprisePlan', () => { - it('returns true for active enterprise subscription', () => { + it.concurrent('returns true for active enterprise subscription', () => { expect(checkEnterprisePlan({ plan: 'enterprise', status: 'active' })).toBeTruthy() }) - it('returns false for inactive enterprise subscription', () => { + it.concurrent('returns false for inactive enterprise subscription', () => { expect(checkEnterprisePlan({ plan: 'enterprise', status: 'canceled' })).toBeFalsy() }) - it('returns false when plan is not enterprise', () => { + it.concurrent('returns false when plan is not enterprise', () => { expect(checkEnterprisePlan({ plan: 'pro', status: 'active' })).toBeFalsy() }) }) describe('calculateUsageLimit', () => { - it('returns free-tier limit when subscription is null', () => { + it.concurrent('returns free-tier limit when subscription is null', () => { expect(calculateUsageLimit(null)).toBe(5) }) - it('returns free-tier limit when subscription is undefined', () => { + it.concurrent('returns free-tier limit when subscription is undefined', () => { expect(calculateUsageLimit(undefined)).toBe(5) }) - it('returns free-tier limit when subscription is not active', () => { + it.concurrent('returns free-tier limit when subscription is not active', () => { expect(calculateUsageLimit({ plan: 'pro', status: 'canceled', seats: 1 })).toBe(5) }) - it('returns pro limit for active pro plan', () => { + it.concurrent('returns pro limit for active pro plan', () => { expect(calculateUsageLimit({ plan: 'pro', status: 'active', seats: 1 })).toBe(20) }) - it('returns team limit multiplied by seats', () => { + it.concurrent('returns team limit multiplied by seats', () => { expect(calculateUsageLimit({ plan: 'team', status: 'active', seats: 3 })).toBe(3 * 40) }) - it('returns enterprise limit using perSeatAllowance metadata', () => { + it.concurrent('returns enterprise limit using perSeatAllowance metadata', () => { const sub = { plan: 'enterprise', status: 'active', @@ -60,7 +60,7 @@ describe('Subscription Utilities', () => { expect(calculateUsageLimit(sub)).toBe(10 * 150) }) - it('returns enterprise limit using totalAllowance metadata', () => { + it.concurrent('returns enterprise limit using totalAllowance metadata', () => { const sub = { plan: 'enterprise', status: 'active', @@ -70,7 +70,7 @@ describe('Subscription Utilities', () => { expect(calculateUsageLimit(sub)).toBe(5000) }) - it('falls back to default enterprise tier when metadata missing', () => { + it.concurrent('falls back to default enterprise tier when metadata missing', () => { const sub = { plan: 'enterprise', status: 'active', seats: 2, metadata: {} } expect(calculateUsageLimit(sub)).toBe(2 * 200) }) diff --git a/apps/sim/lib/utils.test.ts b/apps/sim/lib/utils.test.ts index df54fefe8e4..7034c3e1c17 100644 --- a/apps/sim/lib/utils.test.ts +++ b/apps/sim/lib/utils.test.ts @@ -16,7 +16,6 @@ import { validateName, } from './utils' -// Mock crypto module for encryption/decryption tests vi.mock('crypto', () => ({ createCipheriv: vi.fn().mockReturnValue({ update: vi.fn().mockReturnValue('encrypted-data'), @@ -35,7 +34,6 @@ vi.mock('crypto', () => ({ }), })) -// Mock environment variables for encryption key beforeEach(() => { process.env.ENCRYPTION_KEY = '1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef' }) @@ -45,18 +43,18 @@ afterEach(() => { }) describe('generateApiKey', () => { - it('should generate API key with sim_ prefix', () => { + it.concurrent('should generate API key with sim_ prefix', () => { const key = generateApiKey() expect(key).toMatch(/^sim_/) }) - it('should generate unique API keys for each call', () => { + it.concurrent('should generate unique API keys for each call', () => { const key1 = generateApiKey() const key2 = generateApiKey() expect(key1).not.toBe(key2) }) - it('should generate API keys of correct length', () => { + it.concurrent('should generate API keys of correct length', () => { const key = generateApiKey() // Expected format: 'sim_' + 32 random characters expect(key.length).toBe(36) @@ -64,23 +62,23 @@ describe('generateApiKey', () => { }) describe('cn (class name utility)', () => { - it('should merge class names correctly', () => { + it.concurrent('should merge class names correctly', () => { const result = cn('class1', 'class2') expect(result).toBe('class1 class2') }) - it('should handle conditional classes', () => { + it.concurrent('should handle conditional classes', () => { const isActive = true const result = cn('base', isActive && 'active') expect(result).toBe('base active') }) - it('should handle falsy values', () => { + it.concurrent('should handle falsy values', () => { const result = cn('base', false && 'hidden', null, undefined, 0, '') expect(result).toBe('base') }) - it('should handle arrays of class names', () => { + it.concurrent('should handle arrays of class names', () => { const result = cn('base', ['class1', 'class2']) expect(result).toContain('base') expect(result).toContain('class1') @@ -89,7 +87,7 @@ describe('cn (class name utility)', () => { }) describe('encryption and decryption', () => { - it('should encrypt secrets correctly', async () => { + it.concurrent('should encrypt secrets correctly', async () => { const result = await encryptSecret('my-secret') expect(result).toHaveProperty('encrypted') expect(result).toHaveProperty('iv') @@ -99,34 +97,34 @@ describe('encryption and decryption', () => { expect(result.encrypted).toContain('auth-tag') }) - it('should decrypt secrets correctly', async () => { + it.concurrent('should decrypt secrets correctly', async () => { const result = await decryptSecret('iv:encrypted:authTag') expect(result).toHaveProperty('decrypted') expect(result.decrypted).toBe('decrypted-datafinal-data') }) - it('should throw error for invalid decrypt format', async () => { + it.concurrent('should throw error for invalid decrypt format', async () => { await expect(decryptSecret('invalid-format')).rejects.toThrow('Invalid encrypted value format') }) }) describe('convertScheduleOptionsToCron', () => { - it('should convert minutes schedule to cron', () => { + it.concurrent('should convert minutes schedule to cron', () => { const result = convertScheduleOptionsToCron('minutes', { minutesInterval: '5' }) expect(result).toBe('*/5 * * * *') }) - it('should convert hourly schedule to cron', () => { + it.concurrent('should convert hourly schedule to cron', () => { const result = convertScheduleOptionsToCron('hourly', { hourlyMinute: '30' }) expect(result).toBe('30 * * * *') }) - it('should convert daily schedule to cron', () => { + it.concurrent('should convert daily schedule to cron', () => { const result = convertScheduleOptionsToCron('daily', { dailyTime: '15:30' }) expect(result).toBe('15 30 * * *') }) - it('should convert weekly schedule to cron', () => { + it.concurrent('should convert weekly schedule to cron', () => { const result = convertScheduleOptionsToCron('weekly', { weeklyDay: 'MON', weeklyDayTime: '09:30', @@ -134,7 +132,7 @@ describe('convertScheduleOptionsToCron', () => { expect(result).toBe('09 30 * * 1') }) - it('should convert monthly schedule to cron', () => { + it.concurrent('should convert monthly schedule to cron', () => { const result = convertScheduleOptionsToCron('monthly', { monthlyDay: '15', monthlyTime: '12:00', @@ -142,38 +140,38 @@ describe('convertScheduleOptionsToCron', () => { expect(result).toBe('12 00 15 * *') }) - it('should use custom cron expression directly', () => { + it.concurrent('should use custom cron expression directly', () => { const customCron = '*/15 9-17 * * 1-5' const result = convertScheduleOptionsToCron('custom', { cronExpression: customCron }) expect(result).toBe(customCron) }) - it('should throw error for unsupported schedule type', () => { + it.concurrent('should throw error for unsupported schedule type', () => { expect(() => convertScheduleOptionsToCron('invalid', {})).toThrow('Unsupported schedule type') }) - it('should use default values when options are not provided', () => { + it.concurrent('should use default values when options are not provided', () => { const result = convertScheduleOptionsToCron('daily', {}) expect(result).toBe('00 09 * * *') }) }) describe('date formatting functions', () => { - it('should format datetime correctly', () => { + it.concurrent('should format datetime correctly', () => { const date = new Date('2023-05-15T14:30:00') const result = formatDateTime(date) expect(result).toMatch(/May 15, 2023/) expect(result).toMatch(/2:30 PM|14:30/) }) - it('should format date correctly', () => { + it.concurrent('should format date correctly', () => { const date = new Date('2023-05-15T14:30:00') const result = formatDate(date) expect(result).toMatch(/May 15, 2023/) expect(result).not.toMatch(/2:30|14:30/) }) - it('should format time correctly', () => { + it.concurrent('should format time correctly', () => { const date = new Date('2023-05-15T14:30:00') const result = formatTime(date) expect(result).toMatch(/2:30 PM|14:30/) @@ -182,34 +180,34 @@ describe('date formatting functions', () => { }) describe('formatDuration', () => { - it('should format milliseconds correctly', () => { + it.concurrent('should format milliseconds correctly', () => { const result = formatDuration(500) expect(result).toBe('500ms') }) - it('should format seconds correctly', () => { + it.concurrent('should format seconds correctly', () => { const result = formatDuration(5000) expect(result).toBe('5s') }) - it('should format minutes and seconds correctly', () => { + it.concurrent('should format minutes and seconds correctly', () => { const result = formatDuration(125000) // 2m 5s expect(result).toBe('2m 5s') }) - it('should format hours, minutes correctly', () => { + it.concurrent('should format hours, minutes correctly', () => { const result = formatDuration(3725000) // 1h 2m 5s expect(result).toBe('1h 2m') }) }) describe('getTimezoneAbbreviation', () => { - it('should return UTC for UTC timezone', () => { + it.concurrent('should return UTC for UTC timezone', () => { const result = getTimezoneAbbreviation('UTC') expect(result).toBe('UTC') }) - it('should return PST/PDT for Los Angeles timezone', () => { + it.concurrent('should return PST/PDT for Los Angeles timezone', () => { const winterDate = new Date('2023-01-15') // Standard time const summerDate = new Date('2023-07-15') // Daylight time @@ -220,7 +218,7 @@ describe('getTimezoneAbbreviation', () => { expect(['PST', 'PDT']).toContain(summerResult) }) - it('should return JST for Tokyo timezone (no DST)', () => { + it.concurrent('should return JST for Tokyo timezone (no DST)', () => { const winterDate = new Date('2023-01-15') const summerDate = new Date('2023-07-15') @@ -231,14 +229,14 @@ describe('getTimezoneAbbreviation', () => { expect(summerResult).toBe('JST') }) - it('should return full timezone name for unknown timezones', () => { + it.concurrent('should return full timezone name for unknown timezones', () => { const result = getTimezoneAbbreviation('Unknown/Timezone') expect(result).toBe('Unknown/Timezone') }) }) describe('redactApiKeys', () => { - it('should redact API keys in objects', () => { + it.concurrent('should redact API keys in objects', () => { const obj = { apiKey: 'secret-key', api_key: 'another-secret', @@ -258,7 +256,7 @@ describe('redactApiKeys', () => { expect(result.normalField).toBe('normal-value') }) - it('should redact API keys in nested objects', () => { + it.concurrent('should redact API keys in nested objects', () => { const obj = { config: { apiKey: 'secret-key', @@ -272,7 +270,7 @@ describe('redactApiKeys', () => { expect(result.config.normalField).toBe('normal-value') }) - it('should redact API keys in arrays', () => { + it.concurrent('should redact API keys in arrays', () => { const arr = [{ apiKey: 'secret-key-1' }, { apiKey: 'secret-key-2' }] const result = redactApiKeys(arr) @@ -281,14 +279,14 @@ describe('redactApiKeys', () => { expect(result[1].apiKey).toBe('***REDACTED***') }) - it('should handle primitive values', () => { + it.concurrent('should handle primitive values', () => { expect(redactApiKeys('string')).toBe('string') expect(redactApiKeys(123)).toBe(123) expect(redactApiKeys(null)).toBe(null) expect(redactApiKeys(undefined)).toBe(undefined) }) - it('should handle complex nested structures', () => { + it.concurrent('should handle complex nested structures', () => { const obj = { users: [ { @@ -318,49 +316,49 @@ describe('redactApiKeys', () => { }) describe('validateName', () => { - it('should remove invalid characters', () => { + it.concurrent('should remove invalid characters', () => { const result = validateName('test@#$%name') expect(result).toBe('testname') }) - it('should keep valid characters', () => { + it.concurrent('should keep valid characters', () => { const result = validateName('test_name_123') expect(result).toBe('test_name_123') }) - it('should keep spaces', () => { + it.concurrent('should keep spaces', () => { const result = validateName('test name') expect(result).toBe('test name') }) - it('should handle empty string', () => { + it.concurrent('should handle empty string', () => { const result = validateName('') expect(result).toBe('') }) - it('should handle string with only invalid characters', () => { + it.concurrent('should handle string with only invalid characters', () => { const result = validateName('@#$%') expect(result).toBe('') }) - it('should handle mixed valid and invalid characters', () => { + it.concurrent('should handle mixed valid and invalid characters', () => { const result = validateName('my-workflow@2023!') expect(result).toBe('myworkflow2023') }) - it('should collapse multiple spaces into single spaces', () => { + it.concurrent('should collapse multiple spaces into single spaces', () => { const result = validateName('test multiple spaces') expect(result).toBe('test multiple spaces') }) - it('should handle mixed whitespace and invalid characters', () => { + it.concurrent('should handle mixed whitespace and invalid characters', () => { const result = validateName('test@#$ name') expect(result).toBe('test name') }) }) describe('isValidName', () => { - it('should return true for valid names', () => { + it.concurrent('should return true for valid names', () => { expect(isValidName('test_name')).toBe(true) expect(isValidName('test123')).toBe(true) expect(isValidName('test name')).toBe(true) @@ -368,7 +366,7 @@ describe('isValidName', () => { expect(isValidName('')).toBe(true) }) - it('should return false for invalid names', () => { + it.concurrent('should return false for invalid names', () => { expect(isValidName('test@name')).toBe(false) expect(isValidName('test-name')).toBe(false) expect(isValidName('test#name')).toBe(false) @@ -378,27 +376,27 @@ describe('isValidName', () => { }) describe('getInvalidCharacters', () => { - it('should return empty array for valid names', () => { + it.concurrent('should return empty array for valid names', () => { const result = getInvalidCharacters('test_name_123') expect(result).toEqual([]) }) - it('should return invalid characters', () => { + it.concurrent('should return invalid characters', () => { const result = getInvalidCharacters('test@#$name') expect(result).toEqual(['@', '#', '$']) }) - it('should return unique invalid characters', () => { + it.concurrent('should return unique invalid characters', () => { const result = getInvalidCharacters('test@@##name') expect(result).toEqual(['@', '#']) }) - it('should handle empty string', () => { + it.concurrent('should handle empty string', () => { const result = getInvalidCharacters('') expect(result).toEqual([]) }) - it('should handle string with only invalid characters', () => { + it.concurrent('should handle string with only invalid characters', () => { const result = getInvalidCharacters('@#$%') expect(result).toEqual(['@', '#', '$', '%']) }) diff --git a/apps/sim/lib/variables/variable-manager.test.ts b/apps/sim/lib/variables/variable-manager.test.ts index 00e564ca74d..5cdb8f3303f 100644 --- a/apps/sim/lib/variables/variable-manager.test.ts +++ b/apps/sim/lib/variables/variable-manager.test.ts @@ -3,27 +3,27 @@ import { VariableManager } from './variable-manager' describe('VariableManager', () => { describe('parseInputForStorage', () => { - it('should handle plain type variables', () => { + it.concurrent('should handle plain type variables', () => { expect(VariableManager.parseInputForStorage('hello world', 'plain')).toBe('hello world') expect(VariableManager.parseInputForStorage('123', 'plain')).toBe('123') expect(VariableManager.parseInputForStorage('true', 'plain')).toBe('true') expect(VariableManager.parseInputForStorage('{"foo":"bar"}', 'plain')).toBe('{"foo":"bar"}') }) - it('should handle string type variables', () => { + it.concurrent('should handle string type variables', () => { expect(VariableManager.parseInputForStorage('hello world', 'string')).toBe('hello world') expect(VariableManager.parseInputForStorage('"hello world"', 'string')).toBe('hello world') expect(VariableManager.parseInputForStorage("'hello world'", 'string')).toBe('hello world') }) - it('should handle number type variables', () => { + it.concurrent('should handle number type variables', () => { expect(VariableManager.parseInputForStorage('42', 'number')).toBe(42) expect(VariableManager.parseInputForStorage('-3.14', 'number')).toBe(-3.14) expect(VariableManager.parseInputForStorage('"42"', 'number')).toBe(42) expect(VariableManager.parseInputForStorage('not a number', 'number')).toBe(0) }) - it('should handle boolean type variables', () => { + it.concurrent('should handle boolean type variables', () => { expect(VariableManager.parseInputForStorage('true', 'boolean')).toBe(true) expect(VariableManager.parseInputForStorage('false', 'boolean')).toBe(false) expect(VariableManager.parseInputForStorage('1', 'boolean')).toBe(true) @@ -32,7 +32,7 @@ describe('VariableManager', () => { expect(VariableManager.parseInputForStorage("'false'", 'boolean')).toBe(false) }) - it('should handle object type variables', () => { + it.concurrent('should handle object type variables', () => { expect(VariableManager.parseInputForStorage('{"foo":"bar"}', 'object')).toEqual({ foo: 'bar', }) @@ -40,13 +40,13 @@ describe('VariableManager', () => { expect(VariableManager.parseInputForStorage('42', 'object')).toEqual({ value: '42' }) }) - it('should handle array type variables', () => { + it.concurrent('should handle array type variables', () => { expect(VariableManager.parseInputForStorage('[1,2,3]', 'array')).toEqual([1, 2, 3]) expect(VariableManager.parseInputForStorage('invalid json', 'array')).toEqual([]) expect(VariableManager.parseInputForStorage('42', 'array')).toEqual(['42']) }) - it('should handle empty values', () => { + it.concurrent('should handle empty values', () => { expect(VariableManager.parseInputForStorage('', 'string')).toBe('') expect(VariableManager.parseInputForStorage('', 'number')).toBe('') expect(VariableManager.parseInputForStorage(null as any, 'boolean')).toBe('') @@ -55,32 +55,32 @@ describe('VariableManager', () => { }) describe('formatForEditor', () => { - it('should format plain type variables for editor', () => { + it.concurrent('should format plain type variables for editor', () => { expect(VariableManager.formatForEditor('hello world', 'plain')).toBe('hello world') expect(VariableManager.formatForEditor(42, 'plain')).toBe('42') expect(VariableManager.formatForEditor(true, 'plain')).toBe('true') }) - it('should format string type variables for editor', () => { + it.concurrent('should format string type variables for editor', () => { expect(VariableManager.formatForEditor('hello world', 'string')).toBe('hello world') expect(VariableManager.formatForEditor(42, 'string')).toBe('42') expect(VariableManager.formatForEditor(true, 'string')).toBe('true') }) - it('should format number type variables for editor', () => { + it.concurrent('should format number type variables for editor', () => { expect(VariableManager.formatForEditor(42, 'number')).toBe('42') expect(VariableManager.formatForEditor('42', 'number')).toBe('42') expect(VariableManager.formatForEditor('not a number', 'number')).toBe('0') }) - it('should format boolean type variables for editor', () => { + it.concurrent('should format boolean type variables for editor', () => { expect(VariableManager.formatForEditor(true, 'boolean')).toBe('true') expect(VariableManager.formatForEditor(false, 'boolean')).toBe('false') expect(VariableManager.formatForEditor('true', 'boolean')).toBe('true') expect(VariableManager.formatForEditor('anything else', 'boolean')).toBe('true') }) - it('should format object type variables for editor', () => { + it.concurrent('should format object type variables for editor', () => { expect(VariableManager.formatForEditor({ foo: 'bar' }, 'object')).toBe('{\n "foo": "bar"\n}') expect(VariableManager.formatForEditor('{"foo":"bar"}', 'object')).toBe( '{\n "foo": "bar"\n}' @@ -90,7 +90,7 @@ describe('VariableManager', () => { ) }) - it('should format array type variables for editor', () => { + it.concurrent('should format array type variables for editor', () => { expect(VariableManager.formatForEditor([1, 2, 3], 'array')).toBe('[\n 1,\n 2,\n 3\n]') expect(VariableManager.formatForEditor('[1,2,3]', 'array')).toBe('[\n 1,\n 2,\n 3\n]') expect(VariableManager.formatForEditor('invalid json', 'array')).toEqual( @@ -98,32 +98,32 @@ describe('VariableManager', () => { ) }) - it('should handle empty values', () => { + it.concurrent('should handle empty values', () => { expect(VariableManager.formatForEditor(null, 'string')).toBe('') expect(VariableManager.formatForEditor(undefined, 'number')).toBe('') }) }) describe('resolveForExecution', () => { - it('should resolve plain type variables for execution', () => { + it.concurrent('should resolve plain type variables for execution', () => { expect(VariableManager.resolveForExecution('hello world', 'plain')).toBe('hello world') expect(VariableManager.resolveForExecution(42, 'plain')).toBe('42') expect(VariableManager.resolveForExecution(true, 'plain')).toBe('true') }) - it('should resolve string type variables for execution', () => { + it.concurrent('should resolve string type variables for execution', () => { expect(VariableManager.resolveForExecution('hello world', 'string')).toBe('hello world') expect(VariableManager.resolveForExecution(42, 'string')).toBe('42') expect(VariableManager.resolveForExecution(true, 'string')).toBe('true') }) - it('should resolve number type variables for execution', () => { + it.concurrent('should resolve number type variables for execution', () => { expect(VariableManager.resolveForExecution(42, 'number')).toBe(42) expect(VariableManager.resolveForExecution('42', 'number')).toBe(42) expect(VariableManager.resolveForExecution('not a number', 'number')).toBe(0) }) - it('should resolve boolean type variables for execution', () => { + it.concurrent('should resolve boolean type variables for execution', () => { expect(VariableManager.resolveForExecution(true, 'boolean')).toBe(true) expect(VariableManager.resolveForExecution(false, 'boolean')).toBe(false) expect(VariableManager.resolveForExecution('true', 'boolean')).toBe(true) @@ -132,26 +132,26 @@ describe('VariableManager', () => { expect(VariableManager.resolveForExecution('0', 'boolean')).toBe(false) }) - it('should resolve object type variables for execution', () => { + it.concurrent('should resolve object type variables for execution', () => { expect(VariableManager.resolveForExecution({ foo: 'bar' }, 'object')).toEqual({ foo: 'bar' }) expect(VariableManager.resolveForExecution('{"foo":"bar"}', 'object')).toEqual({ foo: 'bar' }) expect(VariableManager.resolveForExecution('invalid json', 'object')).toEqual({}) }) - it('should resolve array type variables for execution', () => { + it.concurrent('should resolve array type variables for execution', () => { expect(VariableManager.resolveForExecution([1, 2, 3], 'array')).toEqual([1, 2, 3]) expect(VariableManager.resolveForExecution('[1,2,3]', 'array')).toEqual([1, 2, 3]) expect(VariableManager.resolveForExecution('invalid json', 'array')).toEqual([]) }) - it('should handle null and undefined', () => { + it.concurrent('should handle null and undefined', () => { expect(VariableManager.resolveForExecution(null, 'string')).toBe(null) expect(VariableManager.resolveForExecution(undefined, 'number')).toBe(undefined) }) }) describe('formatForTemplateInterpolation', () => { - it('should format plain type variables for interpolation', () => { + it.concurrent('should format plain type variables for interpolation', () => { expect(VariableManager.formatForTemplateInterpolation('hello world', 'plain')).toBe( 'hello world' ) @@ -159,7 +159,7 @@ describe('VariableManager', () => { expect(VariableManager.formatForTemplateInterpolation(true, 'plain')).toBe('true') }) - it('should format string type variables for interpolation', () => { + it.concurrent('should format string type variables for interpolation', () => { expect(VariableManager.formatForTemplateInterpolation('hello world', 'string')).toBe( 'hello world' ) @@ -167,7 +167,7 @@ describe('VariableManager', () => { expect(VariableManager.formatForTemplateInterpolation(true, 'string')).toBe('true') }) - it('should format object type variables for interpolation', () => { + it.concurrent('should format object type variables for interpolation', () => { expect(VariableManager.formatForTemplateInterpolation({ foo: 'bar' }, 'object')).toBe( '{"foo":"bar"}' ) @@ -176,48 +176,48 @@ describe('VariableManager', () => { ) }) - it('should handle empty values', () => { + it.concurrent('should handle empty values', () => { expect(VariableManager.formatForTemplateInterpolation(null, 'string')).toBe('') expect(VariableManager.formatForTemplateInterpolation(undefined, 'number')).toBe('') }) }) describe('formatForCodeContext', () => { - it('should format plain type variables for code context', () => { + it.concurrent('should format plain type variables for code context', () => { expect(VariableManager.formatForCodeContext('hello world', 'plain')).toBe('hello world') expect(VariableManager.formatForCodeContext(42, 'plain')).toBe('42') expect(VariableManager.formatForCodeContext(true, 'plain')).toBe('true') }) - it('should format string type variables for code context', () => { + it.concurrent('should format string type variables for code context', () => { expect(VariableManager.formatForCodeContext('hello world', 'string')).toBe('"hello world"') expect(VariableManager.formatForCodeContext(42, 'string')).toBe('42') expect(VariableManager.formatForCodeContext(true, 'string')).toBe('true') }) - it('should format number type variables for code context', () => { + it.concurrent('should format number type variables for code context', () => { expect(VariableManager.formatForCodeContext(42, 'number')).toBe('42') expect(VariableManager.formatForCodeContext('42', 'number')).toBe('42') }) - it('should format boolean type variables for code context', () => { + it.concurrent('should format boolean type variables for code context', () => { expect(VariableManager.formatForCodeContext(true, 'boolean')).toBe('true') expect(VariableManager.formatForCodeContext(false, 'boolean')).toBe('false') }) - it('should format object and array types for code context', () => { + it.concurrent('should format object and array types for code context', () => { expect(VariableManager.formatForCodeContext({ foo: 'bar' }, 'object')).toBe('{"foo":"bar"}') expect(VariableManager.formatForCodeContext([1, 2, 3], 'array')).toBe('[1,2,3]') }) - it('should handle null and undefined', () => { + it.concurrent('should handle null and undefined', () => { expect(VariableManager.formatForCodeContext(null, 'string')).toBe('null') expect(VariableManager.formatForCodeContext(undefined, 'number')).toBe('undefined') }) }) describe('shouldStripQuotesForDisplay', () => { - it('should identify strings that need quotes stripped', () => { + it.concurrent('should identify strings that need quotes stripped', () => { expect(VariableManager.shouldStripQuotesForDisplay('"hello world"')).toBe(true) expect(VariableManager.shouldStripQuotesForDisplay("'hello world'")).toBe(true) expect(VariableManager.shouldStripQuotesForDisplay('hello world')).toBe(false) @@ -225,7 +225,7 @@ describe('VariableManager', () => { expect(VariableManager.shouldStripQuotesForDisplay("''")).toBe(false) // Too short }) - it('should handle edge cases', () => { + it.concurrent('should handle edge cases', () => { expect(VariableManager.shouldStripQuotesForDisplay('')).toBe(false) expect(VariableManager.shouldStripQuotesForDisplay(null as any)).toBe(false) expect(VariableManager.shouldStripQuotesForDisplay(undefined as any)).toBe(false) diff --git a/apps/sim/providers/anthropic/index.ts b/apps/sim/providers/anthropic/index.ts index ae0dba0f6fd..59e9c19e6da 100644 --- a/apps/sim/providers/anthropic/index.ts +++ b/apps/sim/providers/anthropic/index.ts @@ -2,6 +2,7 @@ import Anthropic from '@anthropic-ai/sdk' import { createLogger } from '@/lib/logs/console-logger' import type { StreamingExecution } from '@/executor/types' import { executeTool } from '@/tools' +import { getProviderDefaultModel, getProviderModels } from '../models' import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types' import { prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils' @@ -36,13 +37,8 @@ export const anthropicProvider: ProviderConfig = { name: 'Anthropic', description: "Anthropic's Claude models", version: '1.0.0', - models: [ - 'claude-sonnet-4-0', - 'claude-opus-4-0', - 'claude-3-7-sonnet-latest', - 'claude-3-5-sonnet-latest', - ], - defaultModel: 'claude-sonnet-4-0', + models: getProviderModels('anthropic'), + defaultModel: getProviderDefaultModel('anthropic'), executeRequest: async ( request: ProviderRequest diff --git a/apps/sim/providers/azure-openai/index.ts b/apps/sim/providers/azure-openai/index.ts index 578a12741cb..1c7dfa660d5 100644 --- a/apps/sim/providers/azure-openai/index.ts +++ b/apps/sim/providers/azure-openai/index.ts @@ -3,6 +3,7 @@ import { env } from '@/lib/env' import { createLogger } from '@/lib/logs/console-logger' import type { StreamingExecution } from '@/executor/types' import { executeTool } from '@/tools' +import { getProviderDefaultModel, getProviderModels } from '../models' import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types' import { prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils' @@ -56,8 +57,8 @@ export const azureOpenAIProvider: ProviderConfig = { name: 'Azure OpenAI', description: 'Microsoft Azure OpenAI Service models', version: '1.0.0', - models: ['azure/gpt-4o', 'azure/o3', 'azure/o4-mini', 'azure/gpt-4.1', 'azure/model-router'], - defaultModel: 'azure/gpt-4o', + models: getProviderModels('azure-openai'), + defaultModel: getProviderDefaultModel('azure-openai'), executeRequest: async ( request: ProviderRequest diff --git a/apps/sim/providers/cerebras/index.ts b/apps/sim/providers/cerebras/index.ts index 3479ce0c259..d1cd664eeb1 100644 --- a/apps/sim/providers/cerebras/index.ts +++ b/apps/sim/providers/cerebras/index.ts @@ -2,6 +2,7 @@ import { Cerebras } from '@cerebras/cerebras_cloud_sdk' import { createLogger } from '@/lib/logs/console-logger' import type { StreamingExecution } from '@/executor/types' import { executeTool } from '@/tools' +import { getProviderDefaultModel, getProviderModels } from '../models' import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types' const logger = createLogger('CerebrasProvider') @@ -36,8 +37,9 @@ export const cerebrasProvider: ProviderConfig = { name: 'Cerebras', description: 'Cerebras Cloud LLMs', version: '1.0.0', - models: ['cerebras/llama-3.3-70b'], - defaultModel: 'cerebras/llama-3.3-70b', + models: getProviderModels('cerebras'), + defaultModel: getProviderDefaultModel('cerebras'), + executeRequest: async ( request: ProviderRequest ): Promise => { diff --git a/apps/sim/providers/deepseek/index.ts b/apps/sim/providers/deepseek/index.ts index cfeb4813f27..6b18551ba38 100644 --- a/apps/sim/providers/deepseek/index.ts +++ b/apps/sim/providers/deepseek/index.ts @@ -2,6 +2,7 @@ import OpenAI from 'openai' import { createLogger } from '@/lib/logs/console-logger' import type { StreamingExecution } from '@/executor/types' import { executeTool } from '@/tools' +import { getProviderDefaultModel, getProviderModels } from '../models' import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types' import { prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils' @@ -34,8 +35,8 @@ export const deepseekProvider: ProviderConfig = { name: 'Deepseek', description: "Deepseek's chat models", version: '1.0.0', - models: ['deepseek-chat'], - defaultModel: 'deepseek-chat', + models: getProviderModels('deepseek'), + defaultModel: getProviderDefaultModel('deepseek'), executeRequest: async ( request: ProviderRequest diff --git a/apps/sim/providers/google/index.ts b/apps/sim/providers/google/index.ts index 841a0e831d1..594799e339c 100644 --- a/apps/sim/providers/google/index.ts +++ b/apps/sim/providers/google/index.ts @@ -1,6 +1,7 @@ import { createLogger } from '@/lib/logs/console-logger' import type { StreamingExecution } from '@/executor/types' import { executeTool } from '@/tools' +import { getProviderDefaultModel, getProviderModels } from '../models' import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types' const logger = createLogger('GoogleProvider') @@ -90,8 +91,8 @@ export const googleProvider: ProviderConfig = { name: 'Google', description: "Google's Gemini models", version: '1.0.0', - models: ['gemini-2.5-pro', 'gemini-2.5-flash'], - defaultModel: 'gemini-2.5-pro', + models: getProviderModels('google'), + defaultModel: getProviderDefaultModel('google'), executeRequest: async ( request: ProviderRequest diff --git a/apps/sim/providers/groq/index.ts b/apps/sim/providers/groq/index.ts index b67fd619507..b67f7fdc3d3 100644 --- a/apps/sim/providers/groq/index.ts +++ b/apps/sim/providers/groq/index.ts @@ -2,6 +2,7 @@ import { Groq } from 'groq-sdk' import { createLogger } from '@/lib/logs/console-logger' import type { StreamingExecution } from '@/executor/types' import { executeTool } from '@/tools' +import { getProviderDefaultModel, getProviderModels } from '../models' import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types' const logger = createLogger('GroqProvider') @@ -32,12 +33,8 @@ export const groqProvider: ProviderConfig = { name: 'Groq', description: "Groq's LLM models with high-performance inference", version: '1.0.0', - models: [ - 'groq/meta-llama/llama-4-scout-17b-16e-instruct', - 'groq/deepseek-r1-distill-llama-70b', - 'groq/qwen-qwq-32b', - ], - defaultModel: 'groq/meta-llama/llama-4-scout-17b-16e-instruct', + models: getProviderModels('groq'), + defaultModel: getProviderDefaultModel('groq'), executeRequest: async ( request: ProviderRequest diff --git a/apps/sim/providers/index.ts b/apps/sim/providers/index.ts index 0b3dccdd754..d1d53011a0d 100644 --- a/apps/sim/providers/index.ts +++ b/apps/sim/providers/index.ts @@ -1,17 +1,19 @@ import { createLogger } from '@/lib/logs/console-logger' import type { StreamingExecution } from '@/executor/types' -import { supportsTemperature } from './model-capabilities' import type { ProviderRequest, ProviderResponse } from './types' -import { calculateCost, generateStructuredOutputInstructions, getProvider } from './utils' +import { + calculateCost, + generateStructuredOutputInstructions, + getProvider, + supportsTemperature, +} from './utils' const logger = createLogger('Providers') // Sanitize the request by removing parameters that aren't supported by the model function sanitizeRequest(request: ProviderRequest): ProviderRequest { - // Create a shallow copy of the request const sanitizedRequest = { ...request } - // Remove temperature if the model doesn't support it if (sanitizedRequest.model && !supportsTemperature(sanitizedRequest.model)) { sanitizedRequest.temperature = undefined } @@ -19,12 +21,10 @@ function sanitizeRequest(request: ProviderRequest): ProviderRequest { return sanitizedRequest } -// Type guard for StreamingExecution function isStreamingExecution(response: any): response is StreamingExecution { return response && typeof response === 'object' && 'stream' in response && 'execution' in response } -// Type guard for ReadableStream function isReadableStream(response: any): response is ReadableStream { return response instanceof ReadableStream } diff --git a/apps/sim/providers/model-capabilities.test.ts b/apps/sim/providers/model-capabilities.test.ts deleted file mode 100644 index c27a31a256f..00000000000 --- a/apps/sim/providers/model-capabilities.test.ts +++ /dev/null @@ -1,80 +0,0 @@ -import { describe, expect, it } from 'vitest' -import { - getMaxTemperature, - PROVIDERS_WITH_TOOL_USAGE_CONTROL, - supportsTemperature, - supportsToolUsageControl, -} from './model-capabilities' - -describe('supportsToolUsageControl', () => { - it('should return true for providers that support tool usage control', () => { - // Test each provider that should support tool usage control - for (const provider of PROVIDERS_WITH_TOOL_USAGE_CONTROL) { - expect(supportsToolUsageControl(provider)).toBe(true) - } - }) - - it('should return false for providers that do not support tool usage control', () => { - const unsupportedProviders = ['google', 'ollama', 'non-existent-provider'] - - for (const provider of unsupportedProviders) { - expect(supportsToolUsageControl(provider)).toBe(false) - } - }) -}) - -describe('supportsTemperature', () => { - it('should return true for models that support temperature', () => { - const supportedModels = [ - 'gpt-4o', - 'gemini-2.5-flash', - 'claude-sonnet-4-0', - 'claude-opus-4-0', - 'claude-3-7-sonnet-latest', - 'claude-3-5-sonnet-latest', - 'grok-3-latest', - 'grok-3-fast-latest', - ] - - for (const model of supportedModels) { - expect(supportsTemperature(model)).toBe(true) - } - }) - - it('should return false for models that do not support temperature', () => { - const unsupportedModels = ['unsupported-model'] - - for (const model of unsupportedModels) { - expect(supportsTemperature(model)).toBe(false) - } - }) -}) - -describe('getMaxTemperature', () => { - it('should return 2 for models with temperature range 0-2', () => { - const models = ['gpt-4o', 'gemini-2.5-flash', 'deepseek-v3'] - - for (const model of models) { - expect(getMaxTemperature(model)).toBe(2) - } - }) - - it('should return 1 for models with temperature range 0-1', () => { - const models = [ - 'claude-sonnet-4-0', - 'claude-opus-4-0', - 'claude-3-7-sonnet-latest', - 'claude-3-5-sonnet-latest', - 'grok-3-latest', - 'grok-3-fast-latest', - ] - - for (const model of models) { - expect(getMaxTemperature(model)).toBe(1) - } - }) - - it('should return undefined for models that do not support temperature', () => { - expect(getMaxTemperature('unsupported-model')).toBeUndefined() - }) -}) diff --git a/apps/sim/providers/model-capabilities.ts b/apps/sim/providers/model-capabilities.ts deleted file mode 100644 index 7bcd1c2d47d..00000000000 --- a/apps/sim/providers/model-capabilities.ts +++ /dev/null @@ -1,83 +0,0 @@ -/** - * This file defines model capabilities and constraints - * It serves as a single source of truth for model-specific features - */ - -// Models that support temperature with range 0-2 -export const MODELS_TEMP_RANGE_0_2 = [ - // OpenAI models - 'gpt-4o', - // Azure OpenAI models - 'azure/gpt-4o', - // Google models - 'gemini-2.5-pro', - 'gemini-2.5-flash', - // Deepseek models - 'deepseek-v3', -] - -// Models that support temperature with range 0-1 -export const MODELS_TEMP_RANGE_0_1 = [ - // Anthropic models - 'claude-sonnet-4-0', - 'claude-opus-4-0', - 'claude-3-7-sonnet-latest', - 'claude-3-5-sonnet-latest', - // xAI models - 'grok-3-latest', - 'grok-3-fast-latest', -] - -// All models that support temperature (combined list) -export const MODELS_WITH_TEMPERATURE_SUPPORT = [...MODELS_TEMP_RANGE_0_2, ...MODELS_TEMP_RANGE_0_1] - -// Models and their providers that support tool usage control (force, auto, none) -export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = [ - 'openai', - 'azure-openai', - 'anthropic', - 'deepseek', - 'xai', -] - -/** - * Check if a model supports temperature parameter - */ -export function supportsTemperature(model: string): boolean { - // Normalize model name for comparison - const normalizedModel = model.toLowerCase() - - // Check if model is in the supported list - return MODELS_WITH_TEMPERATURE_SUPPORT.some( - (supportedModel) => supportedModel.toLowerCase() === normalizedModel - ) -} - -/** - * Get the maximum temperature value for a model - * @returns Maximum temperature value (1 or 2) or undefined if temperature not supported - */ -export function getMaxTemperature(model: string): number | undefined { - // Normalize model name for comparison - const normalizedModel = model.toLowerCase() - - // Check if model is in the 0-2 range - if (MODELS_TEMP_RANGE_0_2.some((m) => m.toLowerCase() === normalizedModel)) { - return 2 - } - - // Check if model is in the 0-1 range - if (MODELS_TEMP_RANGE_0_1.some((m) => m.toLowerCase() === normalizedModel)) { - return 1 - } - - // Temperature not supported - return undefined -} - -/** - * Check if a provider supports tool usage control - */ -export function supportsToolUsageControl(provider: string): boolean { - return PROVIDERS_WITH_TOOL_USAGE_CONTROL.includes(provider) -} diff --git a/apps/sim/providers/models.ts b/apps/sim/providers/models.ts new file mode 100644 index 00000000000..f84bd97986a --- /dev/null +++ b/apps/sim/providers/models.ts @@ -0,0 +1,635 @@ +/** + * Comprehensive provider definitions - Single source of truth + * This file contains all provider and model information including: + * - Model lists + * - Pricing information + * - Model capabilities (temperature support, etc.) + * - Provider configurations + */ + +export interface ModelPricing { + input: number // Per 1M tokens + cachedInput?: number // Per 1M tokens (if supported) + output: number // Per 1M tokens + updatedAt: string +} + +export interface ModelCapabilities { + temperature?: { + min: number + max: number + } + toolUsageControl?: boolean + computerUse?: boolean +} + +export interface ModelDefinition { + id: string + pricing: ModelPricing + capabilities: ModelCapabilities +} + +export interface ProviderDefinition { + id: string + name: string + description: string + models: ModelDefinition[] + defaultModel: string + modelPatterns?: RegExp[] +} + +/** + * Comprehensive provider definitions, single source of truth + */ +export const PROVIDER_DEFINITIONS: Record = { + openai: { + id: 'openai', + name: 'OpenAI', + description: "OpenAI's models", + defaultModel: 'gpt-4o', + modelPatterns: [/^gpt/, /^o1/], + models: [ + { + id: 'gpt-4o', + pricing: { + input: 2.5, + cachedInput: 1.25, + output: 10.0, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'o1', + pricing: { + input: 15.0, + cachedInput: 7.5, + output: 60, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'o3', + pricing: { + input: 2, + cachedInput: 0.5, + output: 8, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'o4-mini', + pricing: { + input: 1.1, + cachedInput: 0.275, + output: 4.4, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'gpt-4.1', + pricing: { + input: 2.0, + cachedInput: 0.5, + output: 8.0, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'gpt-4.1-nano', + pricing: { + input: 0.1, + cachedInput: 0.025, + output: 0.4, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'gpt-4.1-mini', + pricing: { + input: 0.4, + cachedInput: 0.1, + output: 1.6, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + ], + }, + 'azure-openai': { + id: 'azure-openai', + name: 'Azure OpenAI', + description: 'Microsoft Azure OpenAI Service models', + defaultModel: 'azure/gpt-4o', + modelPatterns: [/^azure\//], + models: [ + { + id: 'azure/gpt-4o', + pricing: { + input: 2.5, + cachedInput: 1.25, + output: 10.0, + updatedAt: '2025-06-15', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'azure/o3', + pricing: { + input: 10, + cachedInput: 2.5, + output: 40, + updatedAt: '2025-06-15', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'azure/o4-mini', + pricing: { + input: 1.1, + cachedInput: 0.275, + output: 4.4, + updatedAt: '2025-06-15', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'azure/gpt-4.1', + pricing: { + input: 2.0, + cachedInput: 0.5, + output: 8.0, + updatedAt: '2025-06-15', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'azure/model-router', + pricing: { + input: 2.0, + cachedInput: 0.5, + output: 8.0, + updatedAt: '2025-06-15', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + ], + }, + anthropic: { + id: 'anthropic', + name: 'Anthropic', + description: "Anthropic's Claude models", + defaultModel: 'claude-sonnet-4-0', + modelPatterns: [/^claude/], + models: [ + { + id: 'claude-sonnet-4-0', + pricing: { + input: 3.0, + cachedInput: 1.5, + output: 15.0, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 1 }, + toolUsageControl: true, + }, + }, + { + id: 'claude-opus-4-0', + pricing: { + input: 15.0, + cachedInput: 7.5, + output: 75.0, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 1 }, + toolUsageControl: true, + }, + }, + { + id: 'claude-3-7-sonnet-latest', + pricing: { + input: 3.0, + cachedInput: 1.5, + output: 15.0, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 1 }, + toolUsageControl: true, + computerUse: true, + }, + }, + { + id: 'claude-3-5-sonnet-latest', + pricing: { + input: 3.0, + cachedInput: 1.5, + output: 15.0, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 1 }, + toolUsageControl: true, + computerUse: true, + }, + }, + ], + }, + google: { + id: 'google', + name: 'Google', + description: "Google's Gemini models", + defaultModel: 'gemini-2.5-pro', + modelPatterns: [/^gemini/], + models: [ + { + id: 'gemini-2.5-pro', + pricing: { + input: 0.15, + cachedInput: 0.075, + output: 0.6, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: false, + }, + }, + { + id: 'gemini-2.5-flash', + pricing: { + input: 0.15, + cachedInput: 0.075, + output: 0.6, + updatedAt: '2025-06-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: false, + }, + }, + ], + }, + deepseek: { + id: 'deepseek', + name: 'Deepseek', + description: "Deepseek's chat models", + defaultModel: 'deepseek-chat', + modelPatterns: [], + models: [ + { + id: 'deepseek-chat', + pricing: { + input: 0.75, + cachedInput: 0.4, + output: 1.0, + updatedAt: '2025-03-21', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'deepseek-v3', + pricing: { + input: 0.75, + cachedInput: 0.4, + output: 1.0, + updatedAt: '2025-03-21', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + { + id: 'deepseek-r1', + pricing: { + input: 1.0, + cachedInput: 0.5, + output: 1.5, + updatedAt: '2025-03-21', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + }, + ], + }, + xai: { + id: 'xai', + name: 'xAI', + description: "xAI's Grok models", + defaultModel: 'grok-3-latest', + modelPatterns: [/^grok/], + models: [ + { + id: 'grok-3-latest', + pricing: { + input: 3.0, + cachedInput: 1.5, + output: 15.0, + updatedAt: '2025-04-17', + }, + capabilities: { + temperature: { min: 0, max: 1 }, + toolUsageControl: true, + }, + }, + { + id: 'grok-3-fast-latest', + pricing: { + input: 5.0, + cachedInput: 2.5, + output: 25.0, + updatedAt: '2025-04-17', + }, + capabilities: { + temperature: { min: 0, max: 1 }, + toolUsageControl: true, + }, + }, + ], + }, + cerebras: { + id: 'cerebras', + name: 'Cerebras', + description: 'Cerebras Cloud LLMs', + defaultModel: 'cerebras/llama-3.3-70b', + modelPatterns: [/^cerebras/], + models: [ + { + id: 'cerebras/llama-3.3-70b', + pricing: { + input: 0.94, + cachedInput: 0.47, + output: 0.94, + updatedAt: '2025-03-21', + }, + capabilities: { + toolUsageControl: false, + }, + }, + ], + }, + groq: { + id: 'groq', + name: 'Groq', + description: "Groq's LLM models with high-performance inference", + defaultModel: 'groq/meta-llama/llama-4-scout-17b-16e-instruct', + modelPatterns: [/^groq/], + models: [ + { + id: 'groq/meta-llama/llama-4-scout-17b-16e-instruct', + pricing: { + input: 0.4, + cachedInput: 0.2, + output: 0.6, + updatedAt: '2025-06-17', + }, + capabilities: { + toolUsageControl: false, + }, + }, + { + id: 'groq/deepseek-r1-distill-llama-70b', + pricing: { + input: 0.75, + cachedInput: 0.38, + output: 0.99, + updatedAt: '2025-06-17', + }, + capabilities: { + toolUsageControl: false, + }, + }, + { + id: 'groq/qwen-qwq-32b', + pricing: { + input: 0.29, + cachedInput: 0.29, + output: 0.39, + updatedAt: '2025-06-17', + }, + capabilities: { + toolUsageControl: false, + }, + }, + ], + }, + ollama: { + id: 'ollama', + name: 'Ollama', + description: 'Local LLM models via Ollama', + defaultModel: '', + modelPatterns: [], + models: [], // Populated dynamically + }, +} + +// Helper functions to extract information from the comprehensive definitions + +/** + * Get all models for a specific provider + */ +export function getProviderModels(providerId: string): string[] { + return PROVIDER_DEFINITIONS[providerId]?.models.map((m) => m.id) || [] +} + +/** + * Get the default model for a specific provider + */ +export function getProviderDefaultModel(providerId: string): string { + return PROVIDER_DEFINITIONS[providerId]?.defaultModel || '' +} + +/** + * Get pricing information for a specific model + */ +export function getModelPricing(modelId: string): ModelPricing | null { + for (const provider of Object.values(PROVIDER_DEFINITIONS)) { + const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase()) + if (model) { + return model.pricing + } + } + return null +} + +/** + * Get capabilities for a specific model + */ +export function getModelCapabilities(modelId: string): ModelCapabilities | null { + for (const provider of Object.values(PROVIDER_DEFINITIONS)) { + const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase()) + if (model) { + return model.capabilities + } + } + return null +} + +/** + * Get all models that support temperature + */ +export function getModelsWithTemperatureSupport(): string[] { + const models: string[] = [] + for (const provider of Object.values(PROVIDER_DEFINITIONS)) { + for (const model of provider.models) { + if (model.capabilities.temperature) { + models.push(model.id) + } + } + } + return models +} + +/** + * Get all models with temperature range 0-1 + */ +export function getModelsWithTempRange01(): string[] { + const models: string[] = [] + for (const provider of Object.values(PROVIDER_DEFINITIONS)) { + for (const model of provider.models) { + if (model.capabilities.temperature?.max === 1) { + models.push(model.id) + } + } + } + return models +} + +/** + * Get all models with temperature range 0-2 + */ +export function getModelsWithTempRange02(): string[] { + const models: string[] = [] + for (const provider of Object.values(PROVIDER_DEFINITIONS)) { + for (const model of provider.models) { + if (model.capabilities.temperature?.max === 2) { + models.push(model.id) + } + } + } + return models +} + +/** + * Get all providers that support tool usage control + */ +export function getProvidersWithToolUsageControl(): string[] { + const providers: string[] = [] + for (const [providerId, provider] of Object.entries(PROVIDER_DEFINITIONS)) { + if (provider.models.some((model) => model.capabilities.toolUsageControl)) { + providers.push(providerId) + } + } + return providers +} + +/** + * Get all models that are hosted (don't require user API keys) + */ +export function getHostedModels(): string[] { + // Currently, OpenAI and Anthropic models are hosted + return [...getProviderModels('openai'), ...getProviderModels('anthropic')] +} + +/** + * Get all computer use models + */ +export function getComputerUseModels(): string[] { + const models: string[] = [] + for (const provider of Object.values(PROVIDER_DEFINITIONS)) { + for (const model of provider.models) { + if (model.capabilities.computerUse) { + models.push(model.id) + } + } + } + return models +} + +/** + * Check if a model supports temperature + */ +export function supportsTemperature(modelId: string): boolean { + const capabilities = getModelCapabilities(modelId) + return !!capabilities?.temperature +} + +/** + * Get maximum temperature for a model + */ +export function getMaxTemperature(modelId: string): number | undefined { + const capabilities = getModelCapabilities(modelId) + return capabilities?.temperature?.max +} + +/** + * Check if a provider supports tool usage control + */ +export function supportsToolUsageControl(providerId: string): boolean { + return getProvidersWithToolUsageControl().includes(providerId) +} + +/** + * Update Ollama models dynamically + */ +export function updateOllamaModels(models: string[]): void { + PROVIDER_DEFINITIONS.ollama.models = models.map((modelId) => ({ + id: modelId, + pricing: { + input: 0, + output: 0, + updatedAt: new Date().toISOString().split('T')[0], + }, + capabilities: {}, + })) +} diff --git a/apps/sim/providers/openai/index.ts b/apps/sim/providers/openai/index.ts index 00e84c91c94..9c6624574b7 100644 --- a/apps/sim/providers/openai/index.ts +++ b/apps/sim/providers/openai/index.ts @@ -2,6 +2,7 @@ import OpenAI from 'openai' import { createLogger } from '@/lib/logs/console-logger' import type { StreamingExecution } from '@/executor/types' import { executeTool } from '@/tools' +import { getProviderDefaultModel, getProviderModels } from '../models' import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types' import { prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils' @@ -55,8 +56,8 @@ export const openaiProvider: ProviderConfig = { name: 'OpenAI', description: "OpenAI's GPT models", version: '1.0.0', - models: ['gpt-4o', 'o1', 'o3', 'o4-mini', 'gpt-4.1', 'gpt-4.1-nano', 'gpt-4.1-mini'], - defaultModel: 'gpt-4o', + models: getProviderModels('openai'), + defaultModel: getProviderDefaultModel('openai'), executeRequest: async ( request: ProviderRequest diff --git a/apps/sim/providers/pricing.ts b/apps/sim/providers/pricing.ts deleted file mode 100644 index df17c569c8a..00000000000 --- a/apps/sim/providers/pricing.ts +++ /dev/null @@ -1,211 +0,0 @@ -import type { ModelPricingMap } from './types' - -/** - * Model pricing information per million tokens - * - * Prices are in USD per 1M tokens - * All prices should be regularly updated to reflect current market rates - */ -const modelPricing: ModelPricingMap = { - // OpenAI Models - 'gpt-4o': { - input: 2.5, - cachedInput: 1.25, - output: 10.0, - updatedAt: '2025-06-17', - }, - o1: { - input: 15.0, - cachedInput: 7.5, - output: 60, - updatedAt: '2025-06-17', - }, - o3: { - input: 2, - cachedInput: 0.5, - output: 8, - updatedAt: '2025-06-17', - }, - 'o4-mini': { - input: 1.1, - cachedInput: 0.275, - output: 4.4, - updatedAt: '2025-06-17', - }, - 'gpt-4.1': { - input: 2.0, - cachedInput: 0.5, - output: 8.0, - updatedAt: '2025-06-17', - }, - 'gpt-4.1-nano': { - input: 0.1, - cachedInput: 0.025, - output: 0.4, - updatedAt: '2025-06-17', - }, - 'gpt-4.1-mini': { - input: 0.4, - cachedInput: 0.1, - output: 1.6, - updatedAt: '2025-06-17', - }, - - // Azure OpenAI Models (same pricing as OpenAI) - 'azure/gpt-4o': { - input: 2.5, - cachedInput: 1.25, // 50% discount for cached input - output: 10.0, - updatedAt: '2025-06-15', - }, - 'azure/o3': { - input: 10, - cachedInput: 2.5, - output: 40, - updatedAt: '2025-06-15', - }, - 'azure/o4-mini': { - input: 1.1, - cachedInput: 0.275, - output: 4.4, - updatedAt: '2025-06-15', - }, - 'azure/gpt-4.1': { - input: 2.0, - cachedInput: 0.5, - output: 8.0, - updatedAt: '2025-06-15', - }, - 'azure/model-router': { - input: 2.0, - cachedInput: 0.5, - output: 8.0, - updatedAt: '2025-06-15', - }, - - // Anthropic Models - 'claude-3-5-sonnet-latest': { - input: 3.0, - cachedInput: 1.5, - output: 15.0, - updatedAt: '2025-06-17', - }, - 'claude-3-7-sonnet-latest': { - input: 3.0, - cachedInput: 1.5, - output: 15.0, - updatedAt: '2025-06-17', - }, - 'claude-sonnet-4-0': { - input: 3.0, - cachedInput: 1.5, - output: 15.0, - updatedAt: '2025-06-17', - }, - 'claude-opus-4-0': { - input: 15.0, - cachedInput: 7.5, - output: 75.0, - updatedAt: '2025-06-17', - }, - - // Google Models - 'gemini-2.5-pro': { - input: 0.15, - cachedInput: 0.075, - output: 0.6, - updatedAt: '2025-06-17', - }, - 'gemini-2.5-flash': { - input: 0.15, - cachedInput: 0.075, - output: 0.6, - updatedAt: '2025-06-17', - }, - - // Deepseek Models - 'deepseek-v3': { - input: 0.75, - cachedInput: 0.4, - output: 1.0, - updatedAt: '2025-03-21', - }, - 'deepseek-r1': { - input: 1.0, - cachedInput: 0.5, - output: 1.5, - updatedAt: '2025-03-21', - }, - - // xAI Models - 'grok-3-latest': { - input: 3.0, - cachedInput: 1.5, - output: 15.0, - updatedAt: '2025-04-17', - }, - 'grok-3-fast-latest': { - input: 5.0, - cachedInput: 2.5, - output: 25.0, - updatedAt: '2025-04-17', - }, - - // Cerebras Models - 'cerebras/llama-3.3-70b': { - input: 0.94, - cachedInput: 0.47, - output: 0.94, - updatedAt: '2025-03-21', - }, - - // Groq Models - 'groq/meta-llama/llama-4-scout-17b-16e-instruct': { - input: 0.4, - cachedInput: 0.2, - output: 0.6, - updatedAt: '2025-06-17', - }, - 'groq/deepseek-r1-distill-llama-70b': { - input: 0.75, - cachedInput: 0.38, - output: 0.99, - updatedAt: '2025-06-17', - }, - 'groq/qwen-qwq-32b': { - input: 0.29, - cachedInput: 0.29, - output: 0.39, - updatedAt: '2025-06-17', - }, -} - -/** - * Get pricing for a specific model - * Returns default pricing if model not found - */ -export function getModelPricing(model: string) { - const normalizedModel = model.toLowerCase() - - // Exact match - if (normalizedModel in modelPricing) { - return modelPricing[normalizedModel] - } - - // Partial match (for models with prefixes/versions) - for (const [pricingModel, pricing] of Object.entries(modelPricing)) { - if (normalizedModel.includes(pricingModel.toLowerCase())) { - return pricing - } - } - - // Default pricing if model not found - return { - input: 1.0, - cachedInput: 0.5, - output: 5.0, - updatedAt: '2025-03-21', - } -} - -export default modelPricing diff --git a/apps/sim/providers/utils.test.ts b/apps/sim/providers/utils.test.ts index b981c82f319..b77989f5cc3 100644 --- a/apps/sim/providers/utils.test.ts +++ b/apps/sim/providers/utils.test.ts @@ -1,6 +1,32 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import * as environmentModule from '@/lib/environment' -import { getApiKey } from './utils' +import { + calculateCost, + extractAndParseJSON, + formatCost, + generateStructuredOutputInstructions, + getAllModelProviders, + getAllModels, + getAllProviderIds, + getApiKey, + getBaseModelProviders, + getCustomTools, + getHostedModels, + getMaxTemperature, + getProvider, + getProviderConfigFromModel, + getProviderFromModel, + getProviderModels, + MODELS_TEMP_RANGE_0_1, + MODELS_TEMP_RANGE_0_2, + MODELS_WITH_TEMPERATURE_SUPPORT, + PROVIDERS_WITH_TOOL_USAGE_CONTROL, + prepareToolsWithUsageControl, + supportsTemperature, + supportsToolUsageControl, + transformCustomTool, + updateOllamaProviderModels, +} from './utils' const isHostedSpy = vi.spyOn(environmentModule, 'isHosted', 'get') const mockGetRotatingApiKey = vi.fn().mockReturnValue('rotating-server-key') @@ -78,3 +104,619 @@ describe('getApiKey', () => { ) }) }) + +describe('Model Capabilities', () => { + describe('supportsTemperature', () => { + it.concurrent('should return true for models that support temperature', () => { + const supportedModels = [ + 'gpt-4o', + 'gemini-2.5-flash', + 'claude-sonnet-4-0', + 'claude-opus-4-0', + 'claude-3-7-sonnet-latest', + 'claude-3-5-sonnet-latest', + 'grok-3-latest', + 'grok-3-fast-latest', + 'deepseek-v3', + 'deepseek-chat', + ] + + for (const model of supportedModels) { + expect(supportsTemperature(model)).toBe(true) + } + }) + + it.concurrent('should return false for models that do not support temperature', () => { + const unsupportedModels = [ + 'unsupported-model', + 'cerebras/llama-3.3-70b', // Cerebras models don't have temperature defined + 'groq/meta-llama/llama-4-scout-17b-16e-instruct', // Groq models don't have temperature defined + ] + + for (const model of unsupportedModels) { + expect(supportsTemperature(model)).toBe(false) + } + }) + + it.concurrent('should be case insensitive', () => { + expect(supportsTemperature('GPT-4O')).toBe(true) + expect(supportsTemperature('claude-sonnet-4-0')).toBe(true) + }) + }) + + describe('getMaxTemperature', () => { + it.concurrent('should return 2 for models with temperature range 0-2', () => { + const modelsRange02 = [ + 'gpt-4o', + 'o1', + 'o3', + 'o4-mini', + 'azure/gpt-4o', + 'gemini-2.5-pro', + 'gemini-2.5-flash', + 'deepseek-v3', + 'deepseek-chat', + 'deepseek-r1', + ] + + for (const model of modelsRange02) { + expect(getMaxTemperature(model)).toBe(2) + } + }) + + it.concurrent('should return 1 for models with temperature range 0-1', () => { + const modelsRange01 = [ + 'claude-sonnet-4-0', + 'claude-opus-4-0', + 'claude-3-7-sonnet-latest', + 'claude-3-5-sonnet-latest', + 'grok-3-latest', + 'grok-3-fast-latest', + ] + + for (const model of modelsRange01) { + expect(getMaxTemperature(model)).toBe(1) + } + }) + + it.concurrent('should return undefined for models that do not support temperature', () => { + expect(getMaxTemperature('unsupported-model')).toBeUndefined() + expect(getMaxTemperature('cerebras/llama-3.3-70b')).toBeUndefined() + expect(getMaxTemperature('groq/meta-llama/llama-4-scout-17b-16e-instruct')).toBeUndefined() + }) + + it.concurrent('should be case insensitive', () => { + expect(getMaxTemperature('GPT-4O')).toBe(2) + expect(getMaxTemperature('CLAUDE-SONNET-4-0')).toBe(1) + }) + }) + + describe('supportsToolUsageControl', () => { + it.concurrent('should return true for providers that support tool usage control', () => { + const supportedProviders = ['openai', 'azure-openai', 'anthropic', 'deepseek', 'xai'] + + for (const provider of supportedProviders) { + expect(supportsToolUsageControl(provider)).toBe(true) + } + }) + + it.concurrent( + 'should return false for providers that do not support tool usage control', + () => { + const unsupportedProviders = [ + 'google', + 'ollama', + 'cerebras', + 'groq', + 'non-existent-provider', + ] + + for (const provider of unsupportedProviders) { + expect(supportsToolUsageControl(provider)).toBe(false) + } + } + ) + }) + + describe('Model Constants', () => { + it.concurrent('should have correct models in MODELS_TEMP_RANGE_0_2', () => { + expect(MODELS_TEMP_RANGE_0_2).toContain('gpt-4o') + expect(MODELS_TEMP_RANGE_0_2).toContain('gemini-2.5-flash') + expect(MODELS_TEMP_RANGE_0_2).toContain('deepseek-v3') + expect(MODELS_TEMP_RANGE_0_2).not.toContain('claude-sonnet-4-0') // Should be in 0-1 range + }) + + it.concurrent('should have correct models in MODELS_TEMP_RANGE_0_1', () => { + expect(MODELS_TEMP_RANGE_0_1).toContain('claude-sonnet-4-0') + expect(MODELS_TEMP_RANGE_0_1).toContain('grok-3-latest') + expect(MODELS_TEMP_RANGE_0_1).not.toContain('gpt-4o') // Should be in 0-2 range + }) + + it.concurrent('should have correct providers in PROVIDERS_WITH_TOOL_USAGE_CONTROL', () => { + expect(PROVIDERS_WITH_TOOL_USAGE_CONTROL).toContain('openai') + expect(PROVIDERS_WITH_TOOL_USAGE_CONTROL).toContain('anthropic') + expect(PROVIDERS_WITH_TOOL_USAGE_CONTROL).toContain('deepseek') + expect(PROVIDERS_WITH_TOOL_USAGE_CONTROL).not.toContain('google') + expect(PROVIDERS_WITH_TOOL_USAGE_CONTROL).not.toContain('ollama') + }) + + it.concurrent( + 'should combine both temperature ranges in MODELS_WITH_TEMPERATURE_SUPPORT', + () => { + expect(MODELS_WITH_TEMPERATURE_SUPPORT.length).toBe( + MODELS_TEMP_RANGE_0_2.length + MODELS_TEMP_RANGE_0_1.length + ) + expect(MODELS_WITH_TEMPERATURE_SUPPORT).toContain('gpt-4o') // From 0-2 range + expect(MODELS_WITH_TEMPERATURE_SUPPORT).toContain('claude-sonnet-4-0') // From 0-1 range + } + ) + }) +}) + +describe('Cost Calculation', () => { + describe('calculateCost', () => { + it.concurrent('should calculate cost correctly for known models', () => { + const result = calculateCost('gpt-4o', 1000, 500, false) + + expect(result.input).toBeGreaterThan(0) + expect(result.output).toBeGreaterThan(0) + expect(result.total).toBeCloseTo(result.input + result.output, 6) + expect(result.pricing).toBeDefined() + expect(result.pricing.input).toBe(2.5) // GPT-4o pricing + }) + + it.concurrent('should handle cached input pricing when enabled', () => { + const regularCost = calculateCost('gpt-4o', 1000, 500, false) + const cachedCost = calculateCost('gpt-4o', 1000, 500, true) + + expect(cachedCost.input).toBeLessThan(regularCost.input) + expect(cachedCost.output).toBe(regularCost.output) // Output cost should be same + }) + + it.concurrent('should return default pricing for unknown models', () => { + const result = calculateCost('unknown-model', 1000, 500, false) + + expect(result.input).toBe(0) + expect(result.output).toBe(0) + expect(result.total).toBe(0) + expect(result.pricing.input).toBe(1.0) // Default pricing + }) + + it.concurrent('should handle zero tokens', () => { + const result = calculateCost('gpt-4o', 0, 0, false) + + expect(result.input).toBe(0) + expect(result.output).toBe(0) + expect(result.total).toBe(0) + }) + }) + + describe('formatCost', () => { + it.concurrent('should format costs >= $1 with two decimal places', () => { + expect(formatCost(1.234)).toBe('$1.23') + expect(formatCost(10.567)).toBe('$10.57') + }) + + it.concurrent('should format costs between 1¢ and $1 with three decimal places', () => { + expect(formatCost(0.0234)).toBe('$0.023') + expect(formatCost(0.1567)).toBe('$0.157') + }) + + it.concurrent('should format costs between 0.1¢ and 1¢ with four decimal places', () => { + expect(formatCost(0.00234)).toBe('$0.0023') + expect(formatCost(0.00567)).toBe('$0.0057') + }) + + it.concurrent('should format very small costs with appropriate precision', () => { + expect(formatCost(0.000234)).toContain('$0.000234') + }) + + it.concurrent('should handle zero cost', () => { + expect(formatCost(0)).toBe('$0') + }) + + it.concurrent('should handle undefined/null costs', () => { + expect(formatCost(undefined as any)).toBe('—') + expect(formatCost(null as any)).toBe('—') + }) + }) +}) + +describe('getHostedModels', () => { + it.concurrent('should return OpenAI and Anthropic models as hosted', () => { + const hostedModels = getHostedModels() + + expect(hostedModels).toContain('gpt-4o') + expect(hostedModels).toContain('claude-sonnet-4-0') + expect(hostedModels).toContain('o1') + expect(hostedModels).toContain('claude-opus-4-0') + + // Should not contain models from other providers + expect(hostedModels).not.toContain('gemini-2.5-pro') + expect(hostedModels).not.toContain('deepseek-v3') + }) + + it.concurrent('should return an array of strings', () => { + const hostedModels = getHostedModels() + + expect(Array.isArray(hostedModels)).toBe(true) + expect(hostedModels.length).toBeGreaterThan(0) + hostedModels.forEach((model) => { + expect(typeof model).toBe('string') + }) + }) +}) + +describe('Provider Management', () => { + describe('getProviderFromModel', () => { + it.concurrent('should return correct provider for known models', () => { + expect(getProviderFromModel('gpt-4o')).toBe('openai') + expect(getProviderFromModel('claude-sonnet-4-0')).toBe('anthropic') + expect(getProviderFromModel('gemini-2.5-pro')).toBe('google') + expect(getProviderFromModel('azure/gpt-4o')).toBe('azure-openai') + }) + + it.concurrent('should use model patterns for pattern matching', () => { + expect(getProviderFromModel('gpt-5-custom')).toBe('openai') // Matches /^gpt/ pattern + expect(getProviderFromModel('claude-custom-model')).toBe('anthropic') // Matches /^claude/ pattern + }) + + it.concurrent('should default to ollama for unknown models', () => { + expect(getProviderFromModel('unknown-model')).toBe('ollama') + }) + + it.concurrent('should be case insensitive', () => { + expect(getProviderFromModel('GPT-4O')).toBe('openai') + expect(getProviderFromModel('CLAUDE-SONNET-4-0')).toBe('anthropic') + }) + }) + + describe('getProvider', () => { + it.concurrent('should return provider config for valid provider IDs', () => { + const openaiProvider = getProvider('openai') + expect(openaiProvider).toBeDefined() + expect(openaiProvider?.id).toBe('openai') + expect(openaiProvider?.name).toBe('OpenAI') + + const anthropicProvider = getProvider('anthropic') + expect(anthropicProvider).toBeDefined() + expect(anthropicProvider?.id).toBe('anthropic') + }) + + it.concurrent('should handle provider/service format', () => { + const provider = getProvider('openai/chat') + expect(provider).toBeDefined() + expect(provider?.id).toBe('openai') + }) + + it.concurrent('should return undefined for invalid provider IDs', () => { + expect(getProvider('nonexistent')).toBeUndefined() + }) + }) + + describe('getProviderConfigFromModel', () => { + it.concurrent('should return provider config for model', () => { + const config = getProviderConfigFromModel('gpt-4o') + expect(config).toBeDefined() + expect(config?.id).toBe('openai') + + const anthropicConfig = getProviderConfigFromModel('claude-sonnet-4-0') + expect(anthropicConfig).toBeDefined() + expect(anthropicConfig?.id).toBe('anthropic') + }) + }) + + describe('getAllModels', () => { + it.concurrent('should return all models from all providers', () => { + const allModels = getAllModels() + expect(Array.isArray(allModels)).toBe(true) + expect(allModels.length).toBeGreaterThan(0) + + // Should contain models from different providers + expect(allModels).toContain('gpt-4o') + expect(allModels).toContain('claude-sonnet-4-0') + expect(allModels).toContain('gemini-2.5-pro') + }) + }) + + describe('getAllProviderIds', () => { + it.concurrent('should return all provider IDs', () => { + const providerIds = getAllProviderIds() + expect(Array.isArray(providerIds)).toBe(true) + expect(providerIds).toContain('openai') + expect(providerIds).toContain('anthropic') + expect(providerIds).toContain('google') + expect(providerIds).toContain('azure-openai') + }) + }) + + describe('getProviderModels', () => { + it.concurrent('should return models for specific providers', () => { + const openaiModels = getProviderModels('openai') + expect(Array.isArray(openaiModels)).toBe(true) + expect(openaiModels).toContain('gpt-4o') + expect(openaiModels).toContain('o1') + + const anthropicModels = getProviderModels('anthropic') + expect(anthropicModels).toContain('claude-sonnet-4-0') + expect(anthropicModels).toContain('claude-opus-4-0') + }) + + it.concurrent('should return empty array for unknown providers', () => { + const unknownModels = getProviderModels('unknown' as any) + expect(unknownModels).toEqual([]) + }) + }) + + describe('getBaseModelProviders and getAllModelProviders', () => { + it.concurrent('should return model to provider mapping', () => { + const allProviders = getAllModelProviders() + expect(typeof allProviders).toBe('object') + expect(allProviders['gpt-4o']).toBe('openai') + expect(allProviders['claude-sonnet-4-0']).toBe('anthropic') + + const baseProviders = getBaseModelProviders() + expect(typeof baseProviders).toBe('object') + // Should exclude ollama models + }) + }) + + describe('updateOllamaProviderModels', () => { + it.concurrent('should update ollama models', () => { + const mockModels = ['llama2', 'codellama', 'mistral'] + + // This should not throw + expect(() => updateOllamaProviderModels(mockModels)).not.toThrow() + + // Verify the models were updated + const ollamaModels = getProviderModels('ollama') + expect(ollamaModels).toEqual(mockModels) + }) + }) +}) + +describe('JSON and Structured Output', () => { + describe('extractAndParseJSON', () => { + it.concurrent('should extract and parse valid JSON', () => { + const content = 'Some text before ```json\n{"key": "value"}\n``` some text after' + const result = extractAndParseJSON(content) + expect(result).toEqual({ key: 'value' }) + }) + + it.concurrent('should extract JSON without code blocks', () => { + const content = 'Text before {"name": "test", "value": 42} text after' + const result = extractAndParseJSON(content) + expect(result).toEqual({ name: 'test', value: 42 }) + }) + + it.concurrent('should handle nested objects', () => { + const content = '{"user": {"name": "John", "age": 30}, "active": true}' + const result = extractAndParseJSON(content) + expect(result).toEqual({ + user: { name: 'John', age: 30 }, + active: true, + }) + }) + + it.concurrent('should clean up common JSON issues', () => { + const content = '{\n "key": "value",\n "number": 42,\n}' // Trailing comma + const result = extractAndParseJSON(content) + expect(result).toEqual({ key: 'value', number: 42 }) + }) + + it.concurrent('should throw error for content without JSON', () => { + expect(() => extractAndParseJSON('No JSON here')).toThrow('No JSON object found in content') + }) + + it.concurrent('should throw error for invalid JSON', () => { + const invalidJson = '{"key": invalid, "broken": }' + expect(() => extractAndParseJSON(invalidJson)).toThrow('Failed to parse JSON after cleanup') + }) + }) + + describe('generateStructuredOutputInstructions', () => { + it.concurrent('should return empty string for JSON Schema format', () => { + const schemaFormat = { + schema: { + type: 'object', + properties: { key: { type: 'string' } }, + }, + } + expect(generateStructuredOutputInstructions(schemaFormat)).toBe('') + }) + + it.concurrent('should return empty string for object type with properties', () => { + const objectFormat = { + type: 'object', + properties: { key: { type: 'string' } }, + } + expect(generateStructuredOutputInstructions(objectFormat)).toBe('') + }) + + it.concurrent('should generate instructions for legacy fields format', () => { + const fieldsFormat = { + fields: [ + { name: 'score', type: 'number', description: 'A score from 1-10' }, + { name: 'comment', type: 'string', description: 'A comment' }, + ], + } + const result = generateStructuredOutputInstructions(fieldsFormat) + + expect(result).toContain('JSON format') + expect(result).toContain('score') + expect(result).toContain('comment') + expect(result).toContain('A score from 1-10') + }) + + it.concurrent('should handle object fields with properties', () => { + const fieldsFormat = { + fields: [ + { + name: 'metadata', + type: 'object', + properties: { + version: { type: 'string', description: 'Version number' }, + count: { type: 'number', description: 'Item count' }, + }, + }, + ], + } + const result = generateStructuredOutputInstructions(fieldsFormat) + + expect(result).toContain('metadata') + expect(result).toContain('Properties:') + expect(result).toContain('version') + expect(result).toContain('count') + }) + + it.concurrent('should return empty string for missing fields', () => { + expect(generateStructuredOutputInstructions({})).toBe('') + expect(generateStructuredOutputInstructions(null)).toBe('') + expect(generateStructuredOutputInstructions({ fields: null })).toBe('') + }) + }) +}) + +describe('Tool Management', () => { + describe('transformCustomTool', () => { + it.concurrent('should transform valid custom tool schema', () => { + const customTool = { + id: 'test-tool', + schema: { + function: { + name: 'testFunction', + description: 'A test function', + parameters: { + type: 'object', + properties: { + input: { type: 'string', description: 'Input parameter' }, + }, + required: ['input'], + }, + }, + }, + } + + const result = transformCustomTool(customTool) + + expect(result.id).toBe('custom_test-tool') + expect(result.name).toBe('testFunction') + expect(result.description).toBe('A test function') + expect(result.parameters.type).toBe('object') + expect(result.parameters.properties).toBeDefined() + expect(result.parameters.required).toEqual(['input']) + }) + + it.concurrent('should throw error for invalid schema', () => { + const invalidTool = { id: 'test', schema: null } + expect(() => transformCustomTool(invalidTool)).toThrow('Invalid custom tool schema') + + const noFunction = { id: 'test', schema: {} } + expect(() => transformCustomTool(noFunction)).toThrow('Invalid custom tool schema') + }) + }) + + describe('getCustomTools', () => { + it.concurrent('should return array of transformed custom tools', () => { + const result = getCustomTools() + expect(Array.isArray(result)).toBe(true) + }) + }) + + describe('prepareToolsWithUsageControl', () => { + const mockLogger = { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + } + + beforeEach(() => { + mockLogger.info.mockClear() + }) + + it.concurrent('should return early for no tools', () => { + const result = prepareToolsWithUsageControl(undefined, undefined, mockLogger) + + expect(result.tools).toBeUndefined() + expect(result.toolChoice).toBeUndefined() + expect(result.hasFilteredTools).toBe(false) + expect(result.forcedTools).toEqual([]) + }) + + it.concurrent('should filter out tools with usageControl="none"', () => { + const tools = [ + { function: { name: 'tool1' } }, + { function: { name: 'tool2' } }, + { function: { name: 'tool3' } }, + ] + const providerTools = [ + { id: 'tool1', usageControl: 'auto' }, + { id: 'tool2', usageControl: 'none' }, + { id: 'tool3', usageControl: 'force' }, + ] + + const result = prepareToolsWithUsageControl(tools, providerTools, mockLogger) + + expect(result.tools).toHaveLength(2) + expect(result.hasFilteredTools).toBe(true) + expect(result.forcedTools).toEqual(['tool3']) + expect(mockLogger.info).toHaveBeenCalledWith("Filtered out 1 tools with usageControl='none'") + }) + + it.concurrent('should set toolChoice for forced tools (OpenAI format)', () => { + const tools = [{ function: { name: 'forcedTool' } }] + const providerTools = [{ id: 'forcedTool', usageControl: 'force' }] + + const result = prepareToolsWithUsageControl(tools, providerTools, mockLogger) + + expect(result.toolChoice).toEqual({ + type: 'function', + function: { name: 'forcedTool' }, + }) + }) + + it.concurrent('should set toolChoice for forced tools (Anthropic format)', () => { + const tools = [{ function: { name: 'forcedTool' } }] + const providerTools = [{ id: 'forcedTool', usageControl: 'force' }] + + const result = prepareToolsWithUsageControl(tools, providerTools, mockLogger, 'anthropic') + + expect(result.toolChoice).toEqual({ + type: 'tool', + name: 'forcedTool', + }) + }) + + it.concurrent('should set toolConfig for Google format', () => { + const tools = [{ function: { name: 'forcedTool' } }] + const providerTools = [{ id: 'forcedTool', usageControl: 'force' }] + + const result = prepareToolsWithUsageControl(tools, providerTools, mockLogger, 'google') + + expect(result.toolConfig).toEqual({ + mode: 'ANY', + allowed_function_names: ['forcedTool'], + }) + }) + + it.concurrent('should return empty when all tools are filtered', () => { + const tools = [{ function: { name: 'tool1' } }] + const providerTools = [{ id: 'tool1', usageControl: 'none' }] + + const result = prepareToolsWithUsageControl(tools, providerTools, mockLogger) + + expect(result.tools).toBeUndefined() + expect(result.toolChoice).toBeUndefined() + expect(result.hasFilteredTools).toBe(true) + }) + + it.concurrent('should default to auto when no forced tools', () => { + const tools = [{ function: { name: 'tool1' } }] + const providerTools = [{ id: 'tool1', usageControl: 'auto' }] + + const result = prepareToolsWithUsageControl(tools, providerTools, mockLogger) + + expect(result.toolChoice).toBe('auto') + }) + }) +}) diff --git a/apps/sim/providers/utils.ts b/apps/sim/providers/utils.ts index 6f54e8dd9a6..b48c3c40ff6 100644 --- a/apps/sim/providers/utils.ts +++ b/apps/sim/providers/utils.ts @@ -7,16 +7,30 @@ import { cerebrasProvider } from './cerebras' import { deepseekProvider } from './deepseek' import { googleProvider } from './google' import { groqProvider } from './groq' +import { + getComputerUseModels, + getHostedModels as getHostedModelsFromDefinitions, + getMaxTemperature as getMaxTempFromDefinitions, + getModelPricing as getModelPricingFromDefinitions, + getModelsWithTemperatureSupport, + getModelsWithTempRange01, + getModelsWithTempRange02, + getProviderModels as getProviderModelsFromDefinitions, + getProvidersWithToolUsageControl, + PROVIDER_DEFINITIONS, + supportsTemperature as supportsTemperatureFromDefinitions, + supportsToolUsageControl as supportsToolUsageControlFromDefinitions, + updateOllamaModels as updateOllamaModelsInDefinitions, +} from './models' import { ollamaProvider } from './ollama' import { openaiProvider } from './openai' -import { getModelPricing } from './pricing' import type { ProviderConfig, ProviderId, ProviderToolConfig } from './types' import { xAIProvider } from './xai' const logger = createLogger('ProviderUtils') /** - * Provider configurations with associated model names/patterns + * Provider configurations - built from the comprehensive definitions */ export const providers: Record< ProviderId, @@ -28,59 +42,52 @@ export const providers: Record< > = { openai: { ...openaiProvider, - models: ['gpt-4o', 'o1', 'o3', 'o4-mini', 'gpt-4.1', 'gpt-4.1-nano', 'gpt-4.1-mini'], + models: getProviderModelsFromDefinitions('openai'), computerUseModels: ['computer-use-preview'], - modelPatterns: [/^gpt/, /^o1/], - }, - 'azure-openai': { - ...azureOpenAIProvider, - models: ['azure/gpt-4o', 'azure/o3', 'azure/o4-mini', 'azure/gpt-4.1', 'azure/model-router'], - modelPatterns: [/^azure\//], + modelPatterns: PROVIDER_DEFINITIONS.openai.modelPatterns, }, anthropic: { ...anthropicProvider, - models: [ - 'claude-sonnet-4-0', - 'claude-opus-4-0', - 'claude-3-7-sonnet-latest', - 'claude-3-5-sonnet-latest', - ], - computerUseModels: ['claude-3-5-sonnet-latest', 'claude-3-7-sonnet-latest'], - modelPatterns: [/^claude/], + models: getProviderModelsFromDefinitions('anthropic'), + computerUseModels: getComputerUseModels().filter((model) => + getProviderModelsFromDefinitions('anthropic').includes(model) + ), + modelPatterns: PROVIDER_DEFINITIONS.anthropic.modelPatterns, }, google: { ...googleProvider, - models: ['gemini-2.5-pro', 'gemini-2.5-flash'], - modelPatterns: [/^gemini/], + models: getProviderModelsFromDefinitions('google'), + modelPatterns: PROVIDER_DEFINITIONS.google.modelPatterns, }, deepseek: { ...deepseekProvider, - models: ['deepseek-v3', 'deepseek-r1'], - modelPatterns: [], + models: getProviderModelsFromDefinitions('deepseek'), + modelPatterns: PROVIDER_DEFINITIONS.deepseek.modelPatterns, }, xai: { ...xAIProvider, - models: ['grok-3-latest', 'grok-3-fast-latest'], - modelPatterns: [/^grok/], + models: getProviderModelsFromDefinitions('xai'), + modelPatterns: PROVIDER_DEFINITIONS.xai.modelPatterns, }, cerebras: { ...cerebrasProvider, - models: ['cerebras/llama-3.3-70b'], - modelPatterns: [/^cerebras/], + models: getProviderModelsFromDefinitions('cerebras'), + modelPatterns: PROVIDER_DEFINITIONS.cerebras.modelPatterns, }, groq: { ...groqProvider, - models: [ - 'groq/meta-llama/llama-4-scout-17b-16e-instruct', - 'groq/deepseek-r1-distill-llama-70b', - 'groq/qwen-qwq-32b', - ], - modelPatterns: [/^groq/], + models: getProviderModelsFromDefinitions('groq'), + modelPatterns: PROVIDER_DEFINITIONS.groq.modelPatterns, + }, + 'azure-openai': { + ...azureOpenAIProvider, + models: getProviderModelsFromDefinitions('azure-openai'), + modelPatterns: PROVIDER_DEFINITIONS['azure-openai'].modelPatterns, }, ollama: { ...ollamaProvider, - models: [], - modelPatterns: [], + models: getProviderModelsFromDefinitions('ollama'), + modelPatterns: PROVIDER_DEFINITIONS.ollama.modelPatterns, }, } @@ -97,7 +104,8 @@ Object.entries(providers).forEach(([id, provider]) => { // Function to update Ollama provider models export function updateOllamaProviderModels(models: string[]): void { - providers.ollama.models = models + updateOllamaModelsInDefinitions(models) + providers.ollama.models = getProviderModelsFromDefinitions('ollama') logger.info('Updated Ollama provider models', { models }) } @@ -167,11 +175,13 @@ export function getAllProviderIds(): ProviderId[] { } export function getProviderModels(providerId: ProviderId): string[] { - const provider = providers[providerId] - return provider?.models || [] + return getProviderModelsFromDefinitions(providerId) } export function generateStructuredOutputInstructions(responseFormat: any): string { + // Handle null/undefined input + if (!responseFormat) return '' + // If using the new JSON Schema format, don't add additional instructions // This is necessary because providers now handle the schema directly if (responseFormat.schema || (responseFormat.type === 'object' && responseFormat.properties)) { @@ -179,7 +189,7 @@ export function generateStructuredOutputInstructions(responseFormat: any): strin } // Handle legacy format with fields array - if (!responseFormat?.fields) return '' + if (!responseFormat.fields) return '' function generateFieldStructure(field: any): string { if (field.type === 'object' && field.properties) { @@ -426,7 +436,23 @@ export function calculateCost( completionTokens = 0, useCachedInput = false ) { - const pricing = getModelPricing(model) + const pricing = getModelPricingFromDefinitions(model) + + // If no pricing found, return default pricing + if (!pricing) { + const defaultPricing = { + input: 1.0, + cachedInput: 0.5, + output: 5.0, + updatedAt: '2025-03-21', + } + return { + input: 0, + output: 0, + total: 0, + pricing: defaultPricing, + } + } // Calculate costs in USD // Convert from "per million tokens" to "per token" by dividing by 1,000,000 @@ -479,6 +505,14 @@ export function formatCost(cost: number): string { return '$0' } +/** + * Get the list of models that are hosted by the platform (don't require user API keys) + * These are the models for which we hide the API key field in the hosted environment + */ +export function getHostedModels(): string[] { + return getHostedModelsFromDefinitions() +} + /** * Get an API key for a specific provider, handling rotation and fallbacks * For use server-side only @@ -794,3 +828,30 @@ export function trackForcedToolUsage( : undefined, } } + +export const MODELS_TEMP_RANGE_0_2 = getModelsWithTempRange02() +export const MODELS_TEMP_RANGE_0_1 = getModelsWithTempRange01() +export const MODELS_WITH_TEMPERATURE_SUPPORT = getModelsWithTemperatureSupport() +export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = getProvidersWithToolUsageControl() + +/** + * Check if a model supports temperature parameter + */ +export function supportsTemperature(model: string): boolean { + return supportsTemperatureFromDefinitions(model) +} + +/** + * Get the maximum temperature value for a model + * @returns Maximum temperature value (1 or 2) or undefined if temperature not supported + */ +export function getMaxTemperature(model: string): number | undefined { + return getMaxTempFromDefinitions(model) +} + +/** + * Check if a provider supports tool usage control + */ +export function supportsToolUsageControl(provider: string): boolean { + return supportsToolUsageControlFromDefinitions(provider) +} diff --git a/apps/sim/providers/xai/index.ts b/apps/sim/providers/xai/index.ts index 67ba28bf5c7..a4f2f9be9f3 100644 --- a/apps/sim/providers/xai/index.ts +++ b/apps/sim/providers/xai/index.ts @@ -2,6 +2,7 @@ import OpenAI from 'openai' import { createLogger } from '@/lib/logs/console-logger' import type { StreamingExecution } from '@/executor/types' import { executeTool } from '@/tools' +import { getProviderDefaultModel, getProviderModels } from '../models' import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types' import { prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils' @@ -34,8 +35,8 @@ export const xAIProvider: ProviderConfig = { name: 'xAI', description: "xAI's Grok models", version: '1.0.0', - models: ['grok-3-latest', 'grok-3-fast-latest'], - defaultModel: 'grok-3-latest', + models: getProviderModels('xai'), + defaultModel: getProviderDefaultModel('xai'), executeRequest: async ( request: ProviderRequest