diff --git a/apps/server/src/core/ai-chat/public-share-chat.service.ts b/apps/server/src/core/ai-chat/public-share-chat.service.ts index e0ac5282..a98e738f 100644 --- a/apps/server/src/core/ai-chat/public-share-chat.service.ts +++ b/apps/server/src/core/ai-chat/public-share-chat.service.ts @@ -176,6 +176,29 @@ export class PublicShareChatService { return this.tokenBudget.record(workspaceId, tokens); } + /** + * `streamText` onFinish hook body: account a finished turn's REAL token spend + * (input re-sent per step + output, summed across all steps) against the + * per-workspace rolling-day budget, so a future turn over budget is rejected up + * front (issue #159, finding #5). `totalUsage` fields are `number | undefined`; + * fall back to the sum of input+output when the provider omits `totalTokens`. + * Fire-and-forget: the turn already streamed, so a record failure must not + * break it. + */ + recordTurnUsage( + workspaceId: string, + totalUsage: { + totalTokens?: number; + inputTokens?: number; + outputTokens?: number; + }, + ): void { + const tokens = + totalUsage.totalTokens ?? + (totalUsage.inputTokens ?? 0) + (totalUsage.outputTokens ?? 0); + void this.recordShareTokens(workspaceId, tokens); + } + /** * Resolve the admin-selected agent role for the anonymous public-share * assistant, scoped to the workspace and soft-delete aware. Returns null when @@ -263,18 +286,8 @@ export class PublicShareChatService { // bill even if the per-IP throttle is evaded; worst case = steps × this. maxOutputTokens: resolveShareAiMaxOutputTokens(), abortSignal: signal, - onFinish: ({ totalUsage }) => { - // Account the turn's REAL token spend (input re-sent per step + output, - // summed across all steps) against the per-workspace rolling-day budget - // so a future turn over budget is rejected up front (issue #159 #5). - // totalUsage fields are `number | undefined`; fall back to the sum of - // input+output when the provider omits totalTokens. Fire-and-forget: - // the turn already streamed, so a record failure must not break it. - const u = totalUsage ?? ({} as typeof totalUsage); - const tokens = - u?.totalTokens ?? (u?.inputTokens ?? 0) + (u?.outputTokens ?? 0); - void this.recordShareTokens(workspaceId, tokens); - }, + onFinish: ({ totalUsage }) => + this.recordTurnUsage(workspaceId, totalUsage), onError: ({ error }) => { // Reuse the shared formatter so provider error formatting stays // unified (statusCode + body) with the authenticated path. diff --git a/apps/server/src/core/ai-chat/public-share-chat.spec.ts b/apps/server/src/core/ai-chat/public-share-chat.spec.ts index 3232e631..f65058d9 100644 --- a/apps/server/src/core/ai-chat/public-share-chat.spec.ts +++ b/apps/server/src/core/ai-chat/public-share-chat.spec.ts @@ -728,6 +728,49 @@ describe('PublicShareChatService.withinShareTokenBudget / recordShareTokens', () }); }); +describe('PublicShareChatService.recordTurnUsage (streamText onFinish accounting)', () => { + function makeService() { + const redisService = { getOrThrow: () => new FakeTokenRedis() } as never; + const service = new PublicShareChatService( + {} as never, + {} as never, + {} as never, + redisService, + {} as never, + ); + const recordSpy = jest + .spyOn(service, 'recordShareTokens') + .mockResolvedValue(undefined); + return { service, recordSpy }; + } + + it('sums input+output when the provider omits totalTokens', () => { + const { service, recordSpy } = makeService(); + // The onFinish payload shape: a totalUsage with per-component counts but no + // authoritative total (provider omitted it). + service.recordTurnUsage('ws-1', { inputTokens: 1200, outputTokens: 300 }); + expect(recordSpy).toHaveBeenCalledWith('ws-1', 1500); + }); + + it('treats missing input/output components as 0 in the fallback sum', () => { + const { service, recordSpy } = makeService(); + service.recordTurnUsage('ws-1', { outputTokens: 42 }); + expect(recordSpy).toHaveBeenCalledWith('ws-1', 42); + }); + + it('prefers the authoritative totalTokens when present (not the sum)', () => { + const { service, recordSpy } = makeService(); + // totalTokens is the provider's authoritative figure and may differ from a + // naive input+output sum (e.g. cached/ reasoning tokens); it must win. + service.recordTurnUsage('ws-1', { + totalTokens: 5000, + inputTokens: 1200, + outputTokens: 300, + }); + expect(recordSpy).toHaveBeenCalledWith('ws-1', 5000); + }); +}); + describe('PublicShareChatService.tryConsumeWorkspaceQuota', () => { it('delegates to the redis-backed per-workspace limiter', async () => { const redis = new FakeRedis();