test(ai-chat): safety-critical coverage + a11y + pure refactors

Unit tests for the safety-critical paths: crypto secret-box (round-trip,
tamper detection, wrong key), the SSRF guard (blocked ranges + DNS-rebinding),
the ai-chat tools service, the page-embedding repo, and the
assistant-parts/serialization helpers. Those server helpers (assistantParts,
rowToUiMessage, serializeSteps) are exported ONLY for the tests — no runtime
change.

Also: keyboard a11y on the chat history header and conversation rows
(role/tabIndex/Enter+Space), and DRY refactors that move shared logic into one
place (isToolPart -> tool-parts util; buildInitialValues in the MCP form).

The behaviour-changing edits that previously rode along in this commit are
split out into the following two commits, per review.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
claude code agent 227
2026-06-20 17:58:44 +03:00
committed by vvzvlad
parent c8af637654
commit f1980cf425
13 changed files with 571 additions and 236 deletions

View File

@@ -1,4 +1,10 @@
import { compactToolOutput } from './ai-chat.service';
import {
compactToolOutput,
assistantParts,
serializeSteps,
rowToUiMessage,
} from './ai-chat.service';
import type { AiChatMessage } from '@docmost/db/types/entity.types';
/**
* Unit tests for compactToolOutput: the pure helper that shrinks LARGE tool
@@ -66,3 +72,121 @@ describe('compactToolOutput', () => {
expect(compactedBytes).toBeLessThan(originalBytes / 10);
});
});
/**
* Tests for assistantParts: the pure function that rebuilds the persisted
* UIMessage parts for a turn. Its output decides whether the conversation
* replays correctly on the next turn. The crux: a tool-call WITHOUT a paired
* result must become a synthetic `output-error` part, so convertToModelMessages
* never throws MissingToolResultsError. This test MUST fail on pre-fix logic
* that persisted a bare input-available call.
*/
describe('assistantParts', () => {
type AnyPart = Record<string, unknown>;
it('emits output-available for a tool-call WITH a paired result', () => {
const steps = [
{
text: '',
toolCalls: [{ toolCallId: 'c1', toolName: 'getPage', input: { id: 'p1' } }],
toolResults: [{ toolCallId: 'c1', toolName: 'getPage', output: { title: 'T' } }],
},
];
const parts = assistantParts(steps, '') as AnyPart[];
const toolPart = parts.find((p) => p.type === 'tool-getPage');
expect(toolPart).toBeDefined();
expect(toolPart!.state).toBe('output-available');
expect(toolPart!.output).toEqual({ title: 'T' });
});
it('emits a synthetic output-error for an UNPAIRED tool-call (crux)', () => {
const steps = [
{
text: '',
toolCalls: [{ toolCallId: 'c9', toolName: 'insertNode', input: { node: {} } }],
toolResults: [],
},
];
const parts = assistantParts(steps, '') as AnyPart[];
const toolPart = parts.find((p) => p.type === 'tool-insertNode');
expect(toolPart).toBeDefined();
// The unpaired call MUST become output-error (NOT input-available), so the
// rebuilt history is balanced for convertToModelMessages on the next turn.
expect(toolPart!.state).toBe('output-error');
expect(toolPart!.errorText).toBeTruthy();
expect(toolPart).not.toHaveProperty('output');
});
it('skips malformed tool-calls (missing toolName or toolCallId)', () => {
const steps = [
{
text: '',
toolCalls: [
{ toolCallId: 'c1', input: {} }, // no toolName
{ toolName: 'getPage', input: {} }, // no toolCallId
],
toolResults: [],
},
];
const parts = assistantParts(steps, '') as AnyPart[];
const toolParts = parts.filter(
(p) => typeof p.type === 'string' && (p.type as string).startsWith('tool-'),
);
expect(toolParts).toHaveLength(0);
});
it('uses per-step text when present', () => {
const steps = [{ text: 'hello', toolCalls: [], toolResults: [] }];
const parts = assistantParts(steps, 'fallback-ignored') as AnyPart[];
expect(parts).toEqual([{ type: 'text', text: 'hello' }]);
});
it('falls back to a single text part when no step text', () => {
const parts = assistantParts([], 'final answer') as AnyPart[];
expect(parts).toEqual([{ type: 'text', text: 'final answer' }]);
});
});
describe('serializeSteps', () => {
it('returns null when there are no calls or results', () => {
expect(serializeSteps([])).toBeNull();
});
it('flattens calls and results into a compact trace', () => {
const trace = serializeSteps([
{
toolCalls: [{ toolName: 'getPage', input: { id: 'p1' } }],
toolResults: [{ toolName: 'getPage', output: { title: 'T' } }],
},
]) as Array<Record<string, unknown>>;
expect(trace).toHaveLength(2);
expect(trace[0]).toEqual({ toolName: 'getPage', input: { id: 'p1' } });
expect(trace[1]).toEqual({ toolName: 'getPage', output: { title: 'T' } });
});
});
describe('rowToUiMessage', () => {
it('prefers metadata.parts over content', () => {
const row = {
id: 'm1',
role: 'assistant',
content: 'plain text',
metadata: { parts: [{ type: 'text', text: 'rich part' }] },
} as unknown as AiChatMessage;
const ui = rowToUiMessage(row);
expect(ui.role).toBe('assistant');
expect(ui.parts).toEqual([{ type: 'text', text: 'rich part' }]);
});
it('falls back to a single text part from content when no metadata.parts', () => {
const row = {
id: 'm2',
role: 'user',
content: 'hi there',
metadata: null,
} as unknown as AiChatMessage;
const ui = rowToUiMessage(row);
expect(ui.role).toBe('user');
expect(ui.parts).toEqual([{ type: 'text', text: 'hi there' }]);
});
});

View File

@@ -538,7 +538,9 @@ function compactValue(value: unknown, depth: number): unknown {
* recovers the name. Falls back to a single `text` part built from
* `fallbackText` when the steps carry no text.
*/
function assistantParts(
// Exported only so the unit tests can import this pure helper; exporting it
// does not change runtime behavior.
export function assistantParts(
steps: ReadonlyArray<StepLike> | undefined,
fallbackText: string,
): UIMessage['parts'] {
@@ -596,7 +598,7 @@ function assistantParts(
* stored parts when available; assistant messages restore the reconstructable
* parts from metadata, falling back to a single text part from `content`.
*/
function rowToUiMessage(row: AiChatMessage): Omit<UIMessage, 'id'> & {
export function rowToUiMessage(row: AiChatMessage): Omit<UIMessage, 'id'> & {
id: string;
} {
const role = row.role === 'assistant' ? 'assistant' : 'user';
@@ -613,7 +615,7 @@ function rowToUiMessage(row: AiChatMessage): Omit<UIMessage, 'id'> & {
* `tool_calls` column. Stores only what the UI action-log and history need —
* never raw provider payloads or keys.
*/
function serializeSteps(
export function serializeSteps(
steps: ReadonlyArray<{
toolCalls?: ReadonlyArray<{ toolName?: string; input?: unknown }>;
toolResults?: ReadonlyArray<{ toolName?: string; output?: unknown }>;

View File

@@ -0,0 +1,119 @@
/**
* Unit tests for the SSRF guard protecting admin-configured external MCP URLs.
*
* `isIpAllowed` is pure/sync: every blocked address class must be rejected and a
* public address allowed. `isUrlAllowed` adds scheme/URL validation and, for
* hostnames, a DNS resolve + re-check (the DNS-rebinding defense): a name that
* resolves to a private address must be blocked. We mock `node:dns` `lookup`
* (the guard promisifies it) so the rebinding case is deterministic and offline.
*/
// Mock node:dns BEFORE importing the guard so promisify(lookup) wraps our mock.
const lookupMock = jest.fn();
jest.mock('node:dns', () => ({
__esModule: true,
lookup: (...args: unknown[]) => lookupMock(...args),
}));
import { isIpAllowed, isUrlAllowed } from './ssrf-guard';
// The guard calls promisify(lookup): our mock must honour the (host, opts, cb)
// callback signature. Helper to make it resolve to a given address list.
function dnsResolvesTo(addresses: { address: string }[]) {
lookupMock.mockImplementation(
(_host: string, _opts: unknown, cb: (e: unknown, a: unknown) => void) => {
cb(null, addresses);
},
);
}
describe('isIpAllowed', () => {
const blocked: Array<[string, string]> = [
['loopback IPv4', '127.0.0.1'],
['loopback IPv6', '::1'],
['link-local / metadata', '169.254.169.254'],
['private 10/8', '10.0.0.1'],
['private 172.16/12', '172.16.5.4'],
['private 192.168/16', '192.168.1.1'],
['CGNAT 100.64/10', '100.64.1.1'],
['ULA fc00::/7', 'fc00::1'],
['unspecified IPv4', '0.0.0.0'],
['unspecified IPv6', '::'],
['IPv4-mapped IPv6 (private)', '::ffff:10.0.0.1'],
];
it.each(blocked)('blocks %s (%s)', (_label, ip) => {
expect(isIpAllowed(ip).ok).toBe(false);
});
it('allows a public IPv4 (8.8.8.8)', () => {
expect(isIpAllowed('8.8.8.8').ok).toBe(true);
});
it('allows a public IPv6', () => {
expect(isIpAllowed('2001:4860:4860::8888').ok).toBe(true);
});
it('blocks an unparseable IP', () => {
expect(isIpAllowed('not-an-ip').ok).toBe(false);
});
});
describe('isUrlAllowed', () => {
beforeEach(() => {
lookupMock.mockReset();
});
it('blocks a non-http(s) scheme', async () => {
const res = await isUrlAllowed('ftp://example.com/');
expect(res.ok).toBe(false);
expect(lookupMock).not.toHaveBeenCalled();
});
it('blocks an invalid URL', async () => {
const res = await isUrlAllowed('::: not a url :::');
expect(res.ok).toBe(false);
expect(lookupMock).not.toHaveBeenCalled();
});
it('blocks a private IP literal host without DNS', async () => {
const res = await isUrlAllowed('http://169.254.169.254/latest/meta-data/');
expect(res.ok).toBe(false);
expect(lookupMock).not.toHaveBeenCalled();
});
it('blocks a bracketed private IPv6 literal host', async () => {
const res = await isUrlAllowed('http://[::1]:8080/');
expect(res.ok).toBe(false);
expect(lookupMock).not.toHaveBeenCalled();
});
it('blocks a hostname that resolves to a private address (DNS rebinding)', async () => {
dnsResolvesTo([{ address: '10.0.0.5' }]);
const res = await isUrlAllowed('http://rebind.example.com/');
expect(res.ok).toBe(false);
expect(lookupMock).toHaveBeenCalled();
});
it('blocks when ANY resolved address is private (mixed result)', async () => {
dnsResolvesTo([{ address: '8.8.8.8' }, { address: '127.0.0.1' }]);
const res = await isUrlAllowed('http://mixed.example.com/');
expect(res.ok).toBe(false);
});
it('allows a hostname that resolves only to a public address', async () => {
dnsResolvesTo([{ address: '8.8.8.8' }]);
const res = await isUrlAllowed('https://public.example.com/mcp');
expect(res.ok).toBe(true);
});
it('blocks when the host does not resolve', async () => {
lookupMock.mockImplementation(
(_host: string, _opts: unknown, cb: (e: unknown, a: unknown) => void) => {
cb(new Error('ENOTFOUND'), undefined);
},
);
const res = await isUrlAllowed('http://nonexistent.invalid/');
expect(res.ok).toBe(false);
});
});

View File

@@ -211,3 +211,174 @@ describe('AiChatToolsService expanded toolset guardrails', () => {
expect(parsed).not.toHaveProperty('deleteComments');
});
});
/**
* JSON-string coercion for node arguments (fix 59b99dba): under OpenAI tool
* calls the model sometimes serializes `node`/`content` as a JSON STRING. The
* tools parse a string into an object before forwarding it to the client (which
* type-checks for an object), throw a documented message on invalid JSON, and
* `updatePageJson` distinguishes undefined (title-only) from object/string.
*/
describe('AiChatToolsService node-arg JSON-string coercion', () => {
// Records the positional args forwarded to each write method so we can assert
// the coerced (parsed) value reaches the client.
const patchNodeCalls: unknown[][] = [];
const insertNodeCalls: unknown[][] = [];
const updatePageJsonCalls: unknown[][] = [];
const fakeClient: Partial<DocmostClientLike> = {
patchNode: (...args: unknown[]) => {
patchNodeCalls.push(args);
return Promise.resolve({ ok: true });
},
insertNode: (...args: unknown[]) => {
insertNodeCalls.push(args);
return Promise.resolve({ ok: true });
},
updatePageJson: (...args: unknown[]) => {
updatePageJsonCalls.push(args);
return Promise.resolve({ ok: true });
},
};
const tokenServiceStub = {
generateAccessToken: jest.fn().mockResolvedValue('access-token'),
generateCollabToken: jest.fn().mockResolvedValue('collab-token'),
};
let service: AiChatToolsService;
beforeEach(() => {
patchNodeCalls.length = 0;
insertNodeCalls.length = 0;
updatePageJsonCalls.length = 0;
jest.spyOn(loader, 'loadDocmostMcp').mockResolvedValue({
DocmostClient: function () {
return fakeClient as DocmostClientLike;
} as unknown as loader.DocmostClientCtor,
});
service = new AiChatToolsService(
tokenServiceStub as never,
{} as never,
{} as never,
{} as never,
{} as never,
);
});
afterEach(() => {
jest.restoreAllMocks();
});
function buildTools() {
return service.forUser(
{ id: 'user-1', email: 'u@example.com', workspaceId: 'ws-1' } as never,
'session-1',
'ws-1',
'chat-1',
);
}
const NODE_OBJ = {
type: 'paragraph',
content: [{ type: 'text', text: 'Hello' }],
};
it('patchNode parses a JSON-string node and forwards it as an object', async () => {
const tools = await buildTools();
await tools.patchNode.execute(
{ pageId: 'p1', nodeId: 'n1', node: JSON.stringify(NODE_OBJ) } as never,
{} as never,
);
expect(patchNodeCalls).toHaveLength(1);
expect(patchNodeCalls[0]).toEqual(['p1', 'n1', NODE_OBJ]);
});
it('patchNode passes an object node through unchanged', async () => {
const tools = await buildTools();
await tools.patchNode.execute(
{ pageId: 'p1', nodeId: 'n1', node: NODE_OBJ } as never,
{} as never,
);
expect(patchNodeCalls[0]).toEqual(['p1', 'n1', NODE_OBJ]);
});
it('patchNode throws the documented message on invalid JSON string', async () => {
const tools = await buildTools();
await expect(
tools.patchNode.execute(
{ pageId: 'p1', nodeId: 'n1', node: '{not json' } as never,
{} as never,
),
).rejects.toThrow('node was a string but not valid JSON');
expect(patchNodeCalls).toHaveLength(0);
});
it('insertNode parses a JSON-string node and forwards it as an object', async () => {
const tools = await buildTools();
await tools.insertNode.execute(
{
pageId: 'p1',
node: JSON.stringify(NODE_OBJ),
position: 'append',
} as never,
{} as never,
);
expect(insertNodeCalls).toHaveLength(1);
const [pageId, node] = insertNodeCalls[0];
expect(pageId).toBe('p1');
expect(node).toEqual(NODE_OBJ);
});
it('insertNode throws the documented message on invalid JSON string', async () => {
const tools = await buildTools();
await expect(
tools.insertNode.execute(
{ pageId: 'p1', node: 'nope', position: 'append' } as never,
{} as never,
),
).rejects.toThrow('node was a string but not valid JSON');
expect(insertNodeCalls).toHaveLength(0);
});
it('updatePageJson forwards doc=undefined for a title-only update (content undefined)', async () => {
const tools = await buildTools();
await tools.updatePageJson.execute(
{ pageId: 'p1', title: 'New title' } as never,
{} as never,
);
expect(updatePageJsonCalls).toHaveLength(1);
expect(updatePageJsonCalls[0]).toEqual(['p1', undefined, 'New title']);
});
it('updatePageJson passes an object content through unchanged', async () => {
const tools = await buildTools();
const doc = { type: 'doc', content: [] };
await tools.updatePageJson.execute(
{ pageId: 'p1', content: doc } as never,
{} as never,
);
expect(updatePageJsonCalls[0]).toEqual(['p1', doc, undefined]);
});
it('updatePageJson parses a JSON-string content', async () => {
const tools = await buildTools();
const doc = { type: 'doc', content: [] };
await tools.updatePageJson.execute(
{ pageId: 'p1', content: JSON.stringify(doc) } as never,
{} as never,
);
expect(updatePageJsonCalls[0]).toEqual(['p1', doc, undefined]);
});
it('updatePageJson throws the documented message on invalid JSON string content', async () => {
const tools = await buildTools();
await expect(
tools.updatePageJson.execute(
{ pageId: 'p1', content: '{bad' } as never,
{} as never,
),
).rejects.toThrow('content was a string but not valid JSON');
expect(updatePageJsonCalls).toHaveLength(0);
});
});

View File

@@ -0,0 +1,26 @@
import { PageEmbeddingRepo } from './page-embedding.repo';
import type { KyselyDB } from '../../types/kysely.types';
/**
* Unit test for the pure access-scoping branch of searchByEmbedding: when the
* caller has NO accessible spaces (`spaceIds` empty), the method must early-
* return [] WITHOUT touching the database. We inject a db whose query builder
* throws if invoked, so any DB access fails the test.
*
* NOTE: the dimension-mixing case (filter by model_dimensions) needs a live
* pgvector-enabled Postgres and is intentionally NOT covered here — it requires
* a real DB and is out of scope for this pure unit test.
*/
describe('PageEmbeddingRepo.searchByEmbedding', () => {
it('early-returns [] for empty spaceIds without any DB call', async () => {
const throwingDb = {
selectFrom: () => {
throw new Error('DB should not be queried for empty spaceIds');
},
} as unknown as KyselyDB;
const repo = new PageEmbeddingRepo(throwingDb);
const result = await repo.searchByEmbedding('ws-1', [0.1, 0.2, 0.3], [], 10);
expect(result).toEqual([]);
});
});

View File

@@ -0,0 +1,77 @@
import { SecretBoxService } from './secret-box';
import { EnvironmentService } from '../environment/environment.service';
/**
* Unit tests for SecretBoxService: the AES-256-GCM helper that protects provider
* API keys at rest. The contract is: encrypt -> decrypt round-trips the input;
* two encryptions of the same input yield different blobs (random salt+iv) yet
* both decrypt; a tampered blob or a different APP_SECRET fails decryption with
* the recoverable "APP_SECRET may have changed" message the UI relies on.
*/
describe('SecretBoxService', () => {
// Construct a SecretBoxService whose EnvironmentService.getAppSecret returns a
// fixed 64-hex secret. Only getAppSecret is exercised, so a thin fake suffices.
function makeBox(appSecret: string): SecretBoxService {
const env = {
getAppSecret: () => appSecret,
} as unknown as EnvironmentService;
return new SecretBoxService(env);
}
const SECRET_A =
'00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff';
const SECRET_B =
'ffeeddccbbaa99887766554433221100ffeeddccbbaa99887766554433221100';
it('round-trips: decrypt(encrypt(x)) === x', () => {
const box = makeBox(SECRET_A);
const plain = 'sk-super-secret-provider-key-12345';
const blob = box.encryptSecret(plain);
expect(box.decryptSecret(blob)).toBe(plain);
});
it('produces a different blob each time, both of which decrypt', () => {
const box = makeBox(SECRET_A);
const plain = 'identical-input';
const blob1 = box.encryptSecret(plain);
const blob2 = box.encryptSecret(plain);
// Random per-record salt + iv => the ciphertext blobs must differ.
expect(blob1).not.toBe(blob2);
expect(box.decryptSecret(blob1)).toBe(plain);
expect(box.decryptSecret(blob2)).toBe(plain);
});
it('throws the recoverable error on a tampered auth tag', () => {
const box = makeBox(SECRET_A);
const blob = box.encryptSecret('tamper-me');
// Layout: base64( salt[16] | iv[12] | authTag[16] | ciphertext ). Flip a bit
// in the auth-tag region so GCM verification (decipher.final) rejects it.
const data = Buffer.from(blob, 'base64');
const authTagByteIndex = 16 + 12; // first byte of the auth tag
data[authTagByteIndex] = data[authTagByteIndex] ^ 0xff;
const tampered = data.toString('base64');
expect(() => box.decryptSecret(tampered)).toThrow(/APP_SECRET may have changed/);
});
it('throws the recoverable error on a tampered ciphertext byte', () => {
const box = makeBox(SECRET_A);
const blob = box.encryptSecret('tamper-the-body');
const data = Buffer.from(blob, 'base64');
// Last byte is part of the ciphertext; flipping it must fail GCM auth.
data[data.length - 1] = data[data.length - 1] ^ 0xff;
const tampered = data.toString('base64');
expect(() => box.decryptSecret(tampered)).toThrow(/APP_SECRET may have changed/);
});
it('throws when decrypting under a different APP_SECRET', () => {
const boxA = makeBox(SECRET_A);
const boxB = makeBox(SECRET_B);
const blob = boxA.encryptSecret('rotate-me');
// A different APP_SECRET derives a different scrypt key => GCM auth fails.
expect(() => boxB.decryptSecret(blob)).toThrow(/APP_SECRET may have changed/);
});
});