diff --git a/packages/components/nodes/chatmodels/ChatOpenAICustom/ChatOpenAICustom.test.ts b/packages/components/nodes/chatmodels/ChatOpenAICustom/ChatOpenAICustom.test.ts new file mode 100644 index 00000000000..d0bf5275299 --- /dev/null +++ b/packages/components/nodes/chatmodels/ChatOpenAICustom/ChatOpenAICustom.test.ts @@ -0,0 +1,80 @@ +jest.mock('@langchain/openai', () => ({ + ChatOpenAI: jest.fn().mockImplementation((fields) => ({ fields })) +})) + +jest.mock('../../../src/utils', () => ({ + getBaseClasses: jest.fn().mockReturnValue(['BaseChatModel']), + getCredentialData: jest.fn(), + getCredentialParam: jest.fn() +})) + +import { getCredentialData, getCredentialParam } from '../../../src/utils' + +const { nodeClass: ChatOpenAICustom } = require('./ChatOpenAICustom') + +describe('ChatOpenAICustom', () => { + beforeEach(() => { + jest.clearAllMocks() + ;(getCredentialData as jest.Mock).mockResolvedValue({ openAIApiKey: 'test-api-key' }) + ;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key]) + }) + + it('exposes stopSequence as an additional parameter', () => { + const node = new ChatOpenAICustom() + + expect(node.inputs).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + label: 'Stop Sequence', + name: 'stopSequence', + additionalParams: true + }) + ]) + ) + }) + + it('passes comma-separated stop sequences to ChatOpenAI', async () => { + const node = new ChatOpenAICustom() + const model = await node.init( + { + credential: 'cred-1', + inputs: { + modelName: 'custom-model', + temperature: '0.3', + stopSequence: '<|im_end|>, END', + streaming: false + } + }, + '', + {} + ) + + expect(model.fields).toMatchObject({ + modelName: 'custom-model', + openAIApiKey: 'test-api-key', + apiKey: 'test-api-key', + temperature: 0.3, + streaming: false, + stop: ['<|im_end|>', 'END'] + }) + }) + + it('ignores empty stop sequence entries', async () => { + const node = new ChatOpenAICustom() + const model = await node.init( + { + credential: 'cred-1', + inputs: { + modelName: 'custom-model', + temperature: '0.3', + stopSequence: 'foo,, bar, ', + streaming: false + } + }, + '', + {} + ) + + expect(model.fields.stop).toEqual(['foo', 'bar']) + }) +}) diff --git a/packages/components/nodes/chatmodels/ChatOpenAICustom/ChatOpenAICustom.ts b/packages/components/nodes/chatmodels/ChatOpenAICustom/ChatOpenAICustom.ts index faf2ebd8b37..21981bd9190 100644 --- a/packages/components/nodes/chatmodels/ChatOpenAICustom/ChatOpenAICustom.ts +++ b/packages/components/nodes/chatmodels/ChatOpenAICustom/ChatOpenAICustom.ts @@ -18,7 +18,7 @@ class ChatOpenAICustom_ChatModels implements INode { constructor() { this.label = 'OpenAI Custom Model' this.name = 'chatOpenAICustom' - this.version = 4.0 + this.version = 4.1 this.type = 'ChatOpenAI-Custom' this.icon = 'openai.svg' this.category = 'Chat Models' @@ -92,6 +92,15 @@ class ChatOpenAICustom_ChatModels implements INode { optional: true, additionalParams: true }, + { + label: 'Stop Sequence', + name: 'stopSequence', + type: 'string', + rows: 4, + optional: true, + description: 'List of stop words to use when generating. Use comma to separate multiple stop words.', + additionalParams: true + }, { label: 'Timeout', name: 'timeout', @@ -126,6 +135,7 @@ class ChatOpenAICustom_ChatModels implements INode { const topP = nodeData.inputs?.topP as string const frequencyPenalty = nodeData.inputs?.frequencyPenalty as string const presencePenalty = nodeData.inputs?.presencePenalty as string + const stopSequence = nodeData.inputs?.stopSequence as string const timeout = nodeData.inputs?.timeout as string const streaming = nodeData.inputs?.streaming as boolean const basePath = nodeData.inputs?.basepath as string @@ -147,6 +157,13 @@ class ChatOpenAICustom_ChatModels implements INode { if (topP) obj.topP = parseFloat(topP) if (frequencyPenalty) obj.frequencyPenalty = parseFloat(frequencyPenalty) if (presencePenalty) obj.presencePenalty = parseFloat(presencePenalty) + if (stopSequence) { + const stopSequenceArray = stopSequence + .split(',') + .map((item) => item.trim()) + .filter((item) => item !== '') + if (stopSequenceArray.length > 0) obj.stop = stopSequenceArray + } if (timeout) obj.timeout = parseInt(timeout, 10) if (cache) obj.cache = cache