|
10 | 10 |
|
11 | 11 | logger = logging.getLogger(__name__) |
12 | 12 |
|
# Embedding models served via Azure OpenAI; the "azure-" prefix routes model
# selection to the AzureOpenAI client (prefix is stripped before the API call).
AZURE_EMBEDDING_MODELS = [
    "azure-text-embedding-3-small",
    "azure-text-embedding-3-large",
]

|
class EmbeddingClient:
    """Select and hold an OpenAI or Azure OpenAI client for embedding calls."""

    def __init__(self, model_name: str = "text-embedding-3-small", base_url: str | None = None):
        """
        Initialize the EmbeddingClient.

        Args:
            model_name: The embedding model name to use.
            base_url: Optional base URL for the embedding API endpoint.
        """
        self.client, self.model = self._get_client_model(model_name, base_url)

    def _get_client_model(
        self, model_name: str, base_url: str | None = None
    ) -> tuple[openai.OpenAI, str]:
        """Return the (client, model) pair appropriate for `model_name`.

        Azure-prefixed names get an AzureOpenAI client configured from the
        AZURE_* environment variables; every other name is sent to a plain
        OpenAI client (optionally pointed at `base_url`).
        """
        if model_name in AZURE_EMBEDDING_MODELS:
            # Drop the "azure-" prefix to recover the actual deployment name.
            resolved = model_name.split("azure-")[-1]
            backend = openai.AzureOpenAI(
                api_key=os.getenv("AZURE_OPENAI_API_KEY"),
                api_version=os.getenv("AZURE_API_VERSION"),
                azure_endpoint=os.environ["AZURE_API_ENDPOINT"],
            )
        else:
            # Prefer OPENAI_EMBEDDING_API_KEY, falling back to OPENAI_API_KEY.
            # This allows users to use OpenRouter for LLMs while using OpenAI
            # for embeddings.
            key = os.getenv("OPENAI_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY")
            backend = openai.OpenAI(api_key=key, base_url=base_url)
            resolved = model_name

        return backend, resolved
60 | 49 |
|
|
0 commit comments