diff --git a/packages/components/credentials/BaiduQianfanApiKey.credential.ts b/packages/components/credentials/BaiduQianfanApiKey.credential.ts new file mode 100644 index 00000000000..835ce5d9f38 --- /dev/null +++ b/packages/components/credentials/BaiduQianfanApiKey.credential.ts @@ -0,0 +1,23 @@ +import { INodeParams, INodeCredential } from '../src/Interface' + +class BaiduQianfanApiKey implements INodeCredential { + label: string + name: string + version: number + inputs: INodeParams[] + + constructor() { + this.label = 'Baidu Qianfan API Key' + this.name = 'baiduQianfanApiKey' + this.version = 1.0 + this.inputs = [ + { + label: 'Qianfan API Key', + name: 'qianfanApiKey', + type: 'password' + } + ] + } +} + +module.exports = { credClass: BaiduQianfanApiKey } diff --git a/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerank.test.ts b/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerank.test.ts new file mode 100644 index 00000000000..640e1099843 --- /dev/null +++ b/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerank.test.ts @@ -0,0 +1,96 @@ +import { Document } from '@langchain/core/documents' +import { BaiduQianfanRerank } from './BaiduQianfanRerank' + +const originalFetch = global.fetch +const mockedFetch = jest.fn() + +describe('BaiduQianfanRerank', () => { + beforeEach(() => { + jest.clearAllMocks() + global.fetch = mockedFetch as unknown as typeof fetch + }) + + afterAll(() => { + global.fetch = originalFetch + }) + + it('calls Qianfan rerank API and preserves metadata from ranked indexes', async () => { + mockedFetch.mockResolvedValue({ + ok: true, + json: jest.fn().mockResolvedValue({ + results: [ + { index: 1, document: 'second', relevance_score: 0.92 }, + { index: 0, document: 'first', relevance_score: 0.41 } + ] + }) + }) + + const compressor = new BaiduQianfanRerank('api-key', 'bce-reranker-base', 2) + const documents = [ + new Document({ pageContent: 'first', metadata: { source: 'a' } }), + new Document({ pageContent: 'second', metadata: { source: 'b' } }) + ] + + const result = await compressor.compressDocuments(documents, 'weather in Shanghai') + + expect(mockedFetch).toHaveBeenCalledWith('https://qianfan.baidubce.com/v2/rerank', { + method: 'POST', + headers: { + Authorization: 'Bearer api-key', + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + model: 'bce-reranker-base', + query: 'weather in Shanghai', + documents: ['first', 'second'], + top_n: 2 + }) + }) + expect(result.map((doc) => doc.pageContent)).toEqual(['second', 'first']) + expect(result[0].metadata).toEqual({ source: 'b', relevance_score: 0.92 }) + expect(result[1].metadata).toEqual({ source: 'a', relevance_score: 0.41 }) + }) + + it('returns an empty array without calling Qianfan when no documents are provided', async () => { + const compressor = new BaiduQianfanRerank('api-key', 'bce-reranker-base', 4) + + await expect(compressor.compressDocuments([], 'query')).resolves.toEqual([]) + expect(mockedFetch).not.toHaveBeenCalled() + }) + + it('falls back to the original documents when Qianfan returns an invalid index', async () => { + mockedFetch.mockResolvedValue({ + ok: true, + json: jest.fn().mockResolvedValue({ + results: [{ index: 99, document: 'missing', relevance_score: 0.9 }] + }) + }) + + const compressor = new BaiduQianfanRerank('api-key', 'bce-reranker-base', 4) + const documents = [new Document({ pageContent: 'first', metadata: { source: 'a' } })] + + await expect(compressor.compressDocuments(documents, 'query')).resolves.toBe(documents) + }) + + it('falls back to the original documents when Qianfan returns an API error', async () => { + mockedFetch.mockResolvedValue({ + ok: false, + status: 404, + text: jest.fn().mockResolvedValue('model not found') + }) + + const compressor = new BaiduQianfanRerank('api-key', 'missing-model', 4) + const documents = [new Document({ pageContent: 'first', metadata: { source: 'a' } })] + + await expect(compressor.compressDocuments(documents, 'query')).resolves.toBe(documents) + }) + + it('falls back to the original documents when the Qianfan call fails', async () => { + mockedFetch.mockRejectedValue(new Error('network failed')) + + const compressor = new BaiduQianfanRerank('api-key', 'bce-reranker-base', 4) + const documents = [new Document({ pageContent: 'first', metadata: { source: 'a' } })] + + await expect(compressor.compressDocuments(documents, 'query')).resolves.toBe(documents) + }) +}) diff --git a/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerank.ts b/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerank.ts new file mode 100644 index 00000000000..71caa1f19ec --- /dev/null +++ b/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerank.ts @@ -0,0 +1,77 @@ +import { Callbacks } from '@langchain/core/callbacks/manager' +import { Document } from '@langchain/core/documents' +import { BaseDocumentCompressor } from '@langchain/classic/retrievers/document_compressors' + +const QIANFAN_RERANK_API_URL = 'https://qianfan.baidubce.com/v2/rerank' + +type QianfanRerankResult = { + index: number + document: string + relevance_score: number +} + +type QianfanRerankResponse = { + results?: QianfanRerankResult[] +} + +export class BaiduQianfanRerank extends BaseDocumentCompressor { + private readonly qianfanApiKey: string + private readonly model: string + private readonly topN: number + + constructor(qianfanApiKey: string, model: string, topN: number) { + super() + this.qianfanApiKey = qianfanApiKey + this.model = model + this.topN = topN + } + + async compressDocuments( + documents: Document>[], + query: string, + _?: Callbacks | undefined + ): Promise>[]> { + if (documents.length === 0) return [] + + try { + const response = await fetch(QIANFAN_RERANK_API_URL, { + method: 'POST', + headers: { + Authorization: `Bearer ${this.qianfanApiKey}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + model: this.model, + query, + documents: documents.map((doc) => doc.pageContent), + top_n: this.topN + }) + }) + + if (!response.ok) throw new Error(`Baidu Qianfan Rerank API call failed with status ${response.status}`) + + const rerankResponse = (await response.json()) as QianfanRerankResponse + + if (!Array.isArray(rerankResponse.results)) return documents + + const rerankedDocuments: Document>[] = [] + for (const result of rerankResponse.results) { + const doc = documents[result.index] + if (!doc) return documents + rerankedDocuments.push( + new Document({ + pageContent: doc.pageContent, + metadata: { + ...doc.metadata, + relevance_score: result.relevance_score + } + }) + ) + } + + return rerankedDocuments + } catch (error) { + return documents + } + } +} diff --git a/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerankRetriever.test.ts b/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerankRetriever.test.ts new file mode 100644 index 00000000000..f824939714f --- /dev/null +++ b/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerankRetriever.test.ts @@ -0,0 +1,243 @@ +jest.mock('@langchain/classic/retrievers/contextual_compression', () => ({ + ContextualCompressionRetriever: jest.fn().mockImplementation(({ baseCompressor, baseRetriever }) => ({ + baseCompressor, + baseRetriever, + invoke: jest.fn().mockResolvedValue([{ pageContent: 'reranked doc', metadata: { relevance_score: 0.98 } }]) + })) +})) + +jest.mock('../../../src/utils', () => ({ + getCredentialData: jest.fn(), + getCredentialParam: jest.fn(), + handleEscapeCharacters: jest.fn((text: string) => text) +})) + +jest.mock('./BaiduQianfanRerank', () => ({ + BaiduQianfanRerank: jest.fn().mockImplementation((qianfanApiKey, model, topN) => ({ + qianfanApiKey, + model, + topN + })) +})) + +import { ContextualCompressionRetriever } from '@langchain/classic/retrievers/contextual_compression' +import { getCredentialData, getCredentialParam, handleEscapeCharacters } from '../../../src/utils' +import { BaiduQianfanRerank } from './BaiduQianfanRerank' + +const { nodeClass: BaiduQianfanRerankRetriever } = require('./BaiduQianfanRerankRetriever') + +describe('BaiduQianfanRerankRetriever', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('declares Flowise metadata, existing Baidu credential, and built-in model option', () => { + const node = new BaiduQianfanRerankRetriever() + const modelInput = node.inputs.find((input: { name: string }) => input.name === 'modelName') + + expect(node).toMatchObject({ + label: 'Baidu Qianfan Rerank Retriever', + name: 'baiduQianfanRerankRetriever', + type: 'BaiduQianfanRerankRetriever', + category: 'Retrievers', + icon: 'baiduwenxin.svg' + }) + expect(node.credential).toMatchObject({ + name: 'credential', + credentialNames: ['baiduQianfanApiKey', 'baiduQianfanApi'] + }) + expect(modelInput).toMatchObject({ + type: 'options', + default: 'bce-reranker-base', + options: [{ label: 'bce-reranker-base', name: 'bce-reranker-base' }] + }) + }) + + it('creates a contextual compression retriever with Qianfan API key and base retriever k by default', async () => { + ;(getCredentialData as jest.Mock).mockResolvedValue({ + qianfanApiKey: 'api-key', + qianfanAccessKey: 'access-key', + qianfanSecretKey: 'secret-key' + }) + ;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key]) + const baseRetriever = { k: 6 } + + const node = new BaiduQianfanRerankRetriever() + const result = await node.init( + { + credential: 'cred-1', + inputs: { + baseRetriever, + modelName: 'bce-reranker-base' + }, + outputs: { + output: 'retriever' + } + }, + 'user query', + {} + ) + + expect(BaiduQianfanRerank).toHaveBeenCalledWith('api-key', 'bce-reranker-base', 6) + expect(ContextualCompressionRetriever).toHaveBeenCalledWith({ + baseCompressor: expect.objectContaining({ model: 'bce-reranker-base', topN: 6 }), + baseRetriever + }) + expect(result).toMatchObject({ baseRetriever }) + }) + + it('falls back to Qianfan access key when the dedicated API key field is not configured', async () => { + ;(getCredentialData as jest.Mock).mockResolvedValue({ + qianfanAccessKey: 'fallback-api-key' + }) + ;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key]) + + const node = new BaiduQianfanRerankRetriever() + await node.init( + { + credential: 'cred-1', + inputs: { + baseRetriever: { k: 4 }, + modelName: 'bce-reranker-base' + }, + outputs: { + output: 'retriever' + } + }, + 'user query', + {} + ) + + expect(BaiduQianfanRerank).toHaveBeenCalledWith('fallback-api-key', 'bce-reranker-base', 4) + }) + + it('uses custom model names and explicit topN values', async () => { + ;(getCredentialData as jest.Mock).mockResolvedValue({ + qianfanApiKey: 'api-key' + }) + ;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key]) + + const node = new BaiduQianfanRerankRetriever() + await node.init( + { + credential: 'cred-1', + inputs: { + baseRetriever: { k: 10 }, + modelName: 'bce-reranker-base', + customModelName: 'custom-reranker', + topN: '3' + }, + outputs: { + output: 'retriever' + } + }, + 'user query', + {} + ) + + expect(BaiduQianfanRerank).toHaveBeenCalledWith('api-key', 'custom-reranker', 3) + }) + + it('parses topN as an integer', async () => { + ;(getCredentialData as jest.Mock).mockResolvedValue({ + qianfanApiKey: 'api-key' + }) + ;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key]) + + const node = new BaiduQianfanRerankRetriever() + await node.init( + { + credential: 'cred-1', + inputs: { + baseRetriever: { k: 10 }, + modelName: 'bce-reranker-base', + topN: '3.7' + }, + outputs: { + output: 'retriever' + } + }, + 'user query', + {} + ) + + expect(BaiduQianfanRerank).toHaveBeenCalledWith('api-key', 'bce-reranker-base', 3) + }) + + it('throws when the Qianfan API key is missing from credentials', async () => { + ;(getCredentialData as jest.Mock).mockResolvedValue({}) + ;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key]) + + const node = new BaiduQianfanRerankRetriever() + await expect( + node.init( + { + credential: 'cred-1', + inputs: { + baseRetriever: { k: 4 }, + modelName: 'bce-reranker-base' + }, + outputs: { + output: 'retriever' + } + }, + 'user query', + {} + ) + ).rejects.toThrow('Baidu Qianfan API Key is missing in credentials.') + }) + + it('returns document output by invoking the rerank retriever', async () => { + ;(getCredentialData as jest.Mock).mockResolvedValue({ + qianfanApiKey: 'api-key' + }) + ;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key]) + + const node = new BaiduQianfanRerankRetriever() + const result = await node.init( + { + credential: 'cred-1', + inputs: { + baseRetriever: { k: 4 }, + modelName: 'bce-reranker-base', + query: 'override query' + }, + outputs: { + output: 'document' + } + }, + 'input query', + {} + ) + + const retriever = (ContextualCompressionRetriever as unknown as jest.Mock).mock.results[0].value + expect(retriever.invoke).toHaveBeenCalledWith('override query') + expect(result).toEqual([{ pageContent: 'reranked doc', metadata: { relevance_score: 0.98 } }]) + }) + + it('returns text output by concatenating reranked documents', async () => { + ;(getCredentialData as jest.Mock).mockResolvedValue({ + qianfanApiKey: 'api-key' + }) + ;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key]) + + const node = new BaiduQianfanRerankRetriever() + const result = await node.init( + { + credential: 'cred-1', + inputs: { + baseRetriever: { k: 4 }, + modelName: 'bce-reranker-base' + }, + outputs: { + output: 'text' + } + }, + 'input query', + {} + ) + + expect(handleEscapeCharacters).toHaveBeenCalledWith('reranked doc\n', false) + expect(result).toBe('reranked doc\n') + }) +}) diff --git a/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerankRetriever.ts b/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerankRetriever.ts new file mode 100644 index 00000000000..4691c98d1aa --- /dev/null +++ b/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/BaiduQianfanRerankRetriever.ts @@ -0,0 +1,140 @@ +import { BaseRetriever } from '@langchain/core/retrievers' +import { VectorStoreRetriever } from '@langchain/core/vectorstores' +import { ContextualCompressionRetriever } from '@langchain/classic/retrievers/contextual_compression' +import { getCredentialData, getCredentialParam, handleEscapeCharacters } from '../../../src/utils' +import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface' +import { BaiduQianfanRerank } from './BaiduQianfanRerank' + +class BaiduQianfanRerankRetriever_Retrievers implements INode { + label: string + name: string + version: number + description: string + type: string + icon: string + category: string + baseClasses: string[] + inputs: INodeParams[] + credential: INodeParams + outputs: INodeOutputsValue[] + + constructor() { + this.label = 'Baidu Qianfan Rerank Retriever' + this.name = 'baiduQianfanRerankRetriever' + this.version = 1.0 + this.type = 'BaiduQianfanRerankRetriever' + this.icon = 'baiduwenxin.svg' + this.category = 'Retrievers' + this.description = 'Baidu Qianfan Rerank indexes the documents from most to least semantically relevant to the query.' + this.baseClasses = [this.type, 'BaseRetriever'] + this.credential = { + label: 'Connect Credential', + name: 'credential', + type: 'credential', + credentialNames: ['baiduQianfanApiKey', 'baiduQianfanApi'] + } + this.inputs = [ + { + label: 'Vector Store Retriever', + name: 'baseRetriever', + type: 'VectorStoreRetriever' + }, + { + label: 'Model Name', + name: 'modelName', + type: 'options', + options: [ + { + label: 'bce-reranker-base', + name: 'bce-reranker-base' + } + ], + default: 'bce-reranker-base', + optional: true + }, + { + label: 'Custom Model Name', + name: 'customModelName', + type: 'string', + placeholder: 'bce-reranker-base', + description: 'Custom model name to use. If provided, it will override the selected model.', + additionalParams: true, + optional: true + }, + { + label: 'Query', + name: 'query', + type: 'string', + description: 'Query to retrieve documents from retriever. If not specified, user question will be used', + optional: true, + acceptVariable: true + }, + { + label: 'Top N', + name: 'topN', + description: 'Number of top results to fetch. Default to the TopK of the Base Retriever', + placeholder: '4', + type: 'number', + additionalParams: true, + optional: true + } + ] + this.outputs = [ + { + label: 'Baidu Qianfan Rerank Retriever', + name: 'retriever', + baseClasses: this.baseClasses + }, + { + label: 'Document', + name: 'document', + description: 'Array of document objects containing metadata and pageContent', + baseClasses: ['Document', 'json'] + }, + { + label: 'Text', + name: 'text', + description: 'Concatenated string from pageContent of documents', + baseClasses: ['string', 'json'] + } + ] + } + + async init(nodeData: INodeData, input: string, options: ICommonObject): Promise { + const baseRetriever = nodeData.inputs?.baseRetriever as BaseRetriever + const modelName = nodeData.inputs?.modelName as string + const customModelName = nodeData.inputs?.customModelName as string + const query = nodeData.inputs?.query as string + const topN = nodeData.inputs?.topN as string + const output = nodeData.outputs?.output as string + + const credentialData = await getCredentialData(nodeData.credential ?? '', options) + const qianfanApiKey = + getCredentialParam('qianfanApiKey', credentialData, nodeData) || + getCredentialParam('qianfanAccessKey', credentialData, nodeData) + if (!qianfanApiKey) { + throw new Error('Baidu Qianfan API Key is missing in credentials.') + } + + const k = topN ? parseInt(topN, 10) : (baseRetriever as VectorStoreRetriever).k ?? 4 + + const qianfanCompressor = new BaiduQianfanRerank(qianfanApiKey, customModelName || modelName, k) + const retriever = new ContextualCompressionRetriever({ + baseCompressor: qianfanCompressor, + baseRetriever + }) + + if (output === 'retriever') return retriever + if (output === 'document') return await retriever.invoke(query ? query : input) + if (output === 'text') { + const docs = await retriever.invoke(query ? query : input) + let finaltext = '' + for (const doc of docs) finaltext += `${doc.pageContent}\n` + return handleEscapeCharacters(finaltext, false) + } + + return retriever + } +} + +module.exports = { nodeClass: BaiduQianfanRerankRetriever_Retrievers } diff --git a/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/baiduwenxin.svg b/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/baiduwenxin.svg new file mode 100644 index 00000000000..afe2bc69024 --- /dev/null +++ b/packages/components/nodes/retrievers/BaiduQianfanRerankRetriever/baiduwenxin.svg @@ -0,0 +1,7 @@ + + \ No newline at end of file