diff --git a/CHANGELOG.md b/CHANGELOG.md index 689718d36..8c6281d83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - [EE] Improved Ask Sourcebot prompt caching by splitting static and dynamic prompt sections and advancing cache breakpoints after every agent step instead of only after each message. [#1366](https://github.com/sourcebot-dev/sourcebot/pull/1366) +- Refactored Ask Sourcebot user message text extraction into a shared helper that robustly handles non-text message parts. [#1371](https://github.com/sourcebot-dev/sourcebot/pull/1371) ### Added - Added per-step token cost tracking and estimated tool call token usage to Ask Sourcebot chat history. [#1353](https://github.com/sourcebot-dev/sourcebot/pull/1353) diff --git a/packages/web/src/app/(app)/chat/[id]/page.tsx b/packages/web/src/app/(app)/chat/[id]/page.tsx index 45b67be8a..e49d359b8 100644 --- a/packages/web/src/app/(app)/chat/[id]/page.tsx +++ b/packages/web/src/app/(app)/chat/[id]/page.tsx @@ -14,6 +14,7 @@ import { __unsafePrisma } from '@/prisma'; import { ChatVisibility } from '@sourcebot/db'; import { Metadata } from 'next'; import { SBChatMessage } from '@/features/chat/types'; +import { getUserMessageText } from '@/features/chat/utils'; import { env } from '@sourcebot/shared'; import { hasEntitlement } from '@/lib/entitlements'; import { ChatEntitlementMessage } from '@/features/chat/components/chatEntitlementMessage'; @@ -54,11 +55,11 @@ export const generateMetadata = async ({ params }: PageProps): Promise let description = 'A chat on Sourcebot'; if (firstUserMessage) { - const textPart = firstUserMessage.parts.find(p => p.type === 'text'); - if (textPart && textPart.type === 'text') { - description = textPart.text.length > 160 - ? textPart.text.substring(0, 160).trim() + '...' - : textPart.text; + const text = getUserMessageText(firstUserMessage); + if (text) { + description = text.length > 160 + ? text.substring(0, 160).trim() + '...' + : text; } } diff --git a/packages/web/src/ee/features/chat/agent.ts b/packages/web/src/ee/features/chat/agent.ts index d2f3a4761..b34494c2f 100644 --- a/packages/web/src/ee/features/chat/agent.ts +++ b/packages/web/src/ee/features/chat/agent.ts @@ -22,7 +22,7 @@ import { randomUUID } from "crypto"; import _dedent from "dedent"; import { ANSWER_TAG, FILE_REFERENCE_PREFIX } from "@/features/chat/constants"; import { Source } from "@/features/chat/types"; -import { addLineNumbers, fileReferenceToString, getAnswerPartFromAssistantMessage, getTurnProgressState } from "@/features/chat/utils"; +import { addLineNumbers, fileReferenceToString, getAnswerPartFromAssistantMessage, getTurnProgressState, getUserMessageText } from "@/features/chat/utils"; import { createTools } from "./tools"; import { getConnectedMcpClients } from "@/ee/features/chat/mcp/mcpClientFactory"; import { getMcpTools, McpToolsResult } from "@/ee/features/chat/mcp/mcpToolSets"; @@ -105,7 +105,7 @@ export const createMessageStream = async ({ if (message.role === 'user') { return { role: 'user', - content: message.parts[0].type === 'text' ? message.parts[0].text : '', + content: getUserMessageText(message), }; } diff --git a/packages/web/src/ee/features/chat/components/chatThread/chatThread.tsx b/packages/web/src/ee/features/chat/components/chatThread/chatThread.tsx index 9ceae4311..87faf79f8 100644 --- a/packages/web/src/ee/features/chat/components/chatThread/chatThread.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/chatThread.tsx @@ -5,7 +5,7 @@ import { Button } from '@/components/ui/button'; import { Separator } from '@/components/ui/separator'; import { CustomSlateEditor } from '@/features/chat/customSlateEditor'; import { AdditionalChatRequestParams, CustomEditor, LanguageModelInfo, SBChatMessage, SearchScope, Source } from '@/features/chat/types'; -import { createUIMessage, getAllMentionElements, getTurnProgressState, resetEditor, slateContentToString } from '@/features/chat/utils'; +import { createUIMessage, getAllMentionElements, getTurnProgressState, getUserMessageText, resetEditor, slateContentToString } from '@/features/chat/utils'; import { useChat } from '@ai-sdk/react'; import { CreateUIMessage, DefaultChatTransport, lastAssistantMessageIsCompleteWithApprovalResponses } from 'ai'; import { ArrowDownIcon, CopyIcon } from 'lucide-react'; @@ -204,16 +204,16 @@ export const ChatThread = ({ } satisfies AdditionalChatRequestParams, }); + const userMessageText = getUserMessageText(message); if ( messages.length === 0 && - message.parts.length > 0 && - message.parts[0].type === 'text' + userMessageText.length > 0 ) { generateAndUpdateChatNameFromMessage( { chatId, languageModelId: selectedLanguageModel.model, - message: message.parts[0].text, + message: userMessageText, }, ).then((response) => { if (isServiceError(response)) { diff --git a/packages/web/src/ee/features/chat/components/chatThread/chatThreadListItem.tsx b/packages/web/src/ee/features/chat/components/chatThread/chatThreadListItem.tsx index a9dc9d2f4..6b79de0e6 100644 --- a/packages/web/src/ee/features/chat/components/chatThread/chatThreadListItem.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/chatThreadListItem.tsx @@ -8,7 +8,7 @@ import { CSSProperties, forwardRef, memo, useCallback, useEffect, useMemo, useRe import scrollIntoView from 'scroll-into-view-if-needed'; import { Reference, referenceSchema, SBChatMessage, Source } from "@/features/chat/types"; import { useExtractReferences } from '../../useExtractReferences'; -import { getAnswerPartFromAssistantMessage, getLastStepParts, groupMessageIntoSteps, isSBChatToolPart, repairReferences, tryResolveFileReference } from '@/features/chat/utils'; +import { getAnswerPartFromAssistantMessage, getLastStepParts, getUserMessageText, groupMessageIntoSteps, isSBChatToolPart, repairReferences, tryResolveFileReference } from '@/features/chat/utils'; import { AnswerCard } from './answerCard'; import { DetailsCard } from './detailsCard'; import { ApprovalRequestedToolPart, ToolApprovalBanner } from './toolApprovalBanner'; @@ -49,7 +49,7 @@ const ChatThreadListItemComponent = forwardRef { - return userMessage.parts.length > 0 && userMessage.parts[0].type === 'text' ? userMessage.parts[0].text : ''; + return getUserMessageText(userMessage); }, [userMessage]); // Take the assistant message and repair any references that are not properly formatted. diff --git a/packages/web/src/features/chat/utils.test.ts b/packages/web/src/features/chat/utils.test.ts index 8f0a77b82..1f83c8f50 100644 --- a/packages/web/src/features/chat/utils.test.ts +++ b/packages/web/src/features/chat/utils.test.ts @@ -1,5 +1,5 @@ import { expect, test, describe, vi } from 'vitest' -import { createUIMessage, fileReferenceToString, getAnswerPartFromAssistantMessage, getLastStepParts, getTurnProgressState, groupMessageIntoSteps, repairReferences } from './utils' +import { createUIMessage, fileReferenceToString, getAnswerPartFromAssistantMessage, getLastStepParts, getTurnProgressState, getUserMessageText, groupMessageIntoSteps, repairReferences } from './utils' import { FILE_REFERENCE_REGEX, ANSWER_TAG } from './constants'; import { SBChatMessage, SBChatMessagePart } from './types'; @@ -537,6 +537,75 @@ test('getAnswerPartFromAssistantMessage returns undefined when turn is in progre expect(result).toBeUndefined(); }); +describe('getUserMessageText', () => { + test('returns the text when the text part is first', () => { + const message: SBChatMessage = { + role: 'user', + parts: [ + { + type: 'text', + text: 'Hello, world!', + }, + ], + } as SBChatMessage; + + expect(getUserMessageText(message)).toBe('Hello, world!'); + }); + + test('returns the text when a non-text part precedes the text part', () => { + const message: SBChatMessage = { + role: 'user', + parts: [ + { + type: 'data-source', + data: { + type: 'file', + path: 'auth.ts', + repo: 'github.com/sourcebot-dev/sourcebot', + name: 'auth.ts', + revision: 'main', + }, + }, + { + type: 'text', + text: 'Explain this file', + }, + ], + } as SBChatMessage; + + expect(getUserMessageText(message)).toBe('Explain this file'); + }); + + test('returns an empty string when there is no text part', () => { + const message: SBChatMessage = { + role: 'user', + parts: [ + { + type: 'data-source', + data: { + type: 'file', + path: 'auth.ts', + repo: 'github.com/sourcebot-dev/sourcebot', + name: 'auth.ts', + revision: 'main', + }, + }, + ], + } as SBChatMessage; + + expect(getUserMessageText(message)).toBe(''); + }); + + test('returns an empty string when there are no parts', () => { + const message: SBChatMessage = { + role: 'user', + parts: [], + } as unknown as SBChatMessage; + + expect(getUserMessageText(message)).toBe(''); + }); +}); + test('repairReferences fixes missing colon after @file', () => { const input = 'See the function in @file{github.com/sourcebot-dev/sourcebot::auth.ts} for details.'; const expected = 'See the function in @file:{github.com/sourcebot-dev/sourcebot::auth.ts} for details.'; diff --git a/packages/web/src/features/chat/utils.ts b/packages/web/src/features/chat/utils.ts index c7f409ac7..0f4a4383e 100644 --- a/packages/web/src/features/chat/utils.ts +++ b/packages/web/src/features/chat/utils.ts @@ -397,6 +397,14 @@ export const repairReferences = (text: string): string => { .replace(/`(@file:\{[^`]+)`\}/g, '$1}'); }; +// Extracts the user's text from a message by finding the first text part. +// User messages may contain non-text parts (e.g., file attachments), so we +// cannot assume the text is always at index 0. Accepts anything carrying +// `parts` so it works for both persisted and freshly created messages. +export const getUserMessageText = (message: Pick): string => { + return message.parts.find((part) => part.type === 'text')?.text ?? ''; +} + // Attempts to find the part of the assistant's message // that contains the answer. export const getAnswerPartFromAssistantMessage = (message: SBChatMessage, isTurnInProgress: boolean): TextUIPart | undefined => {