Commit 9183594

Implement first version of embeddings method
1 parent 93ff75f commit 9183594

2 files changed

Lines changed: 88 additions & 7 deletions

llms_wrapper/llms.py

Lines changed: 87 additions & 6 deletions
@@ -53,6 +53,15 @@
     "min_delay", # minimum delay between queries for that model
 ]
 
+
+def cleaned_args(args: dict):
+    """If there is an API key in the dict, censor it"""
+    args = args.copy()
+    if "api_key" in args:
+        args["api_key"] = "***"
+    return args
+
+
 def any2message(
     message: str|List[Dict[str,str]]|Dict[str,str],
     vars: Optional[Dict] = None,
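
The helper that previously lived inside query() is now a module-level function, so both query() and the new embeddings() method can reuse it when logging call arguments. A minimal sketch of its behaviour (the dict below is made up for illustration):

# Hypothetical input, only to illustrate cleaned_args(): "api_key" is masked,
# all other entries pass through unchanged and the original dict is not modified.
args = {"api_key": "sk-secret", "temperature": 0.2}
print(cleaned_args(args))   # {'api_key': '***', 'temperature': 0.2}
print(args)                 # unchanged: {'api_key': 'sk-secret', 'temperature': 0.2}
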
@@ -725,12 +734,6 @@ def query(
         otherwise answer contains the response and error is the empty string.
         The boolean key "ok" is True if there is no error, False otherwise.
         """
-        def cleaned_args(args: dict):
-            """If there is an API key in the dict, censor it"""
-            args = args.copy()
-            if "api_key" in args:
-                args["api_key"] = "***"
-            return args
         if self.debug:
             debug = True
         if self.cost_logger:
@@ -1095,6 +1098,84 @@ def chunk_generator(model_generator, retobj):
             ret["ok"] = False
         return ret
 
+    def embeddings(
+            self,
+            llmalias: str,
+            texts: List[str],
+            return_cost: bool = True,
+            return_response: bool = False,
+            debug=False,
+            litellm_debug=None,
+            **kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Get the embeddings for the batch of texts.
+
+        Args:
+            llmalias: LLM alias
+            texts: a list of texts
+            return_cost: if true, the cost will get returned
+            return_response: if true, the original response will be returned
+            debug: enable debugging
+            litellm_debug: enable even more debugging from the LiteLLM library
+            **kwargs: additional parameters
+
+        Returns:
+            The return object is a dict with the following keys:
+            - answer: the list of embeddings
+            - response: the original response
+            - cost: the cost
+            - n_prompt_tokens
+            - ok: flag to indicate if processing was successful
+            - error: error message if not successful
+        """
+        llm = self.llms[llmalias].config
+        completion_kwargs = dict_except(
+            llm,
+            KNOWN_LLM_CONFIG_FIELDS,
+            ignore_underscored=True,
+        )
+        ret = {}
+        logger.debug(f"Initial completion kwargs: {cleaned_args(completion_kwargs)}")
+        if llm.get("api_key"):
+            completion_kwargs["api_key"] = llm["api_key"]
+        elif llm.get("api_key_env"):
+            completion_kwargs["api_key"] = os.getenv(llm["api_key_env"])
+        if llm.get("api_url"):
+            completion_kwargs["api_base"] = llm["api_url"]
+        if kwargs:
+            completion_kwargs.update(dict_except(kwargs, KNOWN_LLM_CONFIG_FIELDS, ignore_underscored=True))
+        if debug:
+            logger.debug(f"calling query with completion kwargs: {cleaned_args(completion_kwargs)}")
+        try:
+            response = litellm.embedding(
+                model=llm["llm"],
+                input=texts,
+                **completion_kwargs
+            )
+        except Exception as e:
+            logger.debug(f"Exception in query with completion kwargs: {e}", exc_info=True)
+            ret["error"] = str(e)
+            ret["answer"] = []
+            ret["ok"] = False
+            ret["response"] = None
+            ret["n_prompt_tokens"] = 0
+            return ret
+        # response should be a dict with data[].embedding and usage.prompt_tokens
+        if return_cost:
+            ret["cost"] = response._hidden_params["response_cost"]
+        # extract the list of embeddings
+        embs = [d["embedding"] for d in response["data"]]
+        ret["answer"] = embs
+        if return_response:
+            ret["response"] = response
+        else:
+            ret["response"] = None
+        ret["ok"] = True
+        ret["error"] = ""
+        ret["n_prompt_tokens"] = response["usage"]["prompt_tokens"]
+        return ret
+
 
 # For now, this class simply represents the LLM by the config dict and a pointer to the LLMS object it is contained
 # in. In order to avoid changing any code in the LLMS object where we expect the llm config to be a dict
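
A minimal usage sketch of the new embeddings() method. The construction below (import path, LLMS(config) call, the alias "emb" and the model name) is illustrative and not part of this commit; only the embeddings() call and the returned keys (answer, cost, n_prompt_tokens, ok, error, response) follow the code above.

from llms_wrapper.llms import LLMS    # assumed import; this commit only touches llms.py

# Hypothetical config: alias and model name are placeholders, not taken from this diff.
config = {"llms": [{"llm": "openai/text-embedding-3-small", "alias": "emb"}]}
llms = LLMS(config)                   # assumed constructor signature

ret = llms.embeddings("emb", ["first text", "second text"])
if ret["ok"]:
    vectors = ret["answer"]           # one embedding (list of floats) per input text
    print(len(vectors), ret["cost"], ret["n_prompt_tokens"])
else:
    print("Embedding request failed:", ret["error"])
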

llms_wrapper/version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 import importlib.metadata
-__version__ = "0.9.1.7"
+__version__ = "0.9.1.8"
 