diff --git a/.env.example b/.env.example index 834ba7d7..73e57348 100644 --- a/.env.example +++ b/.env.example @@ -187,3 +187,11 @@ MCP_DOCMOST_PASSWORD= # Per-request output-token ceiling for the anonymous assistant (default: 512). # Worst-case output per accepted call = agent steps (5) × this value. # SHARE_AI_MAX_OUTPUT_TOKENS=512 +# +# Second cost backstop: a cluster-wide per-workspace rolling-DAY token budget +# (input re-sent per step + output, summed across every accepted turn). The +# hourly request cap above bounds how MANY calls run, not how expensive each is, +# so this caps the owner's actual provider bill directly. Like the request cap it +# FAILS CLOSED if Redis is unavailable (default: 1,000,000 tokens per workspace +# per rolling day). +# SHARE_AI_WORKSPACE_TOKEN_BUDGET_PER_DAY=1000000 diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c2aa9c9..64872489 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,6 +78,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- **AI chat: the desktop app no longer freezes at 100% CPU on long agent runs.** + `useChat` re-rendered on every streamed token and `MessageItem`/`ReasoningBlock` + re-parsed the whole transcript markdown (marked + DOMPurify) on every delta, so + per-turn work grew quadratically and saturated the main thread. The stream is now + throttled (`experimental_throttle`) to ~20 Hz and each finalized message row / + markdown part / reasoning block is memoized, so a long turn no longer re-parses + already-finished content. (#182) - **Editor: caret/selection landed on the wrong line when clicking inside code blocks and footnotes.** The affected NodeViews rendered their non-editable chrome (language menu, footnotes heading, footnote number marker) before the @@ -92,6 +99,37 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 no longer froze on the previous step's authoritative usage; the current step's estimate is combined per-component with `max`, so the count rises smoothly and never jumps backwards. (#163) +- **AI chat: "New chat" during a streaming first turn now resets the whole + chat, not just the role badge.** Starting a new chat mid-stream cleared the + header but left the in-flight turn's messages behind, so the fresh chat opened + pre-populated with the previous conversation; it now fully resets. (#161) +- **AI chat: a dropped tool argument now yields an actionable error.** When the + model omitted a required parameter (typically `pageId`) in a parallel/batch + tool call, the assistant forwarded zod's raw "expected string, received + undefined" text; tool inputs now return a message naming each missing/invalid + parameter (the JSON Schema contract is unchanged and nothing is backfilled). + (#190) +- **Page move: cycle checks are now atomic and depth-bounded.** Moving a page + under one of its own descendants is rejected in the same transaction as the + update (closing a TOCTOU window where two concurrent A→B / B→A moves could + form a cycle), and the recursive tree-traversal CTEs carry a cycle/depth guard + so a pre-existing cycle can no longer spin a query. (#207) +- **Page/editor robustness batch.** Duplicating a page now copies shared + attachments for every referencing page (not just the first); colliding block + ids are de-duplicated on import/normalize so MCP addressed edits can't hit the + wrong node; transient collab store failures are retried so autosave edits + aren't lost; and an out-of-order tree move no longer drops the moved subtree. + (#206) + +### Security + +- **Public share AI: per-workspace rolling-day token budget.** The anonymous + share assistant now caps a workspace's actual token spend (input + output, + summed across every accepted turn) over a trailing day, on top of the hourly + request cap — so a caller who evades the per-IP throttle still cannot run up + the owner's provider bill without bound. Cluster-wide via Redis and FAILS + CLOSED if Redis is down; default 1,000,000 tokens/day, overridable via + `SHARE_AI_WORKSPACE_TOKEN_BUDGET_PER_DAY`. (#159) ## [0.93.0] - 2026-06-21 diff --git a/apps/client/src/features/ai-chat/components/ai-chat-window.tsx b/apps/client/src/features/ai-chat/components/ai-chat-window.tsx index de0b9923..b3d003db 100644 --- a/apps/client/src/features/ai-chat/components/ai-chat-window.tsx +++ b/apps/client/src/features/ai-chat/components/ai-chat-window.tsx @@ -193,6 +193,7 @@ export default function AiChatWindow() { const { threadKey, waitingForHistory, + startFreshThread, onTurnFinished, onServerChatId, cancelPendingAdoption, @@ -215,12 +216,25 @@ export default function AiChatWindow() { // just-failed chat after they chose a fresh one. const startNewChat = useCallback((): void => { cancelPendingAdoption(); + // Force a fresh, empty thread UNCONDITIONALLY (#161). Pressing "New chat" + // while a brand-new chat's first turn is still streaming leaves activeChatId + // null (the real id is adopted only at turn end), so setActiveChatId(null) + // alone is a no-op and the reconciler never remounts — the chat/stream/history + // would persist and only the role badge would drop. This always remounts the + // thread into a clean new chat. + startFreshThread(); setActiveChatId(null); setHistoryOpen(false); setDraft(""); // Default the picker back to "Universal assistant" for the fresh chat. setSelectedRoleId(null); - }, [cancelPendingAdoption, setActiveChatId, setDraft, setSelectedRoleId]); + }, [ + cancelPendingAdoption, + startFreshThread, + setActiveChatId, + setDraft, + setSelectedRoleId, + ]); const selectChat = useCallback( (chatId: string): void => { @@ -622,6 +636,7 @@ export default function AiChatWindow() { ) : ( void; + * undefined on a failed turn — see adopt-chat-id.ts for the full #137 design. + * `finishingThreadKey` (this thread's mount key) lets the session ignore a turn + * finishing on a thread already abandoned by New chat mid-stream (#161). */ + onTurnFinished: (serverChatId?: string, finishingThreadKey?: string) => void; /** Called EARLY (at the stream's `start` chunk) with the authoritative server * chat id streamed on the assistant message metadata, so a brand-new chat * adopts its real id WHILE the first turn is still streaming (#174 — makes the @@ -109,6 +124,7 @@ function rowToUiMessage(row: IAiChatMessageRow): UIMessage { */ export default function ChatThread({ chatId, + threadKey, initialRows, openPage, roleId, @@ -246,6 +262,8 @@ export default function ChatThread({ id: chatStoreId, messages: initialMessages, transport, + // See STREAM_THROTTLE_MS — bounds re-render/markdown-reparse frequency. + experimental_throttle: STREAM_THROTTLE_MS, // `onFinish` (ai@6 useChat) fires from a `finally` on EVERY terminal outcome // — success, user Stop/abort (`isAbort`), network drop (`isDisconnect`), and // stream error (`isError`). Keep calling `onTurnFinished()` on all of them @@ -257,8 +275,10 @@ export default function ChatThread({ onFinish: ({ message, isAbort, isDisconnect, isError }) => { // Forward the authoritative server chatId (streamed on the assistant // message metadata) so the parent adopts the REAL created chat id for a new - // chat — see adopt-chat-id.ts for the full #137 design. - onTurnFinished(extractServerChatId(message)); + // chat — see adopt-chat-id.ts for the full #137 design. `threadKey` lets the + // session ignore this finish if it belongs to a thread abandoned by New chat + // mid-stream (#161). + onTurnFinished(extractServerChatId(message), threadKey); // Show a neutral "stopped" marker for an aborted turn; the red error banner // (via `error`) already covers isError, and a clean finish clears any marker. if (isError) setStopNotice(null); @@ -279,7 +299,7 @@ export default function ChatThread({ // Surface the raw failure in the browser console (devtools) for debugging; // the UI separately shows a friendly classified banner (see errorView). console.error("AI chat stream error:", streamError); - onTurnFinished(); + onTurnFinished(undefined, threadKey); }, }); diff --git a/apps/client/src/features/ai-chat/components/message-item-memo.test.tsx b/apps/client/src/features/ai-chat/components/message-item-memo.test.tsx new file mode 100644 index 00000000..06c0c5fb --- /dev/null +++ b/apps/client/src/features/ai-chat/components/message-item-memo.test.tsx @@ -0,0 +1,81 @@ +import { describe, expect, it, vi } from "vitest"; +import { render } from "@testing-library/react"; +import { MantineProvider } from "@mantine/core"; +import type { UIMessage } from "@ai-sdk/react"; + +// Stub react-i18next (the component reads `useTranslation`). Mirrors the stub in +// reasoning-block.test.tsx. +vi.mock("react-i18next", () => ({ + useTranslation: () => ({ t: (key: string) => key }), +})); + +// Spy on `renderChatMarkdown` so we can count parse calls per text. We keep every +// OTHER named export of markdown.ts intact via `importActual`, and override only +// `renderChatMarkdown` with a `vi.fn()` that returns simple HTML so the component +// still renders. This is the seam that proves the MarkdownPart memo works: a +// finalized text part must NOT be re-parsed on a later streamed delta. +// `vi.hoisted` so the spy exists when the hoisted `vi.mock` factory runs. +const { renderChatMarkdownSpy } = vi.hoisted(() => ({ + renderChatMarkdownSpy: vi.fn((text: string) => `

${text}

`), +})); +vi.mock("@/features/ai-chat/utils/markdown.ts", async () => { + const actual = await vi.importActual< + typeof import("@/features/ai-chat/utils/markdown.ts") + >("@/features/ai-chat/utils/markdown.ts"); + return { ...actual, renderChatMarkdown: renderChatMarkdownSpy }; +}); + +import MessageItem from "./message-item"; + +// matchMedia (read by MantineProvider) is stubbed globally in vitest.setup.ts. + +const msg = (parts: UIMessage["parts"]): UIMessage => + ({ id: "m1", role: "assistant", parts }) as UIMessage; + +const renderRow = (message: UIMessage) => + render( + + + , + ); + +/** Count how many spy calls parsed exactly `text` (filtering by the first arg). */ +const callsFor = (text: string) => + renderChatMarkdownSpy.mock.calls.filter((c) => c[0] === text).length; + +describe("MessageItem markdown memoization", () => { + it("does not re-parse finalized text parts when only a tail part grows", () => { + renderChatMarkdownSpy.mockClear(); + + // Two finalized text parts. + const first = msg([ + { type: "text", text: "alpha" }, + { type: "text", text: "beta" }, + ]); + const { rerender } = renderRow(first); + + // Both finalized parts parsed exactly once on the initial render. + expect(callsFor("alpha")).toBe(1); + expect(callsFor("beta")).toBe(1); + + // A streamed delta: a NEW message object where only a third tail part grows; + // the first two parts' text is byte-identical. + const next = msg([ + { type: "text", text: "alpha" }, + { type: "text", text: "beta" }, + { type: "text", text: "gamm" }, + ]); + rerender( + + + , + ); + + // The finalized parts hit the MarkdownPart memo: still parsed at most once + // each across BOTH renders (the resilient invariant). The only new parse is + // for the changed/added tail part. + expect(callsFor("alpha")).toBe(1); + expect(callsFor("beta")).toBe(1); + expect(callsFor("gamm")).toBe(1); + }); +}); diff --git a/apps/client/src/features/ai-chat/components/message-item.test.ts b/apps/client/src/features/ai-chat/components/message-item.test.ts new file mode 100644 index 00000000..dfed46f4 --- /dev/null +++ b/apps/client/src/features/ai-chat/components/message-item.test.ts @@ -0,0 +1,73 @@ +import { describe, expect, it, vi } from "vitest"; +import type { UIMessage } from "@ai-sdk/react"; + +// Stub react-i18next: importing the component module pulls in `useTranslation`, +// and we only exercise the pure `arePropsEqual` comparator (no rendering), so a +// minimal `t` that echoes the key is enough. Mirrors the stub in +// reasoning-block.test.tsx. +vi.mock("react-i18next", () => ({ + useTranslation: () => ({ t: (key: string) => key }), +})); + +import { arePropsEqual } from "./message-item"; + +/** + * Tests for `arePropsEqual`, the `React.memo` comparator for MessageItem. It must + * return false on any visible prop/content change (so the row re-renders) and + * true when nothing visible changed (so a finalized row is skipped). A FIXED + * message id is used so a content-identical clone yields an equal signature. + */ +const msg = (parts: UIMessage["parts"]): UIMessage => + ({ id: "m1", role: "assistant", parts }) as UIMessage; + +const props = ( + message: UIMessage, + over: Record = {}, +) => ({ + message, + showCitations: true, + neutralizeInternalLinks: false, + assistantName: "AI", + ...over, +}); + +describe("arePropsEqual", () => { + it("returns false when showCitations differs", () => { + const m = msg([{ type: "text", text: "answer" }]); + expect( + arePropsEqual(props(m), props(m, { showCitations: false })), + ).toBe(false); + }); + + it("returns false when neutralizeInternalLinks differs", () => { + const m = msg([{ type: "text", text: "answer" }]); + expect( + arePropsEqual(props(m), props(m, { neutralizeInternalLinks: true })), + ).toBe(false); + }); + + it("returns false when assistantName differs", () => { + const m = msg([{ type: "text", text: "answer" }]); + expect( + arePropsEqual(props(m), props(m, { assistantName: "Other" })), + ).toBe(false); + }); + + it("returns true on the identity fast path (same message object, equal props)", () => { + const m = msg([{ type: "text", text: "answer" }]); + expect(arePropsEqual(props(m), props(m))).toBe(true); + }); + + it("returns true for the same content in a different message object", () => { + const a = msg([{ type: "text", text: "answer" }]); + const b = msg([{ type: "text", text: "answer" }]); + expect(a).not.toBe(b); + expect(arePropsEqual(props(a), props(b))).toBe(true); + }); + + it("returns false when content changed in a different message object", () => { + const a = msg([{ type: "text", text: "answer" }]); + const b = msg([{ type: "text", text: "answer grown" }]); + expect(arePropsEqual(props(a), props(b))).toBe(false); + }); +}); diff --git a/apps/client/src/features/ai-chat/components/message-item.tsx b/apps/client/src/features/ai-chat/components/message-item.tsx index 6436b4d6..6bd4374d 100644 --- a/apps/client/src/features/ai-chat/components/message-item.tsx +++ b/apps/client/src/features/ai-chat/components/message-item.tsx @@ -1,3 +1,4 @@ +import { memo } from "react"; import { Box, Text } from "@mantine/core"; import { useTranslation } from "react-i18next"; import type { UIMessage } from "@ai-sdk/react"; @@ -10,6 +11,7 @@ import { assistantMessageHasVisibleContent } from "@/features/ai-chat/utils/mess import { renderChatMarkdown } from "@/features/ai-chat/utils/markdown.ts"; import { resolveAssistantName } from "@/features/ai-chat/utils/assistant-name.ts"; import { reasoningTokensForPart } from "@/features/ai-chat/utils/reasoning-tokens.ts"; +import { messageSignature } from "@/features/ai-chat/utils/message-signature.ts"; import { describeChatError } from "@/features/ai-chat/utils/error-message.ts"; import classes from "@/features/ai-chat/components/ai-chat.module.css"; @@ -34,6 +36,39 @@ interface MessageItemProps { assistantName?: string; } +/** + * One assistant text part rendered as sanitized markdown. Memoized on its inputs + * so a finalized text part is NOT re-parsed on every streamed delta: during a + * turn only the actively-growing tail part changes its `text`, so every earlier + * part hits the memo and skips the expensive marked + DOMPurify pass. Props are + * primitives, so React.memo's default shallow compare is exactly right (the + * `text` string is compared by value). + */ +const MarkdownPart = memo(function MarkdownPart({ + text, + neutralizeInternalLinks, +}: { + text: string; + neutralizeInternalLinks: boolean; +}) { + const html = renderChatMarkdown(text, { neutralizeInternalLinks }); + if (html) { + return ( +
+ ); + } + // Fallback when markdown could not render synchronously: raw text. + return ( + + {text} + + ); +}); + /** * Render a single UIMessage by iterating its `parts`: * - `text` parts -> sanitized markdown. @@ -41,12 +76,13 @@ interface MessageItemProps { * Other part kinds (reasoning, sources, files, step-start) are ignored for v1. * User messages render their text as a right-aligned plain bubble. * - * This component is intentionally NOT memoized: `useChat` replaces the streaming - * assistant message with a freshly cloned object on every streamed delta, so the - * `message` prop identity (and its `parts`) changes each tick. Re-rendering the - * text parts on each delta is what makes the answer stream in progressively. + * This component is memoized (see `arePropsEqual` at the bottom) on a cheap + * per-message content signature: the streaming TAIL message's signature changes + * on each delta so it still re-renders and streams in, while finalized rows are + * skipped. Each text part's markdown is itself memoized via `MarkdownPart`, so a + * long turn no longer re-parses the whole transcript on every token. */ -export default function MessageItem({ +function MessageItem({ message, showCitations = true, neutralizeInternalLinks = false, @@ -109,24 +145,12 @@ export default function MessageItem({ // starts with an empty text part before the first token arrives); the // typing indicator covers that gap until real content streams in. if (!part.text.trim()) return null; - const html = renderChatMarkdown(part.text, { - neutralizeInternalLinks, - }); - if (html) { - return ( -
- ); - } - // Fallback when markdown could not render synchronously: raw text. return ( - - {part.text} - + ); } @@ -177,3 +201,26 @@ export default function MessageItem({ ); } + +/** Skip re-rendering a message whose visible content is unchanged. The streaming + * TAIL message gets a fresh object whose signature changes each delta, so it + * still re-renders and streams in; every FINALIZED message is skipped, turning a + * per-token whole-transcript re-render into a tail-only one. */ +export function arePropsEqual( + prev: MessageItemProps, + next: MessageItemProps, +): boolean { + if ( + prev.showCitations !== next.showCitations || + prev.neutralizeInternalLinks !== next.neutralizeInternalLinks || + prev.assistantName !== next.assistantName + ) { + return false; + } + // Fast path: identical message object (finalized rows keep their identity + // across deltas) — skip without building signatures. + if (prev.message === next.message) return true; + return messageSignature(prev.message) === messageSignature(next.message); +} + +export default memo(MessageItem, arePropsEqual); diff --git a/apps/client/src/features/ai-chat/components/reasoning-block.tsx b/apps/client/src/features/ai-chat/components/reasoning-block.tsx index de35229a..cb3335f4 100644 --- a/apps/client/src/features/ai-chat/components/reasoning-block.tsx +++ b/apps/client/src/features/ai-chat/components/reasoning-block.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { memo, useMemo, useState } from "react"; import { Box, Collapse, Group, Text, UnstyledButton } from "@mantine/core"; import { IconChevronDown } from "@tabler/icons-react"; import { useTranslation } from "react-i18next"; @@ -27,19 +27,23 @@ interface ReasoningBlockProps { * Providers that don't stream reasoning TEXT still render this block from the * authoritative count alone (header only, empty body) so the cost is visible. */ -export default function ReasoningBlock({ text, tokens }: ReasoningBlockProps) { +function ReasoningBlock({ text, tokens }: ReasoningBlockProps) { const { t } = useTranslation(); const [open, setOpen] = useState(false); // Authoritative count wins; otherwise estimate live from the streamed text. const count = tokens && tokens > 0 ? tokens : estimateTokens(text); const trimmed = text.trim(); - // Collapse the blank-line gaps the model emits between every list item / - // paragraph so the reasoning renders compactly (tight lists, joined - // paragraphs) — see collapseBlankLines. ONLY here, not in the normal answer. - const html = trimmed - ? renderChatMarkdown(collapseBlankLines(trimmed), {}) - : ""; + // Memoize the markdown render so toggling `open` (or a parent re-render caused + // by an unrelated streamed delta) does not re-parse the reasoning text; it + // recomputes only when the reasoning text itself changes (while it streams in). + // collapseBlankLines collapses the blank-line gaps the model emits between every + // list item / paragraph so the reasoning renders compactly (tight lists, joined + // paragraphs) — ONLY here, not in the normal answer. + const html = useMemo( + () => (trimmed ? renderChatMarkdown(collapseBlankLines(trimmed), {}) : ""), + [trimmed], + ); return ( @@ -87,3 +91,8 @@ export default function ReasoningBlock({ text, tokens }: ReasoningBlockProps) { ); } + +// Memoized: re-renders only when `text`/`tokens` change (primitive props, default +// shallow compare), so a parent re-render during streaming of OTHER content does +// not re-run the markdown parse for an already-finalized reasoning block. +export default memo(ReasoningBlock); diff --git a/apps/client/src/features/ai-chat/hooks/use-chat-session.test.tsx b/apps/client/src/features/ai-chat/hooks/use-chat-session.test.tsx index 0080cc80..39a72628 100644 --- a/apps/client/src/features/ai-chat/hooks/use-chat-session.test.tsx +++ b/apps/client/src/features/ai-chat/hooks/use-chat-session.test.tsx @@ -1,5 +1,5 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; -import { renderHook } from "@testing-library/react"; +import { renderHook, act } from "@testing-library/react"; import { useChatSession } from "./use-chat-session"; import type { UseChatSessionOptions } from "./use-chat-session"; @@ -227,6 +227,50 @@ describe("useChatSession", () => { expect(result.current.threadKey).toBe("C"); }); + it("#161: New chat during a streaming first turn forces a fresh thread (remount), not just a no-op", () => { + // Brand-new chat whose first turn is still streaming: the id is adopted only + // at turn end, so activeChatId AND thread.chatId are both null. Pressing "New + // chat" must still remount to a clean thread even though the atom is unchanged + // — the render-phase reconciler (null === null) would otherwise do nothing, + // leaving the old chat/stream/history in place (the bug: only the role badge + // dropped). + const { result } = setup({ activeChatId: null, chats: { items: [] } }); + const keyBefore = result.current.threadKey; + act(() => result.current.startFreshThread()); + expect(result.current.threadKey).not.toBe(keyBefore); + }); + + it("#161: an abandoned thread's late onTurnFinished does NOT adopt its chat (thread-aware guard)", () => { + // New chat mid-stream remounts to a fresh thread, but @ai-sdk/react does not + // abort the abandoned stream on unmount: its onFinish still fires later with + // the real server id, tagged with the OLD (abandoned) mount key. That must not + // adopt — it would yank the user back into the chat they just left. + const { result, setActiveChatId, onInvalidateChatList } = setup({ + activeChatId: null, + chats: { items: [] }, + }); + const abandonedKey = result.current.threadKey; + act(() => result.current.startFreshThread()); + expect(result.current.threadKey).not.toBe(abandonedKey); + // The abandoned turn finishes in the background, streaming its real id "A". + result.current.onTurnFinished("A", abandonedKey); + expect(setActiveChatId).not.toHaveBeenCalledWith("A"); + // It still refreshes the chat list so the left-behind chat shows in history. + expect(onInvalidateChatList).toHaveBeenCalled(); + }); + + it("#161: a turn finishing on the CURRENT thread still adopts (guard is key-scoped, not blanket)", () => { + // The happy path must keep working: onTurnFinished tagged with the mounted + // thread's own key adopts in place as before. + const { result, setActiveChatId } = setup({ + activeChatId: null, + chats: { items: [] }, + }); + const currentKey = result.current.threadKey; + result.current.onTurnFinished("A", currentKey); + expect(setActiveChatId).toHaveBeenCalledWith("A"); + }); + it("waitingForHistory gates the loader only while opening an unloaded existing chat", () => { // Open an existing chat whose history is still loading => loader on. const { result, rerender } = setup({ diff --git a/apps/client/src/features/ai-chat/hooks/use-chat-session.ts b/apps/client/src/features/ai-chat/hooks/use-chat-session.ts index d21ebd11..14420ad0 100644 --- a/apps/client/src/features/ai-chat/hooks/use-chat-session.ts +++ b/apps/client/src/features/ai-chat/hooks/use-chat-session.ts @@ -31,9 +31,19 @@ export interface UseChatSessionResult { threadKey: string; /** Show the history loader instead of the live thread. */ waitingForHistory: boolean; + /** Force a brand-new, empty thread (new mount key, no chat id) UNCONDITIONALLY, + * even when `activeChatId` is unchanged. The window calls this from + * startNewChat so "New chat" pressed WHILE a brand-new chat's first turn is + * still streaming (activeChatId still null, nothing to diverge) actually + * resets the chat instead of only dropping the role badge (#161). */ + startFreshThread: () => void; /** Call when a turn finishes; `serverChatId` is the authoritative streamed id - * (undefined on a failed turn). Handles new-chat id adoption + invalidations. */ - onTurnFinished: (serverChatId?: string) => void; + * (undefined on a failed turn). `finishingThreadKey` is the mount key of the + * thread that produced the turn (omit => "current thread", back-compatible): + * a turn ABANDONED by New chat mid-stream still fires this after its thread + * unmounted, so adoption is gated to the still-mounted thread (#161). Handles + * new-chat id adoption + invalidations. */ + onTurnFinished: (serverChatId?: string, finishingThreadKey?: string) => void; /** Call EARLY (at the stream's `start` chunk) with the authoritative streamed * chat id so a brand-new chat adopts its real id WHILE its first turn is still * streaming — making `activeChatId`-gated affordances (e.g. the Copy/export @@ -98,6 +108,15 @@ export function useChatSession( : switchThread(activeChatId), ); + // Live mirror of the mounted thread's mount key, read by onTurnFinished to tell + // the CURRENT thread from one ABANDONED by New chat mid-stream. @ai-sdk/react + // does not abort a stream on unmount and proxies callbacks through a ref, so an + // abandoned turn's onFinish/onError still fires AFTER its ChatThread unmounted; + // matching its key against this ref keeps that late finish from adopting the + // abandoned chat and yanking the user out of the fresh chat they opened (#161). + const threadKeyRef = useRef(thread.key); + threadKeyRef.current = thread.key; + // Error-path fallback for new-chat id adoption. When a brand-new chat's first // turn errors BEFORE the server's `start` chunk, no authoritative chatId ever // reaches the client, so the primary metadata adoption cannot run. We then ARM @@ -115,7 +134,23 @@ export function useChatSession( // yet) we adopt the server's AUTHORITATIVE streamed id (never the newest in the // list, which races a second tab — #137; see adopt-chat-id.ts). const onTurnFinished = useCallback( - (serverChatId?: string) => { + (serverChatId?: string, finishingThreadKey?: string) => { + // Thread-aware guard (#161). A turn ABANDONED by "New chat" mid-stream still + // fires onFinish/onError after its ChatThread unmounted (@ai-sdk/react does + // not abort on unmount and proxies callbacks through a ref). If that late + // finish ran the adoption path it would set activeChatId to the abandoned + // chat's real id and yank the user out of the fresh chat they just opened. + // So adopt / arm the fallback ONLY for the still-mounted thread; an + // abandoned one merely refreshes the chat list (so the left-behind chat + // surfaces in history) and does nothing else. A missing key (undefined) + // means "current thread" — keeps old call sites/tests working. + if ( + finishingThreadKey !== undefined && + finishingThreadKey !== threadKeyRef.current + ) { + onInvalidateChatList(); + return; + } // Read the live id from the ref, not the closure: on a failed turn this can // run twice in one turn (onFinish + onError) before any re-render, and the // primary branch below updates the ref so the second call sees the adopted id. @@ -258,9 +293,28 @@ export function useChatSession( pendingNewChatRef.current = null; }, []); + // Force a fresh, empty thread regardless of `activeChatId` (#161). The render- + // phase reconciler only remounts when activeChatId diverges from thread.chatId, + // so "New chat" pressed while a brand-new chat's first turn is still streaming + // (activeChatId AND thread.chatId both null — the real id is adopted only at the + // end of the turn) is a no-op for it and the abandoned thread/stream/history + // would persist. Dispatching reconcile with a fresh key and chatId:null here + // always produces a new mount key, so React remounts ChatThread (a clean useChat + // store) and the post-dispatch state (activeChatId null === thread.chatId null) + // keeps the reconciler from interfering. Also disarms any pending fallback. + const startFreshThread = useCallback(() => { + pendingNewChatRef.current = null; + dispatch({ + type: "reconcile", + chatId: null, + newKey: `new-${generateId()}`, + }); + }, []); + return { threadKey: thread.key, waitingForHistory, + startFreshThread, onTurnFinished, onServerChatId, cancelPendingAdoption, diff --git a/apps/client/src/features/ai-chat/utils/message-signature.test.ts b/apps/client/src/features/ai-chat/utils/message-signature.test.ts new file mode 100644 index 00000000..7c4f7a70 --- /dev/null +++ b/apps/client/src/features/ai-chat/utils/message-signature.test.ts @@ -0,0 +1,241 @@ +import { describe, expect, it } from "vitest"; +import type { UIMessage } from "@ai-sdk/react"; +import { messageSignature } from "@/features/ai-chat/utils/message-signature.ts"; + +/** + * Pure-helper tests for `messageSignature`, the cheap per-message content + * signature that drives MessageItem's memo (a streaming row's signature must + * change on every delta so it re-renders, while a finalized row's stays stable + * so it is skipped). Each test exercises ONE change signal and asserts it flips + * the signature; a content-identical clone must keep an EQUAL signature. + * + * The signature embeds `message.id` and `message.role`, so the `msg` factory + * uses a FIXED id/role here (not `Math.random()`): otherwise two messages with + * identical content would get different signatures and the negative case would + * be impossible to express. + */ +const msg = ( + parts: UIMessage["parts"], + metadata?: unknown, +): UIMessage => + ({ + id: "m1", + role: "assistant", + parts, + metadata, + }) as UIMessage; + +describe("messageSignature", () => { + it("changes when a text part grows", () => { + const before = msg([{ type: "text", text: "alpha" }]); + const after = msg([{ type: "text", text: "alpha beta" }]); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("changes when a new part is appended", () => { + const before = msg([{ type: "text", text: "alpha" }]); + const after = msg([ + { type: "text", text: "alpha" }, + { type: "text", text: "beta" }, + ]); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("changes when a part's state flips", () => { + const before = msg([ + { type: "tool-getPage", state: "input-streaming" } as never, + ]); + const after = msg([ + { type: "tool-getPage", state: "output-available" } as never, + ]); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("changes when a tool part gains an output", () => { + const before = msg([ + { type: "tool-getPage", state: "output-available" } as never, + ]); + const after = msg([ + { + type: "tool-getPage", + state: "output-available", + output: { ok: true }, + } as never, + ]); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("changes when a part gains an errorText", () => { + const before = msg([ + { type: "tool-getPage", state: "output-error" } as never, + ]); + const after = msg([ + { + type: "tool-getPage", + state: "output-error", + errorText: "boom", + } as never, + ]); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("changes when usage.reasoningTokens arrives on finish-step (text/state already frozen)", () => { + // The specifically-commented edge case: the authoritative turn total lands on + // the final finish-step AFTER the reasoning text length and state are frozen. + // Only the token count appears between these two snapshots, so the signature + // MUST still flip — otherwise the "Thinking · N tokens" header would never + // snap from the live estimate to the exact figure. + const before = msg([ + { type: "reasoning", text: "thinking", state: "done" } as never, + ]); + const after = msg( + [{ type: "reasoning", text: "thinking", state: "done" } as never], + { usage: { reasoningTokens: 42 } }, + ); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("changes when metadata.error appears", () => { + const before = msg([{ type: "text", text: "answer" }]); + const after = msg([{ type: "text", text: "answer" }], { error: "boom" }); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("changes when metadata.finishReason changes (e.g. to 'aborted')", () => { + const before = msg([{ type: "text", text: "answer" }], { + finishReason: "stop", + }); + const after = msg([{ type: "text", text: "answer" }], { + finishReason: "aborted", + }); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("is UNCHANGED for a content-identical clone (different object, same values)", () => { + // A finalized row that is re-created as a fresh object (different parts array + // by reference, same parts by value) must keep an EQUAL signature, so the + // memo skips re-rendering it. + const a = msg([ + { type: "text", text: "alpha" }, + { type: "tool-getPage", state: "output-available", output: { ok: true } } as never, + ]); + const b = msg([ + { type: "text", text: "alpha" }, + { type: "tool-getPage", state: "output-available", output: { ok: true } } as never, + ]); + expect(a).not.toBe(b); + expect(messageSignature(a)).toBe(messageSignature(b)); + }); +}); + +/** + * Per-part-kind coupling guard for the load-bearing invariant documented at the + * top of message-signature.ts: the signature MUST sample every VISIBLE field the + * MessageItem render body draws, or the memo freezes a stale row. This is an + * executable lock for the part kinds rendered TODAY — read alongside + * `MessageItem` (message-item.tsx) and the `assistantMessageHasVisibleContent` + * helper (message-content.ts), which "mirrors MessageItem's render decisions + * EXACTLY". For each kind, mutating a field the render body DRAWS must flip the + * signature. If a new visible field is rendered without being added here AND to + * the signature, the corresponding assertion below should fail — that is the + * guard. (This intentionally stops short of the render-descriptor refactor: + * adding a part kind or a visible field still requires a human to extend both + * the signature and this block.) + */ +describe("messageSignature ↔ render coupling (per visible part kind)", () => { + describe("text part — render draws part.text (MarkdownPart text={part.text})", () => { + it("flips when the visible text changes", () => { + // Streaming is append-only, so the visible text only grows; the signature + // samples its length, so the growth is the change signal. + const before = msg([{ type: "text", text: "answer" }]); + const after = msg([{ type: "text", text: "answer extended" }]); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + }); + + describe("reasoning part — render draws text + tokens (ReasoningBlock)", () => { + it("flips when the visible reasoning text changes", () => { + const before = msg([ + { type: "reasoning", text: "think", state: "streaming" } as never, + ]); + const after = msg([ + { type: "reasoning", text: "think harder", state: "streaming" } as never, + ]); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("flips when the visible token count (metadata.usage.reasoningTokens) lands", () => { + // The header's "Thinking · N tokens" reads reasoningTokensForPart, fed by + // metadata.usage.reasoningTokens — a VISIBLE field that arrives on the final + // finish-step after text length and state are frozen. + const before = msg([ + { type: "reasoning", text: "think", state: "done" } as never, + ]); + const after = msg( + [{ type: "reasoning", text: "think", state: "done" } as never], + { usage: { reasoningTokens: 99 } }, + ); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + }); + + describe("tool-* part — render draws state/errorText/citations (ToolCallCard)", () => { + it("flips when the run state changes (running ↔ done icon + label)", () => { + // toolRunState(part.state) selects the spinner/check/error icon. + const before = msg([ + { type: "tool-getPage", state: "input-available" } as never, + ]); + const after = msg([ + { type: "tool-getPage", state: "output-available" } as never, + ]); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("flips when output arrives (drives the rendered citation links)", () => { + // toolCitations reads part.output to render the "/p/{id}" anchors. + const before = msg([ + { type: "tool-getPage", state: "output-available" } as never, + ]); + const after = msg([ + { + type: "tool-getPage", + state: "output-available", + output: { id: "page-1", title: "Doc" }, + } as never, + ]); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("flips when errorText appears (the visible red error detail line)", () => { + const before = msg([ + { type: "tool-getPage", state: "output-error" } as never, + ]); + const after = msg([ + { + type: "tool-getPage", + state: "output-error", + errorText: "permission denied", + } as never, + ]); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + }); + + describe("metadata banners — render draws error / aborted notices", () => { + it("flips when metadata.error appears (ChatErrorAlert banner)", () => { + const before = msg([{ type: "text", text: "answer" }]); + const after = msg([{ type: "text", text: "answer" }], { error: "boom" }); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + + it("flips when metadata.finishReason becomes 'aborted' (ChatStoppedNotice)", () => { + const before = msg([{ type: "text", text: "answer" }], { + finishReason: "stop", + }); + const after = msg([{ type: "text", text: "answer" }], { + finishReason: "aborted", + }); + expect(messageSignature(before)).not.toBe(messageSignature(after)); + }); + }); +}); diff --git a/apps/client/src/features/ai-chat/utils/message-signature.ts b/apps/client/src/features/ai-chat/utils/message-signature.ts new file mode 100644 index 00000000..84c37919 --- /dev/null +++ b/apps/client/src/features/ai-chat/utils/message-signature.ts @@ -0,0 +1,44 @@ +import type { UIMessage } from "@ai-sdk/react"; + +/** Cheap content signature for one message: changes iff something VISIBLE in the + * row changed. Streaming is APPEND-ONLY (text parts only grow, parts are only + * appended, a tool/text part flips state once), so a per-part [type, text + * length, state, error/output presence] tuple + the persisted metadata + * (error/finishReason) is a sufficient change signal without comparing full + * strings on every delta. WARNING — load-bearing for the MessageItem memo: + * if a future part kind's VISIBLE content can change WITHOUT changing [type, + * text length, state, error/output presence] (e.g. a tool that streams + * `preliminary` output, or a client-side regenerate that edits a finalized + * row in place), extend this signature or the memo will freeze a stale row. */ +export function messageSignature(message: UIMessage): string { + const parts = message.parts + .map((p) => { + const any = p as { + type: string; + text?: string; + state?: string; + errorText?: string; + output?: unknown; + }; + return [ + any.type, + any.text?.length ?? 0, + any.state ?? "", + any.errorText ? 1 : 0, + any.output !== undefined ? 1 : 0, + ].join(":"); + }) + .join("|"); + const meta = message.metadata as + | { error?: string; finishReason?: string; usage?: { reasoningTokens?: number } } + | undefined; + // `usage.reasoningTokens` is neither append-only nor part-bound: the authoritative + // turn total arrives on the final `finish-step` AFTER the reasoning text length and + // state are already frozen. Without it in the signature the row's signature would be + // unchanged at that point and the re-render skipped, so the "Thinking · N tokens" + // header (reasoningTokensForPart) would keep the live estimate instead of snapping + // to the exact figure. + return `${message.id}#${message.role}#${parts}#${meta?.error ?? ""}#${ + meta?.finishReason ?? "" + }#${meta?.usage?.reasoningTokens ?? ""}`; +} diff --git a/apps/client/src/features/page/tree/model/tree-model.test.ts b/apps/client/src/features/page/tree/model/tree-model.test.ts index a2dbd6b9..01682e2d 100644 --- a/apps/client/src/features/page/tree/model/tree-model.test.ts +++ b/apps/client/src/features/page/tree/model/tree-model.test.ts @@ -752,6 +752,27 @@ describe("treeModel.placeByPosition", () => { }); expect(t.map((n) => n.id)).toEqual(["r1", "child", "r2", "rp"]); }); + + it("returns same reference (no-op) when the destination parent is inside the source's own subtree (#206 ui-state-races-1)", () => { + // Moving `a` under its own descendant `b` is a cycle. Without the guard, + // remove(a) drops b too and insertByPosition can't re-place a -> the whole + // subtree silently vanishes. The guard refuses the move (same reference). + const cyclic: P[] = [ + { + id: "a", + name: "A", + position: "a0", + children: [{ id: "b", name: "B", position: "a1" }], + }, + ]; + const t = treeModel.placeByPosition(cyclic, "a", { + parentId: "b", + position: "a5", + }); + expect(t).toBe(cyclic); + expect(treeModel.find(t, "a")).not.toBeNull(); + expect(treeModel.find(t, "b")).not.toBeNull(); + }); }); describe("treeModel.move", () => { diff --git a/apps/client/src/features/page/tree/model/tree-model.ts b/apps/client/src/features/page/tree/model/tree-model.ts index aa13d8b4..bda4a74b 100644 --- a/apps/client/src/features/page/tree/model/tree-model.ts +++ b/apps/client/src/features/page/tree/model/tree-model.ts @@ -294,6 +294,20 @@ export const treeModel = { const source = treeModel.find(tree, sourceId); if (!source) return tree; if (to.parentId !== null && !treeModel.find(tree, to.parentId)) return tree; + // Cycle guard, mirroring `move`'s `isDescendant` check (#206 ui-state-races-1). + // If the destination parent is INSIDE the moved node's own subtree (reachable + // when server-authoritative move events arrive out of order — e.g. X moved + // under Y, then Y under X, but on this receiver Y is still inside X), then + // `remove(sourceId)` would drop the future parent along with the whole subtree + // and `insertByPosition` could not find it again — the node and ALL its + // descendants would silently vanish. Refuse the move and return the same + // reference so callers can detect the no-op and reconcile (refetch) instead. + if ( + to.parentId !== null && + treeModel.isDescendant(tree, sourceId, to.parentId) + ) { + return tree; + } const removed = treeModel.remove(tree, sourceId); // Reuse the same position-ordered insertion as `insertByPosition` by // stamping the authoritative position onto the moved node first. diff --git a/apps/client/src/features/websocket/tree-socket-reducers.test.ts b/apps/client/src/features/websocket/tree-socket-reducers.test.ts index f59f27cc..20abdf95 100644 --- a/apps/client/src/features/websocket/tree-socket-reducers.test.ts +++ b/apps/client/src/features/websocket/tree-socket-reducers.test.ts @@ -183,6 +183,34 @@ describe("applyMoveTreeNode", () => { expect(moved?.hasChildren).toBe(true); expect(moved?.position).toBe("a4"); }); + + it("does NOT drop a subtree on a cyclic/out-of-order move (parent inside source) (#206 ui-state-races-1)", () => { + // Locally `b` is still nested inside `a` (an earlier "a under b" echo hasn't + // applied yet). An out-of-order "move a under b" event now arrives — b is a + // descendant of a, so re-parenting would make placeByPosition remove a (and + // its whole subtree, incl. b) and fail to re-insert. Before the fix BOTH a + // and b silently vanished; now the reducer leaves the tree untouched. + const tree: SpaceTreeNode[] = [ + node("a", { + position: "a0", + hasChildren: true, + children: [node("b", { position: "a1", parentPageId: "a" })], + }), + ]; + const next = applyMoveTreeNode(tree, { + id: "a", + parentId: "b", + oldParentId: null, + index: 0, + position: "a4", + pageData: {}, + }); + // No silent data loss: both nodes survive. + expect(treeModel.find(next, "a")).not.toBeNull(); + expect(treeModel.find(next, "b")).not.toBeNull(); + // The cyclic move is refused as a no-op (same reference) pending reconcile. + expect(next).toBe(tree); + }); }); describe("applyDeleteTreeNode", () => { diff --git a/apps/client/src/features/websocket/tree-socket-reducers.ts b/apps/client/src/features/websocket/tree-socket-reducers.ts index 3f6226d9..fe3b1a43 100644 --- a/apps/client/src/features/websocket/tree-socket-reducers.ts +++ b/apps/client/src/features/websocket/tree-socket-reducers.ts @@ -76,6 +76,19 @@ export function applyMoveTreeNode( const oldParentId = (sourceBefore as SpaceTreeNode).parentPageId ?? null; const newParentId = payload.parentId as string | null; + // Cyclic / out-of-order move guard (#206 ui-state-races-1): if the + // authoritative new parent is currently INSIDE the moved node's own subtree on + // this client (e.g. server moved X under Y then Y under X and the events + // arrived such that Y is still nested in X here), re-parenting is impossible to + // represent locally. `placeByPosition` returns `prev` for this, but the + // `placed === prev` fallback below would then `remove` the source — dropping + // the node AND every descendant (incl. the would-be parent) silently. Leave the + // tree untouched instead; a later corrective event or a reconnect refetch + // reconciles it. Never delete a subtree we cannot safely re-place. + if (newParentId && treeModel.isDescendant(prev, payload.id, newParentId)) { + return prev; + } + // Place the node by its fractional `position` among the new siblings — NOT by // the sender's absolute `index` (the sender computed that against its own // loaded set, which differs from this receiver's). Using the position keeps diff --git a/apps/server/src/collaboration/extensions/persistence-store.spec.ts b/apps/server/src/collaboration/extensions/persistence-store.spec.ts index 4262d77f..d0fe703d 100644 --- a/apps/server/src/collaboration/extensions/persistence-store.spec.ts +++ b/apps/server/src/collaboration/extensions/persistence-store.spec.ts @@ -182,4 +182,46 @@ describe('PersistenceExtension.onStoreDocument — Approach-A boundary snapshot' expect(pageHistoryRepo.saveHistory).not.toHaveBeenCalled(); expect(historyQueue.add).not.toHaveBeenCalled(); }); + + // persist-1 — a transient DB failure during store must not silently lose the + // edit. hocuspocus unloads (destroys) the in-memory Y.Doc right after this + // hook resolves, so the store has to retry while it still holds the only copy. + it('retries a transient DB failure and still persists the edit (persist-1)', async () => { + const document = ydocFor(doc('NEW HUMAN CONTENT')); + pageRepo.findById.mockResolvedValue(persistedHumanPage('NEW HUMAN CONTENT')); + let attempts = 0; + pageRepo.updatePage.mockImplementation(async () => { + attempts += 1; + if (attempts === 1) throw new Error('deadlock detected'); // transient + callOrder.push('updatePage'); + }); + + await ext.onStoreDocument(buildData(document, 'user') as any); + + // First attempt failed and rolled back; the retry persisted the edit. + expect(pageRepo.updatePage).toHaveBeenCalledTimes(2); + // The edit WAS saved, so the post-store success path runs as normal. + expect((document as any).broadcastStateless).toHaveBeenCalledTimes(1); + expect(historyQueue.add).toHaveBeenCalledTimes(1); + }); + + // persist-1 — when every attempt fails the hook must NOT report a phantom + // success: no "page.updated" badge broadcast and no history snapshot for + // content that was never written. + it('does not run post-store side effects when every store attempt fails (persist-1)', async () => { + const document = ydocFor(doc('NEW HUMAN CONTENT')); + pageRepo.findById.mockResolvedValue(persistedHumanPage('NEW HUMAN CONTENT')); + pageRepo.updatePage.mockRejectedValue(new Error('connection reset')); + + await expect( + ext.onStoreDocument(buildData(document, 'user') as any), + ).resolves.toBeUndefined(); + + // Bounded retry exhausted (MAX_STORE_ATTEMPTS). + expect(pageRepo.updatePage).toHaveBeenCalledTimes(3); + // No false-success: nothing downstream fires for the unsaved content. + expect((document as any).broadcastStateless).not.toHaveBeenCalled(); + expect(historyQueue.add).not.toHaveBeenCalled(); + expect(aiQueue.add).not.toHaveBeenCalled(); + }); }); diff --git a/apps/server/src/collaboration/extensions/persistence.extension.ts b/apps/server/src/collaboration/extensions/persistence.extension.ts index e9119358..f802f229 100644 --- a/apps/server/src/collaboration/extensions/persistence.extension.ts +++ b/apps/server/src/collaboration/extensions/persistence.extension.ts @@ -181,83 +181,113 @@ export class PersistenceExtension implements Extension { context?.actor, ); - try { - await executeTx(this.db, async (trx) => { - page = await this.pageRepo.findById(pageId, { - withLock: true, - includeContent: true, - trx, - }); + // Persist with a small bounded retry. The in-memory Y.Doc is the ONLY copy + // of the latest edit until this hook returns: hocuspocus destroys/unloads the + // doc right after onStoreDocument resolves (see storeDocumentHooks' finally + // -> unloadDocument). If a transient DB error (deadlock, serialization + // failure, dropped connection) is merely logged and swallowed, the function + // resolves "successfully", the doc is unloaded, and the edit is lost silently + // (#206 persist-1). Retrying here re-attempts the write while we still hold + // the doc; on total failure we clear `page` so the post-store side effects + // (badge broadcast, history snapshot) never report a save that didn't happen. + const MAX_STORE_ATTEMPTS = 3; + for (let attempt = 1; attempt <= MAX_STORE_ATTEMPTS; attempt++) { + try { + await executeTx(this.db, async (trx) => { + page = await this.pageRepo.findById(pageId, { + withLock: true, + includeContent: true, + trx, + }); - if (!page) { - this.logger.error(`Page with id ${pageId} not found`); - return; - } - - if (isDeepStrictEqual(tiptapJson, page.content)) { - page = null; - return; - } - - let contributorIds = undefined; - try { - const existingContributors = page.contributorIds || []; - contributorIds = Array.from( - new Set([ - ...existingContributors, - ...editingUserIds, - page.creatorId, - ]), - ); - } catch (err) { - //this.logger.debug('Contributors error:' + err?.['message']); - } - - // Approach A — boundary snapshot before the agent's first edit. - // When this store is the agent's and the page's currently persisted - // state was authored by a human, pin that human state as its own - // history version BEFORE the agent overwrites it. `page` still holds the - // OLD content/provenance here, so saveHistory(page) captures the - // pre-agent state tagged 'user'. The agent's new content is snapshotted - // later by the debounced PAGE_HISTORY job ('agent'). Skip if the prior - // state is already agent-authored (boundary already pinned on the - // user->agent transition), if the page is effectively empty, or if the - // latest existing snapshot already equals this human state (avoid - // duplicates). - if (lastUpdatedSource === 'agent' && page.lastUpdatedSource !== 'agent') { - const lastHistory = await this.pageHistoryRepo.findPageLastHistory( - pageId, - { includeContent: true, trx }, - ); - const humanBaselineMissing = - !lastHistory || !isDeepStrictEqual(lastHistory.content, page.content); - if (!isEmptyParagraphDoc(page.content as any) && humanBaselineMissing) { - await this.pageHistoryRepo.saveHistory(page, { - contributorIds: page.contributorIds ?? undefined, - trx, - }); + if (!page) { + this.logger.error(`Page with id ${pageId} not found`); + return; } - } - await this.pageRepo.updatePage( - { - content: tiptapJson, - textContent: textContent, - ydoc: ydocState, - lastUpdatedById: context.user.id, - // Human stays the responsible author; these annotate the source. - lastUpdatedSource, - lastUpdatedAiChatId: context?.aiChatId ?? null, - contributorIds: contributorIds, - }, - pageId, - trx, + if (isDeepStrictEqual(tiptapJson, page.content)) { + page = null; + return; + } + + let contributorIds = undefined; + try { + const existingContributors = page.contributorIds || []; + contributorIds = Array.from( + new Set([ + ...existingContributors, + ...editingUserIds, + page.creatorId, + ]), + ); + } catch (err) { + //this.logger.debug('Contributors error:' + err?.['message']); + } + + // Approach A — boundary snapshot before the agent's first edit. + // When this store is the agent's and the page's currently persisted + // state was authored by a human, pin that human state as its own + // history version BEFORE the agent overwrites it. `page` still holds + // the OLD content/provenance here, so saveHistory(page) captures the + // pre-agent state tagged 'user'. The agent's new content is + // snapshotted later by the debounced PAGE_HISTORY job ('agent'). Skip + // if the prior state is already agent-authored (boundary already + // pinned on the user->agent transition), if the page is effectively + // empty, or if the latest existing snapshot already equals this human + // state (avoid duplicates). + if ( + lastUpdatedSource === 'agent' && + page.lastUpdatedSource !== 'agent' + ) { + const lastHistory = await this.pageHistoryRepo.findPageLastHistory( + pageId, + { includeContent: true, trx }, + ); + const humanBaselineMissing = + !lastHistory || + !isDeepStrictEqual(lastHistory.content, page.content); + if ( + !isEmptyParagraphDoc(page.content as any) && + humanBaselineMissing + ) { + await this.pageHistoryRepo.saveHistory(page, { + contributorIds: page.contributorIds ?? undefined, + trx, + }); + } + } + + await this.pageRepo.updatePage( + { + content: tiptapJson, + textContent: textContent, + ydoc: ydocState, + lastUpdatedById: context.user.id, + // Human stays the responsible author; these annotate the source. + lastUpdatedSource, + lastUpdatedAiChatId: context?.aiChatId ?? null, + contributorIds: contributorIds, + }, + pageId, + trx, + ); + + this.logger.debug(`Page updated: ${pageId} - SlugId: ${page.slugId}`); + }); + break; + } catch (err) { + this.logger.error( + `Failed to update page ${pageId} (attempt ${attempt}/${MAX_STORE_ATTEMPTS})`, + err, ); - - this.logger.debug(`Page updated: ${pageId} - SlugId: ${page.slugId}`); - }); - } catch (err) { - this.logger.error(`Failed to update page ${pageId}`, err); + // The write failed and rolled back; clear the partially-assigned `page` + // so the post-store success branch below is skipped (no false "saved" + // broadcast / history snapshot for content that was never persisted). + page = null; + if (attempt < MAX_STORE_ATTEMPTS) { + await new Promise((resolve) => setTimeout(resolve, attempt * 50)); + } + } } if (page) { diff --git a/apps/server/src/core/ai-chat/public-share-chat.controller.spec.ts b/apps/server/src/core/ai-chat/public-share-chat.controller.spec.ts index 08b20b43..66289d05 100644 --- a/apps/server/src/core/ai-chat/public-share-chat.controller.spec.ts +++ b/apps/server/src/core/ai-chat/public-share-chat.controller.spec.ts @@ -34,6 +34,7 @@ describe('resolveShareAssistantRequest (extracted controller funnel)', () => { resolveShareRole?: jest.Mock; getShareChatModel?: jest.Mock; tryConsumeWorkspaceQuota?: jest.Mock; + withinShareTokenBudget?: jest.Mock; } = {}) { const aiSettings = { isPublicShareAssistantEnabled: jest @@ -65,6 +66,8 @@ describe('resolveShareAssistantRequest (extracted controller funnel)', () => { over.getShareChatModel ?? jest.fn().mockResolvedValue('MODEL'), tryConsumeWorkspaceQuota: over.tryConsumeWorkspaceQuota ?? jest.fn().mockResolvedValue(true), + withinShareTokenBudget: + over.withinShareTokenBudget ?? jest.fn().mockResolvedValue(true), }; const deps: ShareAssistantDeps = { aiSettings: aiSettings as never, @@ -191,6 +194,39 @@ describe('resolveShareAssistantRequest (extracted controller funnel)', () => { expect(publicShareChat.tryConsumeWorkspaceQuota).toHaveBeenCalledWith('ws-1'); }); + it('withinShareTokenBudget false => 429 thrown BEFORE any stream (cost cap, #159 #5)', async () => { + const { deps, publicShareChat } = makeDeps({ + withinShareTokenBudget: jest.fn().mockResolvedValue(false), + }); + expect(await statusOf(deps, body())).toBe(429); + expect(publicShareChat.withinShareTokenBudget).toHaveBeenCalledWith('ws-1'); + // The token budget is the COST backstop: an over-budget workspace must be + // rejected WITHOUT consuming a request slot, so the request cap never runs. + expect(publicShareChat.tryConsumeWorkspaceQuota).not.toHaveBeenCalled(); + }); + + it('the token budget is checked BEFORE the request cap (over-budget wins, no slot spent)', async () => { + // Over budget AND the request cap would also reject: the read-only budget + // gate must win so the (mutating) request-slot consume is never reached. + const { deps, publicShareChat } = makeDeps({ + withinShareTokenBudget: jest.fn().mockResolvedValue(false), + tryConsumeWorkspaceQuota: jest.fn().mockResolvedValue(false), + }); + expect(await statusOf(deps, body())).toBe(429); + expect(publicShareChat.tryConsumeWorkspaceQuota).not.toHaveBeenCalled(); + }); + + it('the token-budget gate is checked BEFORE the payload caps (429 wins over 413)', async () => { + const { deps } = makeDeps({ + withinShareTokenBudget: jest.fn().mockResolvedValue(false), + }); + const huge = { + role: 'user', + parts: [{ type: 'text', text: 'x'.repeat(MAX_SHARE_MESSAGE_CHARS + 1) }], + }; + expect(await statusOf(deps, body({ messages: [huge] }))).toBe(429); + }); + it('messages over MAX_SHARE_MESSAGES => 413', async () => { const { deps } = makeDeps(); const tooMany = Array.from({ length: MAX_SHARE_MESSAGES + 1 }, () => ({ diff --git a/apps/server/src/core/ai-chat/public-share-chat.controller.ts b/apps/server/src/core/ai-chat/public-share-chat.controller.ts index 74f8b538..fdab8582 100644 --- a/apps/server/src/core/ai-chat/public-share-chat.controller.ts +++ b/apps/server/src/core/ai-chat/public-share-chat.controller.ts @@ -151,6 +151,7 @@ export interface ShareAssistantDeps { | 'resolveShareRole' | 'getShareChatModel' | 'tryConsumeWorkspaceQuota' + | 'withinShareTokenBudget' >; } @@ -267,9 +268,21 @@ export async function resolveShareAssistantRequest( throw new NotFoundException('Not found'); } - // 5. Per-WORKSPACE anti-abuse cap (IP-independent; defense in depth). Checked - // BEFORE res.hijack(), so an over-cap workspace gets a clean 429 and spends - // nothing. + // 5a. Per-WORKSPACE rolling-day TOKEN budget (the COST backstop). Read-only and + // checked FIRST so a workspace that has already burned its day's token + // budget gets a clean 429 WITHOUT consuming a request slot, and spends + // nothing. Counting requests alone does not bound the owner's provider + // bill (issue #159, finding #5). + if (!(await deps.publicShareChat.withinShareTokenBudget(workspaceId))) { + throw new HttpException( + 'This documentation assistant has reached its usage budget. Please try again later.', + HttpStatus.TOO_MANY_REQUESTS, + ); + } + + // 5b. Per-WORKSPACE anti-abuse request cap (IP-independent; defense in depth). + // Checked BEFORE res.hijack(), so an over-cap workspace gets a clean 429 + // and spends nothing. if (!(await deps.publicShareChat.tryConsumeWorkspaceQuota(workspaceId))) { throw new HttpException( 'This documentation assistant is temporarily busy. Please try again later.', diff --git a/apps/server/src/core/ai-chat/public-share-chat.service.ts b/apps/server/src/core/ai-chat/public-share-chat.service.ts index 8011814b..a98e738f 100644 --- a/apps/server/src/core/ai-chat/public-share-chat.service.ts +++ b/apps/server/src/core/ai-chat/public-share-chat.service.ts @@ -17,7 +17,9 @@ import { buildShareSystemPrompt } from './public-share-chat.prompt'; import { roleModelOverride } from './roles/role-model-config'; import { PublicShareWorkspaceLimiter, + PublicShareWorkspaceTokenBudget, createPublicShareWorkspaceLimiter, + createPublicShareWorkspaceTokenBudget, } from './public-share-workspace-limiter'; import { describeProviderError } from '../../integrations/ai/ai-error.util'; import { @@ -125,6 +127,16 @@ export class PublicShareChatService { */ private readonly workspaceLimiter: PublicShareWorkspaceLimiter; + /** + * COST contour two: a per-workspace TOKEN budget over a rolling day. The + * request-count limiter above bounds how many anonymous calls run; this bounds + * how many provider TOKENS they spend (input re-sent per step + output), + * which is what the owner is actually billed for (issue #159, finding #5). + * Checked read-only before a turn streams; the real usage is recorded once the + * turn finishes (`onFinish`). + */ + private readonly tokenBudget: PublicShareWorkspaceTokenBudget; + constructor( private readonly ai: AiService, private readonly aiSettings: AiSettingsService, @@ -133,6 +145,7 @@ export class PublicShareChatService { private readonly aiAgentRoleRepo: AiAgentRoleRepo, ) { this.workspaceLimiter = createPublicShareWorkspaceLimiter(redisService); + this.tokenBudget = createPublicShareWorkspaceTokenBudget(redisService); } /** @@ -144,6 +157,48 @@ export class PublicShareChatService { return this.workspaceLimiter.tryConsume(workspaceId); } + /** + * Read-only pre-stream COST gate: true while the workspace is under its + * rolling-day token budget, false once the trailing-day token spend has + * reached it (the controller must then 429 BEFORE starting the stream). This + * bounds the owner's actual provider bill, which counting requests alone does + * not (issue #159, finding #5). + */ + async withinShareTokenBudget(workspaceId: string): Promise { + return this.tokenBudget.withinBudget(workspaceId); + } + + /** + * Record a finished turn's real token spend against the rolling-day budget. + * Best-effort (the turn already ran): failures are swallowed by the budget. + */ + async recordShareTokens(workspaceId: string, tokens: number): Promise { + return this.tokenBudget.record(workspaceId, tokens); + } + + /** + * `streamText` onFinish hook body: account a finished turn's REAL token spend + * (input re-sent per step + output, summed across all steps) against the + * per-workspace rolling-day budget, so a future turn over budget is rejected up + * front (issue #159, finding #5). `totalUsage` fields are `number | undefined`; + * fall back to the sum of input+output when the provider omits `totalTokens`. + * Fire-and-forget: the turn already streamed, so a record failure must not + * break it. + */ + recordTurnUsage( + workspaceId: string, + totalUsage: { + totalTokens?: number; + inputTokens?: number; + outputTokens?: number; + }, + ): void { + const tokens = + totalUsage.totalTokens ?? + (totalUsage.inputTokens ?? 0) + (totalUsage.outputTokens ?? 0); + void this.recordShareTokens(workspaceId, tokens); + } + /** * Resolve the admin-selected agent role for the anonymous public-share * assistant, scoped to the workspace and soft-delete aware. Returns null when @@ -231,6 +286,8 @@ export class PublicShareChatService { // bill even if the per-IP throttle is evaded; worst case = steps × this. maxOutputTokens: resolveShareAiMaxOutputTokens(), abortSignal: signal, + onFinish: ({ totalUsage }) => + this.recordTurnUsage(workspaceId, totalUsage), onError: ({ error }) => { // Reuse the shared formatter so provider error formatting stays // unified (statusCode + body) with the authenticated path. diff --git a/apps/server/src/core/ai-chat/public-share-chat.spec.ts b/apps/server/src/core/ai-chat/public-share-chat.spec.ts index 3b80e9be..f65058d9 100644 --- a/apps/server/src/core/ai-chat/public-share-chat.spec.ts +++ b/apps/server/src/core/ai-chat/public-share-chat.spec.ts @@ -11,8 +11,11 @@ import { import { PublicShareChatToolsService } from './tools/public-share-chat-tools.service'; import { PublicShareWorkspaceLimiter, + PublicShareWorkspaceTokenBudget, resolveShareAiWorkspaceMax, + resolveShareAiWorkspaceTokenBudget, SHARE_AI_WORKSPACE_MAX_PER_WINDOW, + SHARE_AI_WORKSPACE_TOKEN_BUDGET_DEFAULT, } from './public-share-workspace-limiter'; /** @@ -546,6 +549,228 @@ describe('PublicShareWorkspaceLimiter (cluster-wide sliding-window per-workspace }); }); +/** + * In-memory fake of the ioredis slice the TOKEN budget uses. Unlike the request + * limiter (one Lua), the budget runs TWO scripts over the same sorted set: + * - the read-only CHECK (sums the token counts encoded as each member's leading + * integer, admits while the sum is under budget, never mutates), and + * - the RECORD (ZADDs a finished turn's `:` member). + * The fake faithfully reproduces both (branching on the script body) so the spec + * exercises the REAL budget math, not a re-implementation. + */ +class FakeTokenRedis { + private sets = new Map>(); + + async eval( + script: string, + _numKeys: number, + key: string, + nowStr: string, + windowMsStr: string, + arg3: string, + ): Promise { + const now = Number(nowStr); + const windowMs = Number(windowMsStr); + const cutoff = now - windowMs; + const arr = (this.sets.get(key) ?? []).filter((e) => e.score > cutoff); + if (script.includes('ZADD')) { + // RECORD: arg3 is the `:` member; append at score=now. + arr.push({ score: now, member: arg3 }); + this.sets.set(key, arr); + return 1; + } + // CHECK: arg3 is the budget; sum the leading integer of each survivor. + const budget = Number(arg3); + this.sets.set(key, arr); + const total = arr.reduce((sum, e) => { + const m = /^(\d+)/.exec(e.member); + return sum + (m ? Number(m[1]) : 0); + }, 0); + return total >= budget ? 0 : 1; + } +} + +function makeTokenBudget(budget: number, windowMs: number, clock: () => number) { + const redis = new FakeTokenRedis() as unknown as import('ioredis').Redis; + return new PublicShareWorkspaceTokenBudget(redis, budget, windowMs, clock); +} + +describe('resolveShareAiWorkspaceTokenBudget (env-overridable per-day token budget)', () => { + const KEY = 'SHARE_AI_WORKSPACE_TOKEN_BUDGET_PER_DAY'; + const saved = process.env[KEY]; + afterEach(() => { + if (saved === undefined) delete process.env[KEY]; + else process.env[KEY] = saved; + }); + + it('falls back to the default when unset', () => { + delete process.env[KEY]; + expect(resolveShareAiWorkspaceTokenBudget()).toBe( + SHARE_AI_WORKSPACE_TOKEN_BUDGET_DEFAULT, + ); + }); + + it('honors a positive override', () => { + process.env[KEY] = '250000'; + expect(resolveShareAiWorkspaceTokenBudget()).toBe(250000); + }); + + it('ignores a non-positive / unparseable value (uses the default)', () => { + for (const bad of ['0', '-5', 'nope', '']) { + process.env[KEY] = bad; + expect(resolveShareAiWorkspaceTokenBudget()).toBe( + SHARE_AI_WORKSPACE_TOKEN_BUDGET_DEFAULT, + ); + } + }); +}); + +describe('PublicShareWorkspaceTokenBudget (cluster-wide rolling-day token cap)', () => { + it('admits while under budget and rejects once the recorded spend reaches it', async () => { + const budget = makeTokenBudget(1000, 60_000, () => 1_000); + expect(await budget.withinBudget('ws-1')).toBe(true); // nothing spent yet + await budget.record('ws-1', 600); + expect(await budget.withinBudget('ws-1')).toBe(true); // 600 < 1000 + await budget.record('ws-1', 400); + // 1000 >= 1000: the budget is exhausted, so the next turn is rejected up front. + expect(await budget.withinBudget('ws-1')).toBe(false); + }); + + it('counts TOKENS, not requests: one fat turn can exhaust the budget alone', async () => { + const budget = makeTokenBudget(1000, 60_000, () => 1_000); + // A single accepted turn re-sends the whole transcript across 5 steps; here + // it lands as 1200 tokens — already over the day budget on its own. + await budget.record('ws-1', 1200); + expect(await budget.withinBudget('ws-1')).toBe(false); + }); + + it('ages out spend older than the window so the budget recovers', async () => { + let now = 0; + const budget = makeTokenBudget(1000, 60_000, () => now); + await budget.record('ws-1', 1000); // at budget + now += 59_999; // still inside the day window + expect(await budget.withinBudget('ws-1')).toBe(false); + now += 2; // the spend is now strictly older than windowMs + expect(await budget.withinBudget('ws-1')).toBe(true); + }); + + it('ignores non-positive / non-finite usage (never records phantom spend)', async () => { + const budget = makeTokenBudget(1000, 60_000, () => 1_000); + await budget.record('ws-1', 0); + await budget.record('ws-1', -50); + await budget.record('ws-1', Number.NaN); + await budget.record('ws-1', Infinity); + expect(await budget.withinBudget('ws-1')).toBe(true); // nothing accumulated + }); + + it('keeps separate budgets per workspace', async () => { + const budget = makeTokenBudget(500, 60_000, () => 1_000); + await budget.record('ws-a', 500); // ws-a exhausted + expect(await budget.withinBudget('ws-a')).toBe(false); + expect(await budget.withinBudget('ws-b')).toBe(true); // ws-b untouched + }); + + it('FAILS CLOSED on the read-only check when Redis rejects', async () => { + const failingRedis = { + eval: () => Promise.reject(new Error('redis down')), + } as unknown as import('ioredis').Redis; + const budget = new PublicShareWorkspaceTokenBudget( + failingRedis, + 1000, + 60_000, + () => 1_000, + ); + const errSpy = jest + .spyOn(Logger.prototype, 'error') + .mockImplementation(() => undefined); + expect(await budget.withinBudget('ws-1')).toBe(false); + expect(errSpy).toHaveBeenCalled(); + errSpy.mockRestore(); + }); + + it('SWALLOWS a record failure (best-effort post-accounting, never throws)', async () => { + // The turn already streamed; a record failure must not surface to the caller. + const failingRedis = { + eval: () => Promise.reject(new Error('redis down')), + } as unknown as import('ioredis').Redis; + const budget = new PublicShareWorkspaceTokenBudget( + failingRedis, + 1000, + 60_000, + () => 1_000, + ); + const errSpy = jest + .spyOn(Logger.prototype, 'error') + .mockImplementation(() => undefined); + await expect(budget.record('ws-1', 100)).resolves.toBeUndefined(); + expect(errSpy).toHaveBeenCalled(); + errSpy.mockRestore(); + }); +}); + +describe('PublicShareChatService.withinShareTokenBudget / recordShareTokens', () => { + it('delegates the cost gate + accounting to the redis-backed token budget', async () => { + const redis = new FakeTokenRedis(); + const redisService = { getOrThrow: () => redis } as never; + const service = new PublicShareChatService( + {} as never, + {} as never, + {} as never, + redisService, + {} as never, + ); + // Default budget is large, so a fresh workspace is under budget; recording a + // modest spend keeps it under budget (asserts the wiring the controller + + // onFinish rely on). + expect(await service.withinShareTokenBudget('ws-1')).toBe(true); + await service.recordShareTokens('ws-1', 1234); + expect(await service.withinShareTokenBudget('ws-1')).toBe(true); + }); +}); + +describe('PublicShareChatService.recordTurnUsage (streamText onFinish accounting)', () => { + function makeService() { + const redisService = { getOrThrow: () => new FakeTokenRedis() } as never; + const service = new PublicShareChatService( + {} as never, + {} as never, + {} as never, + redisService, + {} as never, + ); + const recordSpy = jest + .spyOn(service, 'recordShareTokens') + .mockResolvedValue(undefined); + return { service, recordSpy }; + } + + it('sums input+output when the provider omits totalTokens', () => { + const { service, recordSpy } = makeService(); + // The onFinish payload shape: a totalUsage with per-component counts but no + // authoritative total (provider omitted it). + service.recordTurnUsage('ws-1', { inputTokens: 1200, outputTokens: 300 }); + expect(recordSpy).toHaveBeenCalledWith('ws-1', 1500); + }); + + it('treats missing input/output components as 0 in the fallback sum', () => { + const { service, recordSpy } = makeService(); + service.recordTurnUsage('ws-1', { outputTokens: 42 }); + expect(recordSpy).toHaveBeenCalledWith('ws-1', 42); + }); + + it('prefers the authoritative totalTokens when present (not the sum)', () => { + const { service, recordSpy } = makeService(); + // totalTokens is the provider's authoritative figure and may differ from a + // naive input+output sum (e.g. cached/ reasoning tokens); it must win. + service.recordTurnUsage('ws-1', { + totalTokens: 5000, + inputTokens: 1200, + outputTokens: 300, + }); + expect(recordSpy).toHaveBeenCalledWith('ws-1', 5000); + }); +}); + describe('PublicShareChatService.tryConsumeWorkspaceQuota', () => { it('delegates to the redis-backed per-workspace limiter', async () => { const redis = new FakeRedis(); diff --git a/apps/server/src/core/ai-chat/public-share-workspace-limiter.ts b/apps/server/src/core/ai-chat/public-share-workspace-limiter.ts index cf0dd80d..d6f660a8 100644 --- a/apps/server/src/core/ai-chat/public-share-workspace-limiter.ts +++ b/apps/server/src/core/ai-chat/public-share-workspace-limiter.ts @@ -136,6 +136,177 @@ export class PublicShareWorkspaceLimiter { } } +/** + * SECOND cost contour: a per-workspace TOKEN budget over a rolling DAY. + * + * The request-count cap above bounds how MANY anonymous calls a workspace + * admits, but NOT how expensive each one is: one accepted call runs the agent + * loop up to `stepCountIs(5)`, and every step re-sends the WHOLE client-held + * transcript (~hundreds of KB) as input, so the provider input alone can be tens + * of thousands of tokens PER step while `maxOutputTokens` only caps the output. + * The request cap is also hourly with no daily ceiling, so a steady stream at + * the hourly cap sustains ~24x its count per day. Counting requests therefore + * does not bound the owner's actual LLM bill (issue #159, finding #5). + * + * This contour caps the SPEND directly: the actual tokens consumed (input + + * output, summed across all steps of every accepted turn) over the trailing + * `windowMs` (one rolling day) must stay under `budget`. It is checked BEFORE a + * turn streams (read-only) and the turn's real usage is recorded AFTER it + * finishes (`streamText` onFinish). Like the request cap it is cluster-wide + * (shared Redis) and uses a sliding-window LOG so the day boundary cannot be + * gamed for a 2x burst. + * + * Pre-check is read-only, so a turn already over budget is rejected, but the + * tokens of an in-flight turn are not yet known and are accounted only once it + * finishes. The worst-case overshoot past the budget is therefore one turn + * (bounded by steps x (maxOutputTokens + transcript size)) — acceptable for a + * cost backstop on an optional anonymous assistant. + */ + +/** Default per-workspace token budget over the rolling day. */ +export const SHARE_AI_WORKSPACE_TOKEN_BUDGET_DEFAULT = 1_000_000; +/** Default token-budget window length: one rolling day. */ +export const SHARE_AI_WORKSPACE_TOKEN_WINDOW_MS = 24 * 60 * 60 * 1000; + +/** Redis key namespace for the per-workspace token-spend sliding-window log. */ +const TOKEN_KEY_PREFIX = 'share-ai:ws-tokens:'; + +/** + * Read-only sliding-window token-budget check. + * + * KEYS[1] = the per-workspace token sorted-set key + * ARGV[1] = now (epoch ms) + * ARGV[2] = windowMs + * ARGV[3] = budget (max tokens in the trailing window) + * + * Drops entries older than the window, then sums the token counts encoded as the + * leading integer of each surviving member. Returns 1 if the running total is + * still UNDER budget (admit), 0 once it has reached/exceeded the budget. Does NOT + * add anything — the turn's real usage is recorded separately once it finishes. + */ +const TOKEN_BUDGET_CHECK_LUA = ` +local key = KEYS[1] +local now = tonumber(ARGV[1]) +local windowMs = tonumber(ARGV[2]) +local budget = tonumber(ARGV[3]) +redis.call('ZREMRANGEBYSCORE', key, 0, now - windowMs) +local members = redis.call('ZRANGE', key, 0, -1) +local total = 0 +for i = 1, #members do + local t = tonumber(string.match(members[i], '^(%d+)')) + if t then total = total + t end +end +if total >= budget then + return 0 +end +return 1 +`; + +/** + * Record one finished turn's token spend in the sliding-window log. + * + * KEYS[1] = the per-workspace token sorted-set key + * ARGV[1] = now (epoch ms) — the entry score + * ARGV[2] = windowMs + * ARGV[3] = member (`:`; the leading integer is the token count) + * + * Always ZADDs (the turn already ran and spent the tokens) and refreshes the + * key TTL so idle workspaces cost no memory. Trims expired entries first so the + * set never grows unbounded for a busy workspace. + */ +const TOKEN_RECORD_LUA = ` +local key = KEYS[1] +local now = tonumber(ARGV[1]) +local windowMs = tonumber(ARGV[2]) +local member = ARGV[3] +redis.call('ZREMRANGEBYSCORE', key, 0, now - windowMs) +redis.call('ZADD', key, now, member) +redis.call('PEXPIRE', key, windowMs) +return 1 +`; + +/** + * Cluster-wide, sliding-window per-workspace TOKEN budget backed by Redis. + * `withinBudget(key)` is a read-only pre-stream gate; `record(key, tokens)` + * accounts a finished turn's real usage. Decoupled from NestJS so it is testable + * against a mocked/real ioredis client, mirroring the request-count limiter. + */ +export class PublicShareWorkspaceTokenBudget { + private readonly logger = new Logger(PublicShareWorkspaceTokenBudget.name); + private counter = 0; + + constructor( + private readonly redis: Redis, + private readonly budget: number = SHARE_AI_WORKSPACE_TOKEN_BUDGET_DEFAULT, + private readonly windowMs: number = SHARE_AI_WORKSPACE_TOKEN_WINDOW_MS, + private readonly now: () => number = Date.now, + ) {} + + /** + * Read-only pre-stream check. Returns true while the workspace is under its + * rolling-day token budget, false once the trailing-window spend has reached + * it (caller must then 429 BEFORE streaming any tokens). + * + * FAILS CLOSED (false) on a Redis error: identical reasoning to the request + * limiter — when we cannot prove the workspace is under budget we DENY rather + * than admit an unmetered billable call. The assistant is optional, so a + * transient Redis blip briefly disabling it beats an unbounded provider bill. + */ + async withinBudget(key: string): Promise { + const t = this.now(); + try { + const admitted = await this.redis.eval( + TOKEN_BUDGET_CHECK_LUA, + 1, + TOKEN_KEY_PREFIX + key, + String(t), + String(this.windowMs), + String(this.budget), + ); + return admitted === 1; + } catch (err) { + this.logger.error( + `share-ai token budget Redis failure for key "${key}"; failing closed`, + err as Error, + ); + return false; + } + } + + /** + * Record a finished turn's token spend. Best-effort: the turn already ran, so + * a Redis failure here is logged but not propagated — it would only cause a + * slight under-count of the running budget, never a wrong answer to the + * caller. Non-positive / non-finite usage is ignored. + */ + async record(key: string, tokens: number): Promise { + if (!Number.isFinite(tokens) || tokens <= 0) return; + const spend = Math.floor(tokens); + const t = this.now(); + // Member: `:` — the check Lua sums the leading integer, and + // the unique suffix keeps distinct turns in the same ms from colliding on + // the sorted-set member (which would drop one entry and under-count). + const member = `${spend}:${t}-${this.counter++}-${Math.random() + .toString(36) + .slice(2)}`; + try { + await this.redis.eval( + TOKEN_RECORD_LUA, + 1, + TOKEN_KEY_PREFIX + key, + String(t), + String(this.windowMs), + member, + ); + } catch (err) { + this.logger.error( + `share-ai token budget record failure for key "${key}" (${spend} tokens); ignoring`, + err as Error, + ); + } + } +} + /** * Read the per-workspace cap from the environment (overridable seam), falling * back to the sane default. A non-positive / unparseable value uses the default. @@ -162,3 +333,31 @@ export function createPublicShareWorkspaceLimiter( SHARE_AI_WORKSPACE_WINDOW_MS, ); } + +/** + * Read the per-workspace rolling-day token budget from the environment + * (overridable seam), falling back to the sane default. A non-positive / + * unparseable value uses the default. + */ +export function resolveShareAiWorkspaceTokenBudget(): number { + const raw = Number(process.env.SHARE_AI_WORKSPACE_TOKEN_BUDGET_PER_DAY); + return Number.isFinite(raw) && raw > 0 + ? Math.floor(raw) + : SHARE_AI_WORKSPACE_TOKEN_BUDGET_DEFAULT; +} + +/** + * Build the per-workspace token budget from the injected RedisService (the same + * global ioredis client used by the request-count limiter). Tiny factory so the + * service constructor stays declarative and the budget stays unit-testable with + * a hand-rolled fake redis. + */ +export function createPublicShareWorkspaceTokenBudget( + redisService: RedisService, +): PublicShareWorkspaceTokenBudget { + return new PublicShareWorkspaceTokenBudget( + redisService.getOrThrow(), + resolveShareAiWorkspaceTokenBudget(), + SHARE_AI_WORKSPACE_TOKEN_WINDOW_MS, + ); +} diff --git a/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.spec.ts b/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.spec.ts index 5add9494..ebf1cb6a 100644 --- a/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.spec.ts +++ b/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.spec.ts @@ -120,18 +120,25 @@ describe('AiChatToolsService deletePage guardrail (H4)', () => { const tools = await buildTools(); const deletePage = tools.deletePage; - // The Zod input schema only allows `pageId`; parsing strips/ignores extra - // keys, so a permanent/force flag is never part of the validated input. + // The wrapped input schema (modelFriendlyInput) only allows `pageId`; + // validation strips/ignores extra keys, so a permanent/force flag is never + // part of the validated input handed to execute. const schema = (deletePage as unknown as { inputSchema: unknown }) .inputSchema as { - parse: (v: unknown) => Record; + validate: ( + v: unknown, + ) => + | { success: boolean; value?: Record } + | Promise<{ success: boolean; value?: Record }>; }; - const parsed = schema.parse({ + const result = await schema.validate({ pageId: 'page-789', permanentlyDelete: true, forceDelete: true, }); + expect(result.success).toBe(true); + const parsed = result.value as Record; expect(parsed).toHaveProperty('pageId', 'page-789'); expect(parsed).not.toHaveProperty('permanentlyDelete'); expect(parsed).not.toHaveProperty('forceDelete'); @@ -207,19 +214,26 @@ describe('AiChatToolsService expanded toolset guardrails', () => { const tools = await buildTools(); const transformPage = tools.transformPage; - // The Zod input schema only allows pageId/transformJs/dryRun; parsing - // strips unknown keys, so deleteComments can never reach the client. + // The wrapped input schema only allows pageId/transformJs/dryRun; + // validation strips unknown keys, so deleteComments can never reach the + // client. const schema = (transformPage as unknown as { inputSchema: unknown }) .inputSchema as { - parse: (v: unknown) => Record; + validate: ( + v: unknown, + ) => + | { success: boolean; value?: Record } + | Promise<{ success: boolean; value?: Record }>; }; - const parsed = schema.parse({ + const result = await schema.validate({ pageId: 'p', transformJs: '(d)=>d', dryRun: true, deleteComments: true, }); + expect(result.success).toBe(true); + const parsed = result.value as Record; expect(parsed).toHaveProperty('pageId', 'p'); expect(parsed).not.toHaveProperty('deleteComments'); }); @@ -395,3 +409,95 @@ describe('AiChatToolsService node-arg JSON-string coercion', () => { expect(updatePageJsonCalls).toHaveLength(0); }); }); + +/** + * Model-friendly tool-call validation (#190): when the model drops a required + * `pageId` in a parallel/batch tool call, the built-in input schema must return + * a CLEAR, actionable message (naming the parameter, reminding it not to drop + * ids in batches) instead of zod's raw "expected string, received undefined" — + * while a valid call still validates. This is wired centrally via + * modelFriendlyInput, so it applies to every in-app tool; createComment (the + * tool from the bug report) and a sharedTool-built tool (getPage's sibling + * getOutline) are exercised here end-to-end through forUser(). + */ +describe('AiChatToolsService model-friendly input validation (#190)', () => { + const fakeClient: Partial = {}; + const tokenServiceStub = { + generateAccessToken: jest.fn().mockResolvedValue('access-token'), + generateCollabToken: jest.fn().mockResolvedValue('collab-token'), + }; + let service: AiChatToolsService; + + beforeEach(() => { + jest.spyOn(loader, 'loadDocmostMcp').mockResolvedValue( + mockLoaded(function () { + return fakeClient as DocmostClientLike; + } as unknown as loader.DocmostClientCtor), + ); + service = new AiChatToolsService( + tokenServiceStub as never, + {} as never, + {} as never, + {} as never, + {} as never, + ); + }); + + afterEach(() => jest.restoreAllMocks()); + + function buildTools() { + return service.forUser( + { id: 'user-1', email: 'u@example.com', workspaceId: 'ws-1' } as never, + 'session-1', + 'ws-1', + 'chat-1', + ); + } + + // The AI SDK Schema produced by modelFriendlyInput exposes `validate`. + type ValidatableSchema = { + validate: ( + v: unknown, + ) => + | { success: boolean; value?: unknown; error?: Error } + | Promise<{ success: boolean; value?: unknown; error?: Error }>; + }; + const inputSchemaOf = (t: unknown) => + (t as { inputSchema: unknown }).inputSchema as ValidatableSchema; + + it('createComment: a dropped pageId yields a clear, model-actionable message', async () => { + const tools = await buildTools(); + // The exact failing shape from the bug report's second parallel batch: + // content + selection, but pageId silently dropped. + const result = await inputSchemaOf(tools.createComment).validate({ + content: 'A remark', + selection: 'титановый проводник', + }); + expect(result.success).toBe(false); + expect(result.error?.message).toContain('parameter "pageId": missing (required)'); + expect(result.error?.message).toContain('parallel/batch tool calls'); + // Not the raw zod text the model previously received. + expect(result.error?.message).not.toContain('received undefined'); + }); + + it('createComment: a valid call with pageId validates successfully', async () => { + const tools = await buildTools(); + const result = await inputSchemaOf(tools.createComment).validate({ + pageId: '019efe44-0000-0000-0000-000000000000', + content: 'A remark', + selection: 'титановый проводник', + }); + expect(result.success).toBe(true); + expect(result.value).toMatchObject({ + pageId: '019efe44-0000-0000-0000-000000000000', + content: 'A remark', + }); + }); + + it('sharedTool-built tools (getOutline) also get the friendly message on a dropped pageId', async () => { + const tools = await buildTools(); + const result = await inputSchemaOf(tools.getOutline).validate({}); + expect(result.success).toBe(false); + expect(result.error?.message).toContain('parameter "pageId": missing (required)'); + }); +}); diff --git a/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.ts b/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.ts index f6d38487..377d4036 100644 --- a/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.ts +++ b/apps/server/src/core/ai-chat/tools/ai-chat-tools.service.ts @@ -15,6 +15,7 @@ import { } from './docmost-client.loader'; import { resolveCurrentPageResult } from './current-page.util'; import { parseNodeArg } from './parse-node-arg'; +import { modelFriendlyInput } from './model-friendly-input'; /** * Per-user, per-request adapter that exposes Docmost READ operations to the @@ -102,9 +103,13 @@ export class AiChatToolsService { ): Tool => tool({ description: spec.description, - inputSchema: spec.buildShape - ? z.object(spec.buildShape(z) as z.ZodRawShape) - : z.object({}), + // Wrap via modelFriendlyInput so a dropped/invalid parameter (e.g. a + // pageId omitted in a parallel batch, #190) yields a clear, actionable + // tool error instead of zod's raw text. No-arg specs still get an empty + // object schema. + inputSchema: modelFriendlyInput( + spec.buildShape ? (spec.buildShape(z) as z.ZodRawShape) : {}, + ), execute, }); @@ -118,7 +123,7 @@ export class AiChatToolsService { 'and entities), not a full sentence. If the first results look weak ' + 'or incomplete, search again with different wording or synonyms ' + 'before answering.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ query: z.string().describe('The search query.'), limit: z .number() @@ -227,7 +232,7 @@ export class AiChatToolsService { '"the current page", or "here" refers to. Returns the page id and title, ' + 'or null if the user is not currently on a page. Call this first whenever ' + 'the user refers to the current page without giving an explicit id.', - inputSchema: z.object({}), + inputSchema: modelFriendlyInput({}), execute: async () => resolveCurrentPageResult(openedPage), }), @@ -235,7 +240,7 @@ export class AiChatToolsService { description: 'Fetch a single page as Markdown by its page id. Returns the page ' + 'title and its Markdown content.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id (or slugId) of the page.'), }), execute: async ({ pageId }) => { @@ -259,7 +264,7 @@ export class AiChatToolsService { 'Create a new page with a Markdown body in a space, optionally under ' + 'a parent page. Returns the new page id and title. Reversible: a page ' + 'can be moved to trash later.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ title: z.string().describe('The title of the new page.'), content: z .string() @@ -294,7 +299,7 @@ export class AiChatToolsService { description: "Replace a page's body with new Markdown content (and optionally its " + 'title). Reversible: the previous version is kept in page history.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page to update.'), content: z.string().describe('The new page body as Markdown.'), title: z @@ -316,7 +321,7 @@ export class AiChatToolsService { description: "Rename a page (change its title only; the body is untouched). " + 'Reversible: rename back at any time.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page to rename.'), title: z.string().describe('The new title.'), }), @@ -331,7 +336,7 @@ export class AiChatToolsService { description: 'Move a page under a new parent page, or to the space root when no ' + 'parent is given. Reversible: move it back at any time.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page to move.'), parentPageId: z .string() @@ -353,7 +358,7 @@ export class AiChatToolsService { description: 'Move a page to the trash (SOFT delete only — fully reversible; the ' + 'page can be restored from trash). This NEVER permanently deletes.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page to move to trash.'), }), // GUARDRAIL (§14 H4): the only field ever passed to the client is @@ -379,7 +384,7 @@ export class AiChatToolsService { '"selection not found" error, retry with a corrected EXACT selection ' + 'copied verbatim from a single paragraph/block. Reversible via the ' + 'comment UI.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page to comment on.'), content: z.string().describe('The comment body as Markdown.'), selection: z @@ -428,7 +433,7 @@ export class AiChatToolsService { description: 'Resolve or reopen a top-level comment thread (reversible — toggle ' + 'the resolved flag). Only top-level comments can be resolved.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ commentId: z .string() .describe('The id of the top-level comment to resolve/reopen.'), @@ -460,7 +465,7 @@ export class AiChatToolsService { 'List the most recent pages, optionally scoped to a single space. ' + 'Returns a bounded list (default 50, max 100). Pass tree:true (with ' + "spaceId) to instead get the space's full page hierarchy as a nested tree.", - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ spaceId: z .string() .optional() @@ -488,7 +493,7 @@ export class AiChatToolsService { 'List sidebar pages for a space. With no pageId, returns the ' + "space's ROOT pages; with a pageId, returns that page's direct " + 'CHILDREN.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ spaceId: z.string().describe('The id of the space.'), pageId: z .string() @@ -520,7 +525,7 @@ export class AiChatToolsService { description: 'Read a table as a matrix of cell texts (plus a parallel cellIds ' + 'matrix so cells can be addressed for rich edits).', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page.'), tableRef: z .string() @@ -536,7 +541,7 @@ export class AiChatToolsService { listComments: tool({ description: 'List all comments on a page (content as Markdown).', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page.'), }), execute: async ({ pageId }) => await client.listComments(pageId), @@ -544,7 +549,7 @@ export class AiChatToolsService { getComment: tool({ description: 'Fetch a single comment by id (content as Markdown).', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ commentId: z.string().describe('The id of the comment.'), }), execute: async ({ commentId }) => await client.getComment(commentId), @@ -554,7 +559,7 @@ export class AiChatToolsService { description: 'Find new comments across a space (optionally scoped to a subtree) ' + 'created after a given timestamp.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ spaceId: z.string().describe('The id of the space to scan.'), since: z .string() @@ -586,7 +591,7 @@ export class AiChatToolsService { description: 'Fetch a single page-history version including its lossless ' + 'ProseMirror content.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ historyId: z.string().describe('The id of the history version.'), }), execute: async ({ historyId }) => @@ -604,7 +609,7 @@ export class AiChatToolsService { 'Export a page to a single self-contained Docmost-flavoured ' + 'Markdown file (meta + body + comment threads). Lossless round-trip ' + 'with importPageMarkdown.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page to export.'), }), execute: async ({ pageId }) => { @@ -630,7 +635,7 @@ export class AiChatToolsService { '{"type":"text","text":"x","marks":[{"type":"bold"}]}. The node arg ' + 'may be a JSON object or a JSON string (both accepted). Reversible: ' + 'the previous version is kept in page history.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page.'), nodeId: z .string() @@ -663,7 +668,7 @@ export class AiChatToolsService { '{"type":"text","text":"x","marks":[{"type":"bold"}]}. The node arg ' + 'may be a JSON object or a JSON string (both accepted). Reversible ' + 'via page history.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page.'), node: z .any() @@ -722,7 +727,7 @@ export class AiChatToolsService { 'object or a JSON string (both accepted). Omit content for a ' + 'title-only update. Reversible: the previous version is kept in page ' + 'history.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page to update.'), content: z .any() @@ -753,7 +758,7 @@ export class AiChatToolsService { description: 'Insert a row of plain-text cells into a table. Reversible via ' + 'page history.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page.'), tableRef: z .string() @@ -772,7 +777,7 @@ export class AiChatToolsService { tableDeleteRow: tool({ description: 'Delete a table row at a 0-based index. Reversible via page history.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page.'), tableRef: z .string() @@ -787,7 +792,7 @@ export class AiChatToolsService { description: 'Set the plain-text content of a table cell at [row, col] (0-based). ' + 'Reversible via page history.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page.'), tableRef: z .string() @@ -817,7 +822,7 @@ export class AiChatToolsService { 'Make a page PUBLICLY accessible and return its public URL. ' + 'Reversible via unsharePage. Only share when the user explicitly ' + 'asked, since this exposes the page to anyone with the link.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page to share.'), searchIndexing: z .boolean() @@ -844,7 +849,7 @@ export class AiChatToolsService { "page's ProseMirror document for complex/scripted rewrites. dryRun " + '(default true) previews a diff WITHOUT writing; set dryRun:false to ' + 'apply. Reversible: applying creates a new page-history snapshot.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z.string().describe('The id of the page to transform.'), transformJs: z .string() diff --git a/apps/server/src/core/ai-chat/tools/model-friendly-input.spec.ts b/apps/server/src/core/ai-chat/tools/model-friendly-input.spec.ts new file mode 100644 index 00000000..e1c5cad6 --- /dev/null +++ b/apps/server/src/core/ai-chat/tools/model-friendly-input.spec.ts @@ -0,0 +1,101 @@ +import { z } from 'zod'; +import { + modelFriendlyInput, + buildModelFriendlyMessage, +} from './model-friendly-input'; + +/** + * Unit tests for the centralized in-app tool input wrapper (#190). A dropped or + * invalid parameter must surface a clear, model-actionable message (naming the + * parameter and reminding the model not to drop ids in parallel batches), while + * a valid call validates cleanly and strips unknown keys — and the advertised + * JSON Schema keeps the unchanged required/description contract. + */ +describe('modelFriendlyInput', () => { + // Mirrors createComment's shape: pageId is the required id the model drops in + // parallel batches; selection is optional with a min length. + const shape = { + pageId: z.string().describe('The id of the page to comment on.'), + content: z.string().describe('The comment body as Markdown.'), + selection: z.string().min(1).max(250).optional(), + }; + + // Loose return type: the AI SDK ValidationResult is a discriminated union, but + // these tests assert on both branches, so a flat optional shape is simpler. + async function validate( + value: unknown, + ): Promise<{ success: boolean; value?: unknown; error?: Error }> { + const schema = modelFriendlyInput(shape); + return await schema.validate!(value); + } + + it('rejects a dropped required pageId with a clear, actionable message', async () => { + const result = await validate({ + content: 'Looks off here', + selection: 'титановый проводник', + }); + expect(result.success).toBe(false); + const msg = result.error?.message ?? ''; + // Names the dropped parameter... + expect(msg).toContain('parameter "pageId": missing (required)'); + // ...and gives an explicit, non-raw instruction (not zod's raw text). + expect(msg).toContain('parallel/batch tool calls'); + expect(msg).not.toContain('expected string, received undefined'); + }); + + it('distinguishes a present-but-invalid parameter from a missing one', async () => { + // selection is present but too short (invalid), pageId is missing. + const result = await validate({ content: 'x', selection: '' }); + expect(result.success).toBe(false); + const msg = result.error?.message ?? ''; + expect(msg).toContain('parameter "pageId": missing (required)'); + expect(msg).toContain('parameter "selection": invalid'); + }); + + it('accepts a valid call and strips unknown keys from the validated value', async () => { + const result = await validate({ + pageId: 'page-1', + content: 'A comment', + selection: 'anchor text', + bogus: true, + }); + expect(result.success).toBe(true); + if (!result.success) throw new Error('expected success'); + expect(result.value).toEqual({ + pageId: 'page-1', + content: 'A comment', + selection: 'anchor text', + }); + expect(result.value).not.toHaveProperty('bogus'); + }); + + it('preserves the required/description contract in the advertised JSON Schema', async () => { + const schema = modelFriendlyInput(shape); + const json = (await schema.jsonSchema) as { + required?: string[]; + properties?: Record; + }; + // pageId + content stay required; selection stays optional. + expect(json.required).toEqual(expect.arrayContaining(['pageId', 'content'])); + expect(json.required).not.toContain('selection'); + expect(json.properties?.pageId.description).toBe( + 'The id of the page to comment on.', + ); + }); + + it('handles a no-arg tool (empty shape) without error', async () => { + const schema = modelFriendlyInput({}); + const result = await schema.validate!({}); + expect(result.success).toBe(true); + }); +}); + +describe('buildModelFriendlyMessage', () => { + it('falls back to a generic message when issues carry an empty path', () => { + // safeParse on a non-object yields a root-level issue (empty path). + const error = z.object({ a: z.string() }).safeParse('not-an-object'); + if (error.success) throw new Error('expected failure'); + const msg = buildModelFriendlyMessage(error.error, 'not-an-object'); + expect(msg).toContain('parameter "input"'); + }); +}); diff --git a/apps/server/src/core/ai-chat/tools/model-friendly-input.ts b/apps/server/src/core/ai-chat/tools/model-friendly-input.ts new file mode 100644 index 00000000..e4ba92a7 --- /dev/null +++ b/apps/server/src/core/ai-chat/tools/model-friendly-input.ts @@ -0,0 +1,93 @@ +import { jsonSchema, type Schema } from 'ai'; +import type { JSONSchema7 } from '@ai-sdk/provider'; +import { z } from 'zod'; + +/** + * Centralized input-schema wrapper for every in-app AI-chat tool. + * + * THE PROBLEM (#190): when the model issues PARALLEL / batch tool calls it + * sometimes drops an "obvious" repeated required argument (typically `pageId`) + * from some of the calls. zod v4 correctly rejects the missing value, but the + * AI SDK forwards zod's RAW message ("Invalid input: expected string, received + * undefined") straight back to the model, which is not actionable — the model + * cannot tell WHICH parameter it dropped or that it must re-send it. + * + * THE FIX: keep the exact same validation, but replace the raw zod text with a + * model-friendly message that names every problematic parameter and tells the + * model to re-issue the call with all required parameters present. We do NOT + * guess/backfill the value (a silently-assumed "current page" could comment on + * the wrong page — cf. #159); the model is simply told to retry correctly. + * + * HOW IT WORKS: we build the tool's JSON Schema from the zod shape via + * `z.toJSONSchema(..., { target: 'draft-7' })` (so the advertised contract — + * `required` / `description` / field constraints — is unchanged) and hand the + * AI SDK a custom `validate` that runs `z.object(shape).safeParse(value)`. On + * failure the AI SDK wraps our returned `Error` in `InvalidToolInputError`, so + * our clear text is what reaches the model as the tool error. + */ +export function modelFriendlyInput( + shape: T, +): Schema>> { + const objectSchema = z.object(shape); + // draft-07 keeps required/description/constraints intact, matching what the + // model already saw — the tool contract does not change. + const json = z.toJSONSchema(objectSchema, { + target: 'draft-7', + }) as JSONSchema7; + + return jsonSchema>>(json, { + validate: (value) => { + const result = objectSchema.safeParse(value); + if (result.success) { + return { success: true, value: result.data }; + } + return { + success: false, + error: new Error(buildModelFriendlyMessage(result.error, value)), + }; + }, + }); +} + +/** + * Turn a zod validation failure into a clear, model-actionable message naming + * each problematic parameter (and whether it is missing vs. invalid), plus an + * explicit reminder not to drop required ids in parallel/batch tool calls. + */ +export function buildModelFriendlyMessage( + error: z.ZodError, + value: unknown, +): string { + const seen = new Set(); + const parts: string[] = []; + for (const issue of error.issues) { + const name = issue.path.length ? issue.path.map(String).join('.') : 'input'; + // A parameter the model omitted entirely reads as `undefined` at its path; + // anything else is present-but-invalid (wrong type, too short, etc.). + const missing = valueAtPath(value, issue.path) === undefined; + const part = `parameter "${name}": ${missing ? 'missing (required)' : 'invalid'}`; + if (seen.has(part)) continue; + seen.add(part); + parts.push(part); + } + if (parts.length === 0) { + // Defensive: a ZodError always has issues, but never emit an empty list. + parts.push('input: invalid'); + } + return ( + `Invalid input for this tool — ${parts.join('; ')}. ` + + 'Re-issue the call with EVERY required parameter present and valid. ' + + "Do not drop ids like pageId, even when making parallel/batch tool calls — " + + 'each tool call must carry its own pageId.' + ); +} + +/** Read the value at a zod issue path; returns undefined if any hop is absent. */ +function valueAtPath(value: unknown, path: ReadonlyArray): unknown { + let current: unknown = value; + for (const key of path) { + if (current === null || typeof current !== 'object') return undefined; + current = (current as Record)[key]; + } + return current; +} diff --git a/apps/server/src/core/ai-chat/tools/public-share-chat-tools.service.ts b/apps/server/src/core/ai-chat/tools/public-share-chat-tools.service.ts index 5d3e8df0..2d2da79d 100644 --- a/apps/server/src/core/ai-chat/tools/public-share-chat-tools.service.ts +++ b/apps/server/src/core/ai-chat/tools/public-share-chat-tools.service.ts @@ -5,6 +5,7 @@ import { ShareService } from '../../share/share.service'; import { SearchService } from '../../search/search.service'; import { PageRepo } from '@docmost/db/repos/page/page.repo'; import { jsonToMarkdown } from '../../../collaboration/collaboration.util'; +import { modelFriendlyInput } from './model-friendly-input'; /** * Isolated, READ-ONLY toolset for the ANONYMOUS public-share assistant. @@ -52,7 +53,7 @@ export class PublicShareChatToolsService { '(key terms and entities), not a full sentence. If the first ' + 'results look weak, search again with different wording before ' + 'answering. Only pages inside this share are ever returned.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ query: z.string().describe('The search query.'), limit: z .number() @@ -87,7 +88,7 @@ export class PublicShareChatToolsService { 'Markdown, by its page id. Returns the page title and its Markdown ' + 'content. Only pages inside this share can be read; reading any ' + 'other page fails.', - inputSchema: z.object({ + inputSchema: modelFriendlyInput({ pageId: z .string() .describe('The id (or slugId) of a page within this share.'), @@ -142,7 +143,7 @@ export class PublicShareChatToolsService { 'List the pages (titles + ids) that make up THIS published ' + 'documentation share, so you can orient yourself before reading or ' + 'searching. Only pages inside this share are listed.', - inputSchema: z.object({}), + inputSchema: modelFriendlyInput({}), execute: async () => { // Reuse the same share-tree logic the public /shares/tree route uses: // it validates the share + workspace, excludes restricted subtrees, diff --git a/apps/server/src/core/page/services/page.service.spec.ts b/apps/server/src/core/page/services/page.service.spec.ts index bb387b31..61cd74fe 100644 --- a/apps/server/src/core/page/services/page.service.spec.ts +++ b/apps/server/src/core/page/services/page.service.spec.ts @@ -57,11 +57,28 @@ describe('PageService', () => { const eventEmitter = { emit: jest.fn() }; + // movePage now runs the cycle-check + UPDATE inside executeTx(this.db), + // i.e. this.db.transaction().execute(fn => fn(trx)). A permissive chainable + // Proxy stands in for the Kysely trx so the per-space advisory-lock + // `sql``.execute(trx)` resolves; a thrown BadRequestException still + // propagates out of the transaction unchanged. + const trxStub: any = new Proxy(function () {}, { + get: (_t, p) => + p === 'then' + ? undefined + : p === 'execute' || p === 'executeTakeFirst' + ? () => Promise.resolve([]) + : () => trxStub, + }); + const db = { + transaction: () => ({ execute: (fn: any) => fn(trxStub) }), + }; + const svc = new PageService( pageRepo as any, // pageRepo {} as any, // pagePermissionRepo {} as any, // attachmentRepo - {} as any, // db + db as any, // db {} as any, // storageService {} as any, // attachmentQueue {} as any, // aiQueue @@ -268,9 +285,23 @@ describe('PageService', () => { }), updatePage: jest.fn().mockResolvedValue({ numUpdatedRows: 1n }), }; + // movePage now runs the cycle-check + UPDATE inside executeTx(this.db), + // which calls this.db.transaction().execute(fn => fn(trx)). A permissive + // chainable Proxy stands in for the Kysely trx so the per-space + // advisory-lock `sql``.execute(trx)` resolves and updatePage receives it. + const trxStub: any = new Proxy(function () {}, { + get: (_t, p) => + p === 'then' + ? undefined + : p === 'execute' || p === 'executeTakeFirst' + ? () => Promise.resolve([]) + : () => trxStub, + }); const svc = makeSvc({ pageRepo, - db: {} as any, + db: { + transaction: () => ({ execute: (fn: any) => fn(trxStub) }), + } as any, }); // Legitimate move: destination ancestors do NOT include the moved page. jest diff --git a/apps/server/src/core/page/services/page.service.ts b/apps/server/src/core/page/services/page.service.ts index ff205350..354a80fb 100644 --- a/apps/server/src/core/page/services/page.service.ts +++ b/apps/server/src/core/page/services/page.service.ts @@ -15,13 +15,13 @@ import { executeWithCursorPagination, } from '@docmost/db/pagination/cursor-pagination'; import { InjectKysely } from 'nestjs-kysely'; -import { KyselyDB } from '@docmost/db/types/kysely.types'; +import { KyselyDB, KyselyTransaction } from '@docmost/db/types/kysely.types'; import { generateJitteredKeyBetween } from 'fractional-indexing-jittered'; import { MovePageDto } from '../dto/move-page.dto'; import { shapeSidebarPagesTree } from './sidebar-pages-tree.util'; import { generateSlugId } from '../../../common/helpers'; import { getPageTitle } from '../../../common/helpers'; -import { executeTx } from '@docmost/db/utils'; +import { dbOrTx, executeTx } from '@docmost/db/utils'; import { AttachmentRepo } from '@docmost/db/repos/attachment/attachment.repo'; import { v7 as uuid7 } from 'uuid'; import { @@ -62,6 +62,23 @@ import { agentSourceFields, } from '../../../common/decorators/auth-provenance.decorator'; +// Hard upper bound on how deep the recursive page-tree CTEs (ancestor / +// descendant traversals) may walk. Real page trees are only a handful of levels +// deep, so this cap never truncates a legitimate result; it purely defends the +// recursive CTEs against runaway iteration if a parent/child cycle ever exists +// in the data (e.g. one slipped in before the move guard, #207 #8). Without it a +// cycle makes `withRecursive` loop forever (hang / statement timeout), and the +// move guard itself calls one of these CTEs — so a cycle would disable the very +// guard meant to prevent it. Each CTE carries a depth counter and stops here. +const MAX_PAGE_TREE_DEPTH = 10_000; + +// Advisory-lock namespace (the first key of pg_advisory_xact_lock) used to +// serialize concurrent page moves within a single space so the cycle check and +// the move UPDATE stay atomic (see movePage, #207 #7). A dedicated namespace +// constant keeps these locks from colliding with any other advisory lock; the +// second key is hashtext(spaceId). Fits a signed int4 ('page' in ASCII). +const PAGE_MOVE_LOCK_NAMESPACE = 0x70616765; + @Injectable() export class PageService { private readonly logger = new Logger(PageService.name); @@ -601,7 +618,13 @@ export class PageService { slugIdMap.set(entry.oldSlugId, entry); } - const attachmentMap = new Map(); + // Keyed by old attachmentId. A single attachment can be referenced by more + // than one page in the copied subtree (e.g. a block copy-pasted into a child + // page keeps the same attachmentId). Each referencing page needs its own + // fresh attachment id / row / blob copy, so the value is a LIST of copy + // entries rather than a single one — otherwise the last page's entry would + // clobber the others and their images would 404 in the copies (#206 attach-1). + const attachmentMap = new Map(); const insertablePages: InsertablePage[] = await Promise.all( pages.map(async (page) => { @@ -617,12 +640,14 @@ export class PageService { attachmentIds.forEach((attachmentId: string) => { const newPageId = pageFromMap.newPageId; const newAttachmentId = uuid7(); - attachmentMap.set(attachmentId, { + const existingEntries = attachmentMap.get(attachmentId) ?? []; + existingEntries.push({ newPageId: newPageId, oldPageId: page.id, oldAttachmentId: attachmentId, newAttachmentId: newAttachmentId, }); + attachmentMap.set(attachmentId, existingEntries); prosemirrorDoc.descendants((node: PMNode) => { if (isAttachmentNode(node.type.name)) { @@ -819,51 +844,53 @@ export class PageService { .execute(); for (const attachment of attachments) { - try { - const pageAttachment = attachmentMap.get(attachment.id); - - // make sure the copied attachment belongs to the page it was copied from - if (attachment.pageId !== pageAttachment.oldPageId) { - continue; - } - - const newAttachmentId = pageAttachment.newAttachmentId; - - const newPageId = pageAttachment.newPageId; - - const newPathFile = attachment.filePath.replace( - attachment.id, - newAttachmentId, - ); - + // One source attachment may need to be copied for several destination + // pages (it is referenced by more than one page in the subtree). Copy a + // distinct blob + row for every referencing page so each copy resolves + // (#206 attach-1). The old per-page ownership guard is gone: when the + // same attachmentId is shared, only one page would ever match the row's + // pageId, silently dropping the other copies. + const pageAttachments = attachmentMap.get(attachment.id) ?? []; + for (const pageAttachment of pageAttachments) { try { - await this.storageService.copy(attachment.filePath, newPathFile); + const newAttachmentId = pageAttachment.newAttachmentId; - await this.db - .insertInto('attachments') - .values({ - id: newAttachmentId, - type: attachment.type, - filePath: newPathFile, - fileName: attachment.fileName, - fileSize: attachment.fileSize, - mimeType: attachment.mimeType, - fileExt: attachment.fileExt, - creatorId: attachment.creatorId, - workspaceId: attachment.workspaceId, - pageId: newPageId, - spaceId: spaceId, - }) - .execute(); - } catch (err) { - this.logger.error( - `Duplicate page: failed to copy attachment ${attachment.id}`, - err, + const newPageId = pageAttachment.newPageId; + + const newPathFile = attachment.filePath.replace( + attachment.id, + newAttachmentId, ); - // Continue with other attachments even if one fails + + try { + await this.storageService.copy(attachment.filePath, newPathFile); + + await this.db + .insertInto('attachments') + .values({ + id: newAttachmentId, + type: attachment.type, + filePath: newPathFile, + fileName: attachment.fileName, + fileSize: attachment.fileSize, + mimeType: attachment.mimeType, + fileExt: attachment.fileExt, + creatorId: attachment.creatorId, + workspaceId: attachment.workspaceId, + pageId: newPageId, + spaceId: spaceId, + }) + .execute(); + } catch (err) { + this.logger.error( + `Duplicate page: failed to copy attachment ${attachment.id}`, + err, + ); + // Continue with other attachments even if one fails + } + } catch (err) { + this.logger.error(err); } - } catch (err) { - this.logger.error(err); } } } @@ -915,34 +942,61 @@ export class PageService { } } - // Server-side cycle guard: a page may not be moved into itself or into any - // page within its own subtree. Without this, an MCP/REST/agent caller (or a - // fast drag racing the client check) could persist a cycle and broadcast it. - // Only relevant when re-parenting under a concrete parent; moving to root - // (parentPageId null/undefined) can never create a cycle. - if (dto.parentPageId) { - if (dto.parentPageId === dto.pageId) { - throw new BadRequestException('Cannot move a page into its own subtree'); - } - // Walk the destination parent's ancestor chain (reusing the breadcrumb - // ancestor CTE). If the page being moved appears among those ancestors, - // the destination lives inside the moved page's subtree -> cycle. - const destAncestors = await this.getPageBreadCrumbs(dto.parentPageId); - if (destAncestors.some((ancestor) => ancestor.id === dto.pageId)) { - throw new BadRequestException('Cannot move a page into its own subtree'); - } - } + // Server-side cycle guard + the move UPDATE run in ONE transaction. A page + // may not be moved into itself or into any page within its own subtree; + // without this an MCP/REST/agent caller (or a fast drag racing the client + // check) could persist a cycle and broadcast it. Crucially, doing the guard + // and the write as two separate, unlocked statements is a TOCTOU race: two + // concurrent moves ("A under B" and "B under A") can each read the same + // pre-write acyclic snapshot, both pass the guard, then persist + // A.parentPageId=B AND B.parentPageId=A — a parent/child cycle (#207 #7). A + // per-space advisory lock (held until COMMIT) serializes all moves within a + // space: the second mover blocks until the first commits and then sees the + // freshly written parent, so its guard rejects the cycle. + const updateResult = await executeTx(this.db, async (trx) => { + await sql`select pg_advisory_xact_lock(${sql.lit( + PAGE_MOVE_LOCK_NAMESPACE, + )}, hashtext(${movedPage.spaceId}))`.execute(trx); - const updateResult = await this.pageRepo.updatePage( - { - position: dto.position, - parentPageId: parentPageId, - // Agent-edit provenance: annotate the source on an agent move. A normal - // user request leaves the existing source value unchanged. - ...agentSourceFields(provenance, 'lastUpdatedSource', 'lastUpdatedAiChatId'), - }, - dto.pageId, - ); + // Only relevant when re-parenting under a concrete parent; moving to root + // (parentPageId null/undefined) can never create a cycle. + if (dto.parentPageId) { + if (dto.parentPageId === dto.pageId) { + throw new BadRequestException( + 'Cannot move a page into its own subtree', + ); + } + // Walk the destination parent's ancestor chain (reusing the breadcrumb + // ancestor CTE) inside the lock. If the page being moved appears among + // those ancestors, the destination lives inside the moved page's + // subtree -> cycle. + const destAncestors = await this.getPageBreadCrumbs( + dto.parentPageId, + trx, + ); + if (destAncestors.some((ancestor) => ancestor.id === dto.pageId)) { + throw new BadRequestException( + 'Cannot move a page into its own subtree', + ); + } + } + + return this.pageRepo.updatePage( + { + position: dto.position, + parentPageId: parentPageId, + // Agent-edit provenance: annotate the source on an agent move. A + // normal user request leaves the existing source value unchanged. + ...agentSourceFields( + provenance, + 'lastUpdatedSource', + 'lastUpdatedAiChatId', + ), + }, + dto.pageId, + trx, + ); + }); // Guard against a phantom broadcast: if the row was concurrently deleted or // otherwise not updated, skip the PAGE_MOVED event so we don't replay a move @@ -981,8 +1035,8 @@ export class PageService { }); } - async getPageBreadCrumbs(childPageId: string) { - const ancestors = await this.db + async getPageBreadCrumbs(childPageId: string, trx?: KyselyTransaction) { + const ancestors = await dbOrTx(this.db, trx) .withRecursive('page_ancestors', (db) => db .selectFrom('pages') @@ -996,6 +1050,9 @@ export class PageService { 'spaceId', 'deletedAt', ]) + // Depth counter: bounds the walk so a parent/child cycle in the data + // can't make this recursive CTE loop forever (#207 #8). + .select(sql`0`.as('depth')) .where('id', '=', childPageId) .where('deletedAt', 'is', null) .unionAll((exp) => @@ -1011,12 +1068,25 @@ export class PageService { 'p.spaceId', 'p.deletedAt', ]) + .select(sql`pa.depth + 1`.as('depth')) .innerJoin('page_ancestors as pa', 'pa.parentPageId', 'p.id') - .where('p.deletedAt', 'is', null), + .where('p.deletedAt', 'is', null) + .where(sql`pa.depth`, '<', MAX_PAGE_TREE_DEPTH), ), ) .selectFrom('page_ancestors') - .selectAll('page_ancestors') + // Explicit column list (not selectAll) so the internal `depth` counter + // never leaks into the breadcrumb result shape. + .select([ + 'id', + 'slugId', + 'title', + 'icon', + 'position', + 'parentPageId', + 'spaceId', + 'deletedAt', + ]) .select((eb) => eb .exists( @@ -1137,16 +1207,21 @@ export class PageService { db .selectFrom('pages') .select(['id']) + // Depth counter: bounds the walk so a parent/child cycle in the data + // can't make this recursive CTE loop forever (#207 #8). + .select(sql`0`.as('depth')) .where('id', '=', pageId) .unionAll((exp) => exp .selectFrom('pages as p') .select(['p.id']) - .innerJoin('page_descendants as pd', 'pd.id', 'p.parentPageId'), + .select(sql`pd.depth + 1`.as('depth')) + .innerJoin('page_descendants as pd', 'pd.id', 'p.parentPageId') + .where(sql`pd.depth`, '<', MAX_PAGE_TREE_DEPTH), ), ) .selectFrom('page_descendants') - .selectAll() + .select(['id']) .execute(); const pageIds = descendants.map((d) => d.id); diff --git a/apps/server/test/integration/duplicate-page-shared-attachment.int-spec.ts b/apps/server/test/integration/duplicate-page-shared-attachment.int-spec.ts new file mode 100644 index 00000000..8b61fbd1 --- /dev/null +++ b/apps/server/test/integration/duplicate-page-shared-attachment.int-spec.ts @@ -0,0 +1,207 @@ +import { randomUUID } from 'node:crypto'; +import { Kysely } from 'kysely'; +import { PageRepo } from '@docmost/db/repos/page/page.repo'; +import { PagePermissionRepo } from '@docmost/db/repos/page/page-permission.repo'; +import { PageService } from 'src/core/page/services/page.service'; +import { + getTestDb, + destroyTestDb, + createWorkspace, + createSpace, + createUser, +} from './db'; + +/** + * #206 attach-1 — Duplicating a subtree where the SAME attachment is referenced + * by more than one page must copy a working blob/row for EVERY copy, not just + * the last page processed. + * + * Setup: root page A and child page B both embed the same image (attachmentId X, + * the attachment row owned by A in the DB). Duplicating A produces copies A' and + * B'. Before the fix the per-attachmentId map held a single entry, so B's entry + * clobbered A's and the row-ownership guard (`attachment.pageId !== oldPageId`) + * then skipped the only DB row entirely: zero blobs copied, zero new rows, both + * copies' images 404. The fix keys the map to a LIST and copies once per + * referencing page, dropping the broken guard. + * + * This drives the real PageService.duplicatePage against a real Postgres with a + * recording storage stub, and asserts: storage.copy called twice and two fresh + * attachment rows exist (one owned by A', one by B'), each matching the rewritten + * attachmentId in its page's content. + */ +describe('PageService.duplicatePage shared attachment [integration]', () => { + let db: Kysely; + let pageRepo: PageRepo; + let pagePermissionRepo: PagePermissionRepo; + let pageService: PageService; + let workspaceId: string; + let spaceId: string; + let userId: string; + + // Records every (source, dest) blob copy the service requests. + const copyCalls: Array<{ from: string; to: string }> = []; + const storageService = { + copy: async (from: string, to: string) => { + copyCalls.push({ from, to }); + }, + } as any; + + // Duplicate persists transclusion/reference rows in best-effort try/catch + // blocks; a no-op stub keeps the harness focused on the attachment path. + const transclusionService = { + insertTransclusionsForPages: async () => {}, + insertReferencesForPages: async () => {}, + insertTemplateReferencesForPages: async () => {}, + } as any; + + const eventEmitter = { emit: () => true } as any; + + function imageDoc(attachmentId: string) { + return { + type: 'doc', + content: [ + { + type: 'image', + attrs: { + attachmentId, + src: `/api/files/${attachmentId}/image.png`, + width: '100%', + align: 'center', + }, + }, + ], + }; + } + + beforeAll(async () => { + db = getTestDb(); + pageRepo = new PageRepo(db as any, {} as any, eventEmitter); + // filterAccessiblePageIds short-circuits to the input ids when the space has + // no restricted pages, so groupRepo/cache (2nd/3rd ctor args) are never hit. + pagePermissionRepo = new PagePermissionRepo( + db as any, + {} as any, + {} as any, + ); + pageService = new PageService( + pageRepo, + pagePermissionRepo, + undefined as any, // attachmentRepo (unused on duplicate path) + db as any, + storageService, + undefined as any, // attachmentQueue + undefined as any, // aiQueue + undefined as any, // generalQueue + eventEmitter, + undefined as any, // collaborationGateway + undefined as any, // watcherService + transclusionService, + ); + + workspaceId = (await createWorkspace(db)).id; + spaceId = (await createSpace(db, workspaceId)).id; + userId = (await createUser(db, workspaceId)).id; + }); + + afterAll(async () => { + await destroyTestDb(); + }); + + it('copies a shared attachment for every page that references it', async () => { + copyCalls.length = 0; + + const attachmentId = randomUUID(); + const pageAId = randomUUID(); + const pageBId = randomUUID(); + + // Root A and child B both embed the same attachmentId. + await db + .insertInto('pages') + .values({ + id: pageAId, + slugId: `a-${pageAId.slice(0, 8)}`, + title: 'A', + content: imageDoc(attachmentId) as any, + position: 'a0', + spaceId, + workspaceId, + creatorId: userId, + }) + .execute(); + await db + .insertInto('pages') + .values({ + id: pageBId, + slugId: `b-${pageBId.slice(0, 8)}`, + title: 'B', + content: imageDoc(attachmentId) as any, + position: 'a0', + parentPageId: pageAId, + spaceId, + workspaceId, + creatorId: userId, + }) + .execute(); + + // Single attachment row, owned by A. + await db + .insertInto('attachments') + .values({ + id: attachmentId, + type: 'image', + filePath: `${spaceId}/${attachmentId}/image.png`, + fileName: 'image.png', + fileExt: 'png', + mimeType: 'image/png', + creatorId: userId, + workspaceId, + pageId: pageAId, + spaceId, + }) + .execute(); + + const rootPage = await pageRepo.findById(pageAId); + const result = await pageService.duplicatePage( + rootPage as any, + undefined, + { id: userId, workspaceId } as any, + ); + + const newRootId = result.id; + const newChildIds = result.childPageIds; + expect(newChildIds).toHaveLength(1); + const newChildId = newChildIds[0]; + + // Both pages' images were copied: one blob per referencing page. + expect(copyCalls).toHaveLength(2); + + // Two fresh attachment rows exist, one owned by each copied page. + const newAttachments = await db + .selectFrom('attachments') + .selectAll() + .where('pageId', 'in', [newRootId, newChildId]) + .where('workspaceId', '=', workspaceId) + .execute(); + expect(newAttachments).toHaveLength(2); + + const ownerIds = newAttachments.map((a) => a.pageId).sort(); + expect(ownerIds).toEqual([newRootId, newChildId].sort()); + + // Each copied page's content points at a rewritten attachmentId that now has + // a real row (i.e. the image src resolves instead of 404ing). + for (const pageId of [newRootId, newChildId]) { + const page = await db + .selectFrom('pages') + .select(['content']) + .where('id', '=', pageId) + .executeTakeFirstOrThrow(); + const node = (page.content as any).content[0]; + expect(node.type).toBe('image'); + const referencedId = node.attrs.attachmentId; + expect(referencedId).not.toBe(attachmentId); // remapped to a fresh id + const row = newAttachments.find((a) => a.id === referencedId); + expect(row).toBeDefined(); + expect(row!.pageId).toBe(pageId); + } + }); +}); diff --git a/apps/server/test/integration/page-move-cycle.int-spec.ts b/apps/server/test/integration/page-move-cycle.int-spec.ts new file mode 100644 index 00000000..be7dd5c6 --- /dev/null +++ b/apps/server/test/integration/page-move-cycle.int-spec.ts @@ -0,0 +1,133 @@ +import { Kysely } from 'kysely'; +import { generateJitteredKeyBetween } from 'fractional-indexing-jittered'; +import { PageRepo } from '@docmost/db/repos/page/page.repo'; +import { PageService } from 'src/core/page/services/page.service'; +import { Page } from '@docmost/db/types/entity.types'; +import { + getTestDb, + destroyTestDb, + createWorkspace, + createSpace, + createPage, +} from './db'; + +/** + * #207 #7 — TOCTOU in PageService.movePage: two concurrent moves + * ("A under B" + "B under A") must NOT be able to persist a parent/child cycle. + * + * Before the fix the cycle check (getPageBreadCrumbs) and the UPDATE were two + * separate, unlocked statements, so both movers could read the same pre-write + * acyclic snapshot, both pass the guard, and persist A.parentPageId=B AND + * B.parentPageId=A. The fix runs the guard + UPDATE in one transaction behind a + * per-space advisory lock, so the moves serialize: whichever commits second + * sees the first's write and its guard rejects the cycle. + * + * This test drives the real PageService.movePage against a real Postgres, + * firing the two opposing moves concurrently, and asserts that no cycle ever + * persists (walking parentPageId from both pages always reaches a root with no + * repeated id) and that exactly one of the two opposing moves is rejected. + */ +describe('PageService.movePage concurrent A<->B cycle guard [integration]', () => { + let db: Kysely; + let pageRepo: PageRepo; + let pageService: PageService; + let workspaceId: string; + let spaceId: string; + + // A valid fractional-index position key; movePage validates the position. + const position = generateJitteredKeyBetween(null, null); + + beforeAll(async () => { + db = getTestDb(); + // Event emission is a side effect movePage performs but the cycle behaviour + // does not depend on; a no-op emitter keeps the harness minimal. + const eventEmitter = { emit: () => true } as any; + pageRepo = new PageRepo(db as any, {} as any, eventEmitter); + // Only pageRepo (1), db (4) and eventEmitter (9) are touched by movePage; + // the remaining constructor deps are unused on this path. + pageService = new PageService( + pageRepo, + undefined as any, + undefined as any, + db as any, + undefined as any, + undefined as any, + undefined as any, + undefined as any, + eventEmitter, + undefined as any, + undefined as any, + undefined as any, + ); + + workspaceId = (await createWorkspace(db)).id; + spaceId = (await createSpace(db, workspaceId)).id; + }); + + afterAll(async () => { + await destroyTestDb(); + }); + + async function findPage(id: string): Promise { + const page = await pageRepo.findById(id); + if (!page) throw new Error(`page ${id} not found`); + return page; + } + + // Walk parentPageId upward from startId. Throws if a node repeats (cycle) or + // the walk fails to terminate; returns normally only when a root is reached. + async function assertReachesRoot(startId: string): Promise { + const seen = new Set(); + let cur: string | null = startId; + let steps = 0; + while (cur) { + if (seen.has(cur)) { + throw new Error(`cycle detected: revisited ${cur}`); + } + seen.add(cur); + const row: { parentPageId: string | null } | undefined = await db + .selectFrom('pages') + .select('parentPageId') + .where('id', '=', cur) + .executeTakeFirst(); + cur = row?.parentPageId ?? null; + if (++steps > 1000) { + throw new Error('parent walk did not terminate'); + } + } + } + + it('two opposing concurrent moves never persist a parent/child cycle', async () => { + // Repeat to exercise different scheduler interleavings of the two moves. + for (let i = 0; i < 8; i++) { + const a = await createPage(db, { workspaceId, spaceId, title: `A-${i}` }); + const b = await createPage(db, { workspaceId, spaceId, title: `B-${i}` }); + + const movedA = await findPage(a.id); + const movedB = await findPage(b.id); + + const results = await Promise.allSettled([ + pageService.movePage( + { pageId: a.id, parentPageId: b.id, position } as any, + movedA, + ), + pageService.movePage( + { pageId: b.id, parentPageId: a.id, position } as any, + movedB, + ), + ]); + + // No cycle may have been persisted by either ordering. + await assertReachesRoot(a.id); + await assertReachesRoot(b.id); + + // The serialization guarantees exactly one of the opposing moves wins; + // the other must be rejected as a subtree cycle. + const rejected = results.filter( + (r): r is PromiseRejectedResult => r.status === 'rejected', + ); + expect(rejected).toHaveLength(1); + expect(rejected[0].reason?.message).toMatch(/into its own subtree/); + } + }); +}); diff --git a/apps/server/test/integration/page-recursive-cte-cycle-guard.int-spec.ts b/apps/server/test/integration/page-recursive-cte-cycle-guard.int-spec.ts new file mode 100644 index 00000000..a415edb7 --- /dev/null +++ b/apps/server/test/integration/page-recursive-cte-cycle-guard.int-spec.ts @@ -0,0 +1,134 @@ +import { CamelCasePlugin, Kysely } from 'kysely'; +import { PostgresJSDialect } from 'kysely-postgres-js'; +import * as postgres from 'postgres'; +import { PageService } from 'src/core/page/services/page.service'; +import { + getTestDb, + destroyTestDb, + createWorkspace, + createSpace, + createPage, + TEST_DATABASE_URL, +} from './db'; + +/** + * #207 #8 — recursive page-tree CTEs (ancestors in getPageBreadCrumbs, + * descendants in forceDelete) must not hang when a parent/child cycle already + * exists in the data. Before the fix neither CTE had a CYCLE clause or a depth + * cap, so a cycle (e.g. one persisted by the #7 TOCTOU race) made withRecursive + * loop forever — and since the move guard itself runs the ancestor CTE, a cycle + * would disable the very guard meant to prevent it. + * + * The fix adds a depth counter bounded by MAX_PAGE_TREE_DEPTH to both CTEs. + * These tests seed an A<->B cycle directly (bypassing the guard), then run the + * real CTE paths against Postgres with a short connection-level statement_timeout + * so a regression (an unbounded CTE) fails fast as a query timeout instead of a + * bounded result. + */ +describe('recursive page-tree CTEs cycle/depth guard [integration]', () => { + // Upper bound on rows the depth-capped CTEs can emit for a 2-node cycle: one + // row per depth level 0..MAX. Kept loose so the assertion does not couple to + // the exact constant, only to "bounded". + const BOUNDED_MAX_ROWS = 20_000; + + let db: Kysely; + // Dedicated Kysely whose connections carry a short statement_timeout, so an + // unbounded recursive CTE aborts quickly instead of hanging the suite. + let timeoutDb: Kysely; + let workspaceId: string; + let spaceId: string; + + beforeAll(async () => { + db = getTestDb(); + timeoutDb = new Kysely({ + dialect: new PostgresJSDialect({ + postgres: postgres(TEST_DATABASE_URL, { + max: 2, + onnotice: () => {}, + // Applied to every connection on connect: cap any single statement. + connection: { statement_timeout: 4000 }, + types: { + bigint: { + to: 20, + from: [20, 1700], + serialize: (value: number) => value.toString(), + parse: (value: string) => Number.parseInt(value), + }, + }, + }), + }), + plugins: [new CamelCasePlugin()], + }); + workspaceId = (await createWorkspace(db)).id; + spaceId = (await createSpace(db, workspaceId)).id; + }); + + afterAll(async () => { + await timeoutDb.destroy(); + await destroyTestDb(); + }); + + // Seed two fresh pages and wire them into a direct parent/child cycle, + // bypassing PageService.movePage's guard the way the #7 race would. + async function seedCycle(): Promise<{ aId: string; bId: string }> { + const a = await createPage(db, { workspaceId, spaceId, title: 'cycle-A' }); + const b = await createPage(db, { workspaceId, spaceId, title: 'cycle-B' }); + await db + .updateTable('pages') + .set({ parentPageId: b.id }) + .where('id', '=', a.id) + .execute(); + await db + .updateTable('pages') + .set({ parentPageId: a.id }) + .where('id', '=', b.id) + .execute(); + return { aId: a.id, bId: b.id }; + } + + function makeService(database: Kysely): PageService { + const eventEmitter = { emit: () => true } as any; + const attachmentQueue = { add: async () => undefined } as any; + return new PageService( + undefined as any, // pageRepo (unused by these paths) + undefined as any, // pagePermissionRepo + undefined as any, // attachmentRepo + database as any, // db + undefined as any, // storageService + attachmentQueue, // attachmentQueue + undefined as any, // aiQueue + undefined as any, // generalQueue + eventEmitter, // eventEmitter + undefined as any, // collaborationGateway + undefined as any, // watcherService + undefined as any, // transclusionService + ); + } + + it('getPageBreadCrumbs returns a bounded result (no hang) when a cycle exists', async () => { + const { aId } = await seedCycle(); + const service = makeService(timeoutDb); + + // Must resolve (the depth cap stops the walk) rather than time out. + const crumbs = await service.getPageBreadCrumbs(aId); + + expect(Array.isArray(crumbs)).toBe(true); + expect(crumbs.length).toBeGreaterThan(1); + expect(crumbs.length).toBeLessThanOrEqual(BOUNDED_MAX_ROWS); + }); + + it('forceDelete descendant CTE is bounded (no hang) and removes the cyclic pages', async () => { + const { aId, bId } = await seedCycle(); + const service = makeService(timeoutDb); + + // Must complete instead of looping on the descendant CTE. + await service.forceDelete(aId, workspaceId); + + const survivors = await db + .selectFrom('pages') + .select('id') + .where('id', 'in', [aId, bId]) + .execute(); + expect(survivors).toHaveLength(0); + }); +}); diff --git a/packages/editor-ext/src/lib/unique-id/unique-id.util.test.ts b/packages/editor-ext/src/lib/unique-id/unique-id.util.test.ts new file mode 100644 index 00000000..24d30408 --- /dev/null +++ b/packages/editor-ext/src/lib/unique-id/unique-id.util.test.ts @@ -0,0 +1,103 @@ +import { describe, it, expect } from "vitest"; +import StarterKit from "@tiptap/starter-kit"; +import { addUniqueIdsToDoc } from "./unique-id.util"; +import { UniqueID } from "./unique-id"; +import { TransclusionSource } from "../transclusion/transclusion-source"; + +// Minimal extension set: StarterKit (paragraph/heading) + the UniqueID config +// the server uses for the addressing anchors. +const extensions = [ + StarterKit, + UniqueID.configure({ types: ["heading", "paragraph"] }), +]; + +// `transclusionSource` is also an addressed type, but its id is a cross-reference +// KEY (a transclusionReference / the page_transclusions table resolves a source +// by it), so it lives in the NO_REASSIGN set: a missing id is filled, a colliding +// id is NOT reassigned (rewriting it would orphan its references). +const extensionsWithSource = [ + StarterKit, + // Narrow the content expression to `paragraph+` so the schema builds from + // StarterKit alone (the real allow-list references image/table/etc. nodes this + // minimal harness doesn't register). The node name — what NO_REASSIGN keys on + // — is unchanged. + TransclusionSource.extend({ content: "paragraph+" }), + UniqueID.configure({ + types: ["heading", "paragraph", "transclusionSource"], + }), +]; + +const para = (id: string | undefined, text: string) => ({ + type: "paragraph", + ...(id !== undefined ? { attrs: { id } } : {}), + content: [{ type: "text", text }], +}); + +const source = (id: string | undefined, text: string) => ({ + type: "transclusionSource", + ...(id !== undefined ? { attrs: { id } } : {}), + // The schema requires at least one block child (content expression is `+`). + content: [{ type: "paragraph", content: [{ type: "text", text }] }], +}); + +const ids = (doc: any): (string | undefined)[] => + (doc.content ?? []).map((n: any) => n.attrs?.id); + +describe("addUniqueIdsToDoc", () => { + it("fills ids on nodes that are missing one", () => { + const doc = { type: "doc", content: [para(undefined, "a"), para(undefined, "b")] }; + const out = addUniqueIdsToDoc(doc, extensions); + const [a, b] = ids(out); + expect(a).toBeTruthy(); + expect(b).toBeTruthy(); + expect(a).not.toBe(b); + }); + + it("deduplicates two nodes that share the same id (#206 editor-pm-7)", () => { + // A copy/paste or bulk-JSON duplicate keeps the original id on both nodes. + const doc = { + type: "doc", + content: [para("dup", "first"), para("dup", "second")], + }; + const out = addUniqueIdsToDoc(doc, extensions); + const [first, second] = ids(out); + // The first occurrence keeps the id (stable anchor); the duplicate is + // reassigned a fresh one so MCP addressing can't hit the wrong/both nodes. + expect(first).toBe("dup"); + expect(second).toBeTruthy(); + expect(second).not.toBe("dup"); + }); + + it("leaves already-unique ids untouched", () => { + const doc = { + type: "doc", + content: [para("x1", "first"), para("x2", "second")], + }; + const out = addUniqueIdsToDoc(doc, extensions); + expect(ids(out)).toEqual(["x1", "x2"]); + }); + + it("does NOT reassign a colliding transclusionSource id — BOTH keep it (NO_REASSIGN)", () => { + // Two sync-block sources sharing an id: rewriting either would orphan the + // transclusionReferences / page_transclusions rows that resolve a source by + // this key, so the dedupe MUST leave both ids intact. If the NO_REASSIGN + // guard is removed, the second source is reassigned a fresh id and this fails. + const doc = { + type: "doc", + content: [source("src", "first"), source("src", "second")], + }; + const out = addUniqueIdsToDoc(doc, extensionsWithSource); + const [first, second] = ids(out); + expect(first).toBe("src"); + expect(second).toBe("src"); + }); + + it("still FILLS a missing id on a transclusionSource (only reassignment is suppressed)", () => { + // NO_REASSIGN suppresses dedupe of an EXISTING id, not filling a missing one: + // a source with no id still needs a key its references can resolve. + const doc = { type: "doc", content: [source(undefined, "only")] }; + const out = addUniqueIdsToDoc(doc, extensionsWithSource); + const [id] = ids(out); + expect(id).toBeTruthy(); + }); +}); diff --git a/packages/editor-ext/src/lib/unique-id/unique-id.util.ts b/packages/editor-ext/src/lib/unique-id/unique-id.util.ts index 8d1991ed..88e81324 100644 --- a/packages/editor-ext/src/lib/unique-id/unique-id.util.ts +++ b/packages/editor-ext/src/lib/unique-id/unique-id.util.ts @@ -59,18 +59,44 @@ export function addUniqueIdsToDoc( ]); const contentNode = Node.fromJSON(schema, doc); - // Find nodes that don't have a unique ID - const nodesWithoutId = findChildren(contentNode, (node) => { - return !node.attrs[attributeName] && types.includes(node.type.name); + // All nodes of the configured types, in document order, so that the FIRST + // occurrence of any given id keeps it and later duplicates get reassigned. + const idNodes = findChildren(contentNode, (node) => { + return types.includes(node.type.name); }); - // Edit the document to add unique IDs to the nodes that don't have a unique ID + // `transclusionSource` ids are cross-reference keys (a transclusionReference / + // the page_transclusions table resolves a source by this id), so rewriting one + // would orphan its references. We only fill a MISSING id for those, never + // reassign an existing one; plain block anchors (heading/paragraph) are safe to + // dedupe. + const NO_REASSIGN = new Set(["transclusionSource"]); + + // Edit the document to (a) add ids where missing and (b) dedupe collisions. A + // duplicate id otherwise lets copy/paste/import produce two nodes sharing an + // id, so MCP addressed edits (patch_node / delete_node "before/after id") hit + // the wrong node or both (#206 editor-pm-7). This previously only filled + // missing ids and never deduplicated existing ones. + const seenIds = new Set(); let tr = EditorState.create({ doc: contentNode, }).tr; // eslint-disable-next-line no-restricted-syntax - for (const { node, pos } of nodesWithoutId) { - tr = tr.setNodeAttribute(pos, attributeName, generateID({ node, pos })); + for (const { node, pos } of idNodes) { + const currentId = node.attrs[attributeName]; + const isDuplicate = currentId != null && seenIds.has(currentId); + const needsNewId = + currentId == null || (isDuplicate && !NO_REASSIGN.has(node.type.name)); + + if (needsNewId) { + // setNodeAttribute only changes attributes (no size change), so positions + // from the original node stay valid across the whole loop. + const newId = generateID({ node, pos }); + tr = tr.setNodeAttribute(pos, attributeName, newId); + seenIds.add(newId); + } else if (currentId != null) { + seenIds.add(currentId); + } } // Return the updated document