From e2d180ab0bcd6d4e58e0705963e792abc61c47d3 Mon Sep 17 00:00:00 2001 From: claude code agent 227 Date: Sat, 20 Jun 2026 05:51:52 +0300 Subject: [PATCH] test(ai-chat): cover crypto/SSRF/assistant-parts; a11y + refactors Closes the ai-chat code-review follow-ups. Tests (security-critical paths previously uncovered): - secret-box.spec: AES-256-GCM round-trip, random-salt uniqueness, tampered blob / wrong APP_SECRET throw the expected message. - ssrf-guard.spec: isIpAllowed blocks loopback/link-local/private/CGNAT/ULA/ unspecified/IPv4-mapped, allows public; isUrlAllowed blocks bad scheme, invalid URL, IP literals, and DNS-rebinding (mocked dns.lookup). - ai-chat.service.spec: assistantParts emits output-error for an UNPAIRED tool call (guards the MissingToolResultsError fix), output-available when paired, skips malformed calls; serializeSteps/rowToUiMessage. - ai-chat-tools.service.spec: JSON-string node/content coercion + invalid-JSON throws; updatePageJson title-only vs object. - page-embedding.repo.spec: empty spaceIds early-returns [] with no DB call. a11y: - Chat history toggle and conversation rows are now keyboard-operable (role=button, tabIndex, Enter/Space), matching history-item.tsx. Refactors: - onError on useChat adopts the server chat id when the first turn errors (AI SDK v6 onFinish doesn't fire on error). - isToolPart exported once from tool-parts and shared (was duplicated). - buildInitialValues() dedups the ai-mcp-server-form initial values. - describeProviderError replaces two inline statusCode/message snippets. - tool-parts stale tool-list comment refreshed. Implements docs/backlog/ai-chat-review-followups.md. Co-Authored-By: Claude Opus 4.8 --- .../ai-chat/components/ai-chat-window.tsx | 9 + .../ai-chat/components/chat-thread.tsx | 7 + .../ai-chat/components/conversation-list.tsx | 10 + .../ai-chat/components/message-item.tsx | 7 +- .../ai-chat/components/message-list.tsx | 6 +- .../src/features/ai-chat/utils/tool-parts.tsx | 13 +- .../components/ai-mcp-server-form.tsx | 33 ++-- .../src/core/ai-chat/ai-chat.service.spec.ts | 126 ++++++++++++- .../src/core/ai-chat/ai-chat.service.ts | 29 ++- .../ai-chat/external-mcp/ssrf-guard.spec.ts | 119 ++++++++++++ .../tools/ai-chat-tools.service.spec.ts | 171 ++++++++++++++++++ .../repos/ai-chat/page-embedding.repo.spec.ts | 26 +++ .../src/integrations/ai/ai-error.util.ts | 12 +- .../integrations/crypto/secret-box.spec.ts | 77 ++++++++ 14 files changed, 595 insertions(+), 50 deletions(-) create mode 100644 apps/server/src/core/ai-chat/external-mcp/ssrf-guard.spec.ts create mode 100644 apps/server/src/database/repos/ai-chat/page-embedding.repo.spec.ts create mode 100644 apps/server/src/integrations/crypto/secret-box.spec.ts diff --git a/apps/client/src/features/ai-chat/components/ai-chat-window.tsx b/apps/client/src/features/ai-chat/components/ai-chat-window.tsx index 122f80ff..90b6e33d 100644 --- a/apps/client/src/features/ai-chat/components/ai-chat-window.tsx +++ b/apps/client/src/features/ai-chat/components/ai-chat-window.tsx @@ -400,7 +400,16 @@ export default function AiChatWindow() { >
setHistoryOpen((o) => !o)} + onKeyDown={(event) => { + if (event.key === "Enter" || event.key === " ") { + event.preventDefault(); + setHistoryOpen((o) => !o); + } + }} > onTurnFinished(), + // In AI SDK v6 `onFinish` does NOT fire when the stream errors, so a brand + // new chat that fails on its first turn would never invalidate the chat list + // nor adopt the server-created chat id (the server still creates the row and + // saves the error message). Run the same post-turn path on error so the + // failed chat appears in history immediately instead of after a manual + // refresh. The error itself is still surfaced via `error` below. + onError: () => onTurnFinished(), }); const isStreaming = status === "submitted" || status === "streaming"; diff --git a/apps/client/src/features/ai-chat/components/conversation-list.tsx b/apps/client/src/features/ai-chat/components/conversation-list.tsx index c4c566dd..8fe4549c 100644 --- a/apps/client/src/features/ai-chat/components/conversation-list.tsx +++ b/apps/client/src/features/ai-chat/components/conversation-list.tsx @@ -115,7 +115,17 @@ export default function ConversationList({ classes.conversationItem, isActive && classes.conversationItemActive, )} + role="button" + tabIndex={0} onClick={() => onSelect(chat.id)} + onKeyDown={(e) => { + // Activate on Enter/Space like a native button; the inner menu + // button stops propagation so its own keys never reach this row. + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + onSelect(chat.id); + } + }} > {chat.title || t("Untitled chat")} diff --git a/apps/client/src/features/ai-chat/components/message-item.tsx b/apps/client/src/features/ai-chat/components/message-item.tsx index 680d4715..4ba1d934 100644 --- a/apps/client/src/features/ai-chat/components/message-item.tsx +++ b/apps/client/src/features/ai-chat/components/message-item.tsx @@ -3,7 +3,7 @@ import { IconAlertTriangle } from "@tabler/icons-react"; import { useTranslation } from "react-i18next"; import type { UIMessage } from "@ai-sdk/react"; import ToolCallCard from "@/features/ai-chat/components/tool-call-card.tsx"; -import { ToolUiPart } from "@/features/ai-chat/utils/tool-parts.tsx"; +import { ToolUiPart, isToolPart } from "@/features/ai-chat/utils/tool-parts.tsx"; import { renderChatMarkdown } from "@/features/ai-chat/utils/markdown.ts"; import { describeChatError } from "@/features/ai-chat/utils/error-message.ts"; import classes from "@/features/ai-chat/components/ai-chat.module.css"; @@ -12,11 +12,6 @@ interface MessageItemProps { message: UIMessage; } -/** True for AI SDK tool parts (static `tool-*` or `dynamic-tool`). */ -function isToolPart(type: string): boolean { - return type.startsWith("tool-") || type === "dynamic-tool"; -} - /** * Render a single UIMessage by iterating its `parts`: * - `text` parts -> sanitized markdown. diff --git a/apps/client/src/features/ai-chat/components/message-list.tsx b/apps/client/src/features/ai-chat/components/message-list.tsx index fb9d137e..ed0fb73d 100644 --- a/apps/client/src/features/ai-chat/components/message-list.tsx +++ b/apps/client/src/features/ai-chat/components/message-list.tsx @@ -4,6 +4,7 @@ import { useTranslation } from "react-i18next"; import type { UIMessage } from "@ai-sdk/react"; import MessageItem from "@/features/ai-chat/components/message-item.tsx"; import TypingIndicator from "@/features/ai-chat/components/typing-indicator.tsx"; +import { isToolPart } from "@/features/ai-chat/utils/tool-parts.tsx"; import classes from "@/features/ai-chat/components/ai-chat.module.css"; interface MessageListProps { @@ -11,11 +12,6 @@ interface MessageListProps { isStreaming: boolean; } -/** True for AI SDK tool parts (static `tool-*` or `dynamic-tool`). */ -function isToolPart(type: string): boolean { - return type.startsWith("tool-") || type === "dynamic-tool"; -} - // Distance (px) from the bottom within which the viewport still counts as // "pinned" — absorbs sub-pixel rounding and small content jitter. const BOTTOM_THRESHOLD = 40; diff --git a/apps/client/src/features/ai-chat/utils/tool-parts.tsx b/apps/client/src/features/ai-chat/utils/tool-parts.tsx index e7705936..be972050 100644 --- a/apps/client/src/features/ai-chat/utils/tool-parts.tsx +++ b/apps/client/src/features/ai-chat/utils/tool-parts.tsx @@ -5,9 +5,11 @@ * * A tool part's `type` is `tool-${toolName}` (AI SDK v6 static tool parts) and * its `state` is one of input-streaming / input-available / output-available / - * output-error (we only surface running / done / error). The server tools are: - * searchPages, getPage, createPage, updatePageContent, renamePage, movePage, - * deletePage, createComment, resolveComment — see ai-chat-tools.service.ts. + * output-error (we only surface running / done / error). The full toolset the + * server exposes lives in `ai-chat-tools.service.ts` (the agent now exposes the + * complete Docmost toolset); friendly action-log labels exist ONLY for the + * tools listed in `toolLabelKey` below — every other tool falls through to the + * generic "Ran tool {{name}}" label. */ /** A tool UI part as it arrives from `useChat` / persisted history. */ @@ -38,6 +40,11 @@ export interface ToolCitation { href: string; } +/** True for AI SDK tool parts (static `tool-*` or `dynamic-tool`). */ +export function isToolPart(type: string): boolean { + return type.startsWith("tool-") || type === "dynamic-tool"; +} + /** Extract the tool name from a part `type` of `tool-${name}` (or dynamic). */ export function getToolName(part: ToolUiPart): string { if (part.type === "dynamic-tool") return part.toolName ?? ""; diff --git a/apps/client/src/features/workspace/components/settings/components/ai-mcp-server-form.tsx b/apps/client/src/features/workspace/components/settings/components/ai-mcp-server-form.tsx index 3e6a8958..1d07e7d5 100644 --- a/apps/client/src/features/workspace/components/settings/components/ai-mcp-server-form.tsx +++ b/apps/client/src/features/workspace/components/settings/components/ai-mcp-server-form.tsx @@ -47,6 +47,21 @@ interface AiMcpServerFormProps { onClose: () => void; } +// Build the form's field values from a (possibly undefined) server. Used both +// for the initial mount and for re-hydration when the modal is reused for a +// different server, so the two stay in sync. authHeader is always empty: it is +// a write-only secret buffer never echoed back from the server. +function buildInitialValues(server?: IAiMcpServer): FormValues { + return { + name: server?.name ?? "", + transport: server?.transport ?? "http", + url: server?.url ?? "", + authHeader: "", + toolAllowlist: server?.toolAllowlist ?? [], + enabled: server?.enabled ?? true, + }; +} + // Tavily preset (§8.10): the API key goes in the Authorization HEADER, not the URL. const TAVILY_PRESET = { name: "Tavily", @@ -72,26 +87,12 @@ export default function AiMcpServerForm({ const form = useForm({ validate: zod4Resolver(formSchema), - initialValues: { - name: server?.name ?? "", - transport: server?.transport ?? "http", - url: server?.url ?? "", - authHeader: "", - toolAllowlist: server?.toolAllowlist ?? [], - enabled: server?.enabled ?? true, - }, + initialValues: buildInitialValues(server), }); // Re-hydrate when the target server changes (e.g. reusing the modal). useEffect(() => { - form.setValues({ - name: server?.name ?? "", - transport: server?.transport ?? "http", - url: server?.url ?? "", - authHeader: "", - toolAllowlist: server?.toolAllowlist ?? [], - enabled: server?.enabled ?? true, - }); + form.setValues(buildInitialValues(server)); form.resetDirty(); setHasHeaders(server?.hasHeaders ?? false); setHeadersCleared(false); diff --git a/apps/server/src/core/ai-chat/ai-chat.service.spec.ts b/apps/server/src/core/ai-chat/ai-chat.service.spec.ts index f1f3461a..ef99a462 100644 --- a/apps/server/src/core/ai-chat/ai-chat.service.spec.ts +++ b/apps/server/src/core/ai-chat/ai-chat.service.spec.ts @@ -1,4 +1,10 @@ -import { compactToolOutput } from './ai-chat.service'; +import { + compactToolOutput, + assistantParts, + serializeSteps, + rowToUiMessage, +} from './ai-chat.service'; +import type { AiChatMessage } from '@docmost/db/types/entity.types'; /** * Unit tests for compactToolOutput: the pure helper that shrinks LARGE tool @@ -66,3 +72,121 @@ describe('compactToolOutput', () => { expect(compactedBytes).toBeLessThan(originalBytes / 10); }); }); + +/** + * Tests for assistantParts: the pure function that rebuilds the persisted + * UIMessage parts for a turn. Its output decides whether the conversation + * replays correctly on the next turn. The crux: a tool-call WITHOUT a paired + * result must become a synthetic `output-error` part, so convertToModelMessages + * never throws MissingToolResultsError. This test MUST fail on pre-fix logic + * that persisted a bare input-available call. + */ +describe('assistantParts', () => { + type AnyPart = Record; + + it('emits output-available for a tool-call WITH a paired result', () => { + const steps = [ + { + text: '', + toolCalls: [{ toolCallId: 'c1', toolName: 'getPage', input: { id: 'p1' } }], + toolResults: [{ toolCallId: 'c1', toolName: 'getPage', output: { title: 'T' } }], + }, + ]; + const parts = assistantParts(steps, '') as AnyPart[]; + const toolPart = parts.find((p) => p.type === 'tool-getPage'); + expect(toolPart).toBeDefined(); + expect(toolPart!.state).toBe('output-available'); + expect(toolPart!.output).toEqual({ title: 'T' }); + }); + + it('emits a synthetic output-error for an UNPAIRED tool-call (crux)', () => { + const steps = [ + { + text: '', + toolCalls: [{ toolCallId: 'c9', toolName: 'insertNode', input: { node: {} } }], + toolResults: [], + }, + ]; + const parts = assistantParts(steps, '') as AnyPart[]; + const toolPart = parts.find((p) => p.type === 'tool-insertNode'); + expect(toolPart).toBeDefined(); + // The unpaired call MUST become output-error (NOT input-available), so the + // rebuilt history is balanced for convertToModelMessages on the next turn. + expect(toolPart!.state).toBe('output-error'); + expect(toolPart!.errorText).toBeTruthy(); + expect(toolPart).not.toHaveProperty('output'); + }); + + it('skips malformed tool-calls (missing toolName or toolCallId)', () => { + const steps = [ + { + text: '', + toolCalls: [ + { toolCallId: 'c1', input: {} }, // no toolName + { toolName: 'getPage', input: {} }, // no toolCallId + ], + toolResults: [], + }, + ]; + const parts = assistantParts(steps, '') as AnyPart[]; + const toolParts = parts.filter( + (p) => typeof p.type === 'string' && (p.type as string).startsWith('tool-'), + ); + expect(toolParts).toHaveLength(0); + }); + + it('uses per-step text when present', () => { + const steps = [{ text: 'hello', toolCalls: [], toolResults: [] }]; + const parts = assistantParts(steps, 'fallback-ignored') as AnyPart[]; + expect(parts).toEqual([{ type: 'text', text: 'hello' }]); + }); + + it('falls back to a single text part when no step text', () => { + const parts = assistantParts([], 'final answer') as AnyPart[]; + expect(parts).toEqual([{ type: 'text', text: 'final answer' }]); + }); +}); + +describe('serializeSteps', () => { + it('returns null when there are no calls or results', () => { + expect(serializeSteps([])).toBeNull(); + }); + + it('flattens calls and results into a compact trace', () => { + const trace = serializeSteps([ + { + toolCalls: [{ toolName: 'getPage', input: { id: 'p1' } }], + toolResults: [{ toolName: 'getPage', output: { title: 'T' } }], + }, + ]) as Array>; + expect(trace).toHaveLength(2); + expect(trace[0]).toEqual({ toolName: 'getPage', input: { id: 'p1' } }); + expect(trace[1]).toEqual({ toolName: 'getPage', output: { title: 'T' } }); + }); +}); + +describe('rowToUiMessage', () => { + it('prefers metadata.parts over content', () => { + const row = { + id: 'm1', + role: 'assistant', + content: 'plain text', + metadata: { parts: [{ type: 'text', text: 'rich part' }] }, + } as unknown as AiChatMessage; + const ui = rowToUiMessage(row); + expect(ui.role).toBe('assistant'); + expect(ui.parts).toEqual([{ type: 'text', text: 'rich part' }]); + }); + + it('falls back to a single text part from content when no metadata.parts', () => { + const row = { + id: 'm2', + role: 'user', + content: 'hi there', + metadata: null, + } as unknown as AiChatMessage; + const ui = rowToUiMessage(row); + expect(ui.role).toBe('user'); + expect(ui.parts).toEqual([{ type: 'text', text: 'hi there' }]); + }); +}); diff --git a/apps/server/src/core/ai-chat/ai-chat.service.ts b/apps/server/src/core/ai-chat/ai-chat.service.ts index 3119c3c4..8230dc6d 100644 --- a/apps/server/src/core/ai-chat/ai-chat.service.ts +++ b/apps/server/src/core/ai-chat/ai-chat.service.ts @@ -10,6 +10,7 @@ import { } from 'ai'; import { AiService } from '../../integrations/ai/ai.service'; import { AiSettingsService } from '../../integrations/ai/ai-settings.service'; +import { describeProviderError } from '../../integrations/ai/ai-error.util'; import { AiChatRepo } from '@docmost/db/repos/ai-chat/ai-chat.repo'; import { AiChatMessageRepo } from '@docmost/db/repos/ai-chat/ai-chat-message.repo'; import { User, Workspace, AiChatMessage } from '@docmost/db/types/entity.types'; @@ -271,15 +272,10 @@ export class AiChatService { onError: async ({ error }) => { // NestJS Logger.error(message, stack?, context?): pass the real message // (with statusCode when present) + the stack string, not the Error - // object, so the actual provider cause is clearly logged. - const e = error as { - statusCode?: number; - message?: string; - stack?: string; - }; - const errorText = e?.statusCode - ? `${e.statusCode}: ${e.message ?? String(error)}` - : (e?.message ?? String(error)); + // object, so the actual provider cause is clearly logged. Reuse the + // shared formatter so provider error formatting stays unified. + const e = error as { stack?: string }; + const errorText = describeProviderError(error, String(error)); this.logger.error(`AI chat stream error: ${errorText}`, e?.stack); // Persist whatever text we have (likely empty) so the turn is recorded, // and record the error text in metadata so it is visible in history. @@ -340,10 +336,9 @@ export class AiChatService { result.pipeUIMessageStreamToResponse(res.raw, { headers: { 'X-Accel-Buffering': 'no' }, onError: (error: unknown) => { - const e = error as { statusCode?: number; message?: string }; - return e?.statusCode - ? `${e.statusCode}: ${e.message}` - : (e?.message ?? 'AI stream error'); + // Reuse the shared formatter so provider error formatting stays + // unified between the log line and the streamed error message. + return describeProviderError(error, 'AI stream error'); }, }); @@ -538,7 +533,9 @@ function compactValue(value: unknown, depth: number): unknown { * recovers the name. Falls back to a single `text` part built from * `fallbackText` when the steps carry no text. */ -function assistantParts( +// Exported only so the unit tests can import these pure helpers; exporting +// them does not change runtime behavior. +export function assistantParts( steps: ReadonlyArray | undefined, fallbackText: string, ): UIMessage['parts'] { @@ -596,7 +593,7 @@ function assistantParts( * stored parts when available; assistant messages restore the reconstructable * parts from metadata, falling back to a single text part from `content`. */ -function rowToUiMessage(row: AiChatMessage): Omit & { +export function rowToUiMessage(row: AiChatMessage): Omit & { id: string; } { const role = row.role === 'assistant' ? 'assistant' : 'user'; @@ -613,7 +610,7 @@ function rowToUiMessage(row: AiChatMessage): Omit & { * `tool_calls` column. Stores only what the UI action-log and history need — * never raw provider payloads or keys. */ -function serializeSteps( +export function serializeSteps( steps: ReadonlyArray<{ toolCalls?: ReadonlyArray<{ toolName?: string; input?: unknown }>; toolResults?: ReadonlyArray<{ toolName?: string; output?: unknown }>; diff --git a/apps/server/src/core/ai-chat/external-mcp/ssrf-guard.spec.ts b/apps/server/src/core/ai-chat/external-mcp/ssrf-guard.spec.ts new file mode 100644 index 00000000..bd115129 --- /dev/null +++ b/apps/server/src/core/ai-chat/external-mcp/ssrf-guard.spec.ts @@ -0,0 +1,119 @@ +/** + * Unit tests for the SSRF guard protecting admin-configured external MCP URLs. + * + * `isIpAllowed` is pure/sync: every blocked address class must be rejected and a + * public address allowed. `isUrlAllowed` adds scheme/URL validation and, for + * hostnames, a DNS resolve + re-check (the DNS-rebinding defense): a name that + * resolves to a private address must be blocked. We mock `node:dns` `lookup` + * (the guard promisifies it) so the rebinding case is deterministic and offline. + */ + +// Mock node:dns BEFORE importing the guard so promisify(lookup) wraps our mock. +const lookupMock = jest.fn(); +jest.mock('node:dns', () => ({ + __esModule: true, + lookup: (...args: unknown[]) => lookupMock(...args), +})); + +import { isIpAllowed, isUrlAllowed } from './ssrf-guard'; + +// The guard calls promisify(lookup): our mock must honour the (host, opts, cb) +// callback signature. Helper to make it resolve to a given address list. +function dnsResolvesTo(addresses: { address: string }[]) { + lookupMock.mockImplementation( + (_host: string, _opts: unknown, cb: (e: unknown, a: unknown) => void) => { + cb(null, addresses); + }, + ); +} + +describe('isIpAllowed', () => { + const blocked: Array<[string, string]> = [ + ['loopback IPv4', '127.0.0.1'], + ['loopback IPv6', '::1'], + ['link-local / metadata', '169.254.169.254'], + ['private 10/8', '10.0.0.1'], + ['private 172.16/12', '172.16.5.4'], + ['private 192.168/16', '192.168.1.1'], + ['CGNAT 100.64/10', '100.64.1.1'], + ['ULA fc00::/7', 'fc00::1'], + ['unspecified IPv4', '0.0.0.0'], + ['unspecified IPv6', '::'], + ['IPv4-mapped IPv6 (private)', '::ffff:10.0.0.1'], + ]; + + it.each(blocked)('blocks %s (%s)', (_label, ip) => { + expect(isIpAllowed(ip).ok).toBe(false); + }); + + it('allows a public IPv4 (8.8.8.8)', () => { + expect(isIpAllowed('8.8.8.8').ok).toBe(true); + }); + + it('allows a public IPv6', () => { + expect(isIpAllowed('2001:4860:4860::8888').ok).toBe(true); + }); + + it('blocks an unparseable IP', () => { + expect(isIpAllowed('not-an-ip').ok).toBe(false); + }); +}); + +describe('isUrlAllowed', () => { + beforeEach(() => { + lookupMock.mockReset(); + }); + + it('blocks a non-http(s) scheme', async () => { + const res = await isUrlAllowed('ftp://example.com/'); + expect(res.ok).toBe(false); + expect(lookupMock).not.toHaveBeenCalled(); + }); + + it('blocks an invalid URL', async () => { + const res = await isUrlAllowed('::: not a url :::'); + expect(res.ok).toBe(false); + expect(lookupMock).not.toHaveBeenCalled(); + }); + + it('blocks a private IP literal host without DNS', async () => { + const res = await isUrlAllowed('http://169.254.169.254/latest/meta-data/'); + expect(res.ok).toBe(false); + expect(lookupMock).not.toHaveBeenCalled(); + }); + + it('blocks a bracketed private IPv6 literal host', async () => { + const res = await isUrlAllowed('http://[::1]:8080/'); + expect(res.ok).toBe(false); + expect(lookupMock).not.toHaveBeenCalled(); + }); + + it('blocks a hostname that resolves to a private address (DNS rebinding)', async () => { + dnsResolvesTo([{ address: '10.0.0.5' }]); + const res = await isUrlAllowed('http://rebind.example.com/'); + expect(res.ok).toBe(false); + expect(lookupMock).toHaveBeenCalled(); + }); + + it('blocks when ANY resolved address is private (mixed result)', async () => { + dnsResolvesTo([{ address: '8.8.8.8' }, { address: '127.0.0.1' }]); + const res = await isUrlAllowed('http://mixed.example.com/'); + expect(res.ok).toBe(false); + }); + + it('allows a hostname that resolves only to a public address', async () => { + dnsResolvesTo([{ address: '8.8.8.8' }]); + const res = await isUrlAllowed('https://public.example.com/mcp'); + expect(res.ok).toBe(true); + }); + + it('blocks when the host does not resolve', async () => { + lookupMock.mockImplementation( + (_host: string, _opts: unknown, cb: (e: unknown, a: unknown) => void) => { + cb(new Error('ENOTFOUND'), undefined); + }, + ); + const res = await isUrlAllowed('http://nonexistent.invalid/'); + expect(res.ok).toBe(false); + }); +}); diff --git a/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.spec.ts b/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.spec.ts index 65218300..becf082f 100644 --- a/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.spec.ts +++ b/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.spec.ts @@ -211,3 +211,174 @@ describe('AiChatToolsService expanded toolset guardrails', () => { expect(parsed).not.toHaveProperty('deleteComments'); }); }); + +/** + * JSON-string coercion for node arguments (fix 59b99dba): under OpenAI tool + * calls the model sometimes serializes `node`/`content` as a JSON STRING. The + * tools parse a string into an object before forwarding it to the client (which + * type-checks for an object), throw a documented message on invalid JSON, and + * `updatePageJson` distinguishes undefined (title-only) from object/string. + */ +describe('AiChatToolsService node-arg JSON-string coercion', () => { + // Records the positional args forwarded to each write method so we can assert + // the coerced (parsed) value reaches the client. + const patchNodeCalls: unknown[][] = []; + const insertNodeCalls: unknown[][] = []; + const updatePageJsonCalls: unknown[][] = []; + + const fakeClient: Partial = { + patchNode: (...args: unknown[]) => { + patchNodeCalls.push(args); + return Promise.resolve({ ok: true }); + }, + insertNode: (...args: unknown[]) => { + insertNodeCalls.push(args); + return Promise.resolve({ ok: true }); + }, + updatePageJson: (...args: unknown[]) => { + updatePageJsonCalls.push(args); + return Promise.resolve({ ok: true }); + }, + }; + + const tokenServiceStub = { + generateAccessToken: jest.fn().mockResolvedValue('access-token'), + generateCollabToken: jest.fn().mockResolvedValue('collab-token'), + }; + + let service: AiChatToolsService; + + beforeEach(() => { + patchNodeCalls.length = 0; + insertNodeCalls.length = 0; + updatePageJsonCalls.length = 0; + jest.spyOn(loader, 'loadDocmostMcp').mockResolvedValue({ + DocmostClient: function () { + return fakeClient as DocmostClientLike; + } as unknown as loader.DocmostClientCtor, + }); + service = new AiChatToolsService( + tokenServiceStub as never, + {} as never, + {} as never, + {} as never, + {} as never, + ); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + function buildTools() { + return service.forUser( + { id: 'user-1', email: 'u@example.com', workspaceId: 'ws-1' } as never, + 'session-1', + 'ws-1', + 'chat-1', + ); + } + + const NODE_OBJ = { + type: 'paragraph', + content: [{ type: 'text', text: 'Hello' }], + }; + + it('patchNode parses a JSON-string node and forwards it as an object', async () => { + const tools = await buildTools(); + await tools.patchNode.execute( + { pageId: 'p1', nodeId: 'n1', node: JSON.stringify(NODE_OBJ) } as never, + {} as never, + ); + expect(patchNodeCalls).toHaveLength(1); + expect(patchNodeCalls[0]).toEqual(['p1', 'n1', NODE_OBJ]); + }); + + it('patchNode passes an object node through unchanged', async () => { + const tools = await buildTools(); + await tools.patchNode.execute( + { pageId: 'p1', nodeId: 'n1', node: NODE_OBJ } as never, + {} as never, + ); + expect(patchNodeCalls[0]).toEqual(['p1', 'n1', NODE_OBJ]); + }); + + it('patchNode throws the documented message on invalid JSON string', async () => { + const tools = await buildTools(); + await expect( + tools.patchNode.execute( + { pageId: 'p1', nodeId: 'n1', node: '{not json' } as never, + {} as never, + ), + ).rejects.toThrow('node was a string but not valid JSON'); + expect(patchNodeCalls).toHaveLength(0); + }); + + it('insertNode parses a JSON-string node and forwards it as an object', async () => { + const tools = await buildTools(); + await tools.insertNode.execute( + { + pageId: 'p1', + node: JSON.stringify(NODE_OBJ), + position: 'append', + } as never, + {} as never, + ); + expect(insertNodeCalls).toHaveLength(1); + const [pageId, node] = insertNodeCalls[0]; + expect(pageId).toBe('p1'); + expect(node).toEqual(NODE_OBJ); + }); + + it('insertNode throws the documented message on invalid JSON string', async () => { + const tools = await buildTools(); + await expect( + tools.insertNode.execute( + { pageId: 'p1', node: 'nope', position: 'append' } as never, + {} as never, + ), + ).rejects.toThrow('node was a string but not valid JSON'); + expect(insertNodeCalls).toHaveLength(0); + }); + + it('updatePageJson forwards doc=undefined for a title-only update (content undefined)', async () => { + const tools = await buildTools(); + await tools.updatePageJson.execute( + { pageId: 'p1', title: 'New title' } as never, + {} as never, + ); + expect(updatePageJsonCalls).toHaveLength(1); + expect(updatePageJsonCalls[0]).toEqual(['p1', undefined, 'New title']); + }); + + it('updatePageJson passes an object content through unchanged', async () => { + const tools = await buildTools(); + const doc = { type: 'doc', content: [] }; + await tools.updatePageJson.execute( + { pageId: 'p1', content: doc } as never, + {} as never, + ); + expect(updatePageJsonCalls[0]).toEqual(['p1', doc, undefined]); + }); + + it('updatePageJson parses a JSON-string content', async () => { + const tools = await buildTools(); + const doc = { type: 'doc', content: [] }; + await tools.updatePageJson.execute( + { pageId: 'p1', content: JSON.stringify(doc) } as never, + {} as never, + ); + expect(updatePageJsonCalls[0]).toEqual(['p1', doc, undefined]); + }); + + it('updatePageJson throws the documented message on invalid JSON string content', async () => { + const tools = await buildTools(); + await expect( + tools.updatePageJson.execute( + { pageId: 'p1', content: '{bad' } as never, + {} as never, + ), + ).rejects.toThrow('content was a string but not valid JSON'); + expect(updatePageJsonCalls).toHaveLength(0); + }); +}); diff --git a/apps/server/src/database/repos/ai-chat/page-embedding.repo.spec.ts b/apps/server/src/database/repos/ai-chat/page-embedding.repo.spec.ts new file mode 100644 index 00000000..792c8762 --- /dev/null +++ b/apps/server/src/database/repos/ai-chat/page-embedding.repo.spec.ts @@ -0,0 +1,26 @@ +import { PageEmbeddingRepo } from './page-embedding.repo'; +import type { KyselyDB } from '../../types/kysely.types'; + +/** + * Unit test for the pure access-scoping branch of searchByEmbedding: when the + * caller has NO accessible spaces (`spaceIds` empty), the method must early- + * return [] WITHOUT touching the database. We inject a db whose query builder + * throws if invoked, so any DB access fails the test. + * + * NOTE: the dimension-mixing case (filter by model_dimensions) needs a live + * pgvector-enabled Postgres and is intentionally NOT covered here — it requires + * a real DB and is out of scope for this pure unit test. + */ +describe('PageEmbeddingRepo.searchByEmbedding', () => { + it('early-returns [] for empty spaceIds without any DB call', async () => { + const throwingDb = { + selectFrom: () => { + throw new Error('DB should not be queried for empty spaceIds'); + }, + } as unknown as KyselyDB; + + const repo = new PageEmbeddingRepo(throwingDb); + const result = await repo.searchByEmbedding('ws-1', [0.1, 0.2, 0.3], [], 10); + expect(result).toEqual([]); + }); +}); diff --git a/apps/server/src/integrations/ai/ai-error.util.ts b/apps/server/src/integrations/ai/ai-error.util.ts index 68fa328b..0a0f949b 100644 --- a/apps/server/src/integrations/ai/ai-error.util.ts +++ b/apps/server/src/integrations/ai/ai-error.util.ts @@ -9,10 +9,16 @@ * * None of these fields contain the API key (it is sent as an Authorization * header and never echoed in the response body), so this is safe to log/return. + * + * `fallback` is used when the error carries no usable message (e.g. a bare + * object); defaults to 'Unknown error'. */ -export function describeProviderError(err: unknown): string { +export function describeProviderError( + err: unknown, + fallback = 'Unknown error', +): string { if (typeof err !== 'object' || err === null) { - return typeof err === 'string' ? err : 'Unknown error'; + return typeof err === 'string' && err ? err : fallback; } const e = err as { statusCode?: number; @@ -23,7 +29,7 @@ export function describeProviderError(err: unknown): string { const base = typeof e.statusCode === 'number' ? `${e.statusCode}: ${e.message ?? ''}`.trim() - : (e.message ?? 'Unknown error'); + : (e.message ?? fallback); const body = (e.responseBody ?? e.text ?? '').trim(); if (!body) return base; // Collapse whitespace so a multi-line HTML body stays on one log line. diff --git a/apps/server/src/integrations/crypto/secret-box.spec.ts b/apps/server/src/integrations/crypto/secret-box.spec.ts new file mode 100644 index 00000000..d53d7093 --- /dev/null +++ b/apps/server/src/integrations/crypto/secret-box.spec.ts @@ -0,0 +1,77 @@ +import { SecretBoxService } from './secret-box'; +import { EnvironmentService } from '../environment/environment.service'; + +/** + * Unit tests for SecretBoxService: the AES-256-GCM helper that protects provider + * API keys at rest. The contract is: encrypt -> decrypt round-trips the input; + * two encryptions of the same input yield different blobs (random salt+iv) yet + * both decrypt; a tampered blob or a different APP_SECRET fails decryption with + * the recoverable "APP_SECRET may have changed" message the UI relies on. + */ +describe('SecretBoxService', () => { + // Construct a SecretBoxService whose EnvironmentService.getAppSecret returns a + // fixed 64-hex secret. Only getAppSecret is exercised, so a thin fake suffices. + function makeBox(appSecret: string): SecretBoxService { + const env = { + getAppSecret: () => appSecret, + } as unknown as EnvironmentService; + return new SecretBoxService(env); + } + + const SECRET_A = + '00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff'; + const SECRET_B = + 'ffeeddccbbaa99887766554433221100ffeeddccbbaa99887766554433221100'; + + it('round-trips: decrypt(encrypt(x)) === x', () => { + const box = makeBox(SECRET_A); + const plain = 'sk-super-secret-provider-key-12345'; + const blob = box.encryptSecret(plain); + expect(box.decryptSecret(blob)).toBe(plain); + }); + + it('produces a different blob each time, both of which decrypt', () => { + const box = makeBox(SECRET_A); + const plain = 'identical-input'; + const blob1 = box.encryptSecret(plain); + const blob2 = box.encryptSecret(plain); + // Random per-record salt + iv => the ciphertext blobs must differ. + expect(blob1).not.toBe(blob2); + expect(box.decryptSecret(blob1)).toBe(plain); + expect(box.decryptSecret(blob2)).toBe(plain); + }); + + it('throws the recoverable error on a tampered auth tag', () => { + const box = makeBox(SECRET_A); + const blob = box.encryptSecret('tamper-me'); + + // Layout: base64( salt[16] | iv[12] | authTag[16] | ciphertext ). Flip a bit + // in the auth-tag region so GCM verification (decipher.final) rejects it. + const data = Buffer.from(blob, 'base64'); + const authTagByteIndex = 16 + 12; // first byte of the auth tag + data[authTagByteIndex] = data[authTagByteIndex] ^ 0xff; + const tampered = data.toString('base64'); + + expect(() => box.decryptSecret(tampered)).toThrow(/APP_SECRET may have changed/); + }); + + it('throws the recoverable error on a tampered ciphertext byte', () => { + const box = makeBox(SECRET_A); + const blob = box.encryptSecret('tamper-the-body'); + + const data = Buffer.from(blob, 'base64'); + // Last byte is part of the ciphertext; flipping it must fail GCM auth. + data[data.length - 1] = data[data.length - 1] ^ 0xff; + const tampered = data.toString('base64'); + + expect(() => box.decryptSecret(tampered)).toThrow(/APP_SECRET may have changed/); + }); + + it('throws when decrypting under a different APP_SECRET', () => { + const boxA = makeBox(SECRET_A); + const boxB = makeBox(SECRET_B); + const blob = boxA.encryptSecret('rotate-me'); + // A different APP_SECRET derives a different scrypt key => GCM auth fails. + expect(() => boxB.decryptSecret(blob)).toThrow(/APP_SECRET may have changed/); + }); +});