Skip to content

Commit d6a1765

Browse files
committed
Add devchat log --insert
1 parent 28a89fc commit d6a1765

4 files changed

Lines changed: 133 additions & 16 deletions

File tree

devchat/_cli/log.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,24 @@
11
import json
22
import sys
3+
from typing import Optional, List, Dict
4+
from pydantic import BaseModel
35
import rich_click as click
4-
from devchat.openai.openai_chat import OpenAIChat, OpenAIChatConfig
6+
from devchat.openai.openai_chat import OpenAIChat, OpenAIChatConfig, OpenAIPrompt
57
from devchat.store import Store
6-
from devchat.utils import get_logger
8+
from devchat.utils import get_logger, get_user_info
79
from devchat._cli.utils import handle_errors, init_dir, get_model_config
810

11+
12+
class PromptData(BaseModel):
13+
model: str
14+
messages: List[Dict]
15+
parent: Optional[str] = None
16+
references: Optional[List[str]] = []
17+
timestamp: int
18+
request_tokens: int
19+
response_tokens: int
20+
21+
922
logger = get_logger(__name__)
1023

1124

@@ -14,13 +27,15 @@
1427
@click.option('-n', '--max-count', default=1, help='Limit the number of commits to output.')
1528
@click.option('-t', '--topic', 'topic_root', default=None,
1629
help='Hash of the root prompt of the topic to select prompts from.')
17-
@click.option('--delete', default=None, help='Delete a leaf prompt from the log.')
18-
def log(skip, max_count, topic_root, delete):
30+
@click.option('--insert', default=None, help='JSON string of the prompt to insert into the log.')
31+
@click.option('--delete', default=None, help='Hash of the leaf prompt to delete from the log.')
32+
def log(skip, max_count, topic_root, insert, delete):
1933
"""
2034
Manage the prompt history.
2135
"""
22-
if delete and (skip != 0 or max_count != 1 or topic_root is not None):
23-
click.echo("Error: The --delete option cannot be used with other options.", err=True)
36+
if (insert or delete) and (skip != 0 or max_count != 1 or topic_root is not None):
37+
click.echo("Error: The --insert or --delete option cannot be used with other options.",
38+
err=True)
2439
sys.exit(1)
2540

2641
repo_chat_dir, user_chat_dir = init_dir()
@@ -39,6 +54,19 @@ def log(skip, max_count, topic_root, delete):
3954
else:
4055
click.echo(f"Failed to delete prompt {delete}.")
4156
else:
57+
if insert:
58+
prompt_data = PromptData(**json.loads(insert))
59+
user, email = get_user_info()
60+
prompt = OpenAIPrompt(prompt_data.model, user, email)
61+
prompt.model = prompt_data.model
62+
prompt.input_messages(prompt_data.messages)
63+
prompt.parent = prompt_data.parent
64+
prompt.references = prompt_data.references
65+
prompt._timestamp = prompt_data.timestamp
66+
prompt._request_tokens = prompt_data.request_tokens
67+
prompt._response_tokens = prompt_data.response_tokens
68+
store.store_prompt(prompt)
69+
4270
recent_prompts = store.select_prompts(skip, skip + max_count, topic_root)
4371
logs = []
4472
for record in recent_prompts:

devchat/openai/openai_prompt.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,15 @@ def input_messages(self, messages: List[dict]):
8484
logger.warning("Invalid new context message: %s", message)
8585

8686
if not self.request:
87-
last_user_message = self._history_messages[Message.CHAT].pop()
88-
if last_user_message.role in ("user", "function"):
89-
self._new_messages["request"] = last_user_message
90-
else:
91-
logger.warning("Invalid user request: %s", last_user_message)
87+
while True:
88+
last_message = self._history_messages[Message.CHAT].pop()
89+
if last_message.role in ("user", "function"):
90+
self._new_messages["request"] = last_message
91+
break
92+
if last_message.role == "assistant":
93+
self._new_messages["responses"].append(last_message)
94+
continue
95+
self._history_messages[Message.CHAT].append(last_message)
9296

9397
def append_new(self, message_type: str, content: str,
9498
available_tokens: int = sys.maxsize) -> bool:
@@ -232,7 +236,7 @@ def _validate_model(self, response_data: dict):
232236
f"got '{response_data['model']}'")
233237

234238
def _timestamp_from_dict(self, response_data: dict):
235-
if self._timestamp is None:
239+
if not self._timestamp:
236240
self._timestamp = response_data['created']
237241
elif self._timestamp != response_data['created']:
238242
raise ValueError(f"Time mismatch: expected {self._timestamp}, "

devchat/prompt.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ class Prompt(ABC):
4444
})
4545
parent: str = None
4646
references: List[str] = field(default_factory=list)
47-
_timestamp: int = None
47+
_timestamp: int = 0
4848
_request_tokens: int = 0
4949
_response_tokens: int = 0
5050
_response_reasons: List[str] = field(default_factory=list)
5151
_hash: str = None
5252

53-
def _complete_for_hash(self) -> bool:
53+
def _complete_for_hashing(self) -> bool:
5454
"""
5555
Check if the prompt is complete for hashing.
5656
@@ -62,6 +62,10 @@ def _complete_for_hash(self) -> bool:
6262
self.request, self.responses)
6363
return False
6464

65+
if not self.timestamp:
66+
logger.warning("Prompt lacks timestamp for hashing: %s", self.request)
67+
return False
68+
6569
if not self._response_tokens:
6670
return False
6771

@@ -114,7 +118,7 @@ def messages(self) -> List[dict]:
114118
def input_messages(self, messages: List[dict]):
115119
"""
116120
Input the messages from the chat API to new and history messages.
117-
The message list should be generated by the `messages` property.
121+
The message list must follow the convention of the `messages` property.
118122
119123
Args:
120124
messages (List[dict]): The messages from the chat API.
@@ -185,7 +189,7 @@ def finalize_hash(self) -> str:
185189
Returns:
186190
str: The hash of the prompt. None if the prompt is incomplete.
187191
"""
188-
if not self._complete_for_hash():
192+
if not self._complete_for_hashing():
189193
self._hash = None
190194

191195
if self._hash:

tests/test_cli_log.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,84 @@ def test_tokens_with_log(git_repo): # pylint: disable=W0613
9090
logs = json.loads(result.output)
9191
assert _within_range(logs[1]["request_tokens"], logs[0]["request_tokens"])
9292
assert _within_range(logs[1]["response_tokens"], logs[0]["response_tokens"])
93+
94+
95+
def test_log_insert(git_repo): # pylint: disable=W0613
96+
chat1 = """{
97+
"model": "gpt-3.5-turbo",
98+
"messages": [
99+
{
100+
"role": "user",
101+
"content": "This is Topic 1. Reply the topic number."
102+
},
103+
{
104+
"role": "assistant",
105+
"content": "Topic 1"
106+
}
107+
],
108+
"timestamp": 1610000000,
109+
"request_tokens": 100,
110+
"response_tokens": 100
111+
}"""
112+
result = runner.invoke(
113+
main,
114+
['log', '--insert', chat1]
115+
)
116+
prompt1 = json.loads(result.output)[0]
117+
118+
chat2 = """{
119+
"model": "gpt-3.5-turbo",
120+
"messages": [
121+
{
122+
"role": "user",
123+
"content": "This is Topic 2. Reply the topic number."
124+
},
125+
{
126+
"role": "assistant",
127+
"content": "Topic 2"
128+
}
129+
],
130+
"timestamp": 1620000000,
131+
"request_tokens": 200,
132+
"response_tokens": 200
133+
}"""
134+
result = runner.invoke(
135+
main,
136+
['log', '--insert', chat2]
137+
)
138+
prompt2 = json.loads(result.output)[0]
139+
140+
chat3 = """{
141+
"model": "gpt-3.5-turbo",
142+
"messages": [
143+
{
144+
"role": "user",
145+
"content": "Let's continue with Topic 1."
146+
},
147+
{
148+
"role": "assistant",
149+
"content": "Sure!"
150+
}
151+
],
152+
"parent": "%s",
153+
"timestamp": 1630000000,
154+
"request_tokens": 300,
155+
"response_tokens": 300
156+
}""" % prompt1['hash']
157+
result = runner.invoke(
158+
main,
159+
['log', '--insert', chat3]
160+
)
161+
prompt3 = json.loads(result.output)[0]
162+
assert prompt3['parent'] == prompt1['hash']
163+
164+
result = runner.invoke(main, ['log', '-n', 3])
165+
logs = json.loads(result.output)
166+
assert logs[0]['hash'] == prompt3['hash']
167+
assert logs[1]['hash'] == prompt2['hash']
168+
assert logs[2]['hash'] == prompt1['hash']
169+
170+
result = runner.invoke(main, ['topic', '--list'])
171+
topics = json.loads(result.output)
172+
assert topics[0]['root_prompt']['hash'] == prompt1['hash']
173+
assert topics[1]['root_prompt']['hash'] == prompt2['hash']

0 commit comments

Comments
 (0)