-// Ollama-powered vector embedding engine with cosine similarity search
+// Multi-provider vector embedding engine with cosine similarity search
+// Supports Ollama (local) and OpenAI-compatible APIs (Gemini, OpenAI, etc.)
 // Indexes file headers and symbols, caches embeddings to disk for speed

-import { Ollama } from "ollama";
 import { readFile, writeFile, mkdir } from "fs/promises";
 import { join } from "path";

@@ -74,7 +74,11 @@ export interface EmbeddingCache {
   [path: string]: { hash: string; vector: number[] };
 }

+const EMBED_PROVIDER = (process.env.CONTEXTPLUS_EMBED_PROVIDER ?? "ollama").toLowerCase();
 const EMBED_MODEL = process.env.OLLAMA_EMBED_MODEL ?? "nomic-embed-text";
+const OPENAI_EMBED_MODEL = process.env.CONTEXTPLUS_OPENAI_EMBED_MODEL ?? process.env.OPENAI_EMBED_MODEL ?? "text-embedding-3-small";
+const OPENAI_API_KEY = process.env.CONTEXTPLUS_OPENAI_API_KEY ?? process.env.OPENAI_API_KEY ?? "";
+const OPENAI_BASE_URL = process.env.CONTEXTPLUS_OPENAI_BASE_URL ?? process.env.OPENAI_BASE_URL ?? "https://api.openai.com/v1";
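+// Illustrative configuration (values are examples, not requirements): routing
+// embeddings through an OpenAI-compatible endpoint might look like
+//   CONTEXTPLUS_EMBED_PROVIDER=openai
+//   CONTEXTPLUS_OPENAI_BASE_URL=https://api.openai.com/v1
+//   CONTEXTPLUS_OPENAI_EMBED_MODEL=text-embedding-3-small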
 const CACHE_DIR = ".mcp_data";
 const CACHE_FILE = "embeddings-cache.json";
 const MIN_EMBED_BATCH_SIZE = 5;
@@ -87,7 +91,53 @@ const MIN_EMBED_CHUNK_CHARS = 256;
 const DEFAULT_EMBED_CHUNK_CHARS = 2000;
 const MAX_EMBED_CHUNK_CHARS = 8000;

-const ollama = new Ollama({ host: process.env.OLLAMA_HOST });
+type OllamaEmbedClient = { embed: (params: Record<string, unknown>) => Promise<{ embeddings: number[][] }> };
+let ollamaClient: OllamaEmbedClient | null = null;
+
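+// Lazily import and cache the Ollama client so the "ollama" package is only
+// loaded when the Ollama provider is actually selected.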
+async function getOllamaClient(): Promise<OllamaEmbedClient> {
+  if (!ollamaClient) {
+    const { Ollama } = await import("ollama");
+    ollamaClient = new Ollama({ host: process.env.OLLAMA_HOST }) as unknown as OllamaEmbedClient;
+  }
+  return ollamaClient;
+}
+
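+// Embed a batch through the local Ollama server; the GPU/thread runtime
+// options only apply to this provider.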
+async function callOllamaEmbed(input: string[], signal: AbortSignal): Promise<number[][]> {
+  const client = await getOllamaClient();
+  const options = getEmbedRuntimeOptions();
+  const request: Record<string, unknown> = { model: EMBED_MODEL, input, signal };
+  if (options) request.options = options;
+  const response = await client.embed(request);
+  return response.embeddings;
+}
+
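+// Embed a batch through an OpenAI-compatible /embeddings endpoint; the same
+// request and response shape is served by OpenAI, Gemini's compatibility
+// layer, and similar APIs.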
+async function callOpenAIEmbed(input: string[], signal: AbortSignal): Promise<number[][]> {
+  const url = `${OPENAI_BASE_URL.replace(/\/+$/, "")}/embeddings`;
+  const response = await fetch(url, {
+    method: "POST",
+    headers: {
+      "Content-Type": "application/json",
+      "Authorization": `Bearer ${OPENAI_API_KEY}`,
+    },
+    body: JSON.stringify({ model: OPENAI_EMBED_MODEL, input }),
+    signal,
+  });
+
+  if (!response.ok) {
+    const body = await response.text().catch(() => "");
+    throw new Error(`OpenAI embed API error ${response.status}: ${body}`);
+  }
+
+  const data = await response.json() as { data: { embedding: number[] }[] };
+  return data.data.map((item) => item.embedding);
+}
+
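+// Dispatch to the configured provider; any value other than "openai" falls
+// back to the Ollama path.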
+async function callProviderEmbed(input: string[], signal: AbortSignal): Promise<number[][]> {
+  if (EMBED_PROVIDER === "openai") {
+    return callOpenAIEmbed(input, signal);
+  }
+  return callOllamaEmbed(input, signal);
+}
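+// Usage sketch (input and timeout are illustrative):
+//   const [vec] = await callProviderEmbed(["hello world"], AbortSignal.timeout(30_000));
+//   console.log(vec.length); // e.g. 768 for nomic-embed-text, 1536 for text-embedding-3-small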

 function toIntegerOr(value: string | undefined, fallback: number): number {
   if (!value) return fallback;
@@ -110,6 +160,7 @@ function toOptionalBoolean(value: string | undefined): boolean | undefined {
 }

 function getEmbedRuntimeOptions(): EmbedRuntimeOptions | undefined {
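+  // The num_gpu/main_gpu runtime knobs below are Ollama-specific and have no
+  // OpenAI equivalent, so skip them entirely for the OpenAI provider.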
+  if (EMBED_PROVIDER === "openai") return undefined;
   const options: EmbedRuntimeOptions = {
     num_gpu: toOptionalInteger(process.env.CONTEXTPLUS_EMBED_NUM_GPU),
     main_gpu: toOptionalInteger(process.env.CONTEXTPLUS_EMBED_MAIN_GPU),
@@ -123,17 +174,6 @@ function getEmbedRuntimeOptions(): EmbedRuntimeOptions | undefined {
   return options;
 }

-function buildEmbedRequest(input: string[]): { model: string; input: string[]; options?: EmbedRuntimeOptions } {
-  const options = getEmbedRuntimeOptions();
-  return options ? { model: EMBED_MODEL, input, options } : { model: EMBED_MODEL, input };
-}
-
-async function embedWithTimeout(request: ReturnType<typeof buildEmbedRequest>): Promise<{ embeddings: number[][] }> {
-  const timeoutCtrl = AbortSignal.timeout(EMBED_TIMEOUT_MS);
-  const signal = AbortSignal.any([embedAbortController.signal, timeoutCtrl]);
-  return ollama.embed({ ...request, signal } as Parameters<typeof ollama.embed>[0]);
-}
-
 export function getEmbeddingBatchSize(): number {
   const requested = toIntegerOr(process.env.CONTEXTPLUS_EMBED_BATCH_SIZE, DEFAULT_EMBED_BATCH_SIZE);
   return Math.min(MAX_EMBED_BATCH_SIZE, Math.max(MIN_EMBED_BATCH_SIZE, requested));
@@ -152,7 +192,8 @@ function getErrorMessage(error: unknown): string {
 function isContextLengthError(error: unknown): boolean {
   const message = getErrorMessage(error).toLowerCase();
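+  // The first two clauses match Ollama's error wording; "maximum context
+  // length" matches the phrasing of OpenAI-style API errors.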
   return message.includes("input length exceeds context length")
-    || (message.includes("context") && message.includes("exceed"));
+    || (message.includes("context") && message.includes("exceed"))
+    || message.includes("maximum context length");
 }

 function shrinkEmbeddingInput(input: string): string {
@@ -167,9 +208,11 @@ async function embedSingleAdaptive(input: string): Promise<number[]> {

   for (let attempt = 0; attempt <= MAX_SINGLE_INPUT_RETRIES; attempt++) {
     try {
-      const response = await embedWithTimeout(buildEmbedRequest([candidate]));
-      if (!response.embeddings[0]) throw new Error("Missing embedding vector in Ollama response");
-      return response.embeddings[0];
+      const timeoutCtrl = AbortSignal.timeout(EMBED_TIMEOUT_MS);
+      const signal = AbortSignal.any([embedAbortController.signal, timeoutCtrl]);
+      const embeddings = await callProviderEmbed([candidate], signal);
+      if (!embeddings[0]) throw new Error("Missing embedding vector in response");
+      return embeddings[0];
     } catch (error) {
       if (!isContextLengthError(error)) throw error;
       const nextCandidate = shrinkEmbeddingInput(candidate);
@@ -183,11 +226,13 @@ async function embedSingleAdaptive(input: string): Promise<number[]> {

 async function embedBatchAdaptive(batch: string[]): Promise<number[][]> {
   try {
-    const response = await embedWithTimeout(buildEmbedRequest(batch));
-    if (response.embeddings.length !== batch.length) {
-      throw new Error(`Embedding response size mismatch: expected ${batch.length}, got ${response.embeddings.length}`);
+    const timeoutCtrl = AbortSignal.timeout(EMBED_TIMEOUT_MS);
+    const signal = AbortSignal.any([embedAbortController.signal, timeoutCtrl]);
+    const embeddings = await callProviderEmbed(batch, signal);
+    if (embeddings.length !== batch.length) {
+      throw new Error(`Embedding response size mismatch: expected ${batch.length}, got ${embeddings.length}`);
     }
-    return response.embeddings;
+    return embeddings;
   } catch (error) {
     if (!isContextLengthError(error)) throw error;
     if (batch.length === 1) {