1- // Ollama-powered vector embedding engine with cosine similarity search
1+ // Multi-provider vector embedding engine with cosine similarity search
2+ // Supports Ollama (local) and OpenAI-compatible APIs (Gemini, OpenAI, etc.)
23// Indexes file headers and symbols, caches embeddings to disk for speed
34
4- import { Ollama } from "ollama" ;
55import { readFile , writeFile , mkdir } from "fs/promises" ;
66import { join } from "path" ;
77
@@ -74,9 +74,14 @@ export interface EmbeddingCache {
7474 [ path : string ] : { hash : string ; vector : number [ ] } ;
7575}
7676
77+ const EMBED_PROVIDER = ( process . env . CONTEXTPLUS_EMBED_PROVIDER ?? "ollama" ) . toLowerCase ( ) ;
7778const EMBED_MODEL = process . env . OLLAMA_EMBED_MODEL ?? "nomic-embed-text" ;
79+ const OPENAI_EMBED_MODEL = process . env . CONTEXTPLUS_OPENAI_EMBED_MODEL ?? process . env . OPENAI_EMBED_MODEL ?? "text-embedding-3-small" ;
80+ const OPENAI_API_KEY = process . env . CONTEXTPLUS_OPENAI_API_KEY ?? process . env . OPENAI_API_KEY ?? "" ;
81+ const OPENAI_BASE_URL = process . env . CONTEXTPLUS_OPENAI_BASE_URL ?? process . env . OPENAI_BASE_URL ?? "https://api.openai.com/v1" ;
7882const CACHE_DIR = ".mcp_data" ;
79- const CACHE_FILE = "embeddings-cache.json" ;
// Model name used to key the cache: the OpenAI-style model when the provider
// is "openai", otherwise the Ollama model.
const ACTIVE_EMBED_MODEL = EMBED_PROVIDER === "openai" ? OPENAI_EMBED_MODEL : EMBED_MODEL;
// Cache file name embeds both provider and model. Both values originate from
// environment variables, so each is reduced to [a-zA-Z0-9._-] to keep path
// separators and other unsafe characters out of the file name.
const CACHE_FILE = `embeddings-cache-${EMBED_PROVIDER.replace(/[^a-zA-Z0-9._-]/g, "_")}-${ACTIVE_EMBED_MODEL.replace(/[^a-zA-Z0-9._-]/g, "_")}.json`;
8085const MIN_EMBED_BATCH_SIZE = 5 ;
8186const MAX_EMBED_BATCH_SIZE = 10 ;
8287const DEFAULT_EMBED_BATCH_SIZE = 8 ;
@@ -87,7 +92,53 @@ const MIN_EMBED_CHUNK_CHARS = 256;
8792const DEFAULT_EMBED_CHUNK_CHARS = 2000 ;
8893const MAX_EMBED_CHUNK_CHARS = 8000 ;
8994
90- const ollama = new Ollama ( { host : process . env . OLLAMA_HOST } ) ;
/** Minimal view of the Ollama client: only the `embed` call this module uses. */
type OllamaEmbedClient = {
  embed: (params: Record<string, unknown>) => Promise<{ embeddings: number[][] }>;
};

// Client instance created on first use; null until then.
let ollamaClient: OllamaEmbedClient | null = null;

/**
 * Returns the cached Ollama client, creating it on first call.
 *
 * The "ollama" package is loaded with a dynamic import, so it is only
 * resolved when this function actually runs. The host comes from the
 * OLLAMA_HOST environment variable.
 */
async function getOllamaClient(): Promise<OllamaEmbedClient> {
  if (ollamaClient) return ollamaClient;

  const ollamaModule = await import("ollama");
  const created = new ollamaModule.Ollama({ host: process.env.OLLAMA_HOST }) as unknown as OllamaEmbedClient;
  ollamaClient = created;
  return created;
}
105+
/**
 * Embeds `input` through the Ollama client using EMBED_MODEL.
 *
 * @param input  Texts to embed in one request.
 * @param signal Abort signal forwarded inside the request object.
 * @returns The `embeddings` array from the client response.
 */
async function callOllamaEmbed(input: string[], signal: AbortSignal): Promise<number[][]> {
  const ollama = await getOllamaClient();
  const runtimeOptions = getEmbedRuntimeOptions();

  // `options` is attached only when runtime options were actually produced.
  const payload: Record<string, unknown> = runtimeOptions
    ? { model: EMBED_MODEL, input, signal, options: runtimeOptions }
    : { model: EMBED_MODEL, input, signal };

  const { embeddings } = await ollama.embed(payload);
  return embeddings;
}
114+
/**
 * Embeds `input` by POSTing to `<OPENAI_BASE_URL>/embeddings`.
 *
 * @param input  Texts to embed in one request.
 * @param signal Abort signal passed to `fetch`.
 * @returns One vector per response item, ordered by each item's `index`
 *          field when every item carries one, otherwise in response order.
 * @throws Error when the HTTP status is not ok (the message includes the
 *         status and response body), or when the JSON body has no `data`
 *         array or an item has no `embedding` array.
 */
async function callOpenAIEmbed(input: string[], signal: AbortSignal): Promise<number[][]> {
  // Trailing slashes on the base URL are stripped so the path is not doubled.
  const url = `${OPENAI_BASE_URL.replace(/\/+$/, "")}/embeddings`;
  const response = await fetch(url, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      "Authorization": `Bearer ${OPENAI_API_KEY}`,
    },
    body: JSON.stringify({ model: OPENAI_EMBED_MODEL, input }),
    signal,
  });

  if (!response.ok) {
    // Best effort: an unreadable body must not mask the status code.
    const body = await response.text().catch(() => "");
    throw new Error(`OpenAI embed API error ${response.status}: ${body}`);
  }

  // Validate the shape instead of trusting a cast: a malformed success
  // response should fail with a descriptive error, not a TypeError.
  const payload: unknown = await response.json();
  const items: unknown =
    typeof payload === "object" && payload !== null ? (payload as { data?: unknown }).data : undefined;
  if (!Array.isArray(items)) {
    throw new Error("OpenAI embed API returned a response without a data array");
  }

  const rows = items.map((item: unknown, position: number) => {
    const record = (typeof item === "object" && item !== null ? item : {}) as {
      embedding?: unknown;
      index?: unknown;
    };
    if (!Array.isArray(record.embedding)) {
      throw new Error(`OpenAI embed API returned no embedding vector at position ${position}`);
    }
    return {
      index: typeof record.index === "number" ? record.index : undefined,
      vector: record.embedding as number[],
    };
  });

  // Callers pair vectors with inputs by position, so honor the explicit
  // `index` when the server provides it for every item.
  if (rows.every((row) => row.index !== undefined)) {
    rows.sort((a, b) => (a.index as number) - (b.index as number));
  }
  return rows.map((row) => row.vector);
}
135+
/**
 * Routes an embedding request to the configured provider.
 *
 * Only the exact value "openai" selects the OpenAI-style HTTP path; every
 * other EMBED_PROVIDER value goes to Ollama.
 */
async function callProviderEmbed(input: string[], signal: AbortSignal): Promise<number[][]> {
  const embed = EMBED_PROVIDER === "openai" ? callOpenAIEmbed : callOllamaEmbed;
  return embed(input, signal);
}
91142
92143function toIntegerOr ( value : string | undefined , fallback : number ) : number {
93144 if ( ! value ) return fallback ;
@@ -110,6 +161,7 @@ function toOptionalBoolean(value: string | undefined): boolean | undefined {
110161}
111162
112163function getEmbedRuntimeOptions ( ) : EmbedRuntimeOptions | undefined {
164+ if ( EMBED_PROVIDER === "openai" ) return undefined ;
113165 const options : EmbedRuntimeOptions = {
114166 num_gpu : toOptionalInteger ( process . env . CONTEXTPLUS_EMBED_NUM_GPU ) ,
115167 main_gpu : toOptionalInteger ( process . env . CONTEXTPLUS_EMBED_MAIN_GPU ) ,
@@ -123,17 +175,6 @@ function getEmbedRuntimeOptions(): EmbedRuntimeOptions | undefined {
123175 return options ;
124176}
125177
126- function buildEmbedRequest ( input : string [ ] ) : { model : string ; input : string [ ] ; options ?: EmbedRuntimeOptions } {
127- const options = getEmbedRuntimeOptions ( ) ;
128- return options ? { model : EMBED_MODEL , input, options } : { model : EMBED_MODEL , input } ;
129- }
130-
131- async function embedWithTimeout ( request : ReturnType < typeof buildEmbedRequest > ) : Promise < { embeddings : number [ ] [ ] } > {
132- const timeoutCtrl = AbortSignal . timeout ( EMBED_TIMEOUT_MS ) ;
133- const signal = AbortSignal . any ( [ embedAbortController . signal , timeoutCtrl ] ) ;
134- return ollama . embed ( { ...request , signal } as Parameters < typeof ollama . embed > [ 0 ] ) ;
135- }
136-
137178export function getEmbeddingBatchSize ( ) : number {
138179 const requested = toIntegerOr ( process . env . CONTEXTPLUS_EMBED_BATCH_SIZE , DEFAULT_EMBED_BATCH_SIZE ) ;
139180 return Math . min ( MAX_EMBED_BATCH_SIZE , Math . max ( MIN_EMBED_BATCH_SIZE , requested ) ) ;
@@ -152,7 +193,8 @@ function getErrorMessage(error: unknown): string {
/**
 * Reports whether an error message looks like a "context length exceeded"
 * failure. Matching is done on the lower-cased message text.
 */
function isContextLengthError(error: unknown): boolean {
  const text = getErrorMessage(error).toLowerCase();

  if (text.includes("input length exceeds context length")) return true;
  if (text.includes("context") && text.includes("exceed")) return true;
  return text.includes("maximum context length");
}
157199
158200function shrinkEmbeddingInput ( input : string ) : string {
@@ -167,9 +209,11 @@ async function embedSingleAdaptive(input: string): Promise<number[]> {
167209
168210 for ( let attempt = 0 ; attempt <= MAX_SINGLE_INPUT_RETRIES ; attempt ++ ) {
169211 try {
170- const response = await embedWithTimeout ( buildEmbedRequest ( [ candidate ] ) ) ;
171- if ( ! response . embeddings [ 0 ] ) throw new Error ( "Missing embedding vector in Ollama response" ) ;
172- return response . embeddings [ 0 ] ;
212+ const timeoutCtrl = AbortSignal . timeout ( EMBED_TIMEOUT_MS ) ;
213+ const signal = AbortSignal . any ( [ embedAbortController . signal , timeoutCtrl ] ) ;
214+ const embeddings = await callProviderEmbed ( [ candidate ] , signal ) ;
215+ if ( ! embeddings [ 0 ] ) throw new Error ( "Missing embedding vector in response" ) ;
216+ return embeddings [ 0 ] ;
173217 } catch ( error ) {
174218 if ( ! isContextLengthError ( error ) ) throw error ;
175219 const nextCandidate = shrinkEmbeddingInput ( candidate ) ;
@@ -183,11 +227,13 @@ async function embedSingleAdaptive(input: string): Promise<number[]> {
183227
184228async function embedBatchAdaptive ( batch : string [ ] ) : Promise < number [ ] [ ] > {
185229 try {
186- const response = await embedWithTimeout ( buildEmbedRequest ( batch ) ) ;
187- if ( response . embeddings . length !== batch . length ) {
188- throw new Error ( `Embedding response size mismatch: expected ${ batch . length } , got ${ response . embeddings . length } ` ) ;
230+ const timeoutCtrl = AbortSignal . timeout ( EMBED_TIMEOUT_MS ) ;
231+ const signal = AbortSignal . any ( [ embedAbortController . signal , timeoutCtrl ] ) ;
232+ const embeddings = await callProviderEmbed ( batch , signal ) ;
233+ if ( embeddings . length !== batch . length ) {
234+ throw new Error ( `Embedding response size mismatch: expected ${ batch . length } , got ${ embeddings . length } ` ) ;
189235 }
190- return response . embeddings ;
236+ return embeddings ;
191237 } catch ( error ) {
192238 if ( ! isContextLengthError ( error ) ) throw error ;
193239 if ( batch . length === 1 ) {
0 commit comments