Skip to content

Commit 54ce6e9

Browse files
committed
Added demo for use with Azure
1 parent bec69a4 commit 54ce6e9

14 files changed

Lines changed: 471 additions & 0 deletions
20.2 KB
Binary file not shown.
19.5 KB
Binary file not shown.
19.5 KB
Binary file not shown.
19.6 KB
Binary file not shown.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
myenv/
2+
.data/.indexes
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<!-- markdownlint-disable MD013 MD043 -->
2+
3+
# OpenSSA sample
4+
5+
## ライブラリインストール(初回のみ)
6+
7+
```bash
8+
poetry install
9+
```
10+
11+
## 実行前手順
12+
13+
- `.env.template`をコピーして`.env`ファイルを作成し、ファイル内の環境変数を記載する
14+
- `.data` ディレクトリを作成して、そこに回答の情報となるPDFファイルを配置する
15+
16+
## メモ
17+
18+
- 初回実行時に `.data` 内に `.indexes` ディレクトリが作成されインデックスが配置される
19+
- 初回はその分実行に時間がかかる
20+
21+
## 実行
22+
23+
```bash
24+
poetry run python main.py
25+
```

examples/japanse-easy-demo-azure/lm/__init__.py

Whitespace-only changes.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding as LlamaEmbedding
from llama_index.core import Settings
from .config import LMConfig

# Chat LM backed by an Azure OpenAI GPT deployment; all connection
# details (deployment name, key, endpoint, API version) come from LMConfig.
azure_llama_index_lm = AzureOpenAI(
    deployment_name=LMConfig.AZURE_OPENAI_GPT_DEPLOYMENT_NAME,
    api_key=LMConfig.AZURE_OPENAI_API_KEY,
    azure_endpoint=LMConfig.AZURE_OPENAI_ENDPOINT,
    api_version=LMConfig.AZURE_OPENAI_API_VERSION,
)

# Embedding model backed by an Azure OpenAI embedding deployment,
# sharing the same endpoint/key/version as the chat LM above.
azure_llama_index_embedding = LlamaEmbedding(
    deployment_name=LMConfig.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME,
    api_key=LMConfig.AZURE_OPENAI_API_KEY,
    azure_endpoint=LMConfig.AZURE_OPENAI_ENDPOINT,
    api_version=LMConfig.AZURE_OPENAI_API_VERSION
)

# NOTE: importing this module has the side effect of installing both
# models as llama-index global defaults for the whole process.
Settings.llm = azure_llama_index_lm
Settings.embed_model = azure_llama_index_embedding
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
import json
5+
import time
6+
from typing import TYPE_CHECKING
7+
from loguru import logger
8+
9+
from openai import AzureOpenAI
10+
from openssa.core.util.lm.base import BaseLM
11+
from .config import LMConfig
12+
13+
14+
if TYPE_CHECKING:
15+
from openai.types.chat.chat_completion import ChatCompletion
16+
from openssa.core.util.lm.base import LMChatHist
17+
18+
19+
@dataclass
class AzureOpenAILM(BaseLM):
    """OpenSSA language model backed by the Azure OpenAI chat-completions API.

    Extends `BaseLM` (which supplies `model`, `api_key`, `api_base`) with the
    Azure-specific endpoint and API version, and builds an `AzureOpenAI`
    client after dataclass initialization.
    """

    # Azure-specific connection settings; credentials default to LMConfig values.
    azure_endpoint: str = ''
    api_version: str = LMConfig.AZURE_OPENAI_API_VERSION
    # Built in __post_init__, never passed by callers.
    client: AzureOpenAI = field(init=False)

    def __post_init__(self):
        """Initialize the Azure OpenAI client from the configured credentials."""
        self.client: AzureOpenAI = AzureOpenAI(
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
        )

    @classmethod
    def from_defaults(cls) -> AzureOpenAILM:
        """Get an Azure OpenAI LM instance with default parameters from LMConfig."""
        # pylint: disable=unexpected-keyword-arg
        return cls(
            model=LMConfig.AZURE_OPENAI_GPT_MODEL_NAME,
            api_base=None,
            api_key=LMConfig.AZURE_OPENAI_API_KEY,
            azure_endpoint=LMConfig.AZURE_OPENAI_ENDPOINT,
        )

    def call(self, messages: LMChatHist, **kwargs) -> ChatCompletion:
        """Call the Azure OpenAI chat-completions API and return the raw response object.

        `seed` and `temperature` default to LMConfig values unless overridden
        via kwargs; remaining kwargs are forwarded to the API as-is.
        """
        return self.client.chat.completions.create(
            messages=messages,
            model=self.model,
            seed=kwargs.pop("seed", LMConfig.DEFAULT_SEED),
            temperature=kwargs.pop("temperature", LMConfig.DEFAULT_TEMPERATURE),
            **kwargs,
        )

    def get_response(
        self,
        prompt: str,
        history: LMChatHist | None = None,
        json_format: bool = False,
        **kwargs,
    ) -> str:
        """Call the Azure OpenAI LM API and return the response content.

        Appends *prompt* to *history* (mutating the caller's list if one is
        given) and issues a chat completion.

        When `json_format` is True, requests `json_object` response format and
        retries up to 5 times with exponential backoff when the model returns
        unparseable JSON, raising ValueError after the final failure.

        Raises:
            ValueError: json_format was requested and no valid JSON was
                obtained within the retry budget.
        """
        messages: LMChatHist = history or []
        messages.append({"role": "user", "content": prompt})

        if json_format:
            kwargs["response_format"] = {"type": "json_object"}

            # The model occasionally emits malformed JSON even in json_object
            # mode, so retry with doubling backoff before giving up.
            MAX_RETRIES = 5
            backoff = 1
            for attempt in range(1, MAX_RETRIES + 1):
                logger.info(f"RETRIES: {attempt - 1}")
                # Fetch the content first so only the parse sits in the `try`,
                # and `response` is always bound when the except body logs it.
                response = self.call(messages, **kwargs).choices[0].message.content
                try:
                    # NOTE(review): this returns the parsed JSON object, not a
                    # `str` as annotated; kept as-is since callers may rely on
                    # receiving the parsed value — confirm before tightening.
                    return json.loads(response)
                except json.JSONDecodeError:
                    logger.debug(
                        f"INVALID JSON, TO BE RETRIED ({attempt}/{MAX_RETRIES}):\n{response}"
                    )
                    time.sleep(backoff)
                    backoff *= 2

            raise ValueError("Max retries reached. Unable to parse JSON response.")

        return self.call(messages, **kwargs).choices[0].message.content
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
from typing import Any, TYPE_CHECKING
5+
6+
from openssa.core.knowledge._prompts import knowledge_injection_lm_chat_msgs
7+
from .azure_openai import AzureOpenAILM
8+
9+
from openssa.core.program_store._prompts import PROGRAM_SEARCH_PROMPT_TEMPLATE
10+
11+
if TYPE_CHECKING:
12+
from openssa.core.knowledge.base import Knowledge
13+
from openssa.core.programming.base.program import BaseProgram
14+
from openssa.core.resource.base import BaseResource
15+
from openssa.core.task.task import Task
16+
from openssa.core.util.lm.base import BaseLM, LMChatHist
17+
18+
19+
@dataclass
class ProgramStore:
    """Program Store containing searchable problem-solving Programs."""

    # informative descriptions of stored problem-solving Programs, indexed by name
    descriptions: dict[str, str] = field(default_factory=dict, repr=False)

    # stored problem-solving Programs, indexed by name
    programs: dict[str, BaseProgram] = field(default_factory=dict, repr=False)

    # language model for searching among stored problem-solving Programs
    lm: BaseLM = field(default_factory=AzureOpenAILM.from_defaults)

    def add_or_update_program(self, name: str, description: str, program: BaseProgram):
        """Add or update a Program with its unique identifying name & informative description."""
        self.descriptions[name] = description
        self.programs[name] = program

    def find_program(self, task: Task, knowledge: set[Knowledge] | None = None,
                     adaptations_from_known_programs: dict[str, Any] | None = None) -> BaseProgram | None:
        """Find a suitable Program for the posed Problem, or return None."""
        # Inject domain knowledge into the chat history, if any was given.
        knowledge_lm_hist: LMChatHist | None = None
        if knowledge:
            knowledge_lm_hist = knowledge_injection_lm_chat_msgs(knowledge=knowledge)

        # The LM must answer with either a stored program name or the literal 'NONE'.
        acceptable_answers: set[str] = {'NONE', *self.descriptions}

        # Re-query until the LM produces one of the acceptable answers.
        chosen_name: str = ''
        while chosen_name not in acceptable_answers:
            chosen_name = self.lm.get_response(
                prompt=PROGRAM_SEARCH_PROMPT_TEMPLATE.format(
                    problem=task.ask,
                    resource_overviews={resource.unique_name: resource.overview
                                        for resource in task.resources},
                    program_descriptions=self.descriptions),
                history=knowledge_lm_hist)

        if chosen_name == 'NONE':
            return None

        # Adapt the matched Program and attach the posed Task to it.
        adapted: BaseProgram = self.programs[chosen_name].adapt(**(adaptations_from_known_programs or {}))
        adapted.task = task
        return adapted

0 commit comments

Comments
 (0)