Skip to content

Commit 412ccfd

Browse files
authored
Merge pull request #410 from aitomatic/fenrir-azure-demo-dana
Added demo for use with Azure
2 parents ab16e05 + 01a6d5d commit 412ccfd

15 files changed

Lines changed: 477 additions & 0 deletions
20.2 KB
Binary file not shown.
19.5 KB
Binary file not shown.
19.5 KB
Binary file not shown.
19.6 KB
Binary file not shown.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
AZURE_OPENAI_API_KEY=
2+
AZURE_OPENAI_ENDPOINT=
3+
AZURE_OPENAI_API_VERSION=
4+
AZURE_OPENAI_GPT_MODEL_NAME=
5+
AZURE_OPENAI_GPT_DEPLOYMENT_NAME=
6+
AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME=
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
myenv/
2+
.data/.indexes
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<!-- markdownlint-disable MD013 MD043 -->
2+
3+
# OpenSSA sample
4+
5+
## ライブラリインストール(初回のみ)
6+
7+
```bash
8+
poetry install
9+
```
10+
11+
## 実行前手順
12+
13+
- `.env.template`をコピーして`.env`ファイルを作成し、ファイル内の環境変数を記載する
14+
- `.data` ディレクトリを作成して、そこに回答の情報となるPDFファイルを配置する
15+
16+
## メモ
17+
18+
- 初回実行時に `.data` 内に `.indexes` ディレクトリが作成されインデックスが配置される
19+
- 初回はその分実行に時間がかかる
20+
21+
## 実行
22+
23+
```bash
24+
poetry run python main.py
25+
```

examples/japanse-easy-demo-azure/lm/__init__.py

Whitespace-only changes.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from llama_index.llms.azure_openai import AzureOpenAI
2+
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding as LlamaEmbedding
3+
from llama_index.core import Settings
4+
from .config import LMConfig
5+
6+
# Shared Azure OpenAI credentials, read once from LMConfig so both
# clients below are guaranteed to use identical connection settings.
_AZURE_CREDENTIALS = {
    "api_key": LMConfig.AZURE_OPENAI_API_KEY,
    "azure_endpoint": LMConfig.AZURE_OPENAI_ENDPOINT,
    "api_version": LMConfig.AZURE_OPENAI_API_VERSION,
}

# Chat-completion LM backed by the configured Azure GPT deployment.
azure_llama_index_lm = AzureOpenAI(
    deployment_name=LMConfig.AZURE_OPENAI_GPT_DEPLOYMENT_NAME,
    **_AZURE_CREDENTIALS,
)

# Embedding model backed by the configured Azure embedding deployment.
azure_llama_index_embedding = LlamaEmbedding(
    deployment_name=LMConfig.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME,
    **_AZURE_CREDENTIALS,
)

# Register both as the llama-index global defaults so downstream
# components (indexes, query engines) pick them up without explicit wiring.
Settings.llm = azure_llama_index_lm
Settings.embed_model = azure_llama_index_embedding
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
import json
5+
import time
6+
from typing import TYPE_CHECKING
7+
from loguru import logger
8+
9+
from openai import AzureOpenAI
10+
from openssa.core.util.lm.base import BaseLM
11+
from .config import LMConfig
12+
13+
14+
if TYPE_CHECKING:
15+
from openai.types.chat.chat_completion import ChatCompletion
16+
from openssa.core.util.lm.base import LMChatHist
17+
18+
19+
@dataclass
class AzureOpenAILM(BaseLM):
    """Azure OpenAI LM.

    Wraps the `openai.AzureOpenAI` chat-completions client behind the
    project's `BaseLM` interface.
    """

    # Azure resource endpoint; presumably of the form
    # https://<resource>.openai.azure.com/ — confirm against deployment config.
    azure_endpoint: str = ''
    api_version: str = LMConfig.AZURE_OPENAI_API_VERSION
    # Constructed in __post_init__ from the credential fields above.
    client: AzureOpenAI = field(init=False)

    def __post_init__(self):
        """Initialize the Azure OpenAI client from this instance's credentials."""
        self.client: AzureOpenAI = AzureOpenAI(
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
        )

    @classmethod
    def from_defaults(cls) -> AzureOpenAILM:
        """Get an Azure OpenAI LM instance with default parameters from LMConfig."""
        # pylint: disable=unexpected-keyword-arg
        return cls(
            model=LMConfig.AZURE_OPENAI_GPT_MODEL_NAME,
            api_base=None,
            api_key=LMConfig.AZURE_OPENAI_API_KEY,
            azure_endpoint=LMConfig.AZURE_OPENAI_ENDPOINT,
        )

    def call(self, messages: LMChatHist, **kwargs) -> ChatCompletion:
        """Call the Azure OpenAI chat-completions API and return the raw response object.

        `seed` and `temperature` default to the LMConfig values unless
        overridden via kwargs; all other kwargs pass through to the API.
        """
        return self.client.chat.completions.create(
            messages=messages,
            model=self.model,
            seed=kwargs.pop("seed", LMConfig.DEFAULT_SEED),
            temperature=kwargs.pop("temperature", LMConfig.DEFAULT_TEMPERATURE),
            **kwargs,
        )

    def get_response(
        self,
        prompt: str,
        history: LMChatHist | None = None,
        json_format: bool = False,
        **kwargs,
    ) -> str:
        """Call the Azure OpenAI LM API and return the response content.

        When `json_format` is False (default), makes a single call and
        returns the message content string. When True, requests JSON-object
        output and retries up to 5 times (with exponential backoff) until
        the content parses as JSON; the PARSED object is returned in that
        case (note: the `-> str` annotation is inaccurate for this path —
        preserved for interface compatibility).

        Raises:
            ValueError: if `json_format` is True and no valid JSON is
                obtained within the retry budget.

        NOTE: appends the prompt to the caller-supplied `history` list
        in place, so the caller's history stays up to date.
        """
        messages: LMChatHist = history or []
        messages.append({"role": "user", "content": prompt})

        if not json_format:
            # Plain-text mode: single call, return the content directly.
            # (Previously unreachable — control always fell into the JSON
            # retry loop and raised ValueError on non-JSON output.)
            return self.call(messages, **kwargs).choices[0].message.content

        kwargs["response_format"] = {"type": "json_object"}

        MAX_RETRIES = 5
        backoff = 1  # seconds; doubled after each failed parse

        for attempt in range(1, MAX_RETRIES + 1):
            logger.info(f"RETRIES: {attempt - 1}")
            response = self.call(messages, **kwargs).choices[0].message.content
            try:
                return json.loads(response)
            except json.decoder.JSONDecodeError:
                logger.debug(
                    f"INVALID JSON, TO BE RETRIED ({attempt}/{MAX_RETRIES}):\n{response}"
                )
                time.sleep(backoff)
                backoff *= 2

        raise ValueError("Max retries reached. Unable to parse JSON response.")

0 commit comments

Comments
 (0)