Skip to content
76 changes: 76 additions & 0 deletions apps/web/src/lib/server/interpret-write-turn.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import { beforeEach, describe, expect, it, vi } from "vitest";

vi.mock("@atlas/integrations", () => ({
interpretWriteTurnWithResponses: vi.fn(),
}));

import { interpretWriteTurnWithResponses } from "@atlas/integrations";

import { interpretWriteTurn } from "./interpret-write-turn";

const mockInterpretWriteTurnWithResponses = vi.mocked(
interpretWriteTurnWithResponses,
);

beforeEach(() => {
vi.clearAllMocks();
});

describe("interpretWriteTurn", () => {
it("forwards entityContext to the integrations layer", async () => {
mockInterpretWriteTurnWithResponses.mockResolvedValueOnce({
operationKind: "plan",
actionDomain: "task",
targetRef: { entityId: "task-1", description: null, entityKind: null },
taskName: null,
fields: {
scheduleFields: null,
taskFields: null,
},
confidence: {},
unresolvedFields: [],
});

await interpretWriteTurn({
currentTurnText: "move gym",
turnType: "edit_request",
entityContext: 'Known entities:\n- "Gym" (task, scheduled) [id: task-1]',
});

expect(mockInterpretWriteTurnWithResponses).toHaveBeenCalledWith(
expect.objectContaining({
entityContext:
'Known entities:\n- "Gym" (task, scheduled) [id: task-1]',
}),
undefined,
);
});

it("falls back cleanly when the integrations layer returns malformed output", async () => {
mockInterpretWriteTurnWithResponses.mockResolvedValueOnce({
operationKind: "plan",
actionDomain: "task",
targetRef: null,
taskName: null,
fields: {
scheduleFields: null,
taskFields: null,
},
confidence: {
bad: 2,
},
unresolvedFields: [],
} as never);

await expect(
interpretWriteTurn({
currentTurnText: "schedule gym",
turnType: "planning_request",
}),
).resolves.toMatchObject({
operationKind: "plan",
targetRef: null,
sourceText: "schedule gym",
});
});
});
101 changes: 101 additions & 0 deletions apps/web/src/lib/server/turn-router.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,107 @@ describe("turn router", () => {
expect(mockInterpretWriteTurn).not.toHaveBeenCalled();
});

it("prefers the interpreter targetRef entity over discourse focus for write turns", async () => {
mockClassification({
turnType: "edit_request",
confidence: 0.96,
});
mockInterpretWriteTurn.mockResolvedValueOnce({
operationKind: "edit",
actionDomain: "task",
targetRef: { entityId: "task-2" },
taskName: null,
fields: { scheduleFields: { time: t(11, 0) } },
sourceText: "Move weekly review to 11",
confidence: {
"scheduleFields.time": 0.95,
},
unresolvedFields: [],
});

const result = await routeMessageTurn({
rawText: "Move weekly review to 11",
normalizedText: "Move weekly review to 11",
recentTurns: [],
tasks: [
{
id: "task-2",
userId: "user-1",
sourceInboxItemId: "inbox-1",
lastInboxItemId: "inbox-1",
title: "Weekly review",
lifecycleState: "pending_schedule",
externalCalendarEventId: null,
externalCalendarId: null,
scheduledStartAt: null,
scheduledEndAt: null,
calendarSyncStatus: "in_sync",
calendarSyncUpdatedAt: null,
rescheduleCount: 0,
lastFollowupAt: null,
followupReminderSentAt: null,
completedAt: null,
archivedAt: null,
priority: "medium",
urgency: "medium",
},
],
discourseState: {
focus_entity_id: "task-1",
currently_editable_entity_id: "task-1",
last_user_mentioned_entity_ids: [],
last_presented_items: [],
pending_clarifications: [],
mode: "editing",
},
});

expect(result.interpretation.resolvedEntityIds).toEqual(["task-2"]);
expect(result.policy.targetEntityId).toBe("task-2");
expect(result.policy.resolvedOperation?.targetRef).toEqual({
entityId: "task-2",
});
});

it("falls back to resolveWriteTarget when the interpreter does not resolve an entity", async () => {
mockClassification({
turnType: "edit_request",
confidence: 0.96,
});
mockInterpretWriteTurn.mockResolvedValueOnce({
operationKind: "edit",
actionDomain: "task",
targetRef: null,
taskName: null,
fields: { scheduleFields: { time: t(11, 0) } },
sourceText: "Move it to 11",
confidence: {
"scheduleFields.time": 0.95,
},
unresolvedFields: [],
});

const result = await routeMessageTurn({
rawText: "Move it to 11",
normalizedText: "Move it to 11",
recentTurns: [],
discourseState: {
focus_entity_id: "task-1",
currently_editable_entity_id: null,
last_user_mentioned_entity_ids: [],
last_presented_items: [],
pending_clarifications: [],
mode: "editing",
},
});

expect(result.interpretation.resolvedEntityIds).toEqual(["task-1"]);
expect(result.policy.targetEntityId).toBe("task-1");
expect(result.policy.resolvedOperation?.targetRef).toEqual({
entityId: "task-1",
});
});

it("clears prior committed fields when the interpreted workflow changes", async () => {
mockClassification({
turnType: "planning_request",
Expand Down
29 changes: 25 additions & 4 deletions apps/web/src/lib/server/turn-router.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import {
applyWriteCommit,
buildEntityContext,
type ConversationDiscourseState,
type ConversationEntity,
type ConversationTurn,
createEmptyDiscourseState,
deriveAmbiguity,
type PendingWriteOperation,
renderEntityContext,
type RoutedTurn,
routedTurnSchema,
taskSchema,
type TurnAmbiguity,
type TurnClassifierOutput,
type TurnInterpretation,
Expand Down Expand Up @@ -75,6 +78,7 @@ export async function routeMessageTurn(
): Promise<TurnRouterResult> {
const discourseState = input.discourseState ?? createEmptyDiscourseState();
const entityRegistry = input.entityRegistry ?? [];
const tasks = (input.tasks ?? []).map((task) => taskSchema.parse(task));

// Pipeline A: classify intent
let classification = await classifyTurn({
Expand Down Expand Up @@ -122,6 +126,13 @@ export async function routeMessageTurn(
turnType: classification.turnType,
priorPendingWriteOperation: priorOperation,
conversationContext: deriveConversationContext(input.recentTurns),
entityContext: renderEntityContext(
buildEntityContext({
entityRegistry,
tasks,
discourseState,
}),
),
})
: {
operationKind: priorOperation?.operationKind ?? "plan",
Expand All @@ -134,21 +145,31 @@ export async function routeMessageTurn(
unresolvedFields: [],
};

// LLM-resolved targetRef takes priority over discourse-state entity lookup
const effectiveTargetEntityId =
writeInterpretation.targetRef?.entityId ?? writeTarget.targetEntityId;
const effectiveWriteTarget: WriteTarget = {
...writeTarget,
...(effectiveTargetEntityId
? { targetEntityId: effectiveTargetEntityId }
: {}),
};

// Policy layer: commit + route
const commitResult = applyWriteCommit({
turnType: classification.turnType,
interpretation: writeInterpretation,
priorPendingWriteOperation: priorOperation,
...(writeTarget.targetEntityId !== undefined
? { currentTargetEntityId: writeTarget.targetEntityId }
...(effectiveTargetEntityId !== undefined
? { currentTargetEntityId: effectiveTargetEntityId }
: {}),
});

const policy = decideTurnPolicy({
classification,
commitResult,
routingContext: input,
...writeTarget,
...effectiveWriteTarget,
});

// Assemble the resolved PendingWriteOperation for any turn that advances or maintains
Expand All @@ -168,7 +189,7 @@ export async function routeMessageTurn(
const interpretation = buildInterpretation(
classification,
commitResult,
writeTarget,
effectiveWriteTarget,
);

return routedTurnSchema.parse({
Expand Down
Loading