diff --git a/.gitignore b/.gitignore index e814fb29..4c966923 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ lerna-debug.log* # TypeScript incremental build artifacts *.tsbuildinfo +apps/client/coverage/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 43255596..e85d3ef6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,20 @@ embeds — plus a large batch of security hardening and test coverage. - **Voice dictation (STT)**: server-side speech-to-text with a mic button in the chat and the editor, OpenRouter STT support, an endpoint test, and real provider-error surfacing. +- **Realtime streaming dictation**: a new live-dictation mic mode layered on top + of the existing batch STT. Audio streams over a dedicated `/ai-realtime` + Socket.IO namespace and text is inserted as you speak (interim partials shown + as a ghost decoration, only finals committed to the document). Gated by a new + `dictationRealtime` workspace toggle, with `sttRealtimeModel` and + `sttRealtimeBaseUrl` settings (empty model falls back to `sttModel`; empty base + URL falls back to the STT base URL server-side). + - **Ops caveat (single-process assumption):** the realtime concurrency caps + (1 concurrent session per user, 5 per workspace) are enforced **in-memory, + per API process**. They are therefore authoritative only on a **single API + replica** — running multiple API instances (horizontal scale / load + balancing) lets a user or workspace exceed these caps, since each process + counts only its own sessions. Treat the limits as per-process until the + counters are moved to a shared store. - **Footnotes**: an editor footnotes model (inline references + a definitions list). 
- **Page templates**: live whole-page embed (MVP) with a template-marker icon diff --git a/apps/client/package.json b/apps/client/package.json index 0433c97f..3bfb9d87 100644 --- a/apps/client/package.json +++ b/apps/client/package.json @@ -79,6 +79,7 @@ "@types/react": "18.3.12", "@types/react-dom": "18.3.1", "@vitejs/plugin-react": "6.0.1", + "@vitest/coverage-v8": "4.1.6", "eslint": "9.28.0", "eslint-plugin-react": "7.37.5", "eslint-plugin-react-hooks": "7.0.1", diff --git a/apps/client/src/features/ai-chat/components/chat-input.append.test.ts b/apps/client/src/features/ai-chat/components/chat-input.append.test.ts new file mode 100644 index 00000000..f60bc6b7 --- /dev/null +++ b/apps/client/src/features/ai-chat/components/chat-input.append.test.ts @@ -0,0 +1,26 @@ +import { describe, it, expect } from "vitest"; +import { appendFinalToDraft } from "./chat-input"; + +describe("appendFinalToDraft", () => { + it("an empty draft becomes the final verbatim", () => { + expect(appendFinalToDraft("", "hello")).toBe("hello"); + }); + + it("a non-empty draft gets the final appended with exactly one space", () => { + expect(appendFinalToDraft("draft", "final")).toBe("draft final"); + }); + + it("never introduces a leading or double space", () => { + const out = appendFinalToDraft("draft", "final"); + expect(out.startsWith(" ")).toBe(false); + expect(out).not.toContain(" "); + }); + + it("accumulates left-to-right across repeated calls", () => { + let draft = ""; + draft = appendFinalToDraft(draft, "a"); + draft = appendFinalToDraft(draft, "b"); + draft = appendFinalToDraft(draft, "c"); + expect(draft).toBe("a b c"); + }); +}); diff --git a/apps/client/src/features/ai-chat/components/chat-input.tsx b/apps/client/src/features/ai-chat/components/chat-input.tsx index 2728e7cf..1bbb689a 100644 --- a/apps/client/src/features/ai-chat/components/chat-input.tsx +++ b/apps/client/src/features/ai-chat/components/chat-input.tsx @@ -22,6 +22,16 @@ interface ChatInputProps { disabled?: 
/**
 * Merge a finalized dictation segment into the existing draft. Pure and
 * unit-testable.
 *
 * - An empty draft becomes the final verbatim.
 * - An empty final leaves the draft untouched (no dangling trailing space).
 * - Otherwise the final is appended. A single space is inserted at the join
 *   only when neither side already supplies whitespace there, so this
 *   function never introduces a double separator.
 *
 * Repeated calls accumulate left-to-right ("a" then "b" -> "a b").
 *
 * @param draft - Current composer text.
 * @param final - Finalized transcript segment to append.
 * @returns The merged draft; no characters of either input are removed.
 */
export function appendFinalToDraft(draft: string, final: string): string {
  if (!draft) return final;
  if (!final) return draft;
  // Respect whitespace the user (trailing newline/space) or the STT provider
  // (leading space) already placed at the boundary.
  const alreadySeparated = /\s$/.test(draft) || /^\s/.test(final);
  return alreadySeparated ? `${draft}${final}` : `${draft} ${final}`;
}
+const t = (k: string) => k; + +describe("mapGetUserMediaError", () => { + it("maps NotAllowedError / SecurityError to denied", () => { + expect(mapGetUserMediaError({ name: "NotAllowedError" }, t)).toBe( + "Microphone access denied", + ); + expect(mapGetUserMediaError({ name: "SecurityError" }, t)).toBe( + "Microphone access denied", + ); + }); + + it("maps NotFoundError / OverconstrainedError to not found", () => { + expect(mapGetUserMediaError({ name: "NotFoundError" }, t)).toBe( + "No microphone found", + ); + expect(mapGetUserMediaError({ name: "OverconstrainedError" }, t)).toBe( + "No microphone found", + ); + }); + + it("maps NotReadableError / AbortError to in-use", () => { + expect(mapGetUserMediaError({ name: "NotReadableError" }, t)).toBe( + "Microphone is unavailable or already in use", + ); + expect(mapGetUserMediaError({ name: "AbortError" }, t)).toBe( + "Microphone is unavailable or already in use", + ); + }); + + it("falls back to a detailed message for unknown errors", () => { + const msg = mapGetUserMediaError( + { name: "WeirdError", message: "boom" }, + t, + ); + expect(msg).toContain("Could not start recording"); + expect(msg).toContain("WeirdError"); + expect(msg).toContain("boom"); + }); + + it("falls back without a name", () => { + const msg = mapGetUserMediaError(new Error("nope"), t); + expect(msg).toContain("Could not start recording"); + expect(msg).toContain("nope"); + }); +}); + +describe("canStartCapture", () => { + const base = { + starting: false, + hasStream: false, + hasLiveResource: false, + statusIsIdle: true, + }; + it("allows when idle and nothing live", () => { + expect(canStartCapture(base)).toBe(true); + }); + it("blocks while already starting", () => { + expect(canStartCapture({ ...base, starting: true })).toBe(false); + }); + it("blocks when a stream is live", () => { + expect(canStartCapture({ ...base, hasStream: true })).toBe(false); + }); + it("blocks when a downstream resource is live", () => { + 
// Shared microphone-acquisition front-end used by BOTH the batch (`use-dictation`)
// and streaming (`use-realtime-dictation`) hooks. Only the getUserMedia handshake
// and its error→message mapping live here — the two hooks keep their own distinct
// downstream graphs (MediaRecorder vs AudioWorklet) and their own streamRef
// ownership.

// Translate function shape (react-i18next's `t`). Kept structural so this module
// has no i18next dependency and stays trivially testable.
export type Translate = (key: string) => string;

/** Thrown by `acquireMicStream` when the environment cannot capture audio. */
export class MicUnavailableError extends Error {
  constructor() {
    super("navigator.mediaDevices.getUserMedia is unavailable in this context");
    this.name = "MicUnavailableError";
  }
}

/**
 * Read a non-empty string property from an unknown value without trusting its
 * shape. Returns `undefined` for non-objects, missing keys, non-string values
 * and empty strings.
 */
function readStringField(
  value: unknown,
  key: "name" | "message",
): string | undefined {
  if (typeof value !== "object" || value === null) return undefined;
  const field = (value as Record<string, unknown>)[key];
  return typeof field === "string" && field.length > 0 ? field : undefined;
}

/**
 * Map a getUserMedia rejection to a user-facing, localized message. Pure aside
 * from the injected `t`; safe to unit-test with a stub translator.
 *
 * Known error names map to a fixed translated message. Anything else falls back
 * to "Could not start recording" followed by whatever real detail is available
 * (`name`, then `message`). Objects that carry no message contribute no detail,
 * so the user never sees "[object Object]".
 *
 * @param err - The value getUserMedia rejected with (shape is not trusted).
 * @param t - Translator applied to the source-language message keys.
 * @returns The message to show the user.
 */
export function mapGetUserMediaError(err: unknown, t: Translate): string {
  const name = readStringField(err, "name");
  switch (name) {
    case "NotAllowedError":
    case "SecurityError":
      return t("Microphone access denied");
    case "NotFoundError":
    case "OverconstrainedError":
      return t("No microphone found");
    case "NotReadableError":
    case "AbortError":
      return t("Microphone is unavailable or already in use");
    default:
      break;
  }

  // Unknown failure: show the real reason instead of a generic string. A
  // primitive rejection (e.g. a thrown string) is its own detail.
  const isPrimitive =
    err != null && typeof err !== "object" && typeof err !== "function";
  const detail =
    readStringField(err, "message") ?? (isPrimitive ? String(err) : undefined);
  const reason = [name, detail].filter((part) => part !== undefined).join(": ");
  const base = t("Could not start recording");
  return reason ? `${base}: ${reason}` : base;
}

/**
 * Request the microphone. Throws `MicUnavailableError` when the API is missing
 * — including contexts where `navigator` itself does not exist — so callers can
 * show the "not available in this context" notification. Otherwise the raw
 * getUserMedia rejection propagates for `mapGetUserMediaError`. The caller owns
 * the returned stream (assigns it to its own streamRef and is responsible for
 * stopping the tracks on every exit path).
 *
 * @returns The live audio `MediaStream`.
 * @throws MicUnavailableError when `navigator.mediaDevices.getUserMedia` is absent.
 */
export async function acquireMicStream(): Promise<MediaStream> {
  if (
    typeof navigator === "undefined" ||
    !navigator.mediaDevices?.getUserMedia
  ) {
    throw new MicUnavailableError();
  }
  return navigator.mediaDevices.getUserMedia({ audio: true });
}

/**
 * Shared synchronous double-start guard. Returns true when a new capture may
 * begin, false when one is already starting or live (so the second click is a
 * no-op and never opens a leaking second MediaStream). `statusIsIdle` reflects
 * the React status; the other flags mirror refs that cover the window before
 * the next render commits.
 */
export function canStartCapture(args: {
  starting: boolean;
  hasStream: boolean;
  hasLiveResource: boolean;
  statusIsIdle: boolean;
}): boolean {
  const busy = args.starting || args.hasStream || args.hasLiveResource;
  return !busy && args.statusIsIdle;
}
+ const buf = floatToPcm16LE([1]); + const bytes = new Uint8Array(buf); + expect(bytes[0]).toBe(0xff); + expect(bytes[1]).toBe(0x7f); + expect(buf.byteLength).toBe(2); + }); + + it("emits exactly length*2 bytes and round-trips", () => { + const input = [0, 1, -1, 0.5, -0.5]; + const buf = floatToPcm16LE(input); + expect(buf.byteLength).toBe(input.length * 2); + const back = readInt16LE(buf); + expect(back[0]).toBe(0); + expect(back[1]).toBe(32767); + expect(back[2]).toBe(-32768); + }); + + it("property: output is always within [-32768, 32767]", () => { + for (let i = 0; i < 1000; i++) { + const v = (Math.random() - 0.5) * 10; // span well beyond [-1,1] + const out = floatSampleToInt16(v); + expect(out).toBeGreaterThanOrEqual(-32768); + expect(out).toBeLessThanOrEqual(32767); + } + // Include hostile values explicitly. + for (const v of [NaN, Infinity, -Infinity, 1e308, -1e308]) { + const out = floatSampleToInt16(v); + expect(out).toBeGreaterThanOrEqual(-32768); + expect(out).toBeLessThanOrEqual(32767); + } + }); +}); + +describe("LinearResampler", () => { + function ramp(n: number): Float32Array { + const a = new Float32Array(n); + for (let i = 0; i < n; i++) a[i] = i / n; + return a; + } + + it("48k → 24k produces ~half the samples", () => { + const rs = new LinearResampler(48000, 24000); + const out = rs.process(ramp(1000)); + expect(out.length).toBeGreaterThan(480); + expect(out.length).toBeLessThan(520); + }); + + it("ratio = 1 is approximately a passthrough (length-wise)", () => { + const rs = new LinearResampler(24000, 24000); + const input = ramp(1000); + const out = rs.process(input); + expect(Math.abs(out.length - input.length)).toBeLessThanOrEqual(1); + }); + + it("44.1k → 24k fractional ratio yields the expected count", () => { + const rs = new LinearResampler(44100, 24000); + const n = 4410; + const out = rs.process(ramp(n)); + const expected = n * (24000 / 44100); // ~2400 + expect(Math.abs(out.length - expected)).toBeLessThan(3); + }); + + 
it("cross-quantum continuity: split == single", () => { + const input = ramp(2000); + const single = new LinearResampler(48000, 24000).process(input); + + const split = new LinearResampler(48000, 24000); + const a = split.process(input.subarray(0, 777)); + const b = split.process(input.subarray(777)); + const joined = new Float32Array(a.length + b.length); + joined.set(a, 0); + joined.set(b, a.length); + + expect(joined.length).toBe(single.length); + for (let i = 0; i < single.length; i++) { + expect(joined[i]).toBeCloseTo(single[i], 6); + } + }); + + it("never reads out of bounds (no NaN in output)", () => { + const rs = new LinearResampler(48000, 24000); + for (let q = 0; q < 50; q++) { + const out = rs.process(ramp(128)); + for (const v of out) expect(Number.isNaN(v)).toBe(false); + } + }); +}); + +describe("OnePoleLowPass", () => { + it("is a passthrough when not downsampling", () => { + const lp = new OnePoleLowPass(24000, 24000); + for (const v of [0.5, -0.3, 1, -1]) expect(lp.process(v)).toBe(v); + }); + + it("attenuates a step (smooths) when downsampling", () => { + const lp = new OnePoleLowPass(48000, 24000); + const first = lp.process(1); + // One-pole on a step from 0 should not jump straight to 1. 
+ expect(first).toBeLessThan(1); + expect(first).toBeGreaterThan(0); + }); +}); + +describe("FrameAccumulator", () => { + it("emits exactly one 7200-byte frame for FRAME_SAMPLES samples", () => { + expect(FRAME_SAMPLES).toBe(3600); + const acc = new FrameAccumulator(); + const frames = acc.push(new Float32Array(FRAME_SAMPLES)); + expect(frames).toHaveLength(1); + expect(frames[0].byteLength).toBe(7200); + expect(acc.pending).toBe(0); + }); + + it("emits no frame for FRAME_SAMPLES-1 samples and carries the remainder", () => { + const acc = new FrameAccumulator(); + const frames = acc.push(new Float32Array(FRAME_SAMPLES - 1)); + expect(frames).toHaveLength(0); + expect(acc.pending).toBe(FRAME_SAMPLES - 1); + }); + + it("carries the remainder across pushes", () => { + const acc = new FrameAccumulator(); + expect(acc.push(new Float32Array(2000))).toHaveLength(0); + const frames = acc.push(new Float32Array(2000)); // 4000 total → one frame + expect(frames).toHaveLength(1); + expect(acc.pending).toBe(400); // 4000 - 3600 + }); + + it("flush emits the partial tail then clears", () => { + const acc = new FrameAccumulator(); + acc.push(new Float32Array(100)); + const tail = acc.flush(); + expect(tail).not.toBeNull(); + expect(tail!.byteLength).toBe(200); + expect(acc.pending).toBe(0); + expect(acc.flush()).toBeNull(); + }); +}); diff --git a/apps/client/src/features/dictation/audio/pcm16-dsp.ts b/apps/client/src/features/dictation/audio/pcm16-dsp.ts new file mode 100644 index 00000000..e3f4f201 --- /dev/null +++ b/apps/client/src/features/dictation/audio/pcm16-dsp.ts @@ -0,0 +1,187 @@ +// Pure DSP primitives for the realtime dictation capture path. These functions +// carry NO Web Audio / worklet dependencies so they can be unit-tested directly +// in jsdom/node. 
// The AudioWorklet processor (`pcm16-worklet.ts`) re-implements the same math
// inline (the worklet global scope forbids ES imports at runtime), but THIS
// module is the canonical reference the unit tests exercise; the worklet must be
// kept behaviorally in sync with it. See the note in `pcm16-worklet.ts`.

// Target output rate required by the upstream transcription contract.
export const TARGET_RATE = 24000;
// ~150 ms of audio at the target rate: 24000 * 0.15 = 3600 samples per message.
export const FRAME_SAMPLES = Math.round(TARGET_RATE * 0.15);

/**
 * Convert a single normalized float audio sample in [-1, 1] to a signed 16-bit
 * integer. Values outside the range are clamped; NaN maps to 0 and ±Infinity to
 * the range limits, so the output is ALWAYS within [-32768, 32767]. Negative
 * values scale by 0x8000 and non-negative by 0x7fff so that +1 → 32767 and
 * -1 → -32768 exactly.
 */
export function floatSampleToInt16(sample: number): number {
  if (Number.isNaN(sample)) return 0;
  let s = sample;
  if (s > 1) s = 1;
  else if (s < -1) s = -1;
  // Round to the nearest integer, then hard-clamp as a final guard.
  let v = Math.round(s < 0 ? s * 0x8000 : s * 0x7fff);
  if (v > 32767) v = 32767;
  else if (v < -32768) v = -32768;
  return v;
}

/**
 * Convert a float sample buffer to little-endian PCM16 bytes. The returned
 * ArrayBuffer is exactly `float32.length * 2` bytes; byte order is LE regardless
 * of host endianness (DataView writes are explicit).
 */
export function floatToPcm16LE(float32: ArrayLike<number>): ArrayBuffer {
  const count = float32.length;
  const buffer = new ArrayBuffer(count * 2);
  const view = new DataView(buffer);
  for (let i = 0; i < count; i++) {
    view.setInt16(i * 2, floatSampleToInt16(float32[i]), true);
  }
  return buffer;
}

/**
 * A simple one-pole IIR low-pass filter used as a cheap anti-aliasing stage
 * before downsampling (e.g. 48k → 24k). The cutoff sits at 90% of the output
 * Nyquist frequency. The filter state lives in the private `prev` field, so it
 * is carried across successive `process()` calls (and therefore across quanta)
 * without seams. When `inputRate <= outputRate` (no downsampling) the filter is
 * a passthrough.
 */
export class OnePoleLowPass {
  private readonly alpha: number;
  private prev: number;
  private readonly enabled: boolean;

  /**
   * @param inputRate - Sample rate of the incoming signal, in Hz.
   * @param outputRate - Sample rate the signal will be resampled to, in Hz.
   * @param primed - Initial filter state (defaults to 0).
   */
  constructor(inputRate: number, outputRate: number, primed = 0) {
    // Cutoff a touch below the output Nyquist to leave transition room.
    const cutoff = (outputRate / 2) * 0.9;
    this.enabled = inputRate > outputRate && cutoff > 0 && inputRate > 0;
    // Standard one-pole alpha: dt / (rc + dt), rc = 1 / (2π fc). The
    // Math.max guards keep the arithmetic finite for degenerate rates.
    const dt = 1 / Math.max(inputRate, 1);
    const rc = 1 / (2 * Math.PI * Math.max(cutoff, 1));
    this.alpha = dt / (rc + dt);
    this.prev = primed;
  }

  /** Filter one sample, updating the internal state; passthrough when disabled. */
  process(sample: number): number {
    if (!this.enabled) return sample;
    this.prev = this.prev + this.alpha * (sample - this.prev);
    return this.prev;
  }
}

/**
 * Stateful linear resampler that converts a stream of input quanta at
 * `inputRate` to `outputRate`, carrying the fractional read position and the
 * boundary sample across calls so splitting a signal into two `process()` calls
 * yields the same output as one call (cross-quantum continuity). Never reads out
 * of bounds: the right neighbor of every emitted sample is guaranteed to exist
 * within the current quantum; any leftover position is carried.
 */
export class LinearResampler {
  private readonly ratio: number;
  private resamplePos = 0;
  private prevSample = 0;
  private primed = false;

  /**
   * @param inputRate - Source sample rate in Hz; must be finite and > 0.
   * @param outputRate - Destination sample rate in Hz; must be finite and > 0.
   * @throws RangeError when either rate is not a finite positive number. A
   *   zero or negative ratio would otherwise make `process()` loop forever.
   */
  constructor(inputRate: number, outputRate: number) {
    const valid =
      Number.isFinite(inputRate) &&
      Number.isFinite(outputRate) &&
      inputRate > 0 &&
      outputRate > 0;
    if (!valid) {
      throw new RangeError(
        `LinearResampler rates must be finite and positive (got ${inputRate} -> ${outputRate})`,
      );
    }
    // Input samples consumed per output sample. >1 when downsampling.
    this.ratio = inputRate / outputRate;
  }

  /**
   * Resample one quantum and return the produced output samples. The optional
   * `filter` is applied to every input sample before interpolation
   * (anti-aliasing); its state persists inside the filter object across calls.
   */
  process(channel: ArrayLike<number>, filter?: OnePoleLowPass): Float32Array {
    const n = channel.length;
    if (n === 0) return new Float32Array(0);

    // Filter the raw input once into a local buffer so interpolation reads
    // filtered values.
    let src: ArrayLike<number> = channel;
    if (filter) {
      const filtered = new Float32Array(n);
      for (let i = 0; i < n; i++) filtered[i] = filter.process(channel[i]);
      src = filtered;
    }

    if (!this.primed) {
      // First quantum: no real predecessor exists, so seed the virtual index -1
      // with this quantum's first sample and start reading from 0.
      this.prevSample = src[0];
      this.primed = true;
      this.resamplePos = 0;
    }

    // Output count depends on the carried position, so collect dynamically.
    const out: number[] = [];
    let pos = this.resamplePos;
    // `pos < n - 1` guarantees src[floor + 1] exists. A carried position in
    // [-1, 0) interpolates from the previous quantum's last sample.
    while (pos < n - 1) {
      const floor = Math.floor(pos);
      const frac = pos - floor;
      const s0 = floor < 0 ? this.prevSample : src[floor];
      const s1 = src[floor + 1];
      out.push(s0 + (s1 - s0) * frac);
      pos += this.ratio;
    }

    // Rebase the leftover position relative to the next quantum's start.
    this.resamplePos = pos - n;
    this.prevSample = src[n - 1];
    return Float32Array.from(out);
  }
}

/**
 * Accumulates resampled float samples and emits whole PCM16 frames of exactly
 * `frameSamples` samples (7200-byte ArrayBuffers at the default FRAME_SAMPLES).
 * The remainder is carried until a later push completes a frame. `flush()`
 * emits any partial remainder (used on teardown so the final ~150 ms is not
 * lost).
 */
export class FrameAccumulator {
  private readonly acc: Float32Array;
  private accLen = 0;
  private readonly frameSamples: number;

  /**
   * @param frameSamples - Samples per emitted frame; must be a positive integer.
   * @throws RangeError otherwise (a zero or fractional size would emit empty
   *   frames or silently drop samples).
   */
  constructor(frameSamples: number = FRAME_SAMPLES) {
    if (!Number.isInteger(frameSamples) || frameSamples <= 0) {
      throw new RangeError(
        `FrameAccumulator frameSamples must be a positive integer (got ${frameSamples})`,
      );
    }
    this.frameSamples = frameSamples;
    this.acc = new Float32Array(frameSamples);
  }

  /**
   * Push samples; returns zero or more complete PCM16 frame buffers (each
   * `frameSamples * 2` bytes). The carried remainder stays buffered.
   */
  push(samples: ArrayLike<number>): ArrayBuffer[] {
    const frames: ArrayBuffer[] = [];
    for (let i = 0; i < samples.length; i++) {
      this.acc[this.accLen] = samples[i];
      this.accLen += 1;
      if (this.accLen >= this.frameSamples) {
        // floatToPcm16LE copies into a fresh buffer, so reusing `acc` is safe.
        frames.push(floatToPcm16LE(this.acc.subarray(0, this.accLen)));
        this.accLen = 0;
      }
    }
    return frames;
  }

  /** Emit the partial remainder (if any) as one frame and clear it. */
  flush(): ArrayBuffer | null {
    if (this.accLen === 0) return null;
    const buf = floatToPcm16LE(this.acc.subarray(0, this.accLen));
    this.accLen = 0;
    return buf;
  }

  /** Number of buffered samples not yet emitted. */
  get pending(): number {
    return this.accLen;
  }
}
We read the first input channel (mono), linearly resample -// to 24000 Hz while carrying the fractional read position across calls (so we never -// assume a particular input rate, e.g. 44.1k or 48k), accumulate the resampled -// samples, and once we have ~150 ms worth (3600 samples) we emit them as an -// Int16 ArrayBuffer transferred to the main thread. +// the context sample rate. We read the first input channel (mono), apply a cheap +// anti-aliasing low-pass, linearly resample to 24000 Hz while carrying the +// fractional read position across calls (so we never assume a particular input +// rate, e.g. 44.1k or 48k), accumulate the resampled samples, and once we have +// ~150 ms worth (3600 samples) we emit them as an Int16 ArrayBuffer transferred to +// the main thread. A 'flush' message from the main thread emits the partial tail so +// the last ~150 ms is not lost on stop. // Target output rate required by the upstream transcription contract. const TARGET_RATE = 24000; @@ -34,6 +43,41 @@ class Pcm16Worklet extends AudioWorkletProcessor { private acc: Float32Array = new Float32Array(FRAME_SAMPLES); private accLen = 0; + // --- Anti-aliasing one-pole low-pass state (see OnePoleLowPass in pcm16-dsp) --- + // Configured lazily on the first quantum once `sampleRate` is known. + private lpAlpha = 1; + private lpPrev = 0; + private lpEnabled = false; + private lpConfigured = false; + + constructor() { + super(); + // The main thread asks for a tail flush on stop so the last partial frame + // (~150 ms) is not dropped. Any message triggers a flush of the remainder. 
+ this.port.onmessage = (event: MessageEvent) => { + if (event.data === "flush") this.flush(); + }; + } + + private configureLowPass(): void { + if (this.lpConfigured) return; + this.lpConfigured = true; + const inputRate = sampleRate; + const outputRate = TARGET_RATE; + const cutoff = (outputRate / 2) * 0.9; + this.lpEnabled = inputRate > outputRate && cutoff > 0 && inputRate > 0; + const dt = 1 / Math.max(inputRate, 1); + const rc = 1 / (2 * Math.PI * Math.max(cutoff, 1)); + this.lpAlpha = dt / (rc + dt); + this.lpPrev = 0; + } + + private lowPass(sample: number): number { + if (!this.lpEnabled) return sample; + this.lpPrev = this.lpPrev + this.lpAlpha * (sample - this.lpPrev); + return this.lpPrev; + } + process(inputs: Float32Array[][]): boolean { const input = inputs[0]; // No connected input (or a momentarily empty quantum): keep the node alive @@ -49,16 +93,22 @@ class Pcm16Worklet extends AudioWorkletProcessor { return true; } - // Linearly resample `channel` (at the context `sampleRate`) to TARGET_RATE and - // push the results into the accumulator, flushing whole frames as they fill. + // Apply anti-aliasing, linearly resample `channel` (at the context `sampleRate`) + // to TARGET_RATE, and push the results into the accumulator, flushing whole + // frames as they fill. private resampleAndAccumulate(channel: Float32Array): void { + this.configureLowPass(); const ratio = sampleRate / TARGET_RATE; // input samples consumed per output sample const n = channel.length; + // Anti-alias the raw input first; carry the filter state across quanta. + const src = new Float32Array(n); + for (let i = 0; i < n; i++) src[i] = this.lowPass(channel[i]); + if (!this.primed) { // First quantum: there is no real predecessor, so seed the virtual index -1 // with this quantum's first sample and start reading from 0. 
- this.prevSample = channel[0]; + this.prevSample = src[0]; this.primed = true; this.resamplePos = 0; } @@ -70,13 +120,13 @@ class Pcm16Worklet extends AudioWorkletProcessor { // neighbor at floor === -1 is the carried `prevSample`; floor >= 0 reads the // quantum directly. Any leftover position (whose right neighbor would be the // NEXT quantum's first sample) is carried via `resamplePos` and resolved on - // the next call. This guarantees we never read `channel[n]` (out of bounds). + // the next call. This guarantees we never read `src[n]` (out of bounds). while (pos < n - 1) { const floor = Math.floor(pos); const frac = pos - floor; - const s0 = floor < 0 ? this.prevSample : channel[floor]; - const s1 = channel[floor + 1]; + const s0 = floor < 0 ? this.prevSample : src[floor]; + const s1 = src[floor + 1]; this.pushSample(s0 + (s1 - s0) * frac); pos += ratio; @@ -85,7 +135,7 @@ class Pcm16Worklet extends AudioWorkletProcessor { // Rebase the leftover position relative to the next quantum's start and carry // this quantum's last sample as the predecessor for the boundary interval. this.resamplePos = pos - n; - this.prevSample = channel[n - 1]; + this.prevSample = src[n - 1]; } // Append one resampled sample; flush a full PCM16 frame whenever the @@ -101,6 +151,7 @@ class Pcm16Worklet extends AudioWorkletProcessor { // Convert the accumulated Float32 samples to Int16 LE and post the ArrayBuffer // to the main thread, transferring ownership (zero-copy). DataView writes are // little-endian to match the PCM16 contract regardless of host endianness. + // Also invoked on a 'flush' message to emit a partial tail frame on stop. private flush(): void { const count = this.accLen; if (count === 0) return; @@ -108,11 +159,16 @@ class Pcm16Worklet extends AudioWorkletProcessor { const buffer = new ArrayBuffer(count * 2); const view = new DataView(buffer); for (let i = 0; i < count; i++) { - // Clamp to [-1, 1] then scale to the signed 16-bit range. 
+ // Clamp to [-1, 1] then scale to the signed 16-bit range. Mirrors + // floatSampleToInt16 in pcm16-dsp.ts. let s = this.acc[i]; - if (s > 1) s = 1; + if (Number.isNaN(s)) s = 0; + else if (s > 1) s = 1; else if (s < -1) s = -1; - view.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7fff, true); + let v = Math.round(s < 0 ? s * 0x8000 : s * 0x7fff); + if (v > 32767) v = 32767; + else if (v < -32768) v = -32768; + view.setInt16(i * 2, v, true); } this.accLen = 0; diff --git a/apps/client/src/features/dictation/components/realtime-mic-button.test.tsx b/apps/client/src/features/dictation/components/realtime-mic-button.test.tsx new file mode 100644 index 00000000..d33f3ce4 --- /dev/null +++ b/apps/client/src/features/dictation/components/realtime-mic-button.test.tsx @@ -0,0 +1,103 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { render, screen, fireEvent, cleanup } from "@testing-library/react"; +import { MantineProvider } from "@mantine/core"; + +// jsdom has no matchMedia; Mantine's color-scheme provider needs it. Stub a +// minimal, inert implementation before any MantineProvider mounts. +if (typeof window.matchMedia !== "function") { + window.matchMedia = (query: string) => + ({ + matches: false, + media: query, + onchange: null, + addListener: () => undefined, + removeListener: () => undefined, + addEventListener: () => undefined, + removeEventListener: () => undefined, + dispatchEvent: () => false, + }) as unknown as MediaQueryList; +} + +// Mock i18n so labels render the raw key. +vi.mock("react-i18next", () => ({ + useTranslation: () => ({ t: (k: string) => k, i18n: {} }), +})); + +// Controllable mock of the dictation hook. Tests set the returned status and +// inspect the start/stop spies. 
+const hookState: { + status: "idle" | "recording" | "error"; + start: ReturnType; + stop: ReturnType; + cancel: ReturnType; +} = { + status: "idle", + start: vi.fn(), + stop: vi.fn(), + cancel: vi.fn(), +}; + +vi.mock("@/features/dictation/hooks/use-realtime-dictation", () => ({ + useRealtimeDictation: () => hookState, +})); + +import { RealtimeMicButton } from "./realtime-mic-button"; + +function renderButton(props: Partial[0]> = {}) { + const onInterim = vi.fn(); + const onFinal = vi.fn(); + const utils = render( + + + , + ); + return { onInterim, onFinal, ...utils }; +} + +beforeEach(() => { + cleanup(); + hookState.status = "idle"; + hookState.start = vi.fn(); + hookState.stop = vi.fn(); + hookState.cancel = vi.fn(); +}); + +describe("RealtimeMicButton", () => { + it("idle: clicking calls start", () => { + renderButton(); + fireEvent.click(screen.getByLabelText("Start dictation")); + expect(hookState.start).toHaveBeenCalledTimes(1); + expect(hookState.stop).not.toHaveBeenCalled(); + }); + + it("recording: clicking calls stop", () => { + hookState.status = "recording"; + renderButton(); + fireEvent.click(screen.getByLabelText("Stop recording")); + expect(hookState.stop).toHaveBeenCalledTimes(1); + expect(hookState.start).not.toHaveBeenCalled(); + }); + + it("recording → idle transition fires onInterim('') exactly once", () => { + hookState.status = "recording"; + const { onInterim, rerender } = renderButton(); + expect(onInterim).not.toHaveBeenCalled(); + + hookState.status = "idle"; + rerender( + + + , + ); + expect(onInterim).toHaveBeenCalledTimes(1); + expect(onInterim).toHaveBeenCalledWith(""); + + // A further re-render in idle does not fire it again. 
+ rerender( + + + , + ); + expect(onInterim).toHaveBeenCalledTimes(1); + }); +}); diff --git a/apps/client/src/features/dictation/hooks/use-dictation.ts b/apps/client/src/features/dictation/hooks/use-dictation.ts index 0d32402f..6859afc5 100644 --- a/apps/client/src/features/dictation/hooks/use-dictation.ts +++ b/apps/client/src/features/dictation/hooks/use-dictation.ts @@ -2,6 +2,12 @@ import { useCallback, useEffect, useRef, useState } from "react"; import { notifications } from "@mantine/notifications"; import { useTranslation } from "react-i18next"; import { transcribeAudio } from "@/features/dictation/services/dictation-service"; +import { + acquireMicStream, + canStartCapture, + mapGetUserMediaError, + MicUnavailableError, +} from "@/features/dictation/audio/mic-capture"; export type DictationStatus = "idle" | "recording" | "transcribing" | "error"; @@ -180,46 +186,38 @@ export function useDictation( }, []); const start = useCallback(async (): Promise => { - // Synchronous live guard: status is stale between renders, so also block on - // refs to prevent a double-click from opening two MediaStreams (the first - // would leak). - if (startingRef.current || recorderRef.current || streamRef.current) return; - if (status !== "idle") return; - startingRef.current = true; - - if (!navigator.mediaDevices?.getUserMedia) { - const reason = - "navigator.mediaDevices.getUserMedia is unavailable in this context"; - console.error("[dictation] " + reason); - notifications.show({ - color: "red", - message: t("Audio recording is not available in this browser/context"), - }); - setStatus("idle"); - startingRef.current = false; + // Synchronous live guard (shared with the streaming hook): status is stale + // between renders, so also block on refs to prevent a double-click from + // opening two MediaStreams (the first would leak). 
+ if ( + !canStartCapture({ + starting: startingRef.current, + hasStream: streamRef.current !== null, + hasLiveResource: recorderRef.current !== null, + statusIsIdle: status === "idle", + }) + ) { return; } + startingRef.current = true; let stream: MediaStream; try { - stream = await navigator.mediaDevices.getUserMedia({ audio: true }); + stream = await acquireMicStream(); } catch (err) { - // Always log the full error for diagnosis (name, message, stack). - console.error("[dictation] getUserMedia failed", err); - const name = (err as { name?: string })?.name; - const detail = (err as { message?: string })?.message ?? String(err); - let message: string; - if (name === "NotAllowedError" || name === "SecurityError") { - message = t("Microphone access denied"); - } else if (name === "NotFoundError" || name === "OverconstrainedError") { - message = t("No microphone found"); - } else if (name === "NotReadableError" || name === "AbortError") { - message = t("Microphone is unavailable or already in use"); + if (err instanceof MicUnavailableError) { + console.error("[dictation] " + err.message); + notifications.show({ + color: "red", + message: t( + "Audio recording is not available in this browser/context", + ), + }); } else { - // Unknown failure: show the real reason instead of a generic string. - message = `${t("Could not start recording")}: ${name ? `${name}: ` : ""}${detail}`; + // Always log the full error for diagnosis (name, message, stack). 
+ console.error("[dictation] getUserMedia failed", err); + notifications.show({ color: "red", message: mapGetUserMediaError(err, t) }); } - notifications.show({ color: "red", message }); setStatus("idle"); startingRef.current = false; return; diff --git a/apps/client/src/features/dictation/hooks/use-realtime-dictation.ts b/apps/client/src/features/dictation/hooks/use-realtime-dictation.ts index 8c89c279..a4ae2526 100644 --- a/apps/client/src/features/dictation/hooks/use-realtime-dictation.ts +++ b/apps/client/src/features/dictation/hooks/use-realtime-dictation.ts @@ -2,6 +2,13 @@ import { useCallback, useEffect, useRef, useState } from "react"; import { notifications } from "@mantine/notifications"; import { useTranslation } from "react-i18next"; import { RealtimeDictationClient } from "@/features/dictation/services/realtime-dictation-client"; +import { + acquireMicStream, + canStartCapture, + mapGetUserMediaError, + MicUnavailableError, +} from "@/features/dictation/audio/mic-capture"; +import { baseLanguageSubtag } from "@/features/dictation/services/dictation-reducer"; // The worklet module URL is produced via `new URL(..., import.meta.url)` so Vite // emits the processor as a separate, self-contained module chunk (it must run in @@ -66,6 +73,9 @@ export function useRealtimeDictation( const timerRef = useRef | null>(null); const errorTimerRef = useRef | null>(null); + // Defers the upstream/socket teardown a short beat after a graceful stop so the + // worklet's flushed tail frame can round-trip and be forwarded before we close. 
+ const flushTimerRef = useRef | null>(null); const canceledRef = useRef(false); const startingRef = useRef(false); @@ -82,6 +92,10 @@ export function useRealtimeDictation( clearTimeout(timerRef.current); timerRef.current = null; } + if (flushTimerRef.current !== null) { + clearTimeout(flushTimerRef.current); + flushTimerRef.current = null; + } }, []); const stopTracks = useCallback(() => { @@ -125,7 +139,9 @@ export function useRealtimeDictation( // Full teardown shared by stop/cancel/unmount. Order: stop streaming upstream, // disconnect the socket, then dismantle the local audio graph and tracks, then - // clear timers and reset the ready/pending state. + // clear timers and reset the ready/pending state. Also clears the interim + // "ghost" decoration in the consumer so it does not stick when the toolbar + // closes mid-recording (the unmount path runs teardown). const teardown = useCallback(() => { const client = clientRef.current; if (client) { @@ -145,8 +161,28 @@ export function useRealtimeDictation( readyRef.current = false; pendingAudioRef.current = []; startingRef.current = false; + + // Clear any leftover interim decoration. Guarded so a throwing consumer + // callback can never break teardown. + try { + optionsRef.current.onInterim(""); + } catch (err) { + console.error("[realtime-dictation] onInterim('') during teardown threw", err); + } }, [teardownAudio, stopTracks, clearTimer]); + // Ask the worklet to emit its partial tail frame (the last ~150 ms that has not + // yet filled a full frame) so it is not lost on stop. The worklet posts the + // remaining samples back over the existing port.onmessage handler, which + // forwards them upstream before the socket is closed. + const flushWorklet = useCallback(() => { + try { + workletRef.current?.port.postMessage("flush"); + } catch { + // Port may already be closed; ignore. 
+ } + }, []); + // Surface a concrete failure: log it, notify, flip to "error", and reset to // "idle" after a short delay (mirrors use-dictation's error timer). const handleError = useCallback( @@ -169,61 +205,54 @@ export function useRealtimeDictation( ); const start = useCallback(async (): Promise => { - // Synchronous live guard: status is stale between renders, so also block on - // refs to prevent a double-click from opening two MediaStreams / sockets. + // Synchronous live guard (shared with the batch hook): status is stale between + // renders, so also block on refs to prevent a double-click from opening two + // MediaStreams / sockets. if ( - startingRef.current || - streamRef.current || - audioContextRef.current || - clientRef.current + !canStartCapture({ + starting: startingRef.current, + hasStream: streamRef.current !== null, + hasLiveResource: + audioContextRef.current !== null || clientRef.current !== null, + statusIsIdle: status === "idle", + }) ) { return; } - if (status !== "idle") return; startingRef.current = true; canceledRef.current = false; readyRef.current = false; pendingAudioRef.current = []; - if (!navigator.mediaDevices?.getUserMedia) { - const reason = - "navigator.mediaDevices.getUserMedia is unavailable in this context"; - console.error("[realtime-dictation] " + reason); - notifications.show({ - color: "red", - message: t("Audio recording is not available in this browser/context"), - }); - setStatus("idle"); - startingRef.current = false; - return; - } - let stream: MediaStream; try { - stream = await navigator.mediaDevices.getUserMedia({ audio: true }); + stream = await acquireMicStream(); } catch (err) { - // Always log the full error for diagnosis (name, message, stack). - console.error("[realtime-dictation] getUserMedia failed", err); - const name = (err as { name?: string })?.name; - const detail = (err as { message?: string })?.message ?? 
String(err); - let message: string; - if (name === "NotAllowedError" || name === "SecurityError") { - message = t("Microphone access denied"); - } else if (name === "NotFoundError" || name === "OverconstrainedError") { - message = t("No microphone found"); - } else if (name === "NotReadableError" || name === "AbortError") { - message = t("Microphone is unavailable or already in use"); + if (err instanceof MicUnavailableError) { + console.error("[realtime-dictation] " + err.message); + notifications.show({ + color: "red", + message: t( + "Audio recording is not available in this browser/context", + ), + }); } else { - // Unknown failure: show the real reason instead of a generic string. - message = `${t("Could not start recording")}: ${name ? `${name}: ` : ""}${detail}`; + // Always log the full error for diagnosis (name, message, stack). + console.error("[realtime-dictation] getUserMedia failed", err); + notifications.show({ + color: "red", + message: mapGetUserMediaError(err, t), + }); } - notifications.show({ color: "red", message }); setStatus("idle"); startingRef.current = false; return; } - // If a cancel landed during the await, drop the stream and bail out. + // If a stop/cancel landed during the await (the button was pressed while the + // permission prompt was still pending), drop the just-acquired stream and bail + // out cleanly so the mic does not stay physically on and the button does not + // stick on "recording". if (canceledRef.current) { stream.getTracks().forEach((track) => track.stop()); startingRef.current = false; @@ -372,7 +401,7 @@ export function useRealtimeDictation( // code, not a region-tagged locale; the server omits it upstream when absent. 
client.connect(); const locale = i18n.resolvedLanguage || i18n.language || ""; - const language = locale.split("-")[0] || undefined; + const language = baseLanguageSubtag(locale); client.start({ language }); setStatus("recording"); @@ -387,7 +416,7 @@ export function useRealtimeDictation( }, [status, t, i18n, stopTracks, teardownAudio, handleError]); const stop = useCallback((): void => { - // Nothing live → no-op (never crash on an idle/destroyed state). + // Nothing live and not mid-acquisition → no-op (never crash on idle). if ( !clientRef.current && !audioContextRef.current && @@ -396,9 +425,45 @@ export function useRealtimeDictation( ) { return; } + + // If stop() is pressed while getUserMedia / addModule is still pending, the + // start() continuation has not yet stored every ref. Set the cancel flag so + // the awaiting start() path bails after its await and stops the stream it just + // acquired (otherwise the mic stays physically ON, the red indicator sticks, + // and the button stays on "recording"). teardown() below also tears down + // anything already wired (partial graph / socket), so either path leaves us + // fully idle. The flag is the same one cancel() uses to neutralize late + // socket/worklet callbacks. + if (startingRef.current) { + // Mid-acquisition: no worklet/socket to flush. Set the cancel flag (the + // awaiting start() bails and stops the just-acquired stream) and tear down + // anything already wired. UI returns to idle immediately. + canceledRef.current = true; + teardown(); + setStatus("idle"); + return; + } + + // Graceful stop of a fully-live session: ask the worklet to emit its partial + // tail frame, then defer the socket/graph teardown a short beat so that tail + // can round-trip and be forwarded upstream before the session closes. The UI + // returns to idle right away; the deferred teardown is idempotent and is also + // cancelled by clearTimer() on any subsequent start/cancel/unmount. 
+ if (workletRef.current && clientRef.current) { + flushWorklet(); + if (flushTimerRef.current !== null) clearTimeout(flushTimerRef.current); + flushTimerRef.current = setTimeout(() => { + flushTimerRef.current = null; + teardown(); + }, 60); + setStatus("idle"); + return; + } + + // No live worklet (e.g. graph half-built): tear down immediately. teardown(); setStatus("idle"); - }, [teardown]); + }, [teardown, flushWorklet]); // Keep the stop ref pointed at the latest stop() for the max-duration timer. stopRef.current = stop; diff --git a/apps/client/src/features/dictation/services/dictation-reducer.test.ts b/apps/client/src/features/dictation/services/dictation-reducer.test.ts new file mode 100644 index 00000000..e008ed75 --- /dev/null +++ b/apps/client/src/features/dictation/services/dictation-reducer.test.ts @@ -0,0 +1,91 @@ +import { describe, it, expect } from "vitest"; +import { + baseLanguageSubtag, + initSessionState, + onAudio, + onReady, + onInterim, + onFinal, + onCancel, + onStop, +} from "./dictation-reducer"; + +describe("baseLanguageSubtag", () => { + it("reduces a region-tagged locale to its base subtag", () => { + expect(baseLanguageSubtag("en-US")).toBe("en"); + }); + it("returns a bare subtag unchanged", () => { + expect(baseLanguageSubtag("en")).toBe("en"); + }); + it("returns undefined for empty / blank / nullish", () => { + expect(baseLanguageSubtag("")).toBeUndefined(); + expect(baseLanguageSubtag(" ")).toBeUndefined(); + expect(baseLanguageSubtag(undefined)).toBeUndefined(); + expect(baseLanguageSubtag(null)).toBeUndefined(); + }); +}); + +function ab(n: number): ArrayBuffer { + return new ArrayBuffer(n); +} + +describe("dictation session reducer", () => { + it("buffers audio until ready then flushes in order", () => { + const s = initSessionState(); + const a = ab(1); + const b = ab(2); + expect(onAudio(s, a).send).toEqual([]); + expect(onAudio(s, b).send).toEqual([]); + expect(s.pending).toHaveLength(2); + + const ready = onReady(s); + 
expect(ready.send).toEqual([a, b]); // flushed in arrival order + expect(s.pending).toHaveLength(0); + expect(s.ready).toBe(true); + + // After ready, audio is sent immediately. + const c = ab(3); + expect(onAudio(s, c).send).toEqual([c]); + }); + + it("interim replaces interim", () => { + const s = initSessionState(); + expect(onInterim(s, "hel").emitInterim).toBe("hel"); + expect(onInterim(s, "hello").emitInterim).toBe("hello"); + expect(s.interim).toBe("hello"); + }); + + it("final trims and drops empty, clearing the interim", () => { + const s = initSessionState(); + onInterim(s, "draft"); + expect(onFinal(s, " hi there ").emitFinal).toBe("hi there"); + expect(s.interim).toBe(""); + + const empty = onFinal(s, " "); + expect(empty.emitFinal).toBeUndefined(); + }); + + it("cancel drops pending and ignores later events", () => { + const s = initSessionState(); + onAudio(s, ab(1)); + onCancel(s); + expect(s.pending).toHaveLength(0); + expect(s.canceled).toBe(true); + + // Later events are no-ops. + expect(onAudio(s, ab(2)).send).toEqual([]); + expect(onReady(s).send).toEqual([]); + expect(onInterim(s, "late").emitInterim).toBeUndefined(); + expect(onFinal(s, "late").emitFinal).toBeUndefined(); + }); + + it("closed/stop after stop is a no-op", () => { + const s = initSessionState(); + onReady(s); + onStop(s); + expect(s.canceled).toBe(true); + // Audio arriving after stop is ignored (server has no session). + expect(onAudio(s, ab(1)).send).toEqual([]); + expect(onInterim(s, "x").emitInterim).toBeUndefined(); + }); +}); diff --git a/apps/client/src/features/dictation/services/dictation-reducer.ts b/apps/client/src/features/dictation/services/dictation-reducer.ts new file mode 100644 index 00000000..2dd4bf72 --- /dev/null +++ b/apps/client/src/features/dictation/services/dictation-reducer.ts @@ -0,0 +1,113 @@ +// Pure logic extracted from `use-realtime-dictation` so the transcript/session +// state machine can be unit-tested without React or a live socket. 
The hook wires +// these to refs/callbacks; nothing here touches the DOM or Web Audio. + +/** + * Reduce a BCP-47 locale to its base language subtag for the upstream STT model, + * which expects an ISO language code, not a region-tagged locale. + * "en-US" → "en", "en" → "en", "" → undefined, " " → undefined. + * Returns undefined when no usable subtag exists so the server can omit the hint. + */ +export function baseLanguageSubtag(locale: string | undefined | null): string | undefined { + if (!locale) return undefined; + const base = locale.trim().split("-")[0]?.trim(); + return base && base.length > 0 ? base : undefined; +} + +/** + * Session/transcript reducer. Models the audio-buffering + interim/final/cancel + * lifecycle as a pure state object so the ordering rules (buffer-until-ready, + * cancel ignores later events, closed-after-stop is a no-op) are testable. The + * hook keeps the live socket/graph; this only decides what to emit. + */ +export interface DictationSessionState { + // Server has confirmed the upstream session; audio may flow. + ready: boolean; + // Local stop/cancel happened; later interim/final/audio are ignored. + canceled: boolean; + // Audio captured before `ready`; flushed in arrival order once ready. + pending: ArrayBuffer[]; + // Latest interim transcript for the live (not-yet-final) segment. + interim: string; +} + +export function initSessionState(): DictationSessionState { + return { ready: false, canceled: false, pending: [], interim: "" }; +} + +// Effects the hook should perform after a reduction. Keeps the reducer pure: it +// describes what to do, the hook does it (send over the socket, call callbacks). +export interface DictationEffects { + // Audio chunks to send upstream now, in order. + send: ArrayBuffer[]; + // Interim text to surface, if it changed. + emitInterim?: string; + // Final (trimmed, non-empty) text to surface. 
+ emitFinal?: string; +} + +const NONE: DictationEffects = { send: [] }; + +/** Audio chunk captured: send immediately if ready, else buffer it. */ +export function onAudio( + state: DictationSessionState, + buf: ArrayBuffer, +): DictationEffects { + if (state.canceled) return NONE; + if (state.ready) return { send: [buf] }; + state.pending.push(buf); + return NONE; +} + +/** Server ready: flush all buffered audio in order, then stream live. */ +export function onReady(state: DictationSessionState): DictationEffects { + if (state.canceled) return NONE; + state.ready = true; + const send = state.pending; + state.pending = []; + return { send }; +} + +/** Interim transcript: replaces the previous interim for the live segment. */ +export function onInterim( + state: DictationSessionState, + text: string, +): DictationEffects { + if (state.canceled) return NONE; + state.interim = text; + return { send: [], emitInterim: text }; +} + +/** + * Final transcript: trim and drop if empty; the live interim segment is cleared + * (the final supersedes it). + */ +export function onFinal( + state: DictationSessionState, + text: string, +): DictationEffects { + if (state.canceled) return NONE; + const trimmed = text.trim(); + state.interim = ""; + if (trimmed.length === 0) return { send: [] }; + return { send: [], emitFinal: trimmed }; +} + +/** Cancel: drop pending audio and ignore all later events. */ +export function onCancel(state: DictationSessionState): DictationEffects { + state.canceled = true; + state.pending = []; + state.interim = ""; + return NONE; +} + +/** + * Stop: like cancel for the purposes of "no more events should be processed". + * Distinct name kept so the hook can flush the worklet tail before stopping; the + * reducer treats post-stop events as no-ops the same way. 
+ */ +export function onStop(state: DictationSessionState): DictationEffects { + state.canceled = true; + state.pending = []; + return NONE; +} diff --git a/apps/client/src/features/dictation/services/realtime-dictation-client.test.ts b/apps/client/src/features/dictation/services/realtime-dictation-client.test.ts new file mode 100644 index 00000000..1fb71ef6 --- /dev/null +++ b/apps/client/src/features/dictation/services/realtime-dictation-client.test.ts @@ -0,0 +1,185 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +// --- Mock socket.io-client with a controllable fake socket -------------------- +// The mock records registered listeners so tests can fire server events, and +// records emits so the start/reconnect behavior can be asserted. +interface FakeSocket { + connected: boolean; + listeners: Record void)[]>; + emits: { event: string; args: unknown[] }[]; + on: (e: string, cb: (...a: unknown[]) => void) => FakeSocket; + emit: (e: string, ...a: unknown[]) => void; + connect: () => void; + disconnect: () => void; + removeAllListeners: () => void; + fire: (e: string, ...a: unknown[]) => void; +} + +function makeFakeSocket(): FakeSocket { + const socket: FakeSocket = { + connected: false, + listeners: {}, + emits: [], + on(e, cb) { + (socket.listeners[e] ??= []).push(cb); + return socket; + }, + emit(e, ...a) { + socket.emits.push({ event: e, args: a }); + }, + connect() { + socket.connected = true; + socket.fire("connect"); + }, + disconnect() { + socket.connected = false; + }, + removeAllListeners() { + socket.listeners = {}; + }, + fire(e, ...a) { + (socket.listeners[e] ?? 
[]).forEach((cb) => cb(...a)); + }, + }; + return socket; +} + +let lastSocket: FakeSocket; +const ioMock = vi.fn((..._args: unknown[]) => { + lastSocket = makeFakeSocket(); + return lastSocket; +}); + +vi.mock("socket.io-client", () => ({ + io: (...args: unknown[]) => ioMock(...args), + Socket: class {}, +})); + +vi.mock("@/features/websocket/types", () => ({ SOCKET_URL: undefined })); + +import { RealtimeDictationClient } from "./realtime-dictation-client"; + +function makeHandlers() { + return { + onReady: vi.fn(), + onInterim: vi.fn(), + onFinal: vi.fn(), + onError: vi.fn(), + onClosed: vi.fn(), + }; +} + +beforeEach(() => { + ioMock.mockClear(); +}); + +describe("RealtimeDictationClient", () => { + it("uses a single io() call with the bare namespace URL and shared opts", () => { + const c = new RealtimeDictationClient(makeHandlers()); + c.connect(); + expect(ioMock).toHaveBeenCalledTimes(1); + const call = ioMock.mock.calls[0] as unknown[]; + expect(call[0]).toBe("/ai-realtime"); + expect(call[1]).toMatchObject({ + transports: ["websocket"], + withCredentials: true, + autoConnect: false, + }); + }); + + it("decodes ready/interim/final with ?? 
'' defaults", () => { + const h = makeHandlers(); + const c = new RealtimeDictationClient(h); + c.connect(); + + lastSocket.fire("ready"); + expect(h.onReady).toHaveBeenCalledTimes(1); + + lastSocket.fire("interim", { itemId: "a", text: "hi" }); + expect(h.onInterim).toHaveBeenCalledWith("a", "hi"); + lastSocket.fire("interim", {}); + expect(h.onInterim).toHaveBeenCalledWith("", ""); + + lastSocket.fire("final", { itemId: "b", text: "done" }); + expect(h.onFinal).toHaveBeenCalledWith("b", "done"); + lastSocket.fire("final", undefined); + expect(h.onFinal).toHaveBeenCalledWith("", ""); + }); + + it("surfaces error (string and object) and connect_error", () => { + const h = makeHandlers(); + const c = new RealtimeDictationClient(h); + c.connect(); + lastSocket.fire("error", "boom"); + expect(h.onError).toHaveBeenCalledWith("boom"); + }); + + it("error fires at most once per connection (error-once guard)", () => { + const h = makeHandlers(); + const c = new RealtimeDictationClient(h); + c.connect(); + lastSocket.fire("error", { message: "first" }); + lastSocket.fire("connect_error", new Error("second")); + lastSocket.fire("error", "third"); + expect(h.onError).toHaveBeenCalledTimes(1); + expect(h.onError).toHaveBeenCalledWith("first"); + }); + + it("connect_error builds a concrete message", () => { + const h = makeHandlers(); + const c = new RealtimeDictationClient(h); + c.connect(); + lastSocket.fire("connect_error", new Error("handshake")); + expect(h.onError).toHaveBeenCalledWith( + "Realtime connection failed: handshake", + ); + }); + + it("emits start once on first connect after start()", () => { + const c = new RealtimeDictationClient(makeHandlers()); + c.connect(); // fires connect, socket.connected = true + c.start({ language: "en" }); + const starts = lastSocket.emits.filter((e) => e.event === "start"); + expect(starts).toHaveLength(1); + expect(starts[0].args[0]).toEqual({ language: "en" }); + }); + + it("re-emits start on reconnect (does not double-start 
while live)", () => { + const c = new RealtimeDictationClient(makeHandlers()); + c.connect(); + c.start({ language: "en" }); + // A second connect with no disconnect must NOT re-start (still live). + lastSocket.fire("connect"); + let starts = lastSocket.emits.filter((e) => e.event === "start"); + expect(starts).toHaveLength(1); + + // Transient drop then reconnect → re-establish the session exactly once. + lastSocket.fire("disconnect"); + lastSocket.fire("connect"); + starts = lastSocket.emits.filter((e) => e.event === "start"); + expect(starts).toHaveLength(2); + expect(starts[1].args[0]).toEqual({ language: "en" }); + }); + + it("disconnect removes listeners and resets the error flag", () => { + const h = makeHandlers(); + const c = new RealtimeDictationClient(h); + c.connect(); + const removeSpy = vi.spyOn(lastSocket, "removeAllListeners"); + c.disconnect(); + expect(removeSpy).toHaveBeenCalled(); + expect(lastSocket.connected).toBe(false); + + // A fresh connect on the reused instance can error again. + c.connect(); + lastSocket.fire("error", "again"); + expect(h.onError).toHaveBeenCalledWith("again"); + }); + + it("connect() is a no-op while a socket already exists", () => { + const c = new RealtimeDictationClient(makeHandlers()); + c.connect(); + c.connect(); + expect(ioMock).toHaveBeenCalledTimes(1); + }); +}); diff --git a/apps/client/src/features/dictation/services/realtime-dictation-client.ts b/apps/client/src/features/dictation/services/realtime-dictation-client.ts index 98969c71..31c9b1d2 100644 --- a/apps/client/src/features/dictation/services/realtime-dictation-client.ts +++ b/apps/client/src/features/dictation/services/realtime-dictation-client.ts @@ -31,6 +31,13 @@ export class RealtimeDictationClient { // `connect_error` can both arrive (e.g. an error then a failed reconnect), but // the hook owns the error→idle flow and a second call would double-fire it. 
private erroredFlag = false; + // The last `start` params, retained so we can re-establish the upstream session + // after a transient socket.io reconnect (otherwise the server has no session and + // silently drops audio). Null until start() is first called. + private startOptions: StartOptions | null = null; + // True between a successful `start` emit and the next disconnect, so a reconnect + // re-emits `start` exactly once and we never double-start a live session. + private started = false; constructor(private readonly handlers: RealtimeDictationHandlers) {} @@ -50,21 +57,36 @@ export class RealtimeDictationClient { // SOCKET_URL is undefined in this app (socket.io derives the page origin), so // the `/ai-realtime` namespace rides the same `/socket.io` path as the main - // socket — which the Vite dev server proxies as a websocket. - const socket: Socket = SOCKET_URL - ? io(`${SOCKET_URL}/ai-realtime`, { - transports: ["websocket"], - withCredentials: true, - autoConnect: false, - }) - : io("/ai-realtime", { - transports: ["websocket"], - withCredentials: true, - autoConnect: false, - }); + // socket — which the Vite dev server proxies as a websocket. The URL is the + // only thing that varies; the options are shared (single io() call). + const url = SOCKET_URL ? `${SOCKET_URL}/ai-realtime` : "/ai-realtime"; + const socket: Socket = io(url, { + transports: ["websocket"], + withCredentials: true, + autoConnect: false, + }); this.socket = socket; + // On every (re)connect, re-establish the upstream session if start() has run + // but we are not currently in a started session. The first connect after + // start() is handled by start() itself (started === true by then); this branch + // covers reconnects after a transient drop, where the server lost the session + // and would otherwise silently discard all subsequent audio. The `started` + // guard prevents double-starting a live session. 
/**
 * Request that the server open the upstream STT session for this connection.
 *
 * The options are stored before anything else, so the socket's `connect`
 * listener can send `start` itself once a handshake (or a reconnect after a
 * drop) completes. The event is emitted right away only when a socket exists,
 * reports itself connected, and no session is currently marked as started.
 *
 * @param opts Start parameters; only `language` is forwarded to the server.
 */
start(opts: StartOptions): void {
  this.startOptions = opts;

  const liveSocket = this.socket;
  // No socket, or handshake still pending: the `connect` listener takes over.
  if (!liveSocket?.connected) return;
  // A session is already marked live; never double-start it.
  if (this.started) return;

  this.started = true;
  liveSocket.emit("start", { language: opts.language });
}
this.erroredFlag = false; + this.started = false; + this.startOptions = null; socket.removeAllListeners(); socket.disconnect(); } diff --git a/apps/client/src/features/editor/components/fixed-toolbar/groups/dictation-group.clamp.test.ts b/apps/client/src/features/editor/components/fixed-toolbar/groups/dictation-group.clamp.test.ts new file mode 100644 index 00000000..13331c24 --- /dev/null +++ b/apps/client/src/features/editor/components/fixed-toolbar/groups/dictation-group.clamp.test.ts @@ -0,0 +1,28 @@ +import { describe, it, expect } from "vitest"; +import { clampRange } from "./dictation-group"; + +describe("clampRange", () => { + it("returns the range unchanged when it is within bounds", () => { + expect(clampRange(2, 5, 10)).toEqual({ from: 2, to: 5 }); + }); + + it("clamps the upper bound to the doc size (off-by-one at the end)", () => { + expect(clampRange(8, 12, 10)).toEqual({ from: 8, to: 10 }); + // Exactly at the size is in-bounds, not clamped. + expect(clampRange(10, 10, 10)).toEqual({ from: 10, to: 10 }); + }); + + it("clamps negative positions up to 0 (off-by-one at the start)", () => { + expect(clampRange(-3, 4, 10)).toEqual({ from: 0, to: 4 }); + expect(clampRange(-5, -1, 10)).toEqual({ from: 0, to: 0 }); + }); + + it("clamps both ends when both are out of range", () => { + expect(clampRange(-2, 99, 6)).toEqual({ from: 0, to: 6 }); + }); + + it("handles an empty document (size 0)", () => { + expect(clampRange(0, 0, 0)).toEqual({ from: 0, to: 0 }); + expect(clampRange(3, 7, 0)).toEqual({ from: 0, to: 0 }); + }); +}); diff --git a/apps/client/src/features/editor/components/fixed-toolbar/groups/dictation-group.tsx b/apps/client/src/features/editor/components/fixed-toolbar/groups/dictation-group.tsx index 0d416129..7913a35c 100644 --- a/apps/client/src/features/editor/components/fixed-toolbar/groups/dictation-group.tsx +++ b/apps/client/src/features/editor/components/fixed-toolbar/groups/dictation-group.tsx @@ -13,6 +13,21 @@ interface Props { editor: 
/**
 * Constrain a `[from, to]` position pair to the inclusive bounds `[0, size]`.
 *
 * Each endpoint is limited independently: values above `size` become `size`,
 * values below `0` become `0`, and in-range values pass through untouched.
 * The function is pure and does not reorder the endpoints.
 *
 * @param from Start position of the range.
 * @param to End position of the range.
 * @param size Upper bound (the current document content size at the call site).
 * @returns A new `{ from, to }` object with both endpoints inside `[0, size]`.
 */
export function clampRange(
  from: number,
  to: number,
  size: number,
): { from: number; to: number } {
  // Cap at the upper bound first, then lift anything negative up to zero.
  const capAtSize = (pos: number): number => Math.min(pos, size);
  const start = Math.max(0, capAtSize(from));
  const end = Math.max(0, capAtSize(to));
  return { from: start, to: end };
}
The live caret can drift + // while a segment streams (the user keeps typing, a collaborator edits), so + // inserting at the current selection would land text in the wrong place and + // split words. We advance the snapshot past each committed segment so the + // next final lands right after it (left-to-right accumulation). Interim is + // shown via the ghost decoration only. if (isRealtime) { return ( { - if (editor && !editor.isDestroyed) clearDictationInterim(editor); + if (editor && !editor.isDestroyed) { + const { from, to } = editor.state.selection; + rangeRef.current = { from, to }; + clearDictationInterim(editor); + } }} onInterim={(text) => { if (editor && !editor.isDestroyed) setDictationInterim(editor, text); }} onFinal={(text) => { if (!editor || editor.isDestroyed) return; + // Never write into a read-only page (e.g. a collaborator revoked edit + // access, or this is a read-only view): dictated text must not land. + if (!editor.isEditable) return; clearDictationInterim(editor); + const snapshot = rangeRef.current; + const docSize = editor.state.doc.content.size; + const content = `${text} `; try { - editor.chain().focus().insertContent(`${text} `).run(); + if (snapshot) { + const range = clampRange(snapshot.from, snapshot.to, docSize); + editor + .chain() + .focus() + .insertContentAt(range, content) + .run(); + // Advance the snapshot past what we just inserted so the next + // final segment appends after it instead of overwriting it. + const next = range.from + content.length; + rangeRef.current = { from: next, to: next }; + } else { + editor.chain().focus().insertContent(content).run(); + } } catch { - // The editor may have been destroyed mid-stream; ignore. + // The snapshot drifted out of range or the editor was destroyed + // mid-stream; fall back to the current caret. + try { + editor.chain().focus().insertContent(content).run(); + } catch { + // The editor may have been destroyed; ignore. 
+ } } }} /> diff --git a/apps/client/src/features/editor/extensions/dictation-interim/dictation-interim.test.ts b/apps/client/src/features/editor/extensions/dictation-interim/dictation-interim.test.ts new file mode 100644 index 00000000..1369dfa9 --- /dev/null +++ b/apps/client/src/features/editor/extensions/dictation-interim/dictation-interim.test.ts @@ -0,0 +1,158 @@ +import { describe, it, expect, beforeEach } from "vitest"; +import { Editor } from "@tiptap/core"; +import Document from "@tiptap/extension-document"; +import Paragraph from "@tiptap/extension-paragraph"; +import TiptapText from "@tiptap/extension-text"; +import { History } from "@tiptap/extension-history"; +import { + DictationInterim, + applyInterimMeta, + setDictationInterim, + clearDictationInterim, +} from "./dictation-interim"; + +// --- applyInterimMeta (pure reducer) --------------------------------------- + +describe("applyInterimMeta", () => { + it("replaces the interim text on a meta-only update", () => { + expect(applyInterimMeta({ text: "hello" }, { text: "" })).toEqual({ + text: "hello", + }); + }); + + it("replaces with empty text (clear) when meta carries an empty string", () => { + expect(applyInterimMeta({ text: "" }, { text: "old" })).toEqual({ + text: "", + }); + }); + + it("passes the previous state through when there is no meta", () => { + const prev = { text: "kept" }; + expect(applyInterimMeta(undefined, prev)).toBe(prev); + }); +}); + +// --- editor integration (regression guard) --------------------------------- +// +// A minimal headless editor in jsdom: doc/paragraph/text only, plus History so +// we can assert the interim never creates an undo step. The whole point of the +// feature is that interim text is a DECORATION, never written into the doc — +// these tests are the regression guard for that invariant. + +function makeEditor(content = "

seed

") { + const element = document.createElement("div"); + document.body.appendChild(element); + return new Editor({ + element, + extensions: [ + Document, + Paragraph, + TiptapText, + History, + DictationInterim, + ], + content, + }); +} + +describe("DictationInterim editor integration", () => { + let editor: Editor; + + beforeEach(() => { + editor = makeEditor(); + }); + + it("set/clear produce NO doc change and NO steps (interim is never inserted)", () => { + const before = editor.state.doc.toJSON(); + + let captured: { docChanged: boolean; steps: number } | null = null; + const handler = ({ transaction }: { transaction: any }) => { + captured = { + docChanged: transaction.docChanged, + steps: transaction.steps.length, + }; + }; + + editor.on("transaction", handler); + setDictationInterim(editor, "partial words"); + expect(captured).not.toBeNull(); + expect(captured!.docChanged).toBe(false); + expect(captured!.steps).toBe(0); + + captured = null; + clearDictationInterim(editor); + expect(captured!.docChanged).toBe(false); + expect(captured!.steps).toBe(0); + editor.off("transaction", handler); + + expect(editor.state.doc.toJSON()).toEqual(before); + }); + + it("interim updates add no undo steps; undo reverts typed text, not the interim", () => { + // Type real text (this IS an undoable step). + editor.commands.insertContent(" typed"); + const afterTyping = editor.getText(); + expect(afterTyping).toContain("typed"); + + // Interim updates must not stack onto history. + setDictationInterim(editor, "ghost"); + setDictationInterim(editor, "ghost two"); + clearDictationInterim(editor); + + // A single undo reverts the typed text — proving the interim added no + // history entries that would have to be undone first. 
+ editor.commands.undo(); + expect(editor.getText()).not.toContain("typed"); + }); + + it("empty interim text yields no decoration widget", () => { + setDictationInterim(editor, ""); + const decos = decorationCount(editor); + expect(decos).toBe(0); + }); + + it("non-empty interim yields exactly one contenteditable=false widget at the caret", () => { + setDictationInterim(editor, "live"); + expect(decorationCount(editor)).toBe(1); + + const span = editor.view.dom.querySelector( + 'span[contenteditable="false"]', + ) as HTMLElement | null; + expect(span).not.toBeNull(); + expect(span!.textContent).toBe("live"); + }); + + it("the decoration remaps to follow the caret/selection head on edits", () => { + setDictationInterim(editor, "tail"); + const headBefore = decorationPos(editor); + expect(headBefore).toBe(editor.state.selection.head); + + // Move the caret to the document start; the widget must follow selection.head. + editor.commands.setTextSelection(1); + expect(decorationPos(editor)).toBe(editor.state.selection.head); + }); +}); + +// Helpers reaching into the plugin's decoration set via the editor view. 
/**
 * Pure reducer for the dictation-interim plugin state.
 *
 * When a transaction carries interim meta, the result is a fresh state object
 * holding that meta's `text` (an empty string is a valid value and clears the
 * interim). When there is no meta, the previous state object is returned as-is,
 * by reference.
 *
 * @param meta Interim meta read from the transaction, or `undefined` if absent.
 * @param prev The plugin state before this transaction.
 * @returns The next plugin state.
 */
export function applyInterimMeta(
  meta: DictationInterimState | undefined,
  prev: DictationInterimState,
): DictationInterimState {
  return meta ? { text: meta.text } : prev;
}
The interim is held ONLY in plugin meta state and @@ -34,13 +51,7 @@ export const DictationInterim = Extension.create({ const meta = tr.getMeta(dictationInterimKey) as | DictationInterimState | undefined; - // Meta-only updates replace the interim text; everything else keeps - // the existing value (it follows the caret on its own since the - // decoration is recomputed against the live selection). - if (meta) { - return { text: meta.text }; - } - return value; + return applyInterimMeta(meta, value); }, }, props: { diff --git a/apps/client/src/features/workspace/components/settings/components/ai-provider-settings.pure.test.ts b/apps/client/src/features/workspace/components/settings/components/ai-provider-settings.pure.test.ts new file mode 100644 index 00000000..1f0cbec9 --- /dev/null +++ b/apps/client/src/features/workspace/components/settings/components/ai-provider-settings.pure.test.ts @@ -0,0 +1,103 @@ +import { describe, it, expect } from "vitest"; +import { + resolveUrl, + resolveKeyField, + resolveCardStatus, + isEndpointConfigured, +} from "./ai-provider-settings"; + +describe("resolveUrl", () => { + it("trims a single trailing slash before appending the path", () => { + expect(resolveUrl("https://api.example.com/", "/chat/completions")).toBe( + "https://api.example.com/chat/completions", + ); + }); + + it("leaves a base without a trailing slash intact", () => { + expect(resolveUrl("https://api.example.com", "/embeddings")).toBe( + "https://api.example.com/embeddings", + ); + }); + + it("falls back to the fallback base when the base is empty", () => { + expect(resolveUrl("", "/audio/transcriptions", "https://chat.example.com")).toBe( + "https://chat.example.com/audio/transcriptions", + ); + }); + + it("falls back when the base is whitespace-only", () => { + expect(resolveUrl(" ", "/embeddings", "https://chat.example.com")).toBe( + "https://chat.example.com/embeddings", + ); + }); + + it("returns just the path when both base and fallback are empty", () => { 
+ expect(resolveUrl("", "/chat/completions")).toBe("/chat/completions"); + }); +}); + +describe("resolveKeyField", () => { + it("a non-empty buffer is set to that value", () => { + expect(resolveKeyField("secret", false)).toEqual({ + set: true, + value: "secret", + }); + }); + + it("explicitly cleared with an empty buffer sets the empty string", () => { + expect(resolveKeyField("", true)).toEqual({ set: true, value: "" }); + }); + + it("untouched (empty buffer, not cleared) omits the key", () => { + expect(resolveKeyField("", false)).toEqual({ set: false }); + }); + + it("a buffer wins over the cleared flag (typed key takes precedence)", () => { + // Security-relevant for the write-only sttApiKey: a freshly typed secret + // must be written even if a prior clear was requested. + expect(resolveKeyField("new-secret", true)).toEqual({ + set: true, + value: "new-secret", + }); + }); +}); + +describe("isEndpointConfigured", () => { + it("model + own base URL -> configured", () => { + expect(isEndpointConfigured("model", "https://own", "")).toBe(true); + }); + + it("model + inherited chat base URL (own empty) -> configured", () => { + expect(isEndpointConfigured("model", "", "https://chat")).toBe(true); + }); + + it("model set but both base URLs empty -> not configured", () => { + expect(isEndpointConfigured("model", "", "")).toBe(false); + }); + + it("whitespace-only base URLs do not count as filled", () => { + expect(isEndpointConfigured("model", " ", " ")).toBe(false); + }); + + it("empty model -> not configured even with a base URL", () => { + expect(isEndpointConfigured("", "https://own", "https://chat")).toBe(false); + }); +}); + +describe("resolveCardStatus", () => { + it("configured + enabled -> ready", () => { + expect(resolveCardStatus(true, true)).toBe("ready"); + }); + + it("configured + disabled -> configured", () => { + expect(resolveCardStatus(true, false)).toBe("configured"); + }); + + it("not configured + disabled -> off", () => { + 
/**
 * Build a "Base URL + path" hint string.
 *
 * The base is whitespace-trimmed; if that leaves nothing, the trimmed
 * `fallback` is used instead. At most one trailing slash is removed from the
 * chosen base before `path` is appended verbatim.
 *
 * @param base The endpoint's own base URL (may be empty or whitespace-only).
 * @param path Path to append; it is not modified.
 * @param fallback Base URL used when `base` is blank. Defaults to `""`.
 * @returns The chosen base (minus one trailing slash) followed by `path`.
 */
export function resolveUrl(base: string, path: string, fallback = ""): string {
  const ownBase = base.trim();
  const chosenBase = ownBase !== "" ? ownBase : fallback.trim();
  // Drop a single trailing slash only; "a//" keeps one slash.
  const withoutSlash = chosenBase.endsWith("/")
    ? chosenBase.slice(0, -1)
    : chosenBase;
  return `${withoutSlash}${path}`;
}
+const post = vi.fn(); +vi.mock("@/lib/api-client", () => ({ + default: { post: (...args: unknown[]) => post(...args) }, +})); + +import { testRealtimeConnection } from "./ai-settings-service"; + +describe("testRealtimeConnection", () => { + beforeEach(() => { + post.mockReset(); + }); + + it("POSTs to the /ai-chat/realtime/test route (NOT the /workspace/ai-settings prefix)", async () => { + post.mockResolvedValue({ data: { ok: true } }); + await testRealtimeConnection(); + + expect(post).toHaveBeenCalledTimes(1); + const [route] = post.mock.calls[0]; + expect(route).toBe("/ai-chat/realtime/test"); + expect(route).not.toContain("/workspace/ai-settings"); + }); + + it("sends no request body", async () => { + post.mockResolvedValue({ data: { ok: true } }); + await testRealtimeConnection(); + + // Only the route argument — no payload. + expect(post.mock.calls[0].length).toBe(1); + }); + + it("unwraps the { ok } envelope from res.data", async () => { + post.mockResolvedValue({ data: { ok: true } }); + await expect(testRealtimeConnection()).resolves.toEqual({ ok: true }); + }); + + it("surfaces the failure envelope (ok:false + error) verbatim", async () => { + post.mockResolvedValue({ data: { ok: false, error: "unreachable" } }); + await expect(testRealtimeConnection()).resolves.toEqual({ + ok: false, + error: "unreachable", + }); + }); +}); diff --git a/apps/server/src/core/ai-chat/ai-chat.module.ts b/apps/server/src/core/ai-chat/ai-chat.module.ts index 311d549b..42f785db 100644 --- a/apps/server/src/core/ai-chat/ai-chat.module.ts +++ b/apps/server/src/core/ai-chat/ai-chat.module.ts @@ -49,8 +49,12 @@ import { AiRealtimeService } from './realtime/ai-realtime.service'; PublicShareChatService, PublicShareChatToolsService, // Realtime dictation: the Socket.IO `/ai-realtime` gateway + its upstream - // proxy service. AiSettingsService comes from AiModule; WorkspaceRepo from - // the global DatabaseModule; TokenService from TokenModule (both imported). + // proxy service. 
AiSettingsService comes from AiModule; WorkspaceRepo / + // UserRepo / UserSessionRepo from the global DatabaseModule; + // EnvironmentService from the global EnvironmentModule; TokenService from + // TokenModule (imported). The repos/env are used by the gateway's + // handleConnection auth (active-session + disabled-user) and CSWSH Origin + // checks that mirror the REST jwt.strategy. AiRealtimeGateway, AiRealtimeService, ], diff --git a/apps/server/src/core/ai-chat/ai-chat.realtime-test.spec.ts b/apps/server/src/core/ai-chat/ai-chat.realtime-test.spec.ts new file mode 100644 index 00000000..0a38d0b0 --- /dev/null +++ b/apps/server/src/core/ai-chat/ai-chat.realtime-test.spec.ts @@ -0,0 +1,68 @@ +import { ForbiddenException } from '@nestjs/common'; +import { AiChatController } from './ai-chat.controller'; +import WorkspaceAbilityFactory from '../casl/abilities/workspace-ability.factory'; +import { UserRole } from '../../common/helpers/types/permission'; + +// Contract for POST ai-chat/realtime/test: +// - admin gate: CASL Manage Settings. A non-admin (MEMBER) → Forbidden BEFORE +// aiRealtimeService.testConnection is ever called. +// - response form is the FROZEN { ok: true } | { ok: false, error } passthrough +// from the service (the global transform wraps it; the client reads req.data). +// Constructed directly with stubs + the REAL ability factory (the gate under +// test), mirroring the other controller specs in this codebase. 
/**
 * Construct an `AiChatController` for the realtime-test specs.
 *
 * The first four constructor dependencies are stubbed with empty objects, the
 * realtime service is a stub exposing only `testConnection`, and the ability
 * factory is a real `WorkspaceAbilityFactory` instance.
 *
 * @param testConnection Mock used as the realtime service's `testConnection`;
 *   defaults to a fresh `jest.fn()`.
 * @returns The controller together with the `testConnection` mock.
 */
function buildController(testConnection = jest.fn()) {
  // Each call yields a distinct empty stub, matching one literal per argument.
  const emptyStub = (): any => ({});

  const controller = new AiChatController(
    emptyStub(), // aiChatService
    emptyStub(), // aiChatRepo
    emptyStub(), // aiChatMessageRepo
    emptyStub(), // aiTranscription
    { testConnection } as any, // aiRealtimeService
    new WorkspaceAbilityFactory(), // real ability factory
  );

  return { controller, testConnection };
}
/**
 * Build a minimal fake socket for driving the gateway handlers directly.
 *
 * Handshake headers are assembled as follows:
 * - `cookie` is set only when `opts.cookie` is provided (an empty string counts).
 * - `origin` is set unless `opts.noOriginHeader` is truthy; it uses
 *   `opts.origin` when given and otherwise falls back to `APP_ORIGIN`.
 *
 * @param opts Optional header overrides for the fake handshake.
 * @returns An object with `handshake.headers`, an empty `data` bag, and
 *   `emit` / `disconnect` jest mocks.
 */
function buildSocket(opts?: {
  cookie?: string;
  origin?: string | undefined;
  noOriginHeader?: boolean;
}) {
  // `Record` needs both type arguments to type-check; every header assigned
  // below is a string keyed by a string.
  const headers: Record<string, string> = {};
  if (opts?.cookie !== undefined) headers.cookie = opts.cookie;
  if (!opts?.noOriginHeader) {
    headers.origin = opts?.origin ?? APP_ORIGIN;
  }
  return {
    handshake: { headers },
    data: {} as any,
    emit: jest.fn(),
    disconnect: jest.fn(),
  };
}
+ jest + .fn() + .mockResolvedValue({ id: 's1', userId: 'u1', workspaceId: 'w1' }), + }; + const environmentService = { + getAppUrl: jest.fn().mockReturnValue(APP_ORIGIN), + }; + const workspaceRepo = { + findById: + overrides?.findWorkspace ?? + jest.fn().mockResolvedValue({ + settings: { ai: { dictation: true, dictationRealtime: true } }, + }), + }; + const aiRealtimeService = { + openSession: overrides?.openSession ?? jest.fn(), + }; + + const gateway = new AiRealtimeGateway( + tokenService as any, + workspaceRepo as any, + userRepo as any, + userSessionRepo as any, + environmentService as any, + aiRealtimeService as any, + ); + return { + gateway, + tokenService, + userRepo, + userSessionRepo, + environmentService, + workspaceRepo, + aiRealtimeService, + }; +} + +beforeEach(() => { + __testCounters.reset(); +}); + +describe('AiRealtimeGateway.handleConnection auth', () => { + it('rejects an invalid/expired JWT: error Unauthorized, disconnect, no counter', async () => { + const { gateway } = buildGateway({ + verifyJwt: jest.fn().mockRejectedValue(new Error('jwt expired')), + }); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + + expect(socket.emit).toHaveBeenCalledWith('error', { + message: 'Unauthorized', + }); + expect(socket.disconnect).toHaveBeenCalled(); + expect(__testCounters.user.count('u1')).toBe(0); + expect(__testCounters.workspace.count('w1')).toBe(0); + }); + + it('rejects a missing cookie (no authToken) the same way', async () => { + const { gateway } = buildGateway({ + verifyJwt: jest.fn().mockRejectedValue(new Error('no token')), + }); + const socket = buildSocket({ cookie: '' }); + await gateway.handleConnection(socket as any); + + expect(socket.emit).toHaveBeenCalledWith('error', { + message: 'Unauthorized', + }); + expect(socket.disconnect).toHaveBeenCalled(); + }); + + it('rejects a disabled (deactivated) user: Unauthorized, no counter', async () => { + const { gateway } = buildGateway({ + findUser: 
jest.fn().mockResolvedValue({ + id: 'u1', + deactivatedAt: new Date(), + deletedAt: null, + }), + }); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + + expect(socket.emit).toHaveBeenCalledWith('error', { + message: 'Unauthorized', + }); + expect(socket.disconnect).toHaveBeenCalled(); + expect(__testCounters.user.count('u1')).toBe(0); + }); + + it('rejects a missing user', async () => { + const { gateway } = buildGateway({ + findUser: jest.fn().mockResolvedValue(undefined), + }); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + expect(socket.emit).toHaveBeenCalledWith('error', { + message: 'Unauthorized', + }); + }); + + it('rejects an inactive/revoked session even with a valid signature', async () => { + const { gateway } = buildGateway({ + findActiveSession: jest.fn().mockResolvedValue(undefined), + }); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + + expect(socket.emit).toHaveBeenCalledWith('error', { + message: 'Unauthorized', + }); + expect(socket.disconnect).toHaveBeenCalled(); + expect(__testCounters.user.count('u1')).toBe(0); + }); + + it('rejects a session whose user/workspace does not match the token', async () => { + const { gateway } = buildGateway({ + findActiveSession: jest + .fn() + .mockResolvedValue({ id: 's1', userId: 'OTHER', workspaceId: 'w1' }), + }); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + expect(socket.emit).toHaveBeenCalledWith('error', { + message: 'Unauthorized', + }); + }); + + it('rejects a bad Origin (CSWSH) before auth and increments nothing', async () => { + const { gateway, tokenService } = buildGateway(); + const socket = buildSocket({ origin: 'https://evil.example.com' }); + await gateway.handleConnection(socket as any); + + expect(socket.emit).toHaveBeenCalledWith('error', { + message: 'Unauthorized', + }); + expect(socket.disconnect).toHaveBeenCalled(); + // Origin check runs first: 
auth is never attempted. + expect(tokenService.verifyJwt).not.toHaveBeenCalled(); + expect(__testCounters.user.count('u1')).toBe(0); + }); + + it('allows a request with no Origin header (native/non-browser client)', async () => { + const { gateway } = buildGateway(); + const socket = buildSocket({ noOriginHeader: true }); + await gateway.handleConnection(socket as any); + + expect(socket.disconnect).not.toHaveBeenCalled(); + expect(__testCounters.user.count('u1')).toBe(1); + expect(__testCounters.workspace.count('w1')).toBe(1); + }); + + it('accepts a matching Origin and increments both counters', async () => { + const { gateway } = buildGateway(); + const socket = buildSocket({ origin: APP_ORIGIN }); + await gateway.handleConnection(socket as any); + + expect(socket.disconnect).not.toHaveBeenCalled(); + expect(__testCounters.user.count('u1')).toBe(1); + expect(__testCounters.workspace.count('w1')).toBe(1); + expect(socket.data.countedUserId).toBe('u1'); + expect(socket.data.countedWorkspaceId).toBe('w1'); + }); +}); + +describe('AiRealtimeGateway.handleConnection gate + caps', () => { + it('disconnects when the feature gate is off and leaves counters clean', async () => { + const { gateway } = buildGateway({ + findWorkspace: jest + .fn() + .mockResolvedValue({ settings: { ai: { dictation: true } } }), + }); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + + expect(socket.emit).toHaveBeenCalledWith('error', { + message: expect.stringMatching(/not enabled/i), + }); + expect(socket.disconnect).toHaveBeenCalled(); + expect(__testCounters.user.count('u1')).toBe(0); + expect(__testCounters.workspace.count('w1')).toBe(0); + }); + + it('refuses when the per-user cap is already reached (no increment)', async () => { + __testCounters.user.increment('u1'); // user already at cap (1) + const { gateway } = buildGateway(); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + + 
expect(socket.emit).toHaveBeenCalledWith('error', { + message: expect.stringMatching(/already active/i), + }); + expect(socket.disconnect).toHaveBeenCalled(); + // Still exactly 1 (the pre-existing slot), not bumped to 2. + expect(__testCounters.user.count('u1')).toBe(1); + }); + + it('checks BOTH caps before incrementing EITHER (workspace cap full → user untouched)', async () => { + // Workspace at its cap (5), user at 0. The connection must be refused and the + // user counter must NOT have been bumped. + for (let i = 0; i < 5; i++) __testCounters.workspace.increment('w1'); + const { gateway } = buildGateway(); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + + expect(socket.emit).toHaveBeenCalledWith('error', { + message: expect.stringMatching(/maximum number/i), + }); + expect(__testCounters.user.count('u1')).toBe(0); + expect(__testCounters.workspace.count('w1')).toBe(5); + }); +}); + +describe('AiRealtimeGateway.handleStart lifecycle', () => { + function connectedSocket() { + return { + handshake: { headers: {} }, + data: { userId: 'u1', workspaceId: 'w1' } as any, + emit: jest.fn(), + disconnect: jest.fn(), + }; + } + + it('relays onReady→ready, onInterim→interim{itemId,text}, onFinal→final', async () => { + let captured: any; + const openSession = jest.fn().mockImplementation((_ws, opts) => { + captured = opts; + return { appendAudio: jest.fn(), stop: jest.fn(), close: jest.fn() }; + }); + const { gateway } = buildGateway({ openSession }); + const socket = connectedSocket(); + await gateway.handleStart(socket as any, { language: 'en' }); + + expect(openSession).toHaveBeenCalledWith('w1', expect.any(Object)); + captured.onReady(); + expect(socket.emit).toHaveBeenCalledWith('ready', {}); + captured.onInterim('item-1', 'hello'); + expect(socket.emit).toHaveBeenCalledWith('interim', { + itemId: 'item-1', + text: 'hello', + }); + captured.onFinal('item-1', 'hello world'); + expect(socket.emit).toHaveBeenCalledWith('final', { + 
itemId: 'item-1', + text: 'hello world', + }); + expect(socket.data.handle).toBeDefined(); + }); + + it('onClosed clears the handle and emits closed (releases double-start guard)', async () => { + let captured: any; + const openSession = jest.fn().mockImplementation((_ws, opts) => { + captured = opts; + return { appendAudio: jest.fn(), stop: jest.fn(), close: jest.fn() }; + }); + const { gateway } = buildGateway({ openSession }); + const socket = connectedSocket(); + await gateway.handleStart(socket as any); + + expect(socket.data.handle).toBeDefined(); + captured.onClosed(); + expect(socket.data.handle).toBeUndefined(); + expect(socket.emit).toHaveBeenCalledWith('closed', {}); + }); + + it('guards a double-start: a session already in progress → error, openSession called once', async () => { + const openSession = jest.fn().mockResolvedValue({ + appendAudio: jest.fn(), + stop: jest.fn(), + close: jest.fn(), + }); + const { gateway } = buildGateway({ openSession }); + const socket = connectedSocket(); + await gateway.handleStart(socket as any); + await gateway.handleStart(socket as any); + + expect(openSession).toHaveBeenCalledTimes(1); + expect(socket.emit).toHaveBeenCalledWith('error', { + message: expect.stringMatching(/already in progress/i), + }); + }); + + it('surfaces AiSttNotConfigured message verbatim (a clean 503 reason)', async () => { + const err = new AiSttNotConfiguredException(); + const { gateway } = buildGateway({ + openSession: jest.fn().mockRejectedValue(err), + }); + const socket = connectedSocket(); + await gateway.handleStart(socket as any); + + expect(socket.emit).toHaveBeenCalledWith('error', { + message: err.message, + }); + }); + + it('does not leak a raw key/stack on a provider error (uses describeProviderError)', async () => { + const err = new Error('connect ECONNREFUSED sk-secret-key-1234'); + (err as any).stack = 'Error: at openSession sk-secret-key-1234'; + const { gateway } = buildGateway({ + openSession: 
jest.fn().mockRejectedValue(err), + }); + const socket = connectedSocket(); + await gateway.handleStart(socket as any); + + const call = socket.emit.mock.calls.find((c) => c[0] === 'error'); + expect(call).toBeDefined(); + const message = call![1].message as string; + // A concrete, sanitized reason — never the raw stack. + expect(message).not.toContain('at openSession'); + expect(typeof message).toBe('string'); + }); +}); + +describe('AiRealtimeGateway.handleAudio / handleStop guards', () => { + function socketWithHandle(handle?: any) { + return { + handshake: { headers: {} }, + data: { userId: 'u1', workspaceId: 'w1', handle } as any, + emit: jest.fn(), + disconnect: jest.fn(), + }; + } + + it('no handle → handleAudio never calls appendAudio', () => { + const { gateway } = buildGateway(); + const socket = socketWithHandle(undefined); + gateway.handleAudio(socket as any, Buffer.from([1, 2, 3])); + // Nothing to assert beyond not throwing — there is no handle to call. + expect(socket.data.handle).toBeUndefined(); + }); + + it('a Buffer payload → exactly one appendAudio', () => { + const appendAudio = jest.fn(); + const { gateway } = buildGateway(); + const socket = socketWithHandle({ appendAudio }); + gateway.handleAudio(socket as any, Buffer.from([1, 2, 3])); + expect(appendAudio).toHaveBeenCalledTimes(1); + }); + + it('a non-binary payload → no appendAudio', () => { + const appendAudio = jest.fn(); + const { gateway } = buildGateway(); + const socket = socketWithHandle({ appendAudio }); + gateway.handleAudio(socket as any, 'not binary'); + expect(appendAudio).not.toHaveBeenCalled(); + }); + + it('handleStop with no handle does not throw', () => { + const { gateway } = buildGateway(); + const socket = socketWithHandle(undefined); + expect(() => gateway.handleStop(socket as any)).not.toThrow(); + }); + + it('handleStop with a handle calls stop once', () => { + const stop = jest.fn(); + const { gateway } = buildGateway(); + const socket = socketWithHandle({ stop }); 
+ gateway.handleStop(socket as any); + expect(stop).toHaveBeenCalledTimes(1); + }); +}); + +describe('AiRealtimeGateway.handleDisconnect no-leak', () => { + it('decrements both counters for a fully-accepted connection', async () => { + const { gateway } = buildGateway(); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + expect(__testCounters.user.count('u1')).toBe(1); + + gateway.handleDisconnect(socket as any); + expect(__testCounters.user.count('u1')).toBe(0); + expect(__testCounters.workspace.count('w1')).toBe(0); + }); + + it('is a no-op when the connection was rejected before incrementing', async () => { + const { gateway } = buildGateway({ + verifyJwt: jest.fn().mockRejectedValue(new Error('bad')), + }); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + // No counted ids stashed → disconnect must not touch counters. + gateway.handleDisconnect(socket as any); + expect(__testCounters.user.count('u1')).toBe(0); + expect(__testCounters.workspace.count('w1')).toBe(0); + }); + + it('is idempotent: a double disconnect never goes negative', async () => { + const { gateway } = buildGateway(); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + + gateway.handleDisconnect(socket as any); + gateway.handleDisconnect(socket as any); + expect(__testCounters.user.count('u1')).toBe(0); + expect(__testCounters.workspace.count('w1')).toBe(0); + }); + + it('closes the upstream handle on disconnect', async () => { + const close = jest.fn(); + const { gateway } = buildGateway(); + const socket = buildSocket(); + await gateway.handleConnection(socket as any); + socket.data.handle = { close }; + + gateway.handleDisconnect(socket as any); + expect(close).toHaveBeenCalledTimes(1); + expect(socket.data.handle).toBeUndefined(); + }); +}); + +describe('AiRealtimeGateway.toBuffer', () => { + // toBuffer is private static; access it via the class for the documented + // contract 
(Buffer/Uint8Array/ArrayBuffer → Buffer; everything else → null). + const toBuffer = (AiRealtimeGateway as any).toBuffer as ( + p: unknown, + ) => Buffer | null; + + it('passes a Buffer through', () => { + const buf = Buffer.from([1, 2, 3]); + expect(toBuffer(buf)).toBe(buf); + }); + + it('converts a Uint8Array to a Buffer', () => { + const out = toBuffer(new Uint8Array([1, 2, 3])); + expect(Buffer.isBuffer(out)).toBe(true); + expect(out).toEqual(Buffer.from([1, 2, 3])); + }); + + it('converts an ArrayBuffer to a Buffer', () => { + const out = toBuffer(new Uint8Array([4, 5]).buffer); + expect(Buffer.isBuffer(out)).toBe(true); + }); + + it('returns null for a string', () => { + expect(toBuffer('abc')).toBeNull(); + }); + + it('returns null for null', () => { + expect(toBuffer(null)).toBeNull(); + }); +}); diff --git a/apps/server/src/core/ai-chat/realtime/ai-realtime.gateway.ts b/apps/server/src/core/ai-chat/realtime/ai-realtime.gateway.ts index 5261e828..3bd52719 100644 --- a/apps/server/src/core/ai-chat/realtime/ai-realtime.gateway.ts +++ b/apps/server/src/core/ai-chat/realtime/ai-realtime.gateway.ts @@ -10,12 +10,22 @@ import * as cookie from 'cookie'; import { TokenService } from '../../auth/services/token.service'; import { JwtPayload, JwtType } from '../../auth/dto/jwt-payload'; import { WorkspaceRepo } from '@docmost/db/repos/workspace/workspace.repo'; +import { UserRepo } from '@docmost/db/repos/user/user.repo'; +import { UserSessionRepo } from '@docmost/db/repos/session/user-session.repo'; +import { EnvironmentService } from '../../../integrations/environment/environment.service'; +import { isUserDisabled } from '../../../common/helpers'; import { AiSttNotConfiguredException } from '../../../integrations/ai/ai-stt-not-configured.exception'; import { describeProviderError } from '../../../integrations/ai/ai-error.util'; import { AiRealtimeService, RealtimeSessionHandle, } from './ai-realtime.service'; +import { + SessionCounters, + canConnect, + 
MAX_SESSIONS_PER_USER, + MAX_SESSIONS_PER_WORKSPACE, +} from './session-limits'; /** * Realtime dictation gateway — the server side of the FROZEN normalized @@ -30,32 +40,24 @@ import { * `settings.ai.dictation === true` AND `settings.ai.dictationRealtime === true`. * Hard concurrency caps (realtime is expensive) are enforced in-memory per user * and per workspace. + * + * ──────────────────────────────────────────────────────────────────────────── + * OPS WARNING — CONCURRENCY CAPS ARE PER-PROCESS / IN-MEMORY. + * The `userSessions` / `workspaceSessions` counters live in this Node process's + * heap. They are correct ONLY on a SINGLE API replica. With N horizontally + * scaled replicas behind a load balancer the EFFECTIVE limit is N × the cap + * (each replica counts only the sockets it terminates), and the caps are not + * shared. A multi-replica deployment that must enforce a true global cap needs + * a shared store (e.g. Redis). This is by design for the single-process default + * and is documented loudly so it is not mistaken for a global guarantee. + * ──────────────────────────────────────────────────────────────────────────── */ -/** Realtime is expensive: one live session per user, a handful per workspace. */ -const MAX_SESSIONS_PER_USER = 1; -const MAX_SESSIONS_PER_WORKSPACE = 5; - -// Module-level concurrency counters. A single Node process backs the gateway; -// these caps are best-effort within that process (a horizontally-scaled -// deployment would need a shared store, out of scope here). -const sessionsPerUser = new Map(); -const sessionsPerWorkspace = new Map(); - -function incr(map: Map, key: string): number { - const next = (map.get(key) ?? 0) + 1; - map.set(key, next); - return next; -} - -function decr(map: Map, key: string): void { - const next = (map.get(key) ?? 0) - 1; - if (next <= 0) { - map.delete(key); - } else { - map.set(key, next); - } -} +// Module-level concurrency counters (one Node process backs the gateway). 
See +// the OPS WARNING above: these are per-process/in-memory and only correct on a +// single API replica. +const userSessions = new SessionCounters(); +const workspaceSessions = new SessionCounters(); /** Per-socket state we stash on client.data. */ interface RealtimeClientData { @@ -69,6 +71,12 @@ interface RealtimeClientData { @WebSocketGateway({ namespace: '/ai-realtime', + // CORS origin is '*' at the socket.io layer, but this is a NEW externally + // reachable, provider-billed surface that authenticates via cookie. To block + // cross-site WebSocket hijacking (CSWSH) we enforce an explicit Origin + // allowlist in handleConnection (see assertAllowedOrigin) rather than relying + // on this permissive transport-level setting. (The pre-existing WsGateway is + // out of scope and intentionally unchanged.) cors: { origin: '*' }, transports: ['websocket'], }) @@ -80,11 +88,34 @@ export class AiRealtimeGateway constructor( private readonly tokenService: TokenService, private readonly workspaceRepo: WorkspaceRepo, + private readonly userRepo: UserRepo, + private readonly userSessionRepo: UserSessionRepo, + private readonly environmentService: EnvironmentService, private readonly aiRealtimeService: AiRealtimeService, ) {} + /** + * CSWSH guard: allow the handshake only when the browser-supplied `Origin` + * matches the app's configured origin, OR when there is no Origin header at + * all (native / non-browser clients never send one and cannot be coerced by a + * malicious page into doing so). A mismatched Origin → reject. Scoped to this + * realtime gateway only. + */ + private assertAllowedOrigin(client: Socket): void { + const origin = client.handshake.headers.origin; + // No Origin header → non-browser client; nothing to spoof, allow. 
+ if (!origin) return; + const appOrigin = this.environmentService.getAppUrl(); + if (origin !== appOrigin) { + throw new Error(`Origin not allowed: ${origin}`); + } + } + async handleConnection(client: Socket): Promise { try { + // 1) CSWSH Origin allowlist (before touching auth or any counter). + this.assertAllowedOrigin(client); + const cookies = cookie.parse(client.handshake.headers.cookie ?? ''); const token: JwtPayload = await this.tokenService.verifyJwt( cookies['authToken'], @@ -93,55 +124,60 @@ export class AiRealtimeGateway const userId = token.sub; const workspaceId = token.workspaceId; + if (!workspaceId) { + throw new Error('Token missing workspaceId'); + } + + // 2) Mirror the REST jwt.strategy active-session / disabled-user checks + // (jwt.strategy.ts:53-72): a signed-but-stale token (revoked session, + // deactivated/deleted user) must NOT open a realtime session even though + // the signature is still valid until expiry (default 90d). Verifying the + // signature alone is not enough. + const user = await this.userRepo.findById(userId, workspaceId); + if (!user || isUserDisabled(user)) { + throw new Error('User disabled or not found'); + } + const sessionId = token.sessionId; + if (sessionId) { + const session = await this.userSessionRepo.findActiveById(sessionId); + if ( + !session || + session.userId !== userId || + session.workspaceId !== workspaceId + ) { + throw new Error('Session not active'); + } + } const data = client.data as RealtimeClientData; data.userId = userId; data.workspaceId = workspaceId; - // Gate: realtime dictation must be enabled at the workspace level. + // Gate + concurrency caps. canConnect is a pure decision over the current + // counts; it checks BOTH the feature gate and BOTH caps before we mutate + // either counter, so a rejected connection leaves the counters clean. const workspace = await this.workspaceRepo.findById(workspaceId); const settings = (workspace?.settings ?? 
{}) as { ai?: { dictation?: boolean; dictationRealtime?: boolean }; }; - if ( - settings.ai?.dictation !== true || - settings.ai?.dictationRealtime !== true - ) { - client.emit('error', { - message: 'Realtime dictation is not enabled', - }); + const decision = canConnect(userId, workspaceId, settings, { + userCount: userSessions.count(userId), + workspaceCount: workspaceSessions.count(workspaceId), + }); + if (decision.allowed === false) { + client.emit('error', { message: decision.reason }); client.disconnect(); return; } - // Hard concurrency caps (realtime is expensive). Check both before - // incrementing either, so a rejected connection leaves the counters clean. - const userCount = sessionsPerUser.get(userId) ?? 0; - const workspaceCount = sessionsPerWorkspace.get(workspaceId) ?? 0; - if (userCount >= MAX_SESSIONS_PER_USER) { - client.emit('error', { - message: - 'A realtime dictation session is already active for your account', - }); - client.disconnect(); - return; - } - if (workspaceCount >= MAX_SESSIONS_PER_WORKSPACE) { - client.emit('error', { - message: - 'The maximum number of concurrent realtime dictation sessions for this workspace has been reached', - }); - client.disconnect(); - return; - } - - incr(sessionsPerUser, userId); - incr(sessionsPerWorkspace, workspaceId); + userSessions.increment(userId); + workspaceSessions.increment(workspaceId); // Remember exactly what we counted so disconnect decrements symmetrically. data.countedUserId = userId; data.countedWorkspaceId = workspaceId; } catch (err) { - // Auth failure (or any unexpected connect error): never leak details. + // Auth/origin failure (or any unexpected connect error): never leak + // details, never increment a counter. 
this.logger.error('Realtime dictation connection rejected', err as Error); client.emit('error', { message: 'Unauthorized' }); client.disconnect(); @@ -213,11 +249,11 @@ export class AiRealtimeGateway state.handle?.close(); state.handle = undefined; if (state.countedUserId) { - decr(sessionsPerUser, state.countedUserId); + userSessions.decrement(state.countedUserId); state.countedUserId = undefined; } if (state.countedWorkspaceId) { - decr(sessionsPerWorkspace, state.countedWorkspaceId); + workspaceSessions.decrement(state.countedWorkspaceId); state.countedWorkspaceId = undefined; } } @@ -234,3 +270,22 @@ export class AiRealtimeGateway return null; } } + +// Re-export the extracted concurrency cap constants for the gateway's existing +// importers and any callers that referenced them off this module. +export { MAX_SESSIONS_PER_USER, MAX_SESSIONS_PER_WORKSPACE }; + +/** + * Test-only handles onto the module-level counters. Exposed so the gateway + * lifecycle / no-leak specs can assert and reset the per-process state between + * cases (the counters are shared across every instance in-process). NOT for + * production use. + */ +export const __testCounters = { + user: userSessions, + workspace: workspaceSessions, + reset(): void { + userSessions.reset(); + workspaceSessions.reset(); + }, +}; diff --git a/apps/server/src/core/ai-chat/realtime/ai-realtime.service.spec.ts b/apps/server/src/core/ai-chat/realtime/ai-realtime.service.spec.ts index e99b464b..90449b27 100644 --- a/apps/server/src/core/ai-chat/realtime/ai-realtime.service.spec.ts +++ b/apps/server/src/core/ai-chat/realtime/ai-realtime.service.spec.ts @@ -1,4 +1,121 @@ -import { parseUpstreamEvent } from './ai-realtime.service'; +import { EventEmitter } from 'node:events'; + +// Mock the SSRF guard so openSession's pre-flight check is controllable and the +// pinned `lookup` never touches real DNS in these unit tests. 
+jest.mock('../external-mcp/ssrf-guard', () => ({ + isUrlAllowed: jest.fn(async () => ({ ok: true })), + isIpAllowed: jest.fn(() => ({ ok: true })), +})); + +import { + parseUpstreamEvent, + AiRealtimeService, + type WsFactory, +} from './ai-realtime.service'; +import { isUrlAllowed } from '../external-mcp/ssrf-guard'; +import { AiSttNotConfiguredException } from '../../../integrations/ai/ai-stt-not-configured.exception'; + +const mockedIsUrlAllowed = isUrlAllowed as jest.MockedFunction< + typeof isUrlAllowed +>; + +/** ws readyState constants (CONNECTING/OPEN/CLOSING/CLOSED). */ +const WS_CONNECTING = 0; +const WS_OPEN = 1; +const WS_CLOSED = 3; + +/** + * Minimal fake of the `ws` WebSocket: an EventEmitter with readyState/send/close. + * `sent` records every frame; `failSend` makes the next send throw; helpers + * simulate the upstream emitting open/message/close/error. + */ +class FakeWs extends EventEmitter { + static readonly OPEN = WS_OPEN; + static readonly CONNECTING = WS_CONNECTING; + static readonly CLOSED = WS_CLOSED; + + readyState = WS_CONNECTING; + sent: string[] = []; + closed = false; + failSend = false; + lastOpts: unknown; + + constructor(opts?: unknown) { + super(); + this.lastOpts = opts; + } + + send(data: string): void { + if (this.failSend) throw new Error('send boom'); + this.sent.push(data); + } + + close(): void { + this.closed = true; + this.readyState = WS_CLOSED; + } + + /** Parsed view of the sent frames. 
*/ + sentJson(): Array> { + return this.sent.map((s) => JSON.parse(s) as Record); + } + + // --- upstream-side simulation helpers --- + open(): void { + this.readyState = WS_OPEN; + this.emit('open'); + } + message(obj: unknown): void { + this.emit('message', Buffer.from(JSON.stringify(obj))); + } + rawMessage(raw: string): void { + this.emit('message', Buffer.from(raw)); + } + upstreamClose(code: number, reason = ''): void { + this.readyState = WS_CLOSED; + this.emit('close', code, Buffer.from(reason)); + } + errorEvent(err: Error): void { + this.emit('error', err); + } +} + +/** Build a service with a stub AiSettingsService.resolve and the fake ws factory. */ +function makeService( + resolveValue: Record | null, +): { service: AiRealtimeService; created: FakeWs[] } { + const created: FakeWs[] = []; + const aiSettings = { + resolve: jest.fn(async () => resolveValue), + } as unknown as ConstructorParameters[0]; + const service = new AiRealtimeService(aiSettings); + const factory: WsFactory = (_url, opts) => { + const ws = new FakeWs(opts); + created.push(ws); + return ws as unknown as ReturnType; + }; + service.setWsFactory(factory); + return { service, created }; +} + +/** A fully-configured STT config that resolves to the OpenAI default URL. */ +const OPENAI_CFG = { + driver: 'openai', + sttRealtimeModel: 'gpt-4o-transcribe', + sttApiKey: 'sk-test', +}; + +/** Callback spies + the options object to pass to openSession. */ +function makeCallbacks(language?: string) { + const cb = { + onReady: jest.fn(), + onInterim: jest.fn(), + onFinal: jest.fn(), + onError: jest.fn(), + onClosed: jest.fn(), + }; + return { ...cb, language }; +} /** * Unit tests for the PURE `parseUpstreamEvent` normalizer (no network). They @@ -50,7 +167,8 @@ describe('parseUpstreamEvent (OpenAI GA → normalized realtime events)', () => itemId: 'item-1', text: 'Hello world', }); - expect(acc.get('item-1')).toBe('Hello world'); + // Accumulator is keyed by `${item_id}:${content_index ?? 0}`. 
+ expect(acc.get('item-1:0')).toBe('Hello world'); }); it('emits a trimmed final from the completed transcript and clears the accumulator', () => { @@ -63,7 +181,7 @@ describe('parseUpstreamEvent (OpenAI GA → normalized realtime events)', () => }), acc, ); - expect(acc.has('item-2')).toBe(true); + expect(acc.has('item-2:0')).toBe(true); const final = parseUpstreamEvent( JSON.stringify({ @@ -79,7 +197,7 @@ describe('parseUpstreamEvent (OpenAI GA → normalized realtime events)', () => text: 'Final transcript.', }); // The accumulator entry for the completed segment is removed. - expect(acc.has('item-2')).toBe(false); + expect(acc.has('item-2:0')).toBe(false); }); it('falls back to the accumulated text when completed omits the transcript', () => { @@ -103,7 +221,7 @@ describe('parseUpstreamEvent (OpenAI GA → normalized realtime events)', () => itemId: 'item-3', text: 'accumulated only', }); - expect(acc.has('item-3')).toBe(false); + expect(acc.has('item-3:0')).toBe(false); }); it('maps an error frame to { type: "error" } with the provider message', () => { @@ -183,4 +301,531 @@ describe('parseUpstreamEvent (OpenAI GA → normalized realtime events)', () => // Every started segment was completed → the accumulator is empty. 
expect(acc.size).toBe(0); }); + + it('ignores a delta with no item_id', () => { + expect( + parseUpstreamEvent( + JSON.stringify({ + type: 'conversation.item.input_audio_transcription.delta', + delta: 'orphan', + }), + acc, + ), + ).toEqual({ type: 'ignore' }); + expect(acc.size).toBe(0); + }); + + it('ignores a completed with no item_id', () => { + expect( + parseUpstreamEvent( + JSON.stringify({ + type: 'conversation.item.input_audio_transcription.completed', + transcript: 'orphan', + }), + acc, + ), + ).toEqual({ type: 'ignore' }); + }); + + it('falls back to describeProviderError when an error frame carries no message', () => { + const out = parseUpstreamEvent( + JSON.stringify({ type: 'error', error: { code: 'x', type: 'server' } }), + acc, + ); + expect(out.type).toBe('error'); + // No provider message → the generic fallback string. + expect(out.message).toBe('Realtime transcription error'); + }); + + it('emits text:"" for a completed with no transcript and an empty accumulator', () => { + const out = parseUpstreamEvent( + JSON.stringify({ + type: 'conversation.item.input_audio_transcription.completed', + item_id: 'empty-1', + }), + acc, + ); + expect(out).toEqual({ type: 'final', itemId: 'empty-1', text: '' }); + }); + + it('ignores non-object JSON payloads ("42", "null")', () => { + expect(parseUpstreamEvent('42', acc)).toEqual({ type: 'ignore' }); + expect(parseUpstreamEvent('null', acc)).toEqual({ type: 'ignore' }); + expect(acc.size).toBe(0); + }); + + it('keeps two content_index parts under the same item_id separate', () => { + // Two segments share item_id 'seg' but differ by content_index. 
+ const a = parseUpstreamEvent( + JSON.stringify({ + type: 'conversation.item.input_audio_transcription.delta', + item_id: 'seg', + content_index: 0, + delta: 'first', + }), + acc, + ); + expect(a).toEqual({ type: 'interim', itemId: 'seg', text: 'first' }); + + const b = parseUpstreamEvent( + JSON.stringify({ + type: 'conversation.item.input_audio_transcription.delta', + item_id: 'seg', + content_index: 1, + delta: 'second', + }), + acc, + ); + // content_index 1 must NOT be concatenated onto content_index 0. + expect(b).toEqual({ type: 'interim', itemId: 'seg', text: 'second' }); + + // Two distinct accumulator entries. + expect(acc.get('seg:0')).toBe('first'); + expect(acc.get('seg:1')).toBe('second'); + + // Completing content_index 0 leaves content_index 1 intact. + const finalA = parseUpstreamEvent( + JSON.stringify({ + type: 'conversation.item.input_audio_transcription.completed', + item_id: 'seg', + content_index: 0, + transcript: 'first', + }), + acc, + ); + expect(finalA).toEqual({ type: 'final', itemId: 'seg', text: 'first' }); + expect(acc.has('seg:0')).toBe(false); + expect(acc.get('seg:1')).toBe('second'); + }); +}); + +describe('AiRealtimeService.deriveRealtimeUrl', () => { + const OPENAI_DEFAULT = 'wss://api.openai.com/v1/realtime?intent=transcription'; + + it('returns the OpenAI default for no base / whitespace base', () => { + expect(AiRealtimeService.deriveRealtimeUrl()).toBe(OPENAI_DEFAULT); + expect(AiRealtimeService.deriveRealtimeUrl('')).toBe(OPENAI_DEFAULT); + expect(AiRealtimeService.deriveRealtimeUrl(' ')).toBe(OPENAI_DEFAULT); + }); + + it('derives /v1/realtime for a bare host base', () => { + expect(AiRealtimeService.deriveRealtimeUrl('https://stt.example.com')).toBe( + 'wss://stt.example.com/v1/realtime?intent=transcription', + ); + }); + + it('does not duplicate /v1 when the base already ends in /v1', () => { + expect( + AiRealtimeService.deriveRealtimeUrl('https://stt.example.com/v1'), + 
).toBe('wss://stt.example.com/v1/realtime?intent=transcription'); + }); + + it('does not duplicate /v1 when the base already ends in /v1/realtime', () => { + expect( + AiRealtimeService.deriveRealtimeUrl('https://stt.example.com/v1/realtime'), + ).toBe('wss://stt.example.com/v1/realtime?intent=transcription'); + }); + + it('strips a trailing slash before normalizing the path', () => { + expect( + AiRealtimeService.deriveRealtimeUrl('https://stt.example.com/v1/'), + ).toBe('wss://stt.example.com/v1/realtime?intent=transcription'); + }); + + it('upgrades https to wss', () => { + expect( + AiRealtimeService.deriveRealtimeUrl('https://stt.example.com:8443/v1'), + ).toBe('wss://stt.example.com:8443/v1/realtime?intent=transcription'); + }); + + it('THROWS (fail-closed) for a non-empty unparseable base', () => { + expect(() => AiRealtimeService.deriveRealtimeUrl('not a url')).toThrow( + /could not be parsed/i, + ); + // The bad base is mentioned in the message. + expect(() => AiRealtimeService.deriveRealtimeUrl('not a url')).toThrow( + /not a url/, + ); + }); + + it('THROWS for an http:// base (Bearer key would be plaintext)', () => { + expect(() => + AiRealtimeService.deriveRealtimeUrl('http://stt.example.com/v1'), + ).toThrow(/secure|plaintext|http:/i); + }); + + it('THROWS for a ws:// base (insecure scheme)', () => { + expect(() => + AiRealtimeService.deriveRealtimeUrl('ws://stt.example.com/v1'), + ).toThrow(/secure|plaintext|ws:/i); + }); +}); + +describe('AiRealtimeService.openSession (fake ws seam)', () => { + beforeEach(() => { + mockedIsUrlAllowed.mockReset(); + mockedIsUrlAllowed.mockResolvedValue({ ok: true }); + }); + + it('not configured (no driver/model) → throws AiSttNotConfiguredException and creates NO socket', async () => { + const { service, created } = makeService(null); + await expect( + service.openSession('ws-1', makeCallbacks()), + ).rejects.toBeInstanceOf(AiSttNotConfiguredException); + expect(created).toHaveLength(0); + }); + + it('SSRF guard 
blocks (isUrlAllowed=false) → throws and creates NO socket', async () => { + mockedIsUrlAllowed.mockResolvedValue({ ok: false, reason: 'blocked range' }); + const { service, created } = makeService(OPENAI_CFG); + await expect(service.openSession('ws-1', makeCallbacks())).rejects.toThrow( + /SSRF guard.*blocked range/i, + ); + expect(created).toHaveLength(0); + }); + + it('on open sends exactly one session.update with the GA transcription shape', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + await service.openSession('ws-1', cb); + const ws = created[0]; + + // Authorization header set; NO OpenAI-Beta header. + const opts = ws.lastOpts as { headers?: Record }; + expect(opts.headers?.Authorization).toBe('Bearer sk-test'); + expect(Object.keys(opts.headers ?? {})).not.toContain('OpenAI-Beta'); + + ws.open(); + expect(ws.sent).toHaveLength(1); + const frame = ws.sentJson()[0]; + expect(frame.type).toBe('session.update'); + const session = frame.session as Record; + expect(session.type).toBe('transcription'); + const input = (session.audio as Record>) + .input; + expect(input.format).toEqual({ type: 'audio/pcm', rate: 24000 }); + expect(input.turn_detection).toEqual({ type: 'server_vad' }); + const transcription = input.transcription as Record; + expect(transcription.model).toBe('gpt-4o-transcribe'); + // No language was supplied → omitted. 
+ expect(transcription).not.toHaveProperty('language'); + }); + + it('includes language only when supplied', async () => { + const { service, created } = makeService(OPENAI_CFG); + await service.openSession('ws-1', makeCallbacks('en')); + const ws = created[0]; + ws.open(); + const input = ( + (ws.sentJson()[0].session as Record).audio as Record< + string, + Record + > + ).input; + const transcription = input.transcription as Record; + expect(transcription.language).toBe('en'); + }); + + it('fires onReady once on a session.created/updated frame', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + await service.openSession('ws-1', cb); + const ws = created[0]; + ws.open(); + ws.message({ type: 'session.created' }); + expect(cb.onReady).toHaveBeenCalledTimes(1); + ws.message({ type: 'session.updated' }); + expect(cb.onReady).toHaveBeenCalledTimes(2); + }); + + it('routes interim → onInterim and completed → onFinal', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + await service.openSession('ws-1', cb); + const ws = created[0]; + ws.open(); + ws.message({ + type: 'conversation.item.input_audio_transcription.delta', + item_id: 'i1', + delta: 'hi', + }); + expect(cb.onInterim).toHaveBeenCalledWith('i1', 'hi'); + ws.message({ + type: 'conversation.item.input_audio_transcription.completed', + item_id: 'i1', + transcript: 'hi there', + }); + expect(cb.onFinal).toHaveBeenCalledWith('i1', 'hi there'); + }); + + it('routes an error frame → onError + teardown (onClosed once)', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + await service.openSession('ws-1', cb); + const ws = created[0]; + ws.open(); + ws.message({ type: 'error', error: { message: 'invalid_api_key' } }); + expect(cb.onError).toHaveBeenCalledWith('invalid_api_key'); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + expect(ws.closed).toBe(true); + }); + + 
it('ignores messages after teardown (post-close → no-op)', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + const handle = await service.openSession('ws-1', cb); + const ws = created[0]; + ws.open(); + handle.close(); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + ws.message({ type: 'session.created' }); + ws.message({ + type: 'conversation.item.input_audio_transcription.completed', + item_id: 'late', + transcript: 'dropped', + }); + expect(cb.onReady).not.toHaveBeenCalled(); + expect(cb.onFinal).not.toHaveBeenCalled(); + }); + + it('appendAudio is a no-op when the socket is not OPEN', async () => { + const { service, created } = makeService(OPENAI_CFG); + const handle = await service.openSession('ws-1', makeCallbacks()); + const ws = created[0]; + // Still CONNECTING (no open yet). + handle.appendAudio(Buffer.from([1, 2, 3])); + expect(ws.sent).toHaveLength(0); + }); + + it('appendAudio base64-encodes and sends input_audio_buffer.append when OPEN', async () => { + const { service, created } = makeService(OPENAI_CFG); + const handle = await service.openSession('ws-1', makeCallbacks()); + const ws = created[0]; + ws.open(); + ws.sent.length = 0; // drop the session.update frame + handle.appendAudio(Buffer.from([1, 2, 3])); + const frame = ws.sentJson()[0]; + expect(frame.type).toBe('input_audio_buffer.append'); + expect(frame.audio).toBe(Buffer.from([1, 2, 3]).toString('base64')); + }); + + it('appendAudio send throw → onError + teardown', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + const handle = await service.openSession('ws-1', cb); + const ws = created[0]; + ws.open(); + ws.failSend = true; + handle.appendAudio(Buffer.from([1])); + expect(cb.onError).toHaveBeenCalled(); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + }); +}); + +describe('AiRealtimeService stop() drain + idempotency', () => { + beforeEach(() => { + mockedIsUrlAllowed.mockReset(); + 
mockedIsUrlAllowed.mockResolvedValue({ ok: true }); + }); + + it('with nothing in flight, stop commits then tears down immediately', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + const handle = await service.openSession('ws-1', cb); + const ws = created[0]; + ws.open(); + ws.sent.length = 0; + handle.stop(); + // Commit was sent, socket closed, onClosed fired once. + expect(ws.sentJson()[0].type).toBe('input_audio_buffer.commit'); + expect(ws.closed).toBe(true); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + }); + + it('drains an in-flight segment: holds open until the tail completed, delivers it, then closes', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + const handle = await service.openSession('ws-1', cb); + const ws = created[0]; + ws.open(); + // A segment is in flight (delta but no completed yet). + ws.message({ + type: 'conversation.item.input_audio_transcription.delta', + item_id: 'tail', + delta: 'last phrase', + }); + ws.sent.length = 0; + handle.stop(); + // Commit sent; socket NOT yet closed (draining). + expect(ws.sentJson()[0].type).toBe('input_audio_buffer.commit'); + expect(ws.closed).toBe(false); + expect(cb.onClosed).not.toHaveBeenCalled(); + // The tail's completed arrives after stop → it is still delivered. + ws.message({ + type: 'conversation.item.input_audio_transcription.completed', + item_id: 'tail', + transcript: 'last phrase', + }); + expect(cb.onFinal).toHaveBeenCalledWith('tail', 'last phrase'); + // Now closed. 
+ expect(ws.closed).toBe(true); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + }); + + it('drain timeout closes if the tail never arrives (fake timers)', async () => { + jest.useFakeTimers(); + try { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + const handle = await service.openSession('ws-1', cb); + const ws = created[0]; + ws.open(); + ws.message({ + type: 'conversation.item.input_audio_transcription.delta', + item_id: 'tail', + delta: 'partial', + }); + handle.stop(); + expect(ws.closed).toBe(false); + // Drain window (2.5s) elapses with no completed → close. + jest.advanceTimersByTime(3_000); + expect(ws.closed).toBe(true); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + } finally { + jest.useRealTimers(); + } + }); + + it('is idempotent: double close fires onClosed exactly once', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + const handle = await service.openSession('ws-1', cb); + created[0].open(); + handle.close(); + handle.close(); + handle.stop(); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + }); +}); + +describe('AiRealtimeService session timers (fake timers)', () => { + beforeEach(() => { + mockedIsUrlAllowed.mockReset(); + mockedIsUrlAllowed.mockResolvedValue({ ok: true }); + jest.useFakeTimers(); + }); + afterEach(() => { + jest.useRealTimers(); + }); + + it('idle 15s with no audio → onError + onClosed', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + await service.openSession('ws-1', cb); + created[0].open(); + jest.advanceTimersByTime(15_000); + expect(cb.onError).toHaveBeenCalledWith( + expect.stringContaining('idle'), + ); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + }); + + it('max 120s duration → onError + onClosed', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + const handle = await service.openSession('ws-1', cb); + 
created[0].open(); + // Keep pushing audio so the idle timer never fires before the max cap. + for (let t = 0; t < 120_000; t += 10_000) { + handle.appendAudio(Buffer.from([1])); + jest.advanceTimersByTime(10_000); + } + expect(cb.onError).toHaveBeenCalledWith( + expect.stringContaining('maximum duration'), + ); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + }); + + it('an unexpected upstream close (code !== 1000) reports onError', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + await service.openSession('ws-1', cb); + const ws = created[0]; + ws.open(); + ws.upstreamClose(1006, 'abnormal'); + expect(cb.onError).toHaveBeenCalledWith( + expect.stringContaining('1006'), + ); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + }); + + it('no timer fires after teardown', async () => { + const { service, created } = makeService(OPENAI_CFG); + const cb = makeCallbacks(); + const handle = await service.openSession('ws-1', cb); + created[0].open(); + handle.close(); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + jest.advanceTimersByTime(200_000); + expect(cb.onError).not.toHaveBeenCalled(); + expect(cb.onClosed).toHaveBeenCalledTimes(1); + }); +}); + +describe('AiRealtimeService.testConnection (settle once)', () => { + beforeEach(() => { + mockedIsUrlAllowed.mockReset(); + mockedIsUrlAllowed.mockResolvedValue({ ok: true }); + }); + + it('resolves { ok: true } on the first onReady', async () => { + const { service, created } = makeService(OPENAI_CFG); + const promise = service.testConnection('ws-1'); + // Let openSession's await resolve and create the socket. 
+ await Promise.resolve(); + await Promise.resolve(); + const ws = created[0]; + ws.open(); + ws.message({ type: 'session.created' }); + await expect(promise).resolves.toEqual({ ok: true }); + expect(ws.closed).toBe(true); + }); + + it('resolves { ok: false, error } on the first onError', async () => { + const { service, created } = makeService(OPENAI_CFG); + const promise = service.testConnection('ws-1'); + await Promise.resolve(); + await Promise.resolve(); + const ws = created[0]; + ws.open(); + ws.message({ type: 'error', error: { message: 'bad key' } }); + await expect(promise).resolves.toEqual({ ok: false, error: 'bad key' }); + }); + + it('times out after 8s → { ok: false } (fake timers)', async () => { + jest.useFakeTimers(); + try { + const { service, created } = makeService(OPENAI_CFG); + const promise = service.testConnection('ws-1'); + // Flush the openSession microtasks so the socket exists. + await Promise.resolve(); + await Promise.resolve(); + created[0].open(); + jest.advanceTimersByTime(8_000); + await expect(promise).resolves.toEqual({ + ok: false, + error: 'Realtime connection timed out', + }); + } finally { + jest.useRealTimers(); + } + }); + + it('not configured → { ok: false } with the not-configured message', async () => { + const { service } = makeService(null); + await expect(service.testConnection('ws-1')).resolves.toEqual({ + ok: false, + error: 'AI STT model not configured', + }); + }); }); diff --git a/apps/server/src/core/ai-chat/realtime/ai-realtime.service.ts b/apps/server/src/core/ai-chat/realtime/ai-realtime.service.ts index 76cabaf7..37359b2c 100644 --- a/apps/server/src/core/ai-chat/realtime/ai-realtime.service.ts +++ b/apps/server/src/core/ai-chat/realtime/ai-realtime.service.ts @@ -1,9 +1,97 @@ import { Injectable, Logger } from '@nestjs/common'; -import WebSocket from 'ws'; +import { + lookup as dnsLookup, + type LookupAddress, + type LookupAllOptions, + type LookupOneOptions, +} from 'node:dns'; +import type { LookupFunction 
} from 'node:net'; +// CJS import-equals so the WebSocket class (and its static OPEN/CONNECTING +// constants) resolves as a runtime VALUE regardless of esModuleInterop — a +// plain `import WebSocket from 'ws'` compiles to `ws_1.default`, which is +// undefined for this CommonJS module. +import WebSocket = require('ws'); import { AiSettingsService } from '../../../integrations/ai/ai-settings.service'; import { AiSttNotConfiguredException } from '../../../integrations/ai/ai-stt-not-configured.exception'; import { describeProviderError } from '../../../integrations/ai/ai-error.util'; -import { isUrlAllowed } from '../external-mcp/ssrf-guard'; +import { isUrlAllowed, isIpAllowed } from '../external-mcp/ssrf-guard'; + +/** + * WebSocket client options plus the `lookup` DNS hook. `@types/ws` does not + * declare `lookup`, but the `ws` lib spreads the options into the underlying + * `http(s).request` → `net`/`tls.connect`, both of which honor `lookup`. We add + * it here so the SSRF IP-pinning hook is type-checked. + */ +export type RealtimeWsOptions = WebSocket.ClientOptions & { + lookup?: LookupFunction; +}; + +/** + * Factory that creates the upstream WebSocket. Defaults to the real `ws` + * constructor; tests inject a fake (an EventEmitter with readyState/send/close) + * so the handshake/message-routing/timer logic can be exercised without a + * network. The `lookup` SSRF-pinning option lives in `opts` (see openSession). + */ +export type WsFactory = (url: string, opts: RealtimeWsOptions) => WebSocket; + +/** The default, production WebSocket factory. */ +const defaultWsFactory: WsFactory = (url, opts) => new WebSocket(url, opts); + +/** + * Build a `net`/`tls` `lookup` option that pins the connection to SSRF-validated + * addresses. 
The `ws` lib forwards this option through `http(s).request` into + * `net.connect`/`tls.connect`, so it is invoked for EVERY DNS resolution the + * socket performs — there is no second, unchecked resolution after our pre-flight + * `isUrlAllowed` check. We always resolve ALL addresses ourselves, validate each + * with `isIpAllowed`, and only ever hand back validated addresses; a single + * blocked address fails the whole connect. Mirrors buildPinnedDispatcher in the + * external-MCP layer. The hostname (SNI / Host header) is left untouched so TLS + * certificate validation still uses the real hostname. + */ +function buildPinnedLookup(): LookupFunction { + const pinnedLookup = ( + hostname: string, + options: LookupOneOptions | LookupAllOptions | number, + callback: ( + err: NodeJS.ErrnoException | null, + address: string | LookupAddress[], + family?: number, + ) => void, + ): void => { + // Whether net requested a single address or all of them, resolve ALL + // ourselves and validate every one; do not trust the caller's `all` flag. + const wantAll = + typeof options === 'object' && options !== null && options.all === true; + dnsLookup(hostname, { all: true }, (err, addresses) => { + if (err) { + callback(err, '', 0); + return; + } + const addrs = addresses as LookupAddress[]; + if (addrs.length === 0) { + callback(new Error(`No address resolved for ${hostname}`), '', 0); + return; + } + const blocked = addrs.find((a) => !isIpAllowed(a.address).ok); + if (blocked) { + // Refuse the connection: net/tls.connect never sees this address. + callback(new Error(`Blocked address for ${hostname}`), '', 0); + return; + } + const validated: LookupAddress[] = addrs.map((a) => ({ + address: a.address, + family: a.family, + })); + if (wantAll) { + callback(null, validated); + } else { + // Single-address form: hand back the first validated entry. 
+ callback(null, validated[0].address, validated[0].family); + } + }); + }; + return pinnedLookup as LookupFunction; +} /** * Realtime STT proxy (server side of the A2 transport: browser ↔ OUR server ↔ @@ -55,6 +143,12 @@ export interface RealtimeSessionHandle { const IDLE_TIMEOUT_MS = 15_000; /** Hard cap on a single realtime session's lifetime (mirrors the client's 120s). */ const MAX_SESSION_DURATION_MS = 120_000; +/** + * On graceful stop, how long to wait for the final in-flight segment's + * `...transcription.completed` (the tail after the last VAD pause) before + * closing anyway. Without this drain the last dictated phrase is silently lost. + */ +const STOP_DRAIN_TIMEOUT_MS = 2_500; /** How long testConnection waits for the upstream to become ready before failing. */ const TEST_CONNECTION_TIMEOUT_MS = 8_000; @@ -66,11 +160,16 @@ const TEST_CONNECTION_TIMEOUT_MS = 8_000; * * - session.created / session.updated → { type: 'ready' } * - conversation.item.input_audio_transcription.delta → append delta to - * acc[item_id]; return { type: 'interim', itemId, text: } + * acc[`${item_id}:${content_index}`]; return { type: 'interim', itemId, + * text: } * - conversation.item.input_audio_transcription.completed → final transcript - * (trimmed), delete acc[item_id]; return { type: 'final', itemId, text } + * (trimmed), delete acc entry; return { type: 'final', itemId, text } * - error → { type: 'error', message } (provider message, else describeProviderError) * - anything else / unparseable → { type: 'ignore' } + * + * The accumulator is keyed by `${item_id}:${content_index ?? 0}` so distinct + * content parts of the same item (GA may stream multiple `content_index` + * segments under one item_id) are never concatenated together. 
*/ export function parseUpstreamEvent( raw: string, @@ -79,6 +178,7 @@ export function parseUpstreamEvent( let evt: { type?: string; item_id?: string; + content_index?: number; delta?: string; transcript?: string; error?: { message?: string; code?: string; type?: string }; @@ -94,6 +194,11 @@ export function parseUpstreamEvent( return { type: 'ignore' }; } + // Compose the accumulator key from item_id + content_index so two distinct + // content parts under the same item_id don't get merged. + const accKey = (itemId: string): string => + `${itemId}:${evt.content_index ?? 0}`; + switch (evt.type) { case 'session.created': case 'session.updated': @@ -102,19 +207,21 @@ export function parseUpstreamEvent( case 'conversation.item.input_audio_transcription.delta': { const itemId = evt.item_id; if (!itemId) return { type: 'ignore' }; - const prev = acc.get(itemId) ?? ''; + const key = accKey(itemId); + const prev = acc.get(key) ?? ''; const next = prev + (evt.delta ?? ''); - acc.set(itemId, next); + acc.set(key, next); return { type: 'interim', itemId, text: next }; } case 'conversation.item.input_audio_transcription.completed': { const itemId = evt.item_id; if (!itemId) return { type: 'ignore' }; + const key = accKey(itemId); // Prefer the authoritative `transcript`; fall back to whatever we // accumulated from deltas if the completed frame omits it. - const text = (evt.transcript ?? acc.get(itemId) ?? '').trim(); - acc.delete(itemId); + const text = (evt.transcript ?? acc.get(key) ?? '').trim(); + acc.delete(key); return { type: 'final', itemId, text }; } @@ -135,8 +242,21 @@ export function parseUpstreamEvent( export class AiRealtimeService { private readonly logger = new Logger(AiRealtimeService.name); + /** + * WebSocket factory seam (R-SVC1). Defaults to the real `ws` constructor; + * tests can override it with `setWsFactory` to inject a fake socket. 
Nest + * always constructs the service with the single `aiSettings` dependency, so + * the factory is settable rather than a constructor param. + */ + private wsFactory: WsFactory = defaultWsFactory; + constructor(private readonly aiSettings: AiSettingsService) {} + /** Test seam: override the WebSocket factory (R-SVC1). */ + setWsFactory(factory: WsFactory): void { + this.wsFactory = factory; + } + /** * Resolve the workspace STT config, SSRF-check the upstream, open the upstream * realtime WS and wire its events to the supplied callbacks. Returns a handle @@ -157,9 +277,14 @@ export class AiRealtimeService { const baseUrl = cfg.sttRealtimeBaseUrl || cfg.sttBaseUrl || cfg.baseUrl; const wssUrl = AiRealtimeService.deriveRealtimeUrl(baseUrl); - // SSRF check on the http(s) equivalent (ssrf-guard only allows http/https): - // wss→https, ws→http. Re-checked here, right before connecting, to close the - // DNS-rebinding window (same defense the external-MCP layer uses). + // Fast pre-flight SSRF check on the http(s) equivalent (ssrf-guard only + // allows http/https): wss→https, ws→http. This is only a cheap pre-check — + // the `ws` lib re-resolves DNS independently when it connects, so a pass here + // does NOT guarantee the socket connects to a public address (TOCTOU / + // DNS-rebinding). The authoritative defense is the pinned `lookup` below, + // which validates EVERY resolved address right before connect (full parity + // with the external-MCP buildPinnedDispatcher — no second unchecked DNS + // resolution). const httpEquivalent = wssUrl.replace(/^wss:/i, 'https:').replace(/^ws:/i, 'http:'); const check = await isUrlAllowed(httpEquivalent); if (!check.ok) { @@ -172,14 +297,22 @@ export class AiRealtimeService { // Never log the key; only the (non-secret) URL is safe to log. this.logger.log(`Opening realtime STT session for workspace ${workspaceId}`); - const ws = new WebSocket(wssUrl, { + const ws = this.wsFactory(wssUrl, { headers: key ? 
{ Authorization: `Bearer ${key}` } : {}, // DO NOT send OpenAI-Beta: realtime=v1 — removed in GA. + // SSRF IP-pinning: the `ws` lib forwards this `lookup` into the underlying + // net/tls connect, so every resolved address is validated before connect. + lookup: buildPinnedLookup(), }); let closed = false; + // Graceful-stop drain state: once stop() commits with an in-flight segment, + // `draining` is true and we hold the socket open until that segment's + // `completed` arrives (delivered via onFinal) or the drain timer fires. + let draining = false; let idleTimer: NodeJS.Timeout | undefined; let maxTimer: NodeJS.Timeout | undefined; + let drainTimer: NodeJS.Timeout | undefined; const clearTimers = (): void => { if (idleTimer) { @@ -190,6 +323,10 @@ export class AiRealtimeService { clearTimeout(maxTimer); maxTimer = undefined; } + if (drainTimer) { + clearTimeout(drainTimer); + drainTimer = undefined; + } }; // Idempotent teardown: clears timers, force-closes the upstream, fires @@ -225,6 +362,7 @@ export class AiRealtimeService { `Realtime session idle for ${IDLE_TIMEOUT_MS}ms (no audio received); closing.`, ); }, IDLE_TIMEOUT_MS); + idleTimer.unref?.(); }; // Hard lifetime cap, armed immediately so a never-opening or runaway session @@ -234,6 +372,7 @@ export class AiRealtimeService { `Realtime session exceeded the maximum duration of ${MAX_SESSION_DURATION_MS}ms; closing.`, ); }, MAX_SESSION_DURATION_MS); + maxTimer.unref?.(); // Also guard the handshake itself: if the upstream never opens / never sends, // the idle timer (15s) reclaims it well before the 120s max-duration cap. @@ -285,6 +424,12 @@ export class AiRealtimeService { break; case 'final': opts.onFinal(parsed.itemId!, parsed.text ?? ''); + // If we're draining on a graceful stop and the accumulator is now + // empty (no more in-flight segments), the tail we were waiting for has + // arrived and been delivered → close now. 
+ if (draining && acc.size === 0) { + teardown(); + } break; case 'error': // Log the full upstream error then surface the concrete reason. @@ -335,9 +480,10 @@ export class AiRealtimeService { resetIdleTimer(); }, stop: (): void => { + if (closed) return; // Graceful stop: with server_vad no manual commit is required, but an // explicit commit flushes any buffered tail before we close. - if (!closed && ws.readyState === WebSocket.OPEN) { + if (ws.readyState === WebSocket.OPEN) { try { ws.send(JSON.stringify({ type: 'input_audio_buffer.commit' })); } catch (err) { @@ -345,6 +491,32 @@ export class AiRealtimeService { this.logger.error('Failed to commit realtime audio buffer', err as Error); } } + + // If there is an in-flight segment (a non-empty accumulation), do NOT + // close immediately: the committed tail's `...transcription.completed` + // arrives a beat later and would be dropped by the `closed` guard, + // silently losing the last dictated phrase. Instead drain: hold the + // socket open until that `completed` lands (delivered via onFinal, which + // then tears down) or a short timeout fires. With nothing in flight we + // can close right away. + if (ws.readyState === WebSocket.OPEN && acc.size > 0) { + draining = true; + // The idle timer would otherwise fire mid-drain; the drain timer is + // the bounded fallback that closes if the tail never arrives. + if (idleTimer) { + clearTimeout(idleTimer); + idleTimer = undefined; + } + drainTimer = setTimeout(() => { + // Tail never arrived within the drain window: close anyway. The + // partial (if any) was already delivered as the latest interim. + teardown(); + }, STOP_DRAIN_TIMEOUT_MS); + // Don't let the drain timer keep the event loop alive (mirror idle/max). + drainTimer.unref?.(); + return; + } + teardown(); }, close: (): void => { @@ -437,25 +609,43 @@ export class AiRealtimeService { /** * Derive the upstream realtime WSS URL from the (optional) effective base URL. 
* - * - No base URL → OpenAI default + * - Empty / whitespace base → OpenAI default * `wss://api.openai.com/v1/realtime?intent=transcription`. + * - A non-empty base that fails to parse → THROW. We must NOT silently fall + * back to OpenAI: a self-hosted STT key + live audio would then leak to + * OpenAI (fail-closed, not fail-open). + * - A non-secure scheme (`http://` / `ws://`) → THROW. The Authorization + * Bearer key would otherwise travel in plaintext to a self-hosted endpoint. * - Otherwise: take the base origin, ensure exactly one - * `/v1/realtime?intent=transcription` path, and upgrade the scheme to wss - * (http→ws→wss; https→wss). A base that already ends in `/v1` (or - * `/v1/realtime`) does not get a duplicated `/v1`. + * `/v1/realtime?intent=transcription` path, and produce a `wss://` URL. A + * base that already ends in `/v1` (or `/v1/realtime`) does not get a + * duplicated `/v1`. */ static deriveRealtimeUrl(baseUrl?: string): string { if (!baseUrl || !baseUrl.trim()) { return 'wss://api.openai.com/v1/realtime?intent=transcription'; } + const trimmed = baseUrl.trim(); let parsed: URL; try { - parsed = new URL(baseUrl.trim()); + parsed = new URL(trimmed); } catch { - // Unparseable base: fall back to the OpenAI default rather than throwing - // here; the SSRF check on the default still applies downstream. - return 'wss://api.openai.com/v1/realtime?intent=transcription'; + // Fail-closed: a configured-but-unparseable base must NOT silently fall + // back to OpenAI (would leak a self-hosted key + audio to OpenAI). + throw new Error( + `Invalid realtime STT base URL: "${trimmed}" could not be parsed.`, + ); + } + + // Require a secure scheme. Reject http/ws: the Bearer key would be sent in + // plaintext. (Default is reject; loopback-only relaxations, if ever needed, + // would be added explicitly here.) 
+ if (parsed.protocol === 'http:' || parsed.protocol === 'ws:') { + throw new Error( + `Insecure realtime STT base URL: "${trimmed}" uses ${parsed.protocol}//; ` + + 'a secure https:// / wss:// scheme is required so the API key is not sent in plaintext.', + ); } // Normalize the path: strip a trailing slash, drop an existing @@ -467,11 +657,10 @@ export class AiRealtimeService { } path = `${path}/realtime`; - // Scheme → wss (secure) / ws (insecure). The SSRF guard runs on the - // http(s) equivalent before connecting. - const scheme = parsed.protocol === 'http:' || parsed.protocol === 'ws:' ? 'ws' : 'wss'; - - return `${scheme}://${parsed.host}${path}?intent=transcription`; + // The scheme is already secure (https: / wss:); the resulting realtime URL + // is always wss://. The SSRF guard runs on the https equivalent before + // connecting, and the pinned lookup validates each resolved address. + return `wss://${parsed.host}${path}?intent=transcription`; } /** Normalize a ws RawData payload (Buffer | ArrayBuffer | Buffer[]) to a string. */ diff --git a/apps/server/src/core/ai-chat/realtime/session-limits.spec.ts b/apps/server/src/core/ai-chat/realtime/session-limits.spec.ts new file mode 100644 index 00000000..b16a65f7 --- /dev/null +++ b/apps/server/src/core/ai-chat/realtime/session-limits.spec.ts @@ -0,0 +1,145 @@ +import { + SessionCounters, + canConnect, + MAX_SESSIONS_PER_USER, + MAX_SESSIONS_PER_WORKSPACE, +} from './session-limits'; + +describe('SessionCounters', () => { + let counters: SessionCounters; + + beforeEach(() => { + counters = new SessionCounters(); + }); + + it('starts at 0 for an unknown key', () => { + expect(counters.count('a')).toBe(0); + }); + + it('increments and reports the running count', () => { + expect(counters.increment('a')).toBe(1); + expect(counters.increment('a')).toBe(2); + expect(counters.count('a')).toBe(2); + // Distinct keys are independent. 
+ expect(counters.count('b')).toBe(0); + }); + + it('decrements and deletes the key at zero (no phantom zero-entry)', () => { + counters.increment('a'); + counters.increment('a'); + expect(counters.decrement('a')).toBe(1); + expect(counters.decrement('a')).toBe(0); + expect(counters.count('a')).toBe(0); + // Re-incrementing a deleted key starts fresh at 1. + expect(counters.increment('a')).toBe(1); + }); + + it('decrementing an absent key is a no-op: no negative, no phantom slot', () => { + expect(counters.decrement('missing')).toBe(0); + expect(counters.count('missing')).toBe(0); + // A subsequent increment must start at 1, proving no -1 was stored. + expect(counters.increment('missing')).toBe(1); + }); + + it('double-decrement past zero never goes negative', () => { + counters.increment('a'); + expect(counters.decrement('a')).toBe(0); + expect(counters.decrement('a')).toBe(0); + expect(counters.count('a')).toBe(0); + }); + + it('reset() clears all counts', () => { + counters.increment('a'); + counters.increment('b'); + counters.reset(); + expect(counters.count('a')).toBe(0); + expect(counters.count('b')).toBe(0); + }); +}); + +describe('canConnect', () => { + const enabled = { ai: { dictation: true, dictationRealtime: true } }; + const zero = { userCount: 0, workspaceCount: 0 }; + + it('allows when both flags are true and caps are not hit', () => { + expect(canConnect('u', 'w', enabled, zero)).toEqual({ allowed: true }); + }); + + it('denies when only dictation is enabled (XOR)', () => { + const res = canConnect('u', 'w', { ai: { dictation: true } }, zero); + expect(res.allowed).toBe(false); + if (res.allowed === false) expect(res.reason).toMatch(/not enabled/i); + }); + + it('denies when only dictationRealtime is enabled (XOR)', () => { + const res = canConnect( + 'u', + 'w', + { ai: { dictationRealtime: true } }, + zero, + ); + expect(res.allowed).toBe(false); + }); + + it('denies when settings.ai is absent', () => { + const res = canConnect('u', 'w', {}, 
zero); + expect(res.allowed).toBe(false); + if (res.allowed === false) expect(res.reason).toMatch(/not enabled/i); + }); + + it('denies when settings is undefined', () => { + const res = canConnect('u', 'w', undefined, zero); + expect(res.allowed).toBe(false); + }); + + it('enforces the user cap with >= (no off-by-one)', () => { + // At exactly the cap → deny (>= boundary). + const atCap = canConnect('u', 'w', enabled, { + userCount: MAX_SESSIONS_PER_USER, + workspaceCount: 0, + }); + expect(atCap.allowed).toBe(false); + if (atCap.allowed === false) expect(atCap.reason).toMatch(/already active/i); + + // One below the cap → allow. + const below = canConnect('u', 'w', enabled, { + userCount: MAX_SESSIONS_PER_USER - 1, + workspaceCount: 0, + }); + expect(below.allowed).toBe(true); + }); + + it('enforces the workspace cap with >= (no off-by-one)', () => { + const atCap = canConnect('u', 'w', enabled, { + userCount: 0, + workspaceCount: MAX_SESSIONS_PER_WORKSPACE, + }); + expect(atCap.allowed).toBe(false); + if (atCap.allowed === false) expect(atCap.reason).toMatch(/maximum number/i); + + const below = canConnect('u', 'w', enabled, { + userCount: 0, + workspaceCount: MAX_SESSIONS_PER_WORKSPACE - 1, + }); + expect(below.allowed).toBe(true); + }); + + it('reports the gate reason before any cap reason (gate has priority)', () => { + // Feature disabled AND caps exceeded → the gate reason wins. 
+ const res = canConnect('u', 'w', {}, { + userCount: MAX_SESSIONS_PER_USER, + workspaceCount: MAX_SESSIONS_PER_WORKSPACE, + }); + expect(res.allowed).toBe(false); + if (res.allowed === false) expect(res.reason).toMatch(/not enabled/i); + }); + + it('reports the user-cap reason before the workspace-cap reason', () => { + const res = canConnect('u', 'w', enabled, { + userCount: MAX_SESSIONS_PER_USER, + workspaceCount: MAX_SESSIONS_PER_WORKSPACE, + }); + expect(res.allowed).toBe(false); + if (res.allowed === false) expect(res.reason).toMatch(/already active/i); + }); +}); diff --git a/apps/server/src/core/ai-chat/realtime/session-limits.ts b/apps/server/src/core/ai-chat/realtime/session-limits.ts new file mode 100644 index 00000000..a4623dc2 --- /dev/null +++ b/apps/server/src/core/ai-chat/realtime/session-limits.ts @@ -0,0 +1,132 @@ +/** + * Extracted, dependency-free concurrency-cap primitives for the realtime + * dictation gateway. Pulled out of the gateway so the cap arithmetic and the + * connect decision can be unit-tested directly, without standing up Socket.IO, + * the token service, or the database. + * + * ──────────────────────────────────────────────────────────────────────────── + * OPS WARNING — PER-PROCESS / IN-MEMORY ONLY. + * SessionCounters is a plain in-heap Map. The caps it backs are correct ONLY on + * a SINGLE API replica. With N horizontally scaled replicas the effective limit + * is N × the cap (each replica counts only the sockets it terminates). A true + * global cap across replicas needs a shared store (e.g. Redis). By design for + * the single-process default; documented loudly so it is never mistaken for a + * global guarantee. + * ──────────────────────────────────────────────────────────────────────────── + */ + +/** Realtime is expensive: one live session per user, a handful per workspace. 
*/ +export const MAX_SESSIONS_PER_USER = 1; +export const MAX_SESSIONS_PER_WORKSPACE = 5; + +/** + * A small Map-backed counter keyed by an arbitrary string (user id / workspace + * id). `increment` adds a slot, `decrement` removes one and deletes the key at + * zero so the map never accumulates phantom zero-entries and never goes + * negative. `count` reads the current value, `reset` clears everything (tests). + */ +export class SessionCounters { + private readonly counts = new Map(); + + /** Returns the current count for a key (0 when absent). */ + count(key: string): number { + return this.counts.get(key) ?? 0; + } + + /** Add one slot for `key`; returns the new count. */ + increment(key: string): number { + const next = (this.counts.get(key) ?? 0) + 1; + this.counts.set(key, next); + return next; + } + + /** + * Remove one slot for `key`. Deletes the key once it reaches zero so no + * phantom entry lingers, and never records a negative count (decrementing an + * absent or zero key is a no-op). Returns the resulting count. + */ + decrement(key: string): number { + const current = this.counts.get(key); + // Absent / zero / negative → nothing to release; never create a phantom slot. + if (!current || current <= 0) { + this.counts.delete(key); + return 0; + } + const next = current - 1; + if (next <= 0) { + this.counts.delete(key); + return 0; + } + this.counts.set(key, next); + return next; + } + + /** Drop all counts. Test helper. */ + reset(): void { + this.counts.clear(); + } +} + +/** Current concurrency snapshot passed into `canConnect`. */ +export interface SessionCounts { + userCount: number; + workspaceCount: number; +} + +/** Minimal view of the workspace settings `canConnect` needs. 
*/ +export interface RealtimeGateSettings { + ai?: { dictation?: boolean; dictationRealtime?: boolean }; +} + +export type CanConnectResult = + | { allowed: true } + | { allowed: false; reason: string }; + +/** + * Pure decision: may a new realtime dictation socket open right now? + * + * Gate (both required): `settings.ai.dictation === true` AND + * `settings.ai.dictationRealtime === true`. Missing `settings.ai` → deny. + * + * Caps (checked with `>=`, no off-by-one): deny when the user already holds + * `MAX_SESSIONS_PER_USER` sessions, or the workspace already holds + * `MAX_SESSIONS_PER_WORKSPACE`. The gate is evaluated before the caps so a + * disabled feature reports the gate reason, not a cap reason. + * + * The caller MUST check this BEFORE incrementing any counter so a denied + * connection leaves the counters untouched. + */ +export function canConnect( + _userId: string, + _workspaceId: string, + settings: RealtimeGateSettings | undefined, + counts: SessionCounts, +): CanConnectResult { + // Feature gate first. + if ( + settings?.ai?.dictation !== true || + settings?.ai?.dictationRealtime !== true + ) { + return { allowed: false, reason: 'Realtime dictation is not enabled' }; + } + + // Per-user cap. + if (counts.userCount >= MAX_SESSIONS_PER_USER) { + return { + allowed: false, + reason: + 'A realtime dictation session is already active for your account', + }; + } + + // Per-workspace cap. 
+ if (counts.workspaceCount >= MAX_SESSIONS_PER_WORKSPACE) { + return { + allowed: false, + reason: + 'The maximum number of concurrent realtime dictation sessions for this workspace has been reached', + }; + } + + return { allowed: true }; +} diff --git a/apps/server/src/core/workspace/dto/update-workspace.dto.spec.ts b/apps/server/src/core/workspace/dto/update-workspace.dto.spec.ts index 2ef48315..fa0b3b0f 100644 --- a/apps/server/src/core/workspace/dto/update-workspace.dto.spec.ts +++ b/apps/server/src/core/workspace/dto/update-workspace.dto.spec.ts @@ -48,6 +48,28 @@ describe('UpdateWorkspaceDto.trackerHead validation', () => { }); }); +describe('UpdateWorkspaceDto.aiDictationRealtime validation', () => { + it('accepts aiDictationRealtime: true', async () => { + const errors = await validateDto({ aiDictationRealtime: true }); + expect(hasError(errors, 'aiDictationRealtime')).toBe(false); + }); + + it('accepts aiDictationRealtime: false', async () => { + const errors = await validateDto({ aiDictationRealtime: false }); + expect(hasError(errors, 'aiDictationRealtime')).toBe(false); + }); + + it('rejects a non-boolean aiDictationRealtime with an isBoolean error', async () => { + const errors = await validateDto({ aiDictationRealtime: 'yes' }); + expect(hasError(errors, 'aiDictationRealtime', 'isBoolean')).toBe(true); + }); + + it('accepts an omitted aiDictationRealtime (optional)', async () => { + const errors = await validateDto({}); + expect(hasError(errors, 'aiDictationRealtime')).toBe(false); + }); +}); + describe('UpdateWorkspaceDto.htmlEmbed validation', () => { it('accepts htmlEmbed: true', async () => { const errors = await validateDto({ htmlEmbed: true }); diff --git a/apps/server/src/core/workspace/services/workspace-ai-dictation-realtime.spec.ts b/apps/server/src/core/workspace/services/workspace-ai-dictation-realtime.spec.ts new file mode 100644 index 00000000..ca4424ca --- /dev/null +++ 
b/apps/server/src/core/workspace/services/workspace-ai-dictation-realtime.spec.ts @@ -0,0 +1,131 @@ +import { WorkspaceService } from './workspace.service'; + +/** + * Exercises the REAL WorkspaceService.update aiDictationRealtime branch at the + * service seam: + * - an update carrying `aiDictationRealtime` calls + * `workspaceRepo.updateAiSettings(workspaceId, 'dictationRealtime', value, trx)` + * (note the SETTING KEY is 'dictationRealtime', not the DTO field name); + * - the change is audited (before/after) when the value actually changes; + * - the `aiDictationRealtime` field is removed from the DTO BEFORE the generic + * `updateWorkspace(dto, ...)` runs, so it can never be written as a (non- + * existent) workspaces column. + * + * The repo, db transaction, and audit service are mocked; `executeTx` runs the + * callback against a fake trx. + */ +describe('WorkspaceService.update — aiDictationRealtime branch (real code)', () => { + function buildService(opts: { settingsBefore?: Record }) { + const updateAiSettings = jest.fn().mockResolvedValue(undefined); + const updateWorkspace = jest.fn().mockResolvedValue(undefined); + const workspaceRepo = { + // First call: read settingsBefore. Second call (with options): the updated + // workspace (must include a licenseKey because update() destructures it). + findById: jest + .fn() + .mockResolvedValueOnce({ id: 'w1', settings: opts.settingsBefore ?? {} }) + .mockResolvedValueOnce({ id: 'w1', name: 'WS', licenseKey: null }), + updateAiSettings, + updateWorkspace, + }; + + // Fake kysely db: only .transaction().execute(cb) is used on this path. 
+ const db = { + transaction: jest.fn(() => ({ + execute: jest.fn(async (cb: any) => cb({ __trx: true })), + })), + }; + + const auditService = { log: jest.fn() }; + + const service = new WorkspaceService( + workspaceRepo as any, // workspaceRepo + {} as any, // spaceService + {} as any, // spaceMemberService + {} as any, // groupRepo + {} as any, // groupUserRepo + {} as any, // userRepo + {} as any, // environmentService + {} as any, // domainService + {} as any, // licenseCheckService + {} as any, // shareRepo + {} as any, // watcherRepo + {} as any, // favoriteRepo + db as any, // db (InjectKysely) + {} as any, // attachmentQueue + {} as any, // billingQueue + {} as any, // aiQueue + auditService as any, // auditService + {} as any, // userSessionRepo + ); + + return { service, workspaceRepo, updateAiSettings, updateWorkspace, auditService }; + } + + it("persists true via updateAiSettings with the 'dictationRealtime' key", async () => { + const { service, updateAiSettings } = buildService({}); + + await service.update('w1', { aiDictationRealtime: true } as any); + + expect(updateAiSettings).toHaveBeenCalledTimes(1); + expect(updateAiSettings).toHaveBeenCalledWith( + 'w1', + 'dictationRealtime', + true, + expect.anything(), // the transaction handle + ); + }); + + it('persists false (explicit disable is not dropped)', async () => { + const { service, updateAiSettings } = buildService({ + settingsBefore: { ai: { dictationRealtime: true } }, + }); + + await service.update('w1', { aiDictationRealtime: false } as any); + + expect(updateAiSettings).toHaveBeenCalledWith( + 'w1', + 'dictationRealtime', + false, + expect.anything(), + ); + }); + + it('does NOT call updateAiSettings when aiDictationRealtime is undefined', async () => { + const { service, updateAiSettings } = buildService({}); + + await service.update('w1', { name: 'New name' } as any); + + // updateAiSettings is only reached by AI branches; none fire here. 
+ expect(updateAiSettings).not.toHaveBeenCalled(); + }); + + it('audits the change (before/after) when the value actually changes', async () => { + const { service, auditService } = buildService({ + settingsBefore: { ai: { dictationRealtime: false } }, + }); + + await service.update('w1', { aiDictationRealtime: true } as any); + + expect(auditService.log).toHaveBeenCalledTimes(1); + const logged = auditService.log.mock.calls[0][0]; + expect(logged.changes.before.aiDictationRealtime).toBe(false); + expect(logged.changes.after.aiDictationRealtime).toBe(true); + }); + + it('removes aiDictationRealtime from the DTO before the generic updateWorkspace', async () => { + const { service, updateWorkspace } = buildService({}); + + await service.update('w1', { + aiDictationRealtime: true, + name: 'New name', + } as any); + + expect(updateWorkspace).toHaveBeenCalledTimes(1); + const [dtoPassed] = updateWorkspace.mock.calls[0]; + // The AI toggle must NOT reach the generic column writer (no such column). + expect('aiDictationRealtime' in dtoPassed).toBe(false); + // A genuine column (name) still flows through. + expect(dtoPassed.name).toBe('New name'); + }); +}); diff --git a/apps/server/src/integrations/ai/ai-settings.service.spec.ts b/apps/server/src/integrations/ai/ai-settings.service.spec.ts new file mode 100644 index 00000000..7359a53c --- /dev/null +++ b/apps/server/src/integrations/ai/ai-settings.service.spec.ts @@ -0,0 +1,130 @@ +import { AiSettingsService } from './ai-settings.service'; + +// Unit tests for the partial-merge behaviour of AiSettingsService.update and the +// key-fallback behaviour of resolve. Constructed directly with stub deps (no +// Nest graph): we assert exactly which repo calls fire for a given partial DTO, +// proving the realtime STT fields merge in without clobbering chat fields and +// that an empty patch writes nothing. 
+ +interface Deps { + updateAiProviderSettings: jest.Mock; + readProviderResult?: Record; + getMaskedResult?: unknown; +} + +function buildService(deps: Deps) { + const workspaceRepo = { + updateAiProviderSettings: deps.updateAiProviderSettings, + // findById feeds the private readProvider() (target-driver resolution). + findById: jest.fn().mockResolvedValue({ + settings: { ai: { provider: deps.readProviderResult ?? {} } }, + }), + }; + const aiProviderCredentialsRepo = { + find: jest.fn(), + upsert: jest.fn(), + clearKey: jest.fn(), + upsertEmbeddingKey: jest.fn(), + clearEmbeddingKey: jest.fn(), + upsertSttKey: jest.fn(), + clearSttKey: jest.fn(), + }; + const secretBox = { + encryptSecret: jest.fn((v: string) => `enc(${v})`), + decryptSecret: jest.fn((v: string) => `dec(${v})`), + }; + const pageEmbeddingRepo = { countIndexedPages: jest.fn().mockResolvedValue(0) }; + const pageRepo = { countEmbeddablePages: jest.fn().mockResolvedValue(0) }; + + const service = new AiSettingsService( + workspaceRepo as any, + {} as any, // aiAgentRoleRepo + aiProviderCredentialsRepo as any, + pageEmbeddingRepo as any, + pageRepo as any, + secretBox as any, + {} as any, // aiQueue + ); + + // getMasked is exercised at the end of update(); stub it so update() resolves + // without a second repo round-trip we don't care about here. + jest + .spyOn(service, 'getMasked') + .mockResolvedValue((deps.getMaskedResult ?? 
{}) as any); + + return { service, workspaceRepo, aiProviderCredentialsRepo, secretBox }; +} + +describe('AiSettingsService.update partial merge', () => { + it('a DTO with only realtime fields patches exactly those keys', async () => { + const updateAiProviderSettings = jest.fn().mockResolvedValue(undefined); + const { service } = buildService({ updateAiProviderSettings }); + + await service.update('w1', { + sttRealtimeModel: 'gpt-4o-realtime', + sttRealtimeBaseUrl: 'https://api.example.com/v1', + }); + + expect(updateAiProviderSettings).toHaveBeenCalledTimes(1); + const [, patch] = updateAiProviderSettings.mock.calls[0]; + expect(Object.keys(patch).sort()).toEqual( + ['sttRealtimeBaseUrl', 'sttRealtimeModel'].sort(), + ); + }); + + it('a DTO with chatModel does NOT clobber realtime fields (only chatModel patched)', async () => { + const updateAiProviderSettings = jest.fn().mockResolvedValue(undefined); + const { service } = buildService({ updateAiProviderSettings }); + + await service.update('w1', { chatModel: 'gpt-4o' }); + + const [, patch] = updateAiProviderSettings.mock.calls[0]; + expect(patch).toEqual({ chatModel: 'gpt-4o' }); + expect(patch).not.toHaveProperty('sttRealtimeModel'); + expect(patch).not.toHaveProperty('sttRealtimeBaseUrl'); + }); + + it('an empty patch never calls updateAiProviderSettings', async () => { + const updateAiProviderSettings = jest.fn().mockResolvedValue(undefined); + const { service } = buildService({ updateAiProviderSettings }); + + await service.update('w1', {}); + + expect(updateAiProviderSettings).not.toHaveBeenCalled(); + }); +}); + +describe('AiSettingsService.resolve STT key fallback', () => { + it('uses the STT-specific key when sttApiKeyEnc is present (decrypt)', async () => { + const { service, aiProviderCredentialsRepo, secretBox } = buildService({ + updateAiProviderSettings: jest.fn(), + readProviderResult: { driver: 'openai', chatModel: 'gpt-4o' }, + }); + aiProviderCredentialsRepo.find.mockResolvedValue({ + 
apiKeyEnc: 'CHAT', + sttApiKeyEnc: 'STT', + }); + + const cfg = await service.resolve('w1'); + + expect(cfg?.sttApiKey).toBe('dec(STT)'); + expect(secretBox.decryptSecret).toHaveBeenCalledWith('STT'); + }); + + it('falls back to the chat apiKey when sttApiKeyEnc is absent', async () => { + const { service, aiProviderCredentialsRepo } = buildService({ + updateAiProviderSettings: jest.fn(), + readProviderResult: { driver: 'openai', chatModel: 'gpt-4o' }, + }); + aiProviderCredentialsRepo.find.mockResolvedValue({ + apiKeyEnc: 'CHAT', + // no sttApiKeyEnc + }); + + const cfg = await service.resolve('w1'); + + // sttApiKey === the resolved chat apiKey (dec(CHAT)). + expect(cfg?.sttApiKey).toBe('dec(CHAT)'); + expect(cfg?.apiKey).toBe('dec(CHAT)'); + }); +}); diff --git a/apps/server/src/integrations/ai/dto/update-ai-settings.ssrf.spec.ts b/apps/server/src/integrations/ai/dto/update-ai-settings.ssrf.spec.ts new file mode 100644 index 00000000..a527c41a --- /dev/null +++ b/apps/server/src/integrations/ai/dto/update-ai-settings.ssrf.spec.ts @@ -0,0 +1,54 @@ +import 'reflect-metadata'; +import { plainToInstance } from 'class-transformer'; +import { validate } from 'class-validator'; +import { UpdateAiSettingsDto } from './update-ai-settings.dto'; +import { isUrlAllowed } from '../../../core/ai-chat/external-mcp/ssrf-guard'; + +// SSRF contract for sttRealtimeBaseUrl. +// +// The DTO intentionally validates sttRealtimeBaseUrl with @IsString() ONLY (no +// @IsUrl): an admin may legitimately point at an internal-looking host that DNS +// resolves to a public address, and over-strict URL validation would reject +// valid setups. The real defense is the CONNECT-TIME SSRF guard (isUrlAllowed on +// the http-equivalent of the wss URL), which blocks link-local/loopback/private +// targets. This pins both halves of that contract. 
+ +async function validateDto(payload: Record) { + const dto = plainToInstance(UpdateAiSettingsDto, payload); + return validate(dto as object); +} + +describe('UpdateAiSettingsDto.sttRealtimeBaseUrl is @IsString only (no @IsUrl)', () => { + it('accepts a metadata-service URL at the DTO layer (string, not URL-validated)', async () => { + const errors = await validateDto({ + sttRealtimeBaseUrl: 'http://169.254.169.254/v1', + }); + const fieldErr = errors.find( + (e) => e.property === 'sttRealtimeBaseUrl', + ); + // No DTO-level rejection: blocking is deferred to the connect-time guard. + expect(fieldErr).toBeUndefined(); + }); + + it('rejects a non-string sttRealtimeBaseUrl with an isString error', async () => { + const errors = await validateDto({ sttRealtimeBaseUrl: 123 }); + const fieldErr = errors.find( + (e) => e.property === 'sttRealtimeBaseUrl', + ); + expect(Object.keys(fieldErr?.constraints ?? {})).toContain('isString'); + }); +}); + +describe('connect-time SSRF guard blocks the metadata service', () => { + it('isUrlAllowed denies the http-equivalent of the cloud metadata endpoint', async () => { + // The realtime path derives a wss URL then checks isUrlAllowed on the + // http(s)-equivalent. For http://169.254.169.254 the equivalent is itself. 
+ const result = await isUrlAllowed('http://169.254.169.254/v1'); + expect(result.ok).toBe(false); + }); + + it('isUrlAllowed denies loopback', async () => { + const result = await isUrlAllowed('http://127.0.0.1/v1'); + expect(result.ok).toBe(false); + }); +}); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 4816bcd7..09b4d03f 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -447,6 +447,9 @@ importers: '@vitejs/plugin-react': specifier: 6.0.1 version: 6.0.1(vite@8.0.5(@types/node@22.19.1)(esbuild@0.28.0)(jiti@2.4.2)(less@4.2.0)(sugarss@5.0.1(postcss@8.5.14))(terser@5.39.0)(tsx@4.21.0)(yaml@2.8.3)) + '@vitest/coverage-v8': + specifier: 4.1.6 + version: 4.1.6(vitest@4.1.6) eslint: specifier: 9.28.0 version: 9.28.0(jiti@2.4.2) @@ -491,7 +494,7 @@ importers: version: 8.0.5(@types/node@22.19.1)(esbuild@0.28.0)(jiti@2.4.2)(less@4.2.0)(sugarss@5.0.1(postcss@8.5.14))(terser@5.39.0)(tsx@4.21.0)(yaml@2.8.3) vitest: specifier: 4.1.6 - version: 4.1.6(@opentelemetry/api@1.9.0)(@types/node@22.19.1)(happy-dom@20.8.9)(jsdom@25.0.0)(vite@8.0.5(@types/node@22.19.1)(esbuild@0.28.0)(jiti@2.4.2)(less@4.2.0)(sugarss@5.0.1(postcss@8.5.14))(terser@5.39.0)(tsx@4.21.0)(yaml@2.8.3)) + version: 4.1.6(@opentelemetry/api@1.9.0)(@types/node@22.19.1)(@vitest/coverage-v8@4.1.6)(happy-dom@20.8.9)(jsdom@25.0.0)(vite@8.0.5(@types/node@22.19.1)(esbuild@0.28.0)(jiti@2.4.2)(less@4.2.0)(sugarss@5.0.1(postcss@8.5.14))(terser@5.39.0)(tsx@4.21.0)(yaml@2.8.3)) apps/server: dependencies: @@ -1427,10 +1430,18 @@ packages: resolution: {integrity: sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==} engines: {node: '>=6.9.0'} + '@babel/helper-string-parser@7.29.7': + resolution: {integrity: sha512-Pb5ijPrZ89GDH8223L4UP8i6QApWxs04RbPQJTeWDV0/keR2E36MeKnyr6LYmUUvqRRI+Iv87SuF1W6ErINzYw==} + engines: {node: '>=6.9.0'} + '@babel/helper-validator-identifier@7.28.5': resolution: {integrity: 
sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==} engines: {node: '>=6.9.0'} + '@babel/helper-validator-identifier@7.29.7': + resolution: {integrity: sha512-qehxGkRj55h/ff8EMaJ+cYhyaKlHIxqYDn682wQD7RNp9UujOQsHog2uS0r2vzr4pW+sXf90NeeayjcNaX3fFg==} + engines: {node: '>=6.9.0'} + '@babel/helper-validator-option@7.27.1': resolution: {integrity: sha512-YvjJow9FxbhFFKDSuFnVCe2WxXk1zWc22fFePVNEaWJEu8IrZVlda6N0uHwzZrUM1il7NC9Mlp4MaJYbYd9JSg==} engines: {node: '>=6.9.0'} @@ -1453,6 +1464,11 @@ packages: engines: {node: '>=6.0.0'} hasBin: true + '@babel/parser@7.29.7': + resolution: {integrity: sha512-hnORnjP/1P/zFEndoeX+n+t1RwWRJiJpM/jO7FW32Kn9r5+sJB2JWOdYo4L6k78j15eCwY3Gm/7364B1EMwtNg==} + engines: {node: '>=6.0.0'} + hasBin: true + '@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression@7.23.3': resolution: {integrity: sha512-iRkKcCqb7iGnq9+3G6rZ+Ciz5VywC4XNRHe57lKM+jOeYAoR0lVqdeeDRfh0tQcTfw/+vBhHn926FmQhLtlFLQ==} engines: {node: '>=6.9.0'} @@ -1942,9 +1958,17 @@ packages: resolution: {integrity: sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==} engines: {node: '>=6.9.0'} + '@babel/types@7.29.7': + resolution: {integrity: sha512-4zBIxpPzowiZpusoFkyGVwakdRJUyuH5PxQ/PrqghfdFWWasvnCdPfQXHrenDai+gyLARulZjZowCOj6fjT4pA==} + engines: {node: '>=6.9.0'} + '@bcoe/v8-coverage@0.2.3': resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==} + '@bcoe/v8-coverage@1.0.2': + resolution: {integrity: sha512-6zABk/ECA/QYSCQ1NGiVwwbQerUCZ+TQbp64Q3AgmfNvurHH0j8TtXa1qbShXA6qqkpAj4V5W8pP6mLe1mcMqA==} + engines: {node: '>=18'} + '@borewit/text-codec@0.2.1': resolution: {integrity: sha512-k7vvKPbf7J2fZ5klGRD9AeKfUvojuZIQ3BT5u7Jfv+puwXkUBUT5PVyMDfJZpy30CBDXGMgw7fguK/lpOMBvgw==} @@ -5373,6 +5397,15 @@ packages: babel-plugin-react-compiler: optional: true + '@vitest/coverage-v8@4.1.6': + resolution: {integrity: 
sha512-36l628fQ/9a/8ihy97eOtEnvWQEdqULQOJtcaxtoNq0G1w3Mxd4szSahOaMM9/NGyZ+hyKcMtIW/WIxq0XQViQ==} + peerDependencies: + '@vitest/browser': 4.1.6 + vitest: 4.1.6 + peerDependenciesMeta: + '@vitest/browser': + optional: true + '@vitest/expect@4.1.6': resolution: {integrity: sha512-7EHDquPthALSV0jhhjgEW8FXaviMx7rSqu8W6oqCoAuOhKov814P99QDV1pxMA3QPv21YudvJngIhjrNI4opLg==} @@ -5658,6 +5691,9 @@ packages: resolution: {integrity: sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==} engines: {node: '>=12'} + ast-v8-to-istanbul@1.0.4: + resolution: {integrity: sha512-0bC0/4bTSrnwdhU3IsZDwEdojvuPrSg59OYZfKsLRtJZ0u8VBx9DebfqqG8bRdCC0I7vjgxmPi41P0lpkhJHtA==} + async-lock@1.4.1: resolution: {integrity: sha512-Az2ZTpuytrtqENulXwO3GGv1Bztugx6TT37NIo7imr/Qo0gsYiGtSdBa2B6fsXhTpVZDNfu1Qn3pk531e3q+nQ==} @@ -7575,6 +7611,10 @@ packages: resolution: {integrity: sha512-BewmUXImeuRk2YY0PVbxgKAysvhRPUQE0h5QRM++nVWyubKGV0l8qQ5op8+B2DOmwSe63Jivj0BjkPQVf8fP5g==} engines: {node: '>=8'} + istanbul-reports@3.2.0: + resolution: {integrity: sha512-HGYWWS/ehqTV3xN10i23tkPkpH46MLCIMFNCaaKNavAXTF1RkqxawEPtnjnGZ6XKSInBKkiOA5BKS+aZiY3AvA==} + engines: {node: '>=8'} + iterare@1.2.1: resolution: {integrity: sha512-RKYVTCjAnRthyJes037NX/IiqeidgN1xc3j1RjFfECFp28A1GVwK9nA+i0rJPaHqSZwygLzRnFlzUuHFoWWy+Q==} engines: {node: '>=6'} @@ -7800,6 +7840,9 @@ packages: js-tiktoken@1.0.21: resolution: {integrity: sha512-biOj/6M5qdgx5TKjDnFT1ymSpM5tbd3ylwDtrQvFQSu0Z7bBYko2dF+W/aUkXUPuk6IVpRxk/3Q2sHOzGlS36g==} + js-tokens@10.0.0: + resolution: {integrity: sha512-lM/UBzQmfJRo9ABXbPWemivdCW8V2G8FHaHdypQaIy523snUjog0W71ayWXTjiR+ixeMyVHN2XcpnTd/liPg/Q==} + js-tokens@4.0.0: resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} @@ -8245,6 +8288,9 @@ packages: magic-string@0.30.21: resolution: {integrity: sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ==} + magicast@0.5.3: 
+ resolution: {integrity: sha512-pVKE4UdSQ7DvHzivsCIFx2BJn1mHG6KsyrFcaxFx6tONdneEuThrDx0Cj3AMg58KyN4pzYT+LHOotxDQDjNvkw==} + make-dir@2.1.0: resolution: {integrity: sha512-LS9X+dc8KLxXCb8dni79fLIIUA5VyZoyjSMCwTluaXA0o27cCK0bhXkpgw+sTXVpPy/lSO57ilRixqk0vDmtRA==} engines: {node: '>=6'} @@ -11540,8 +11586,12 @@ snapshots: '@babel/helper-string-parser@7.27.1': {} + '@babel/helper-string-parser@7.29.7': {} + '@babel/helper-validator-identifier@7.28.5': {} + '@babel/helper-validator-identifier@7.29.7': {} + '@babel/helper-validator-option@7.27.1': {} '@babel/helper-wrap-function@7.22.20': @@ -11563,6 +11613,10 @@ snapshots: dependencies: '@babel/types': 7.28.5 + '@babel/parser@7.29.7': + dependencies: + '@babel/types': 7.29.7 + '@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression@7.23.3(@babel/core@7.28.5)': dependencies: '@babel/core': 7.28.5 @@ -12168,8 +12222,15 @@ snapshots: '@babel/helper-string-parser': 7.27.1 '@babel/helper-validator-identifier': 7.28.5 + '@babel/types@7.29.7': + dependencies: + '@babel/helper-string-parser': 7.29.7 + '@babel/helper-validator-identifier': 7.29.7 + '@bcoe/v8-coverage@0.2.3': {} + '@bcoe/v8-coverage@1.0.2': {} + '@borewit/text-codec@0.2.1': {} '@braintree/sanitize-url@6.0.2': {} @@ -15818,6 +15879,20 @@ snapshots: '@rolldown/pluginutils': 1.0.0-rc.7 vite: 8.0.5(@types/node@22.19.1)(esbuild@0.28.0)(jiti@2.4.2)(less@4.2.0)(sugarss@5.0.1(postcss@8.5.14))(terser@5.39.0)(tsx@4.21.0)(yaml@2.8.3) + '@vitest/coverage-v8@4.1.6(vitest@4.1.6)': + dependencies: + '@bcoe/v8-coverage': 1.0.2 + '@vitest/utils': 4.1.6 + ast-v8-to-istanbul: 1.0.4 + istanbul-lib-coverage: 3.2.2 + istanbul-lib-report: 3.0.1 + istanbul-reports: 3.2.0 + magicast: 0.5.3 + obug: 2.1.1 + std-env: 4.1.0 + tinyrainbow: 3.1.0 + vitest: 
4.1.6(@opentelemetry/api@1.9.0)(@types/node@22.19.1)(@vitest/coverage-v8@4.1.6)(happy-dom@20.8.9)(jsdom@25.0.0)(vite@8.0.5(@types/node@22.19.1)(esbuild@0.28.0)(jiti@2.4.2)(less@4.2.0)(sugarss@5.0.1(postcss@8.5.14))(terser@5.39.0)(tsx@4.21.0)(yaml@2.8.3)) + '@vitest/expect@4.1.6': dependencies: '@standard-schema/spec': 1.1.0 @@ -16147,6 +16222,12 @@ snapshots: assertion-error@2.0.1: {} + ast-v8-to-istanbul@1.0.4: + dependencies: + '@jridgewell/trace-mapping': 0.3.31 + estree-walker: 3.0.3 + js-tokens: 10.0.0 + async-lock@1.4.1: {} async-mutex@0.5.0: @@ -18358,6 +18439,11 @@ snapshots: html-escaper: 2.0.2 istanbul-lib-report: 3.0.1 + istanbul-reports@3.2.0: + dependencies: + html-escaper: 2.0.2 + istanbul-lib-report: 3.0.1 + iterare@1.2.1: {} iterator.prototype@1.1.5: @@ -18768,6 +18854,8 @@ snapshots: dependencies: base64-js: 1.5.1 + js-tokens@10.0.0: {} + js-tokens@4.0.0: {} js-yaml@3.14.2: @@ -19204,6 +19292,12 @@ snapshots: dependencies: '@jridgewell/sourcemap-codec': 1.5.5 + magicast@0.5.3: + dependencies: + '@babel/parser': 7.29.7 + '@babel/types': 7.29.7 + source-map-js: 1.2.1 + make-dir@2.1.0: dependencies: pify: 4.0.1 @@ -21528,7 +21622,7 @@ snapshots: tsx: 4.21.0 yaml: 2.8.3 - vitest@4.1.6(@opentelemetry/api@1.9.0)(@types/node@22.19.1)(happy-dom@20.8.9)(jsdom@25.0.0)(vite@8.0.5(@types/node@22.19.1)(esbuild@0.28.0)(jiti@2.4.2)(less@4.2.0)(sugarss@5.0.1(postcss@8.5.14))(terser@5.39.0)(tsx@4.21.0)(yaml@2.8.3)): + vitest@4.1.6(@opentelemetry/api@1.9.0)(@types/node@22.19.1)(@vitest/coverage-v8@4.1.6)(happy-dom@20.8.9)(jsdom@25.0.0)(vite@8.0.5(@types/node@22.19.1)(esbuild@0.28.0)(jiti@2.4.2)(less@4.2.0)(sugarss@5.0.1(postcss@8.5.14))(terser@5.39.0)(tsx@4.21.0)(yaml@2.8.3)): dependencies: '@vitest/expect': 4.1.6 '@vitest/mocker': 4.1.6(vite@8.0.5(@types/node@22.19.1)(esbuild@0.28.0)(jiti@2.4.2)(less@4.2.0)(sugarss@5.0.1(postcss@8.5.14))(terser@5.39.0)(tsx@4.21.0)(yaml@2.8.3)) @@ -21553,6 +21647,7 @@ snapshots: optionalDependencies: '@opentelemetry/api': 1.9.0 
'@types/node': 22.19.1 + '@vitest/coverage-v8': 4.1.6(vitest@4.1.6) happy-dom: 20.8.9 jsdom: 25.0.0 transitivePeerDependencies: