-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathaskcode_index_query.py
More file actions
212 lines (167 loc) · 6.69 KB
/
askcode_index_query.py
File metadata and controls
212 lines (167 loc) · 6.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os
import re
import sys
import json
import tempfile
import uuid
from chat.ask_codebase.store.qdrant import QdrantWrapper as Q, get_client
from chat.ask_codebase.indexing.embedding import EmbeddingWrapper as E
from langchain.embeddings import HuggingFaceEmbeddings
from chat.ask_codebase.indexing.loader.file import (
FileLoader,
FileSource,
gen_local_reference_maker,
)
from chat.util.misc import is_source_code
from chat.ask_codebase.chains.simple_qa import SimpleQA
from chat.ask_codebase.chains.stuff_dc_qa import StuffDocumentCodeQa
def get_app_data_dir(app_name):
    """Return the per-user data directory for *app_name*, creating it if needed.

    On Windows this is ``~/AppData/Roaming/<app_name>``; on every other OS it
    is ``~/.local/share/<app_name>``.

    Args:
        app_name: Name of the application sub-directory.

    Returns:
        The absolute path of the (now existing) directory.
    """
    home = os.path.expanduser("~")
    if os.name == "nt":  # Windows
        app_path = os.path.join(home, "AppData", "Roaming", app_name)
    else:  # Unix and Linux
        app_path = os.path.join(home, ".local", "share", app_name)
    # exist_ok=True replaces the exists()/makedirs() pair, which could raise
    # FileExistsError if the directory appeared between the two calls.
    os.makedirs(app_path, exist_ok=True)
    return app_path
# Regex patterns selecting which files count as source code. Empty until the
# "index" command in main() rebinds it from the command line.
supportedFileTypes = []
# Location of the local Qdrant storage handed to get_client(). Note that calling
# get_app_data_dir() here creates the "devchat" data directory at import time.
STORAGE_FILE = os.path.join(get_app_data_dir("devchat"), "qdrant_storage2")
# source_name used for the vector store; the __main__ block replaces this empty
# default with the value persisted in .chat/askcode.json (or a fresh UUID).
SOURCE_NAME = ""
# Last-modified time recorded for each already-analysed file
# (path relative to the cwd -> os.path.getmtime value).
g_file_last_modified_saved = {}
def load_file_last_modified(filePath: str):
    """Load the per-file last-modified times recorded by a previous run.

    Args:
        filePath: Path of a JSON file mapping file paths to modification times.

    Returns:
        The decoded mapping. An empty dict is returned when the file does not
        exist (nothing has been analysed yet) or when its content is not a
        valid JSON object. In that case a damaged cache only forces a full
        re-index instead of aborting the whole run.
    """
    try:
        with open(filePath, 'r', encoding="utf-8") as f:
            fileLastModified = json.load(f)
    except FileNotFoundError:
        # No record yet: nothing has been analysed.
        return {}
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors:
        # the cache is corrupted (e.g. a write was interrupted).
        return {}
    # Anything other than a JSON object cannot serve as a path -> mtime map.
    if not isinstance(fileLastModified, dict):
        return {}
    return fileLastModified
def save_file_last_modified(filePath: str, fileLastModified: dict):
    """Persist the per-file last-modified times as JSON.

    The data is written to a temporary file in the same directory and then
    moved over *filePath* with ``os.replace``. An interrupted run therefore
    never leaves a truncated JSON file behind. The parent directory is
    created if it does not exist yet.

    Args:
        filePath: Destination JSON file.
        fileLastModified: Mapping to serialize.

    Returns:
        *fileLastModified*, unchanged.
    """
    directory = os.path.dirname(os.path.abspath(filePath))
    os.makedirs(directory, exist_ok=True)
    # Same directory as the target so os.replace stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-", suffix=".json")
    try:
        with os.fdopen(fd, 'w', encoding="utf-8") as f:
            json.dump(fileLastModified, f)
        os.replace(tmp_path, filePath)
    except BaseException:
        # Drop the partial temp file; any previous cache file stays intact.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return fileLastModified
def is_source_code_new(filePath: str):
    """Tell whether *filePath* matches any regex in ``supportedFileTypes``.

    Each pattern is stripped of surrounding whitespace and applied with
    ``re.match``, i.e. anchored at the start of the path.
    """
    return any(re.match(regex.strip(), filePath) for regex in supportedFileTypes)
def is_file_modified(filePath: str) -> bool:
    """Decide whether *filePath* should be (re)indexed.

    A file qualifies when it matches ``supportedFileTypes``, when no component
    of its cwd-relative path starts with '.' or is ``node_modules`` /
    ``__pycache__``, and when its current mtime differs from the recorded one.

    Side effect: for a qualifying file, the new mtime is stored in
    ``g_file_last_modified_saved`` under its cwd-relative path.
    """
    if not is_source_code_new(filePath):
        return False
    relativePath = os.path.relpath(filePath, os.getcwd())
    excluded_dirs = ("node_modules", "__pycache__")
    for part in relativePath.split(os.sep):
        # "." and ".." are path navigation, not hidden directories. Without this
        # guard every file outside the cwd (relpath starts with "..") was skipped.
        if part in (os.curdir, os.pardir):
            continue
        if part.startswith('.') or part in excluded_dirs:
            return False
    # mtime recorded by the previous run (0 = never analysed).
    fileLastModified = g_file_last_modified_saved.get(relativePath, 0)
    fileCurrentModified = os.path.getmtime(filePath)
    if fileLastModified != fileCurrentModified:
        g_file_last_modified_saved[relativePath] = fileCurrentModified
        return True
    return False
def index(repo_path: str):
    """Embed the new or modified files under *repo_path* into the Qdrant store.

    Files are selected by ``is_file_modified``. On any exception the error is
    printed and the process exits with status 1.
    """
    try:
        qdrant_client = get_client(STORAGE_FILE)
        file_source = FileSource(
            path=repo_path,
            rel_root=repo_path,
            ref_maker=gen_local_reference_maker(repo_path),
            file_filter=is_file_modified,
        )
        docs = FileLoader(sources=[file_source]).load()
        embedded = E(embedding=HuggingFaceEmbeddings()).embed(docs)
        store = Q.create(
            source_name=SOURCE_NAME,
            embedding_cls=HuggingFaceEmbeddings,
            client=qdrant_client,
        )
        store.insert(embedded)
    except Exception as err:
        print(err)
        sys.exit(1)
import json
def query(question: str, doc_context: str, lsp_brige_port: int):
    """Answer *question* against the indexed store and save the retrieved context.

    Prints the answer and the relevant documents to stdout, and writes a JSON
    object ``{"path": "AskCode Context", "content": <json list>}`` to
    *doc_context*. Each list entry holds a retrieved document's ``filepath``
    and ``content``. On any exception the error is printed and the process
    exits with status 1.
    """
    try:
        store = Q.reuse(
            source_name=SOURCE_NAME,
            embedding_cls=HuggingFaceEmbeddings,
            client=get_client(mode=STORAGE_FILE),
        )
        answer, relevant_docs = StuffDocumentCodeQa(store).run(question)

        print(f"LSP brige port: {lsp_brige_port}")
        print(f"\n# Question: \n{question}")
        print(f"\n# Answer: \n{answer}")
        print("\n# Relevant Documents: \n")

        snippets = [
            {"filepath": doc.metadata.get('filepath'), "content": doc.page_content}
            for doc in relevant_docs
        ]
        context_payload = {"path": "AskCode Context", "content": json.dumps(snippets)}
        with open(doc_context, 'w') as fh:
            json.dump(context_payload, fh)

        for doc in relevant_docs:
            print(f"- filepath: {doc.metadata.get('filepath')}")
            print(f" location: {doc.metadata.get('reference')}\n")
        print(f"Save doc context to {doc_context}")
    except Exception as err:
        print(err)
        sys.exit(1)
def main():
    """Dispatch the ``index`` or ``query`` command given on the command line.

    Usage:
        <script> index <repo_path> <supportedFileTypes>
            supportedFileTypes is a comma-separated list of regexes.
        <script> query <question> <doc_context> <port>

    Prints a usage message and exits with status 1 on missing arguments or an
    unknown command. Any other exception is printed and also exits with 1.
    """
    global supportedFileTypes
    try:
        # Name the script actually being run instead of a hard-coded, stale name.
        script = os.path.basename(sys.argv[0]) if sys.argv else "askcode_index_query.py"
        if len(sys.argv) < 2:
            print(f"Usage: python {script} [command] [args]")
            print("Available commands: index, query")
            sys.exit(1)
        command = sys.argv[1]
        if command == "index":
            if len(sys.argv) < 4:
                print(f"Usage: python {script} index [repo_path] [supportedFileTypes]")
                sys.exit(1)
            repo_path = sys.argv[2]
            # Drop blank entries (e.g. from a trailing comma): an empty regex
            # would make is_source_code_new accept every file.
            supportedFileTypes = [p for p in sys.argv[3].split(',') if p.strip()]
            index(repo_path)
        elif command == "query":
            if len(sys.argv) < 5:
                print(f"Usage: python {script} query [question] [doc_context] [port]")
                sys.exit(1)
            question = sys.argv[2]
            doc_context = sys.argv[3]
            # Passed through as the raw string; query() only prints it.
            port = sys.argv[4]
            query(question, doc_context, port)
        else:
            print("Invalid command. Available commands: index, query")
            sys.exit(1)
    except Exception as e:
        print(e)
        sys.exit(1)
if __name__ == "__main__":
    try:
        # mtimes recorded by the previous run; consulted by is_file_modified().
        g_file_last_modified_saved = load_file_last_modified('./.chat/.index_modified.json')
        # Both state files live under .chat/; make sure it exists before writing.
        os.makedirs(".chat", exist_ok=True)
        askcode_data = {}
        if os.path.exists(".chat/askcode.json"):
            with open(".chat/askcode.json", "r", encoding="utf-8") as f:
                askcode_data = json.load(f)
        # Reuse the persisted source name so runs share the same store entry;
        # generate a fresh one on first use.
        SOURCE_NAME = askcode_data.get("SOURCE_NAME", str(uuid.uuid4()))
        # Merge rather than overwrite so other keys in askcode.json survive.
        askcode_data["SOURCE_NAME"] = SOURCE_NAME
        with open(".chat/askcode.json", "w", encoding="utf-8") as f:
            json.dump(askcode_data, f)
        main()
        save_file_last_modified('./.chat/.index_modified.json', g_file_last_modified_saved)
        sys.exit(0)
    except Exception as e:
        print(e)
        sys.exit(1)