feat(ai-chat): bound external MCP tool calls with per-call timeouts
External MCP tools (web search, crawl) had no per-call timeout: a hung tool call was only broken by the 15-min transport silence timeout shared with the chat provider, and a server that kept the socket warm but never returned could spin until the user cancelled. Add two independent, composing bounds for external MCP traffic (the chat provider path is unchanged): - Silence 5 min: buildPinnedDispatcher now overrides headersTimeout/ bodyTimeout with mcpStreamTimeoutMs() (AI_MCP_STREAM_TIMEOUT_MS, default 300000) on the external-MCP dispatcher only, so a byte-silent upstream is severed in ~5 min instead of 15. - Total per-call 15 min: wrapToolWithCallTimeout wraps each external tool's execute with a fresh AbortController + timer composed with the turn signal via AbortSignal.any (AI_MCP_CALL_TIMEOUT_MS, default 900000). It RACES the call against the abort signal because @ai-sdk/mcp does not settle its in-flight promise on abort, so a warm-but-stuck call would otherwise hang forever. On timeout the call surfaces as a tool-error and the agent loop recovers. Add tests (incl. a never-settling real-client-style stub) and document both env vars in .env.example.
This commit is contained in:
@@ -0,0 +1,205 @@
|
||||
import { type Tool, type ToolCallOptions } from 'ai';
|
||||
import {
|
||||
wrapToolWithCallTimeout,
|
||||
wrapToolsWithCallTimeout,
|
||||
} from './mcp-clients.service';
|
||||
import {
|
||||
mcpStreamTimeoutMs,
|
||||
mcpCallTimeoutMs,
|
||||
} from '../../../integrations/ai/ai-streaming-fetch';
|
||||
|
||||
/**
|
||||
* Per-call total-timeout guard for external MCP tools (mcp-clients.service).
|
||||
*
|
||||
* `@ai-sdk/mcp`'s tool execute has NO built-in per-call timeout — a tool that
|
||||
* keeps the connection warm but never returns is otherwise unbounded. The
|
||||
* wrapper attaches a fresh AbortController + timer per CALL and composes it with
|
||||
* the turn's abortSignal via AbortSignal.any, so EITHER the per-call timeout OR a
|
||||
* client disconnect aborts the in-flight call.
|
||||
*
|
||||
* Fake timers prove the timeout fires WITHOUT real waiting; no leaked timer keeps
|
||||
* the process alive after a fast resolve.
|
||||
*/
|
||||
const CALL_TIMEOUT_MS = 900_000;
|
||||
|
||||
/** Build a Tool around an `execute` impl, mirroring the SDK's minimal shape. */
|
||||
function toolWith(
|
||||
execute: (args: unknown, options: ToolCallOptions) => unknown,
|
||||
): Tool {
|
||||
return { description: 'x', inputSchema: undefined, execute } as unknown as Tool;
|
||||
}
|
||||
|
||||
/** Invoke a (possibly wrapped) tool's execute with an optional turn signal. */
|
||||
function callExecute(
|
||||
tool: Tool,
|
||||
args: unknown,
|
||||
abortSignal?: AbortSignal,
|
||||
): unknown {
|
||||
const execute = tool.execute as (
|
||||
args: unknown,
|
||||
options: ToolCallOptions,
|
||||
) => unknown;
|
||||
return execute(args, { abortSignal } as ToolCallOptions);
|
||||
}
|
||||
|
||||
describe('wrapToolWithCallTimeout', () => {
|
||||
beforeEach(() => jest.useFakeTimers());
|
||||
afterEach(() => {
|
||||
jest.clearAllTimers();
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('aborts a tool that only rejects when its abortSignal fires, after ms elapses', async () => {
|
||||
// The tool resolves NEVER on its own — it only settles when the abortSignal
|
||||
// it is handed aborts. So a resolution proves the per-call timer fired and
|
||||
// aborted the call (not the tool finishing by itself).
|
||||
let received: AbortSignal | undefined;
|
||||
const tool = toolWith((_args, options) => {
|
||||
received = options.abortSignal;
|
||||
return new Promise((_resolve, reject) => {
|
||||
options.abortSignal?.addEventListener('abort', () => {
|
||||
reject(options.abortSignal?.reason ?? new Error('aborted'));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
const wrapped = wrapToolWithCallTimeout(tool, CALL_TIMEOUT_MS);
|
||||
const promise = callExecute(wrapped, { q: 'x' }) as Promise<unknown>;
|
||||
// Attach the rejection handler synchronously so advancing timers cannot mark
|
||||
// it an unhandled rejection.
|
||||
const settled = promise.then(
|
||||
() => ({ ok: true as const }),
|
||||
(err: unknown) => ({ ok: false as const, err }),
|
||||
);
|
||||
|
||||
// Nothing fired yet.
|
||||
jest.advanceTimersByTime(CALL_TIMEOUT_MS - 1);
|
||||
// Past the cap -> the per-call timer aborts the composed signal.
|
||||
jest.advanceTimersByTime(2);
|
||||
|
||||
const result = await settled;
|
||||
expect(result.ok).toBe(false);
|
||||
expect(received).toBeInstanceOf(AbortSignal);
|
||||
// The abort reason / rejection mentions the timeout.
|
||||
const message =
|
||||
(result as { err: unknown }).err instanceof Error
|
||||
? ((result as { err: Error }).err.message)
|
||||
: String((result as { err: unknown }).err);
|
||||
expect(message).toMatch(/timed out after 900000ms/);
|
||||
});
|
||||
|
||||
it('aborts a REAL-client-style tool that never settles and ignores abort (race fix)', async () => {
|
||||
// Models the ACTUAL @ai-sdk/mcp semantics: its in-flight promise does NOT
|
||||
// reject on abort (it only checks the signal when a response arrives), so a
|
||||
// warm-but-stuck call NEVER settles on its own and does NOT listen to the
|
||||
// abort signal. The wrapper must still reject after `ms` via the race — an
|
||||
// implementation that merely `await original(...)` would hang here forever.
|
||||
// This test FAILS against the old await-only code and PASSES with the race.
|
||||
const tool = toolWith(() => new Promise(() => {})); // never settles, no abort
|
||||
const wrapped = wrapToolWithCallTimeout(tool, CALL_TIMEOUT_MS);
|
||||
const promise = callExecute(wrapped, { q: 'x' }) as Promise<unknown>;
|
||||
// Assert the rejection without hanging: drive fake time async so the timer's
|
||||
// abort -> race rejection microtasks flush, then await the rejection.
|
||||
const expectation = expect(promise).rejects.toThrow(/timed out after 900000ms/);
|
||||
await jest.advanceTimersByTimeAsync(CALL_TIMEOUT_MS + 1);
|
||||
await expectation;
|
||||
});
|
||||
|
||||
it('passes a fast tool through and leaks no timer (advancing later does not throw)', async () => {
|
||||
const tool = toolWith(() => Promise.resolve('fast-result'));
|
||||
const wrapped = wrapToolWithCallTimeout(tool, CALL_TIMEOUT_MS);
|
||||
|
||||
const value = await (callExecute(wrapped, {}) as Promise<unknown>);
|
||||
expect(value).toBe('fast-result');
|
||||
|
||||
// The timer was cleared in the finally — advancing past the cap aborts
|
||||
// nothing and throws nothing.
|
||||
expect(() => jest.advanceTimersByTime(CALL_TIMEOUT_MS * 2)).not.toThrow();
|
||||
});
|
||||
|
||||
it('aborts when the caller turn signal aborts before the timeout (disconnect path)', async () => {
|
||||
// Real-client semantics: the tool never settles and does NOT listen to abort,
|
||||
// so the wrapper must reject via the race when the caller's turn signal (a
|
||||
// client disconnect) aborts BEFORE the per-call cap. The race propagates the
|
||||
// caller's abort reason.
|
||||
const tool = toolWith(() => new Promise(() => {})); // never settles, no abort
|
||||
const wrapped = wrapToolWithCallTimeout(tool, CALL_TIMEOUT_MS);
|
||||
const turn = new AbortController();
|
||||
const promise = callExecute(wrapped, {}, turn.signal) as Promise<unknown>;
|
||||
const settled = promise.then(
|
||||
() => ({ ok: true as const }),
|
||||
(err: unknown) => ({ ok: false as const, err }),
|
||||
);
|
||||
|
||||
// Disconnect well before the cap; the per-call timer never fires here.
|
||||
turn.abort(new Error('client disconnected'));
|
||||
const result = await settled;
|
||||
expect(result.ok).toBe(false);
|
||||
const message =
|
||||
(result as { err: unknown }).err instanceof Error
|
||||
? (result as { err: Error }).err.message
|
||||
: String((result as { err: unknown }).err);
|
||||
// The caller's abort reason propagates through the race.
|
||||
expect(message).toMatch(/client disconnected/);
|
||||
});
|
||||
|
||||
it('passes a tool with no execute through unchanged', () => {
|
||||
const noExecute = { description: 'x', inputSchema: undefined } as unknown as Tool;
|
||||
const wrapped = wrapToolWithCallTimeout(noExecute, CALL_TIMEOUT_MS);
|
||||
// Same object back, execute still absent.
|
||||
expect(wrapped).toBe(noExecute);
|
||||
expect((wrapped as { execute?: unknown }).execute).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('wrapToolsWithCallTimeout', () => {
|
||||
beforeEach(() => jest.useFakeTimers());
|
||||
afterEach(() => {
|
||||
jest.clearAllTimers();
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('wraps every tool in the map (each call gets its own guard)', async () => {
|
||||
const tools: Record<string, Tool> = {
|
||||
a: toolWith(() => Promise.resolve('A')),
|
||||
b: toolWith(() => Promise.resolve('B')),
|
||||
};
|
||||
const out = wrapToolsWithCallTimeout(tools, CALL_TIMEOUT_MS);
|
||||
expect(Object.keys(out)).toEqual(['a', 'b']);
|
||||
expect(await (callExecute(out.a, {}) as Promise<unknown>)).toBe('A');
|
||||
expect(await (callExecute(out.b, {}) as Promise<unknown>)).toBe('B');
|
||||
});
|
||||
});
|
||||
|
||||
describe('mcp timeout env helpers', () => {
|
||||
const ORIG_SILENCE = process.env.AI_MCP_STREAM_TIMEOUT_MS;
|
||||
const ORIG_CALL = process.env.AI_MCP_CALL_TIMEOUT_MS;
|
||||
afterEach(() => {
|
||||
if (ORIG_SILENCE === undefined) delete process.env.AI_MCP_STREAM_TIMEOUT_MS;
|
||||
else process.env.AI_MCP_STREAM_TIMEOUT_MS = ORIG_SILENCE;
|
||||
if (ORIG_CALL === undefined) delete process.env.AI_MCP_CALL_TIMEOUT_MS;
|
||||
else process.env.AI_MCP_CALL_TIMEOUT_MS = ORIG_CALL;
|
||||
});
|
||||
|
||||
it('mcpStreamTimeoutMs defaults to 5 min and honors a positive override', () => {
|
||||
delete process.env.AI_MCP_STREAM_TIMEOUT_MS;
|
||||
expect(mcpStreamTimeoutMs()).toBe(300_000);
|
||||
process.env.AI_MCP_STREAM_TIMEOUT_MS = '60000';
|
||||
expect(mcpStreamTimeoutMs()).toBe(60_000);
|
||||
for (const bad of ['0', '-1', 'x', '']) {
|
||||
process.env.AI_MCP_STREAM_TIMEOUT_MS = bad;
|
||||
expect(mcpStreamTimeoutMs()).toBe(300_000);
|
||||
}
|
||||
});
|
||||
|
||||
it('mcpCallTimeoutMs defaults to 15 min and honors a positive override', () => {
|
||||
delete process.env.AI_MCP_CALL_TIMEOUT_MS;
|
||||
expect(mcpCallTimeoutMs()).toBe(900_000);
|
||||
process.env.AI_MCP_CALL_TIMEOUT_MS = '120000';
|
||||
expect(mcpCallTimeoutMs()).toBe(120_000);
|
||||
for (const bad of ['0', '-1', 'x', '']) {
|
||||
process.env.AI_MCP_CALL_TIMEOUT_MS = bad;
|
||||
expect(mcpCallTimeoutMs()).toBe(900_000);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1,12 +1,16 @@
|
||||
import { isIP } from 'node:net';
|
||||
import { lookup as dnsLookup, type LookupAddress } from 'node:dns';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { type Tool } from 'ai';
|
||||
import { type Tool, type ToolCallOptions } from 'ai';
|
||||
import { createMCPClient } from '@ai-sdk/mcp';
|
||||
import { Agent, type Dispatcher } from 'undici';
|
||||
import { AiMcpServerRepo } from '@docmost/db/repos/ai-chat/ai-mcp-server.repo';
|
||||
import { AiMcpServer } from '@docmost/db/types/entity.types';
|
||||
import { streamingDispatcherOptions } from '../../../integrations/ai/ai-streaming-fetch';
|
||||
import {
|
||||
streamingDispatcherOptions,
|
||||
mcpStreamTimeoutMs,
|
||||
mcpCallTimeoutMs,
|
||||
} from '../../../integrations/ai/ai-streaming-fetch';
|
||||
import { SecretBoxService } from '../../../integrations/crypto/secret-box';
|
||||
import { isUrlAllowed, isIpAllowed } from './ssrf-guard';
|
||||
|
||||
@@ -219,6 +223,8 @@ export class McpClientsService {
|
||||
const tools: Record<string, Tool> = {};
|
||||
const clients: McpClient[] = [];
|
||||
const outcomes: ServerOutcome[] = [];
|
||||
// Per-call total wall-clock cap, read once for this build (env-overridable).
|
||||
const callTimeoutMs = mcpCallTimeoutMs();
|
||||
|
||||
for (const server of servers) {
|
||||
try {
|
||||
@@ -230,10 +236,13 @@ export class McpClientsService {
|
||||
Array.isArray(allow) && allow.length > 0
|
||||
? pick(raw, allow)
|
||||
: raw;
|
||||
// Bound each tool's execute with a per-call total-timeout guard before
|
||||
// merging, so a single chatty-but-stuck call is aborted after the cap.
|
||||
const guarded = wrapToolsWithCallTimeout(picked, callTimeoutMs);
|
||||
// Namespace each tool with the sanitized server name AND disambiguate
|
||||
// against names already merged from earlier servers, so no external
|
||||
// tool is silently overwritten on collision.
|
||||
this.mergeNamespaced(tools, picked, server.name, server.id);
|
||||
this.mergeNamespaced(tools, guarded, server.name, server.id);
|
||||
outcomes.push({ name: server.name, ok: true });
|
||||
} catch (err) {
|
||||
// A failed server is skipped — the turn proceeds with the rest. Log a
|
||||
@@ -400,17 +409,21 @@ export function validateResolvedAddresses(
|
||||
* to an IP literal).
|
||||
*/
|
||||
function buildPinnedDispatcher(): Agent {
|
||||
// External-MCP traffic uses a DEDICATED, shorter silence timeout
|
||||
// (`AI_MCP_STREAM_TIMEOUT_MS`, default 5 min) — deliberately tighter than the
|
||||
// chat provider's 15-min `streamTimeoutMs()` — so a byte-silent/hung MCP
|
||||
// upstream is broken in ~5 min instead of 15. We keep the keep-alive options
|
||||
// from `streamingDispatcherOptions()` but OVERRIDE headers/body timeouts.
|
||||
// Accepted trade-off: a legitimately long but byte-silent single tool call,
|
||||
// and an SSE transport idling >5 min BETWEEN tool calls, are also cut here; the
|
||||
// per-call total cap (wrapToolsWithCallTimeout, `AI_MCP_CALL_TIMEOUT_MS`) is the
|
||||
// complementary guard for chatty-but-stuck calls that keep the socket warm yet
|
||||
// never return.
|
||||
const mcpSilenceMs = mcpStreamTimeoutMs();
|
||||
return new Agent({
|
||||
// Raise undici's default 300s headers/body timeouts on external MCP traffic
|
||||
// to the same generous-but-finite silence timeout the chat fetch uses (#175).
|
||||
// A long agent turn keeps an SSE transport (e.g. crawl4ai's /mcp/sse) open
|
||||
// across the whole turn; that connection can idle BETWEEN tool calls longer
|
||||
// than 5 min, and undici's bodyTimeout would otherwise sever it mid-task — a
|
||||
// tool-call failure that aborts the streamed turn and shows the user "Lost
|
||||
// connection to the AI provider". A slow single tool call (a crawl) can
|
||||
// likewise exceed headersTimeout. The timeout stays FINITE so a genuinely
|
||||
// hung server is still broken eventually.
|
||||
...streamingDispatcherOptions(),
|
||||
headersTimeout: mcpSilenceMs,
|
||||
bodyTimeout: mcpSilenceMs,
|
||||
connect: {
|
||||
lookup: (hostname, _options, callback) => {
|
||||
// Always resolve ALL addresses ourselves; do not trust the caller's
|
||||
@@ -572,6 +585,78 @@ function disambiguate(
|
||||
return capName(`${name.slice(0, MAX_TOOL_NAME_LENGTH - 14)}_${Date.now()}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap every tool's execute with a per-call total-timeout guard so a single
|
||||
* external MCP tool call that keeps the connection warm but never returns is
|
||||
* aborted after `ms` wall-clock (complements the transport silence timeout).
|
||||
*/
|
||||
export function wrapToolsWithCallTimeout(
|
||||
tools: Record<string, Tool>,
|
||||
ms: number,
|
||||
): Record<string, Tool> {
|
||||
const out: Record<string, Tool> = {};
|
||||
for (const [name, t] of Object.entries(tools)) {
|
||||
out[name] = wrapToolWithCallTimeout(t, ms);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Per-call total-timeout wrapper for one MCP tool. A fresh AbortController +
|
||||
* timer bounds the call; it is composed with the turn's abortSignal via
|
||||
* AbortSignal.any so EITHER the per-call timeout OR a client disconnect aborts
|
||||
* the call. We RACE the call against the composed abort signal rather than just
|
||||
* awaiting it, because @ai-sdk/mcp does NOT settle its in-flight promise on abort
|
||||
* (verified in @ai-sdk/mcp@1.0.52: request() only does throwIfAborted() once
|
||||
* before send and only re-checks the signal inside the response-message handler,
|
||||
* which runs ONLY when a response arrives). So for a warm-but-stuck call awaiting
|
||||
* `original` alone would hang forever even after the timer aborts.
|
||||
*/
|
||||
export function wrapToolWithCallTimeout(tool: Tool, ms: number): Tool {
|
||||
const original = tool.execute;
|
||||
if (typeof original !== 'function') return tool;
|
||||
const execute = async (args: unknown, options: ToolCallOptions) => {
|
||||
const controller = new AbortController();
|
||||
const timer = setTimeout(() => {
|
||||
controller.abort(new Error(`MCP tool call timed out after ${ms}ms`));
|
||||
}, ms);
|
||||
timer.unref?.();
|
||||
const abortSignal = options?.abortSignal
|
||||
? AbortSignal.any([options.abortSignal, controller.signal])
|
||||
: controller.signal;
|
||||
// Reject as soon as the composed signal fires, independent of whether
|
||||
// `original` ever settles. The losing `original` promise is left pending; it
|
||||
// is cleaned up when the client is closed at turn end, and Promise.race
|
||||
// attaches a rejection handler to BOTH inputs so a late rejection of either
|
||||
// is never an unhandled rejection (do NOT add an extra .catch — it could
|
||||
// swallow the real result and would break the race semantics).
|
||||
const aborted = new Promise<never>((_, reject) => {
|
||||
const fail = () => reject(abortReason(abortSignal));
|
||||
if (abortSignal.aborted) fail();
|
||||
else abortSignal.addEventListener('abort', fail, { once: true });
|
||||
});
|
||||
try {
|
||||
return await Promise.race([
|
||||
original(args, { ...options, abortSignal }),
|
||||
aborted,
|
||||
]);
|
||||
} finally {
|
||||
clearTimeout(timer);
|
||||
}
|
||||
};
|
||||
// `Tool` is a union whose `execute` overloads conflict; cast narrowly so the
|
||||
// wrapped tool keeps every other field while swapping only `execute`.
|
||||
return { ...tool, execute } as unknown as Tool;
|
||||
}
|
||||
|
||||
/** The signal's reason as an Error (informative thrown value on abort/timeout). */
|
||||
function abortReason(signal: AbortSignal): Error {
|
||||
const r = signal.reason;
|
||||
return r instanceof Error
|
||||
? r
|
||||
: new Error(typeof r === 'string' ? r : 'MCP tool call aborted');
|
||||
}
|
||||
|
||||
/** Reject a promise after `ms`, so a hung connect/tools() never stalls a turn. */
|
||||
function withTimeout<T>(promise: Promise<T>, ms: number): Promise<T> {
|
||||
return new Promise<T>((resolve, reject) => {
|
||||
|
||||
Reference in New Issue
Block a user