Skip to content

Commit b27e889

Browse files
authored
feat(components): improve Baidu Wenxin chat model configuration (#6140)
1 parent 0e52877 commit b27e889

3 files changed

Lines changed: 149 additions & 5 deletions

File tree

packages/components/models.json

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,27 @@
815815
}
816816
]
817817
},
818+
{
819+
"name": "chatBaiduWenxin",
820+
"models": [
821+
{
822+
"label": "ernie-4.5-8k-preview",
823+
"name": "ernie-4.5-8k-preview"
824+
},
825+
{
826+
"label": "ernie-4.0-8k",
827+
"name": "ernie-4.0-8k"
828+
},
829+
{
830+
"label": "ernie-3.5-8k-preview",
831+
"name": "ernie-3.5-8k-preview"
832+
},
833+
{
834+
"label": "ernie-speed-128k",
835+
"name": "ernie-speed-128k"
836+
}
837+
]
838+
},
818839
{
819840
"name": "chatAlibabaTongyi",
820841
"models": [
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
jest.mock('@langchain/baidu-qianfan', () => ({
2+
ChatBaiduQianfan: jest.fn().mockImplementation((fields) => ({ fields }))
3+
}))
4+
5+
jest.mock('../../../src/utils', () => ({
6+
getBaseClasses: jest.fn().mockReturnValue(['BaseChatModel']),
7+
getCredentialData: jest.fn(),
8+
getCredentialParam: jest.fn()
9+
}))
10+
11+
jest.mock('../../../src/modelLoader', () => ({
12+
MODEL_TYPE: { CHAT: 'chat' },
13+
getModels: jest.fn()
14+
}))
15+
16+
import { getCredentialData, getCredentialParam } from '../../../src/utils'
17+
import { getModels } from '../../../src/modelLoader'
18+
19+
const { nodeClass: ChatBaiduWenxin } = require('./ChatBaiduWenxin')
20+
21+
describe('ChatBaiduWenxin', () => {
22+
beforeEach(() => {
23+
jest.clearAllMocks()
24+
})
25+
26+
it('loads model options from the shared model loader', async () => {
27+
;(getModels as jest.Mock).mockResolvedValue([{ label: 'ernie-4.5-8k-preview', name: 'ernie-4.5-8k-preview' }])
28+
29+
const node = new ChatBaiduWenxin()
30+
const models = await node.loadMethods.listModels()
31+
32+
expect(getModels).toHaveBeenCalledWith('chat', 'chatBaiduWenxin')
33+
expect(models).toEqual([{ label: 'ernie-4.5-8k-preview', name: 'ernie-4.5-8k-preview' }])
34+
})
35+
36+
it('passes advanced settings and custom model names to ChatBaiduQianfan', async () => {
37+
;(getCredentialData as jest.Mock).mockResolvedValue({
38+
qianfanAccessKey: 'access-key',
39+
qianfanSecretKey: 'secret-key'
40+
})
41+
;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key])
42+
43+
const node = new ChatBaiduWenxin()
44+
const model = await node.init(
45+
{
46+
credential: 'cred-1',
47+
inputs: {
48+
modelName: 'ernie-4.0-8k',
49+
customModelName: 'ernie-speed-128k',
50+
temperature: '0.2',
51+
streaming: false,
52+
topP: '0.8',
53+
penaltyScore: '1.4',
54+
userId: 'user-123'
55+
}
56+
},
57+
'',
58+
{}
59+
)
60+
61+
expect(model.fields).toMatchObject({
62+
qianfanAccessKey: 'access-key',
63+
qianfanSecretKey: 'secret-key',
64+
modelName: 'ernie-speed-128k',
65+
temperature: 0.2,
66+
streaming: false,
67+
topP: 0.8,
68+
penaltyScore: 1.4,
69+
userId: 'user-123'
70+
})
71+
})
72+
})

packages/components/nodes/chatmodels/ChatBaiduWenxin/ChatBaiduWenxin.ts

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { BaseCache } from '@langchain/core/caches'
22
import { ChatBaiduQianfan } from '@langchain/baidu-qianfan'
3-
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
3+
import { ICommonObject, INode, INodeData, INodeOptionsValue, INodeParams } from '../../../src/Interface'
4+
import { MODEL_TYPE, getModels } from '../../../src/modelLoader'
45
import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
56

67
class ChatBaiduWenxin_ChatModels implements INode {
@@ -18,7 +19,7 @@ class ChatBaiduWenxin_ChatModels implements INode {
1819
constructor() {
1920
this.label = 'Baidu Wenxin'
2021
this.name = 'chatBaiduWenxin'
21-
this.version = 2.0
22+
this.version = 3.0
2223
this.type = 'ChatBaiduWenxin'
2324
this.icon = 'baiduwenxin.svg'
2425
this.category = 'Chat Models'
@@ -38,10 +39,20 @@ class ChatBaiduWenxin_ChatModels implements INode {
3839
optional: true
3940
},
4041
{
41-
label: 'Model',
42+
label: 'Model Name',
4243
name: 'modelName',
44+
type: 'asyncOptions',
45+
loadMethod: 'listModels',
46+
default: 'ernie-4.5-8k-preview'
47+
},
48+
{
49+
label: 'Custom Model Name',
50+
name: 'customModelName',
4351
type: 'string',
44-
placeholder: 'ERNIE-Bot-turbo'
52+
placeholder: 'ernie-speed-128k',
53+
description: 'Custom model name to use. If provided, it will override the selected model.',
54+
additionalParams: true,
55+
optional: true
4556
},
4657
{
4758
label: 'Temperature',
@@ -57,15 +68,52 @@ class ChatBaiduWenxin_ChatModels implements INode {
5768
type: 'boolean',
5869
default: true,
5970
optional: true
71+
},
72+
{
73+
label: 'Top Probability',
74+
name: 'topP',
75+
type: 'number',
76+
description: 'Nucleus sampling. The model considers tokens whose cumulative probability mass reaches this value.',
77+
step: 0.1,
78+
optional: true,
79+
additionalParams: true
80+
},
81+
{
82+
label: 'Penalty Score',
83+
name: 'penaltyScore',
84+
type: 'number',
85+
description: 'Penalizes repeated tokens according to frequency. Baidu Qianfan accepts values from 1.0 to 2.0.',
86+
step: 0.1,
87+
optional: true,
88+
additionalParams: true
89+
},
90+
{
91+
label: 'User ID',
92+
name: 'userId',
93+
type: 'string',
94+
description: 'Optional unique identifier for the end user making the request.',
95+
optional: true,
96+
additionalParams: true
6097
}
6198
]
6299
}
63100

101+
//@ts-ignore
102+
loadMethods = {
103+
async listModels(): Promise<INodeOptionsValue[]> {
104+
return await getModels(MODEL_TYPE.CHAT, 'chatBaiduWenxin')
105+
}
106+
}
107+
64108
async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
65109
const cache = nodeData.inputs?.cache as BaseCache
66110
const temperature = nodeData.inputs?.temperature as string
67111
const modelName = nodeData.inputs?.modelName as string
112+
const customModelName = nodeData.inputs?.customModelName as string
68113
const streaming = nodeData.inputs?.streaming as boolean
114+
const topP = nodeData.inputs?.topP as string
115+
const penaltyScore = nodeData.inputs?.penaltyScore as string
116+
const userId = nodeData.inputs?.userId as string
69117

70118
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
71119
const qianfanAccessKey = getCredentialParam('qianfanAccessKey', credentialData, nodeData)
@@ -75,9 +123,12 @@ class ChatBaiduWenxin_ChatModels implements INode {
75123
streaming: streaming ?? true,
76124
qianfanAccessKey,
77125
qianfanSecretKey,
78-
modelName,
126+
modelName: customModelName || modelName,
79127
temperature: temperature ? parseFloat(temperature) : undefined
80128
}
129+
if (topP) obj.topP = parseFloat(topP)
130+
if (penaltyScore) obj.penaltyScore = parseFloat(penaltyScore)
131+
if (userId) obj.userId = userId
81132
if (cache) obj.cache = cache
82133

83134
const model = new ChatBaiduQianfan(obj)

0 commit comments

Comments
 (0)