53 | 53 | "min_delay", # minimum delay between queries for that model |
54 | 54 | ] |
55 | 55 | |
| 56 | + |
| 57 | +def cleaned_args(args: dict): |
| 58 | + """If there is an API key in the dict, censor it""" |
| 59 | + args = args.copy() |
| 60 | + if "api_key" in args: |
| 61 | + args["api_key"] = "***" |
| 62 | + return args |
| 63 | + |
| 64 | + |
56 | 65 | def any2message( |
57 | 66 | message: str|List[Dict[str,str]]|Dict[str,str], |
58 | 67 | vars: Optional[Dict] = None, |
@@ -725,12 +734,6 @@ def query( |
725 | 734 | otherwise answer contains the response and error is the empty string. |
726 | 735 | The boolean key "ok" is True if there is no error, False otherwise. |
727 | 736 | """ |
728 | | - def cleaned_args(args: dict): |
729 | | - """If there is an API key in the dict, censor it""" |
730 | | - args = args.copy() |
731 | | - if "api_key" in args: |
732 | | - args["api_key"] = "***" |
733 | | - return args |
734 | 737 | if self.debug: |
735 | 738 | debug = True |
736 | 739 | if self.cost_logger: |
@@ -1095,6 +1098,84 @@ def chunk_generator(model_generator, retobj): |
1095 | 1098 | ret["ok"] = False |
1096 | 1099 | return ret |
1097 | 1100 | |
| 1101 | + def embeddings( |
| 1102 | + self, |
| 1103 | + llmalias: str, |
| 1104 | + texts: List[str], |
| 1105 | + return_cost: bool = True, |
| 1106 | + return_response: bool = False, |
| 1107 | + debug=False, |
| 1108 | + litellm_debug=None, |
| 1109 | + **kwargs, |
| 1110 | + ) -> Dict[str, Any]: |
| 1111 | + """ |
| 1112 | + Get the embeddings for the batch of texts. |
| 1113 | + |
| 1114 | + Args: |
| 1115 | + llmalias: LLM alias |
| 1116 | + texts: a list of texts |
| 1117 | + return_cost: if True, the cost will be returned |
| 1118 | + return_response: if True, the original response will be returned |
| 1119 | + debug: enable debugging |
| 1120 | + litellm_debug: enable even more debugging from the LiteLLM library |
| 1121 | + **kwargs: additional parameters passed on to the LiteLLM embedding call |
| 1122 | + |
| 1123 | + Returns: |
| 1124 | + The return object is a dict with the following keys: |
| 1125 | + - answer: the list of embeddings, one per input text |
| 1126 | + - response: the original response if return_response is True, otherwise None |
| 1127 | + - cost: the cost of the call (only if return_cost is True and the call succeeded) |
| 1128 | + - n_prompt_tokens: the number of prompt tokens in the request |
| 1129 | + - ok: flag indicating whether processing was successful |
| 1130 | + - error: the error message if not successful, otherwise the empty string |
| 1131 | + """ |
| 1132 | + llm = self.llms[llmalias].config |
| 1133 | + completion_kwargs = dict_except( |
| 1134 | + llm, |
| 1135 | + KNOWN_LLM_CONFIG_FIELDS, |
| 1136 | + ignore_underscored=True, |
| 1137 | + ) |
| 1138 | + ret = {} |
| 1139 | + logger.debug(f"Initial completion kwargs: {cleaned_args(completion_kwargs)}") |
| 1140 | + if llm.get("api_key"): |
| 1141 | + completion_kwargs["api_key"] = llm["api_key"] |
| 1142 | + elif llm.get("api_key_env"): |
| 1143 | + completion_kwargs["api_key"] = os.getenv(llm["api_key_env"]) |
| 1144 | + if llm.get("api_url"): |
| 1145 | + completion_kwargs["api_base"] = llm["api_url"] |
| 1146 | + if kwargs: |
| 1147 | + completion_kwargs.update(dict_except(kwargs, KNOWN_LLM_CONFIG_FIELDS, ignore_underscored=True)) |
| 1148 | + if debug: |
| 1149 | + logger.debug(f"calling embedding with completion kwargs: {cleaned_args(completion_kwargs)}") |
| 1150 | + try: |
| 1151 | + response = litellm.embedding( |
| 1152 | + model=llm["llm"], |
| 1153 | + input=texts, |
| 1154 | + **completion_kwargs |
| 1155 | + ) |
| 1156 | + except Exception as e: |
| 1157 | + logger.debug(f"Exception in embedding call with completion kwargs: {e}", exc_info=True) |
| 1158 | + ret["error"] = str(e) |
| 1159 | + ret["answer"] = [] |
| 1160 | + ret["ok"] = False |
| 1161 | + ret["response"] = None |
| 1162 | + ret["n_prompt_tokens"] = 0 |
| 1163 | + return ret |
| 1164 | + # response should be a dict with data[].embedding and usage.prompt_tokens |
| 1165 | + if return_cost: |
| 1166 | + ret["cost"] = response._hidden_params["response_cost"] |
| 1167 | + # extract the list of embeddings |
| 1168 | + embs = [d["embedding"] for d in response["data"]] |
| 1169 | + ret["answer"] = embs |
| 1170 | + if return_response: |
| 1171 | + ret["response"] = response |
| 1172 | + else: |
| 1173 | + ret["response"] = None |
| 1174 | + ret["ok"] = True |
| 1175 | + ret["error"] = "" |
| 1176 | + ret["n_prompt_tokens"] = response["usage"]["prompt_tokens"] |
| 1177 | + return ret |
| 1178 | + |
1098 | 1179 | |
1099 | 1180 | # For now, this class simply represents the LLM by the config dict and a pointer to the LLMS object it is contained |
1100 | 1181 | # in. In order to avoid changing any code in the LLMS object where we expect the llm config to be a dict |
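
The cleaned_args helper that this commit hoists out of query() to module level (so the new embeddings() method can reuse it) copies the dict before censoring. A minimal sketch of its behavior, with illustrative values:

    args = {"api_key": "sk-secret", "model": "gpt-4o", "temperature": 0.2}
    print(cleaned_args(args))   # {'api_key': '***', 'model': 'gpt-4o', 'temperature': 0.2}
    print(args["api_key"])      # 'sk-secret' -- the caller's dict is copied, not mutated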
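A usage sketch for the new embeddings() method, assuming an already-configured LLMS instance; the alias "embedder" and the input texts are illustrative, not taken from the commit:

    result = llms.embeddings("embedder", ["first document", "second document"])
    if result["ok"]:
        vectors = result["answer"]   # one embedding vector per input text
        print(len(vectors), result["n_prompt_tokens"], result.get("cost"))
    else:
        print("embedding request failed:", result["error"])

Since errors are reported through the ok/error keys rather than raised, callers check the flag instead of wrapping the call in try/except, matching the contract of query().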