Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions packages/components/nodes/agentflow/Agent/Agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ import {
revertBase64ImagesToFileRefs,
normalizeMessagesForStorage,
replaceInlineDataWithFileReferences,
updateFlowState
updateFlowState,
createTokenCounter
} from '../utils'
import {
convertMultiOptionsToStringArray,
Expand Down Expand Up @@ -1806,10 +1807,11 @@ class Agent_Agentflow implements INode {
abortController: AbortController
): Promise<void> {
const maxTokenLimit = (nodeData.inputs?.agentMemoryMaxTokenLimit as number) || 2000
const countTokens = createTokenCounter(llmWithoutToolsBind)

// Convert past messages to a format suitable for token counting
const messagesString = pastMessages.map((msg: any) => `${msg.role}: ${msg.content}`).join('\n')
const tokenCount = await llmWithoutToolsBind.getNumTokens(messagesString)
const tokenCount = await countTokens(messagesString)

if (tokenCount > maxTokenLimit) {
// Calculate how many messages to summarize (messages that exceed the token limit)
Expand All @@ -1824,7 +1826,7 @@ class Agent_Agentflow implements INode {
messagesToSummarize.push(poppedMessage)
// Recalculate token count for remaining messages
const remainingMessagesString = remainingMessages.map((msg: any) => `${msg.role}: ${msg.content}`).join('\n')
currBufferLength = await llmWithoutToolsBind.getNumTokens(remainingMessagesString)
currBufferLength = await countTokens(remainingMessagesString)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import { AnalyticHandler } from '../../../src/handler'
import { ICommonObject, IMessage, INode, INodeData, INodeOptionsValue, INodeOutputsValue, INodeParams } from '../../../src/Interface'
import { AIMessageChunk, BaseMessageLike } from '@langchain/core/messages'
import { getPastChatHistoryImageMessages, getUniqueImageMessages, processMessagesWithImages, revertBase64ImagesToFileRefs } from '../utils'
import {
createTokenCounter,
getPastChatHistoryImageMessages,
getUniqueImageMessages,
processMessagesWithImages,
revertBase64ImagesToFileRefs
} from '../utils'
import { CONDITION_AGENT_SYSTEM_PROMPT, DEFAULT_SUMMARIZER_TEMPLATE } from '../prompt'
import { BaseChatModel } from '@langchain/core/language_models/chat_models'
import { findBestScenarioIndex } from './matchScenario'
Expand Down Expand Up @@ -608,10 +614,11 @@ class ConditionAgent_Agentflow implements INode {
abortController: AbortController
): Promise<void> {
const maxTokenLimit = (nodeData.inputs?.conditionAgentMemoryMaxTokenLimit as number) || 2000
const countTokens = createTokenCounter(llmNodeInstance)

// Convert past messages to a format suitable for token counting
const messagesString = pastMessages.map((msg: any) => `${msg.role}: ${msg.content}`).join('\n')
const tokenCount = await llmNodeInstance.getNumTokens(messagesString)
const tokenCount = await countTokens(messagesString)

if (tokenCount > maxTokenLimit) {
// Calculate how many messages to summarize (messages that exceed the token limit)
Expand All @@ -626,7 +633,7 @@ class ConditionAgent_Agentflow implements INode {
messagesToSummarize.push(poppedMessage)
// Recalculate token count for remaining messages
const remainingMessagesString = remainingMessages.map((msg: any) => `${msg.role}: ${msg.content}`).join('\n')
currBufferLength = await llmNodeInstance.getNumTokens(remainingMessagesString)
currBufferLength = await countTokens(remainingMessagesString)
}
}

Expand Down
8 changes: 5 additions & 3 deletions packages/components/nodes/agentflow/LLM/LLM.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ import {
processMessagesWithImages,
revertBase64ImagesToFileRefs,
replaceInlineDataWithFileReferences,
updateFlowState
updateFlowState,
createTokenCounter
} from '../utils'
import { processTemplateVariables, configureStructuredOutput, extractResponseContent } from '../../../src/utils'
import { getModelConfigByModelName, MODEL_TYPE } from '../../../src/modelLoader'
Expand Down Expand Up @@ -797,10 +798,11 @@ class LLM_Agentflow implements INode {
abortController: AbortController
): Promise<void> {
const maxTokenLimit = (nodeData.inputs?.llmMemoryMaxTokenLimit as number) || 2000
const countTokens = createTokenCounter(llmNodeInstance)

// Convert past messages to a format suitable for token counting
const messagesString = pastMessages.map((msg: any) => `${msg.role}: ${msg.content}`).join('\n')
const tokenCount = await llmNodeInstance.getNumTokens(messagesString)
const tokenCount = await countTokens(messagesString)

if (tokenCount > maxTokenLimit) {
// Calculate how many messages to summarize (messages that exceed the token limit)
Expand All @@ -815,7 +817,7 @@ class LLM_Agentflow implements INode {
messagesToSummarize.push(poppedMessage)
// Recalculate token count for remaining messages
const remainingMessagesString = remainingMessages.map((msg: any) => `${msg.role}: ${msg.content}`).join('\n')
currBufferLength = await llmNodeInstance.getNumTokens(remainingMessagesString)
currBufferLength = await countTokens(remainingMessagesString)
}
}

Expand Down
91 changes: 90 additions & 1 deletion packages/components/nodes/agentflow/utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import { revertBase64ImagesToFileRefs, processMessagesWithImages, addImageArtifactsToMessages, getUniqueImageMessages } from './utils'
import {
revertBase64ImagesToFileRefs,
processMessagesWithImages,
addImageArtifactsToMessages,
getUniqueImageMessages,
createTokenCounter,
getApproximateTokenCount
} from './utils'
import { sanitizeFileName } from '../../src/validator'
import { IChatMessage, IMultimodalContentItem } from './Interface.Agentflow'
import { IFileUpload } from '../../src/Interface'
Expand Down Expand Up @@ -657,3 +664,85 @@ describe('path traversal prevention in image processing', () => {
expect(fileRef?.fileName).not.toContain('..')
})
})

/**
 * Behavioural tests for createTokenCounter: model-backed counting, the
 * environment switches that force approximate counts, and the sticky
 * fallback taken after an error or a timeout.
 */
describe('createTokenCounter', () => {
    const originalEnv = process.env
    let warnSpy: jest.SpyInstance

    beforeEach(() => {
        // NOTE(review): './utils' is imported statically at the top of this file, so
        // resetModules presumably has no effect on these tests — confirm before removing.
        jest.resetModules()
        // Work on a copy so each test starts without any token-counting overrides.
        process.env = { ...originalEnv }
        delete process.env.DISABLE_TIKTOKEN
        delete process.env.USE_APPROXIMATE_TOKENS
        delete process.env.TIKTOKEN_TIMEOUT
        // Silence the fallback warning while still letting tests count how often it fired.
        warnSpy = jest.spyOn(console, 'warn').mockImplementation(() => undefined)
    })

    afterEach(() => {
        warnSpy.mockRestore()
        process.env = originalEnv
    })

    it('uses the model token counter when it resolves before timeout', async () => {
        const llm = { getNumTokens: jest.fn().mockResolvedValue(42) }
        const countTokens = createTokenCounter(llm)

        await expect(countTokens('hello world')).resolves.toBe(42)
        expect(llm.getNumTokens).toHaveBeenCalledWith('hello world')
        expect(warnSpy).not.toHaveBeenCalled()
    })

    it('can use approximate token counts without calling the model counter', async () => {
        process.env.DISABLE_TIKTOKEN = 'true'
        const llm = { getNumTokens: jest.fn().mockResolvedValue(42) }
        const countTokens = createTokenCounter(llm)

        await expect(countTokens('12345678')).resolves.toBe(2)
        expect(llm.getNumTokens).not.toHaveBeenCalled()
    })

    it('falls back to approximate counts after a tokenizer error', async () => {
        const llm = { getNumTokens: jest.fn().mockRejectedValue(new Error('fetch failed')) }
        const countTokens = createTokenCounter(llm)

        await expect(countTokens('12345678')).resolves.toBe(2)
        // The second call must not reach the model again: the fallback is sticky.
        await expect(countTokens('123456789012')).resolves.toBe(3)
        expect(llm.getNumTokens).toHaveBeenCalledTimes(1)
        expect(warnSpy).toHaveBeenCalledTimes(1)
    })

    it('clears the timeout when token counting throws synchronously', async () => {
        const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout')
        // Restore in `finally` so a failing assertion cannot leak the spy on the
        // global clearTimeout into the tests that run afterwards.
        try {
            const llm = {
                getNumTokens: jest.fn(() => {
                    throw new Error('sync tokenizer failure')
                })
            }
            const countTokens = createTokenCounter(llm)

            await expect(countTokens('12345678')).resolves.toBe(2)
            expect(llm.getNumTokens).toHaveBeenCalledTimes(1)
            expect(clearTimeoutSpy).toHaveBeenCalledTimes(1)
            expect(warnSpy).toHaveBeenCalledTimes(1)
        } finally {
            clearTimeoutSpy.mockRestore()
        }
    })

    it('uses approximate counts when the model token counter is unavailable', async () => {
        const countTokens = createTokenCounter(undefined)

        await expect(countTokens('12345678')).resolves.toBe(2)
        expect(warnSpy).not.toHaveBeenCalled()
    })

    it('falls back to approximate counts after token counting times out', async () => {
        process.env.TIKTOKEN_TIMEOUT = '1'
        // A promise that never settles forces the timeout path.
        const llm = { getNumTokens: jest.fn(() => new Promise<number>(() => undefined)) }
        const countTokens = createTokenCounter(llm)

        await expect(countTokens('123456789')).resolves.toBe(getApproximateTokenCount('123456789'))
        await expect(countTokens('1234')).resolves.toBe(1)
        expect(llm.getNumTokens).toHaveBeenCalledTimes(1)
        expect(warnSpy).toHaveBeenCalledTimes(1)
    })
})
62 changes: 62 additions & 0 deletions packages/components/nodes/agentflow/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,70 @@ const ARTIFACT_TYPES: Record<string, string> = {
pdf: 'text'
}

/** Timeout, in milliseconds, that getTokenCountTimeoutMs returns when TIKTOKEN_TIMEOUT is not a finite positive number. */
const DEFAULT_TOKEN_COUNT_TIMEOUT_MS = 1500

/** Minimal structural shape the token-counting helpers need from a model: an async `getNumTokens` method. */
type TokenCountingModel = {
    /** Resolves with the number of tokens the model counts in `text`. */
    getNumTokens(text: string): Promise<number>
}

// ─── Shared helpers (used across multiple functions) ─────────────────────────

/**
 * Reports whether an environment-variable value is an affirmative flag.
 * Only '1', 'true' and 'yes' count (compared case-insensitively, without
 * trimming); a missing or empty value is treated as false.
 */
const isTruthyEnv = (value?: string): boolean => {
    if (!value) {
        return false
    }

    switch (value.toLowerCase()) {
        case '1':
        case 'true':
        case 'yes':
            return true
        default:
            return false
    }
}

/**
 * Resolves the token-counting timeout in milliseconds.
 *
 * Reads TIKTOKEN_TIMEOUT on every call. A value that is not a finite positive
 * number yields DEFAULT_TOKEN_COUNT_TIMEOUT_MS. Valid values are capped at
 * 2147483647 ms, the largest delay setTimeout honours: Node coerces anything
 * larger to a 1 ms delay, which would turn "wait very long" into an
 * immediate timeout.
 *
 * @returns the timeout to pass to setTimeout, always within (0, 2147483647].
 */
const getTokenCountTimeoutMs = (): number => {
    // Largest delay representable by timers (2^31 - 1).
    const MAX_TIMER_DELAY_MS = 2147483647

    const timeout = Number(process.env.TIKTOKEN_TIMEOUT)
    if (!Number.isFinite(timeout) || timeout <= 0) {
        return DEFAULT_TOKEN_COUNT_TIMEOUT_MS
    }

    return Math.min(timeout, MAX_TIMER_DELAY_MS)
}

/**
 * Asks `llm` for the token count of `text`, giving up after `timeoutMs`.
 *
 * @param llm model whose `getNumTokens` is invoked exactly once
 * @param text text to count tokens for
 * @param timeoutMs delay after which the returned promise rejects
 * @returns the count reported by the model
 * @throws rejects with the model's own error (thrown synchronously or
 * asynchronously), or with a timeout error when no answer arrives in time
 */
const getNumTokensWithTimeout = async (llm: TokenCountingModel, text: string, timeoutMs: number): Promise<number> => {
    return new Promise<number>((resolve, reject) => {
        const timer = setTimeout(() => {
            reject(new Error(`Token counting timed out after ${timeoutMs}ms`))
        }, timeoutMs)

        // Shared failure path: cancel the pending timer before propagating the error.
        const fail = (reason: unknown): void => {
            clearTimeout(timer)
            reject(reason)
        }

        let pending: Promise<number>
        try {
            // Promise.resolve tolerates counters that answer with a plain value.
            pending = Promise.resolve(llm.getNumTokens(text))
        } catch (thrown) {
            // The counter threw before returning anything.
            fail(thrown)
            return
        }

        pending.then((tokenCount) => {
            clearTimeout(timer)
            resolve(tokenCount)
        }, fail)
    })
}
Comment thread
vicksiyi marked this conversation as resolved.

/**
 * Estimates the token count of `text` as one token per four characters,
 * rounded up. Empty or missing text yields 0.
 */
export const getApproximateTokenCount = (text: string): number => {
    const charsPerToken = 4
    const charCount = text ? text.length : 0
    return Math.ceil(charCount / charsPerToken)
}

/**
 * Builds a token-counting function that prefers the model's own counter and
 * degrades to a length-based approximation.
 *
 * Approximate counts are used from the start when DISABLE_TIKTOKEN or
 * USE_APPROXIMATE_TOKENS is set to a truthy value, or when `llm` is missing
 * or has no `getNumTokens` function. Otherwise the model is asked, bounded by
 * the configured timeout. The first failure — an error, a timeout, or a count
 * that is not a finite non-negative number — logs one warning and switches
 * this counter to approximate counts for all later calls.
 *
 * @param llm model used for exact counts; may be null or undefined
 * @returns an async function mapping text to a token count; it never rejects
 * because of a model failure
 */
export const createTokenCounter = (llm?: TokenCountingModel | null): ((text: string) => Promise<number>) => {
    // Narrow once into a const so the closure below works with a definitely-usable model
    // instead of passing a possibly null/undefined `llm` to the model-backed path.
    const model = llm != null && typeof llm.getNumTokens === 'function' ? llm : undefined

    let useApproximateCount =
        isTruthyEnv(process.env.DISABLE_TIKTOKEN) || isTruthyEnv(process.env.USE_APPROXIMATE_TOKENS) || model === undefined

    return async (text: string): Promise<number> => {
        if (model === undefined || useApproximateCount) {
            return getApproximateTokenCount(text)
        }

        try {
            const count = await getNumTokensWithTimeout(model, text, getTokenCountTimeoutMs())
            if (Number.isFinite(count) && count >= 0) {
                return count
            }
            // A NaN/negative count would make every `count > limit` comparison silently false.
            throw new Error(`Model returned an invalid token count: ${String(count)}`)
        } catch (error) {
            // Sticky fallback: do not keep retrying a counter that has already failed.
            useApproximateCount = true
            console.warn('Failed to calculate number of tokens, falling back to approximate count', error)
            return getApproximateTokenCount(text)
        }
    }
}

/** Reads a file from storage and returns a base64 data-URL string. */
const storedFileToBase64 = async (fileName: string, mime: string, options: ICommonObject): Promise<string> => {
const contents = await getFileFromStorage(fileName, options.orgId, options.chatflowid, options.chatId)
Expand Down