-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcall_center.py
More file actions
executable file
·241 lines (197 loc) · 8.05 KB
/
call_center.py
File metadata and controls
executable file
·241 lines (197 loc) · 8.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import logging
import sys
import threading
from dataclasses import dataclass
from pathlib import Path
from uuid import uuid4
import uvicorn
import webrtcvad
from websockets.sync.client import connect
from config import LLMServerConfig, LoggingConfig, WebSocketConfig
from helper.custom_sts_handler import LLM, Speech2Text, Text2Speech
from helper.llm_backends.api import CacheServerAPIBackend
from helper.PROMPT import SYSTEM_PROMPT
from helper.wav_handler import WavHandler
from helper.ws_command import WSCommandHelper
from model.rtp import PayloadType
from model.ws_command import CommandType
from web_chat import app as web_app
from web_chat import broadcast_message
# Module-level helpers shared by RTPSession and the main loop.
ws_cmd = WSCommandHelper()  # parses incoming (.parser) and builds outgoing (.builder) WebSocket commands
wav_handler = WavHandler()  # RTP payload conversions: hex2pcm / hex2wav / wav2base64
llm_server_config = LLMServerConfig()  # supplies api_version / api_url for the LLM backend
# getLogger() with no name returns the root logger; when run as a script the
# __main__ block configures its file output and adds a stdout handler.
logger = logging.getLogger()
@dataclass
class RTPPacket:
    """One RTP payload received over the WebSocket, tagged with its codec."""

    payload_type: PayloadType  # codec of ``data`` (the session defaults to PayloadType.PCMA)
    data: bytes  # raw payload bytes decoded from the hex string

    @classmethod
    def from_hex(cls, payload_type: int, hex_string: str) -> "RTPPacket":
        """Build a packet from a numeric payload type and a hex-encoded payload.

        Raises:
            ValueError: if ``hex_string`` is not valid hex; presumably also when
                ``payload_type`` is not a known ``PayloadType`` value (confirm
                against the enum definition).
        """
        # Resolve the codec first, then decode the payload (same order as before,
        # so the first invalid field is the one reported).
        codec = PayloadType(payload_type)
        raw_bytes = bytes.fromhex(hex_string)
        return cls(payload_type=codec, data=raw_bytes)
class RTPSession:
    """Per-call state that buffers voiced RTP packets until an utterance ends.

    Every incoming packet is converted to PCM and classified by ``webrtcvad``.
    Voiced packets are appended to ``buffer``. The first unvoiced packet that
    arrives once at least ``minimum_number_packet`` voiced packets are buffered
    marks the end of an utterance (``add_packet`` returns True).
    """

    def __init__(
        self,
        vad_mode: int = 0,
        sample_rate: int = 8000,
        frame_duration_ms: int = 20,
        minimum_number_packet: int = 50,
    ):
        """Create an empty session.

        Args:
            vad_mode: mode passed straight to ``webrtcvad.Vad``.
            sample_rate: sample rate in Hz handed to ``Vad.is_speech``.
            frame_duration_ms: nominal frame duration; here it is only used to
                derive ``samples_per_frame``.
            minimum_number_packet: voiced packets that must be buffered before a
                silent packet completes an utterance (default 50, the value that
                was previously hard-coded).
        """
        self.call_id: str = ""
        self.sample_rate: int = sample_rate
        self.frame_duration_ms = frame_duration_ms
        self.codec: PayloadType = PayloadType.PCMA
        self.samples_per_frame = (sample_rate * frame_duration_ms) // 1000
        self.buffer: list[RTPPacket] = []
        self.vad = webrtcvad.Vad(vad_mode)
        self.minimum_number_packet = minimum_number_packet

    def add_packet(self, packet: RTPPacket) -> bool:
        """Feed one packet to the VAD; return True when an utterance is ready.

        Voiced packets are buffered and return False. An unvoiced packet returns
        True only if at least ``minimum_number_packet`` voiced packets are
        already buffered; the caller then drains the buffer with ``flush()``.
        """
        pcm_data = wav_handler.hex2pcm([packet.data])
        is_speech_frame = self.vad.is_speech(b"".join(pcm_data), self.sample_rate)
        # Called for every RTP packet: use lazy %-args so the message is only
        # formatted when DEBUG logging is actually enabled.
        logger.debug(
            "is_speech_frame=%r pcm_data[0][:20]=%r", is_speech_frame, pcm_data[0][:20]
        )
        if is_speech_frame:
            self.buffer.append(packet)
            logger.debug("number of packet %d", len(self.buffer))
            return False
        # NOTE(review): a silent packet arriving with fewer than
        # minimum_number_packet buffered packets does not reset the buffer, so
        # short voiced bursts accumulate across pauses; confirm this is intended.
        return len(self.buffer) >= self.minimum_number_packet

    def flush(self) -> list[bytes]:
        """Return the payloads of all buffered packets (oldest first) and empty the buffer."""
        data = [p.data for p in self.buffer]
        self.clear()
        return data

    def clear(self) -> None:
        """Discard every buffered packet."""
        self.buffer.clear()
def handle_ans(session: RTPSession, call_id: str) -> None:
    """Begin a call: bind ``session`` to ``call_id`` and drop any stale buffered audio.

    Args:
        session: the shared RTP session.
        call_id: identifier carried by the CALL_ANS command.
    """
    session.call_id = call_id
    session.clear()  # leftover packets from a previous call must not leak in
    logger.info(f"Call started: {session.call_id}")
def handle_bye(session: RTPSession, call_id: str) -> None:
    """End the current call and reset the session.

    Logs the ending call, discards any audio still buffered for it and clears
    ``session.call_id``, so that main() skips RTP packets until the next
    CALL_ANS binds a new call.

    Args:
        session: the shared RTP session.
        call_id: payload of the BYE command. It is accepted to match the handler
            signature shared with ``handle_ans``; the session's own id is the
            one logged.
    """
    logger.info(f"Call ended: {session.call_id}")
    # Without this reset, packets arriving after BYE would still be buffered and
    # answered under the finished call's id.
    session.clear()
    session.call_id = ""
async def _respond_to_utterance(
    websocket,
    session: "RTPSession",
    llm_handler: "LLM",
    stt: "Speech2Text",
    tts: "Text2Speech",
) -> None:
    """Answer one completed utterance of the current call.

    Flushes ``session``'s buffer into a WAV file and transcribes it. It then
    broadcasts the user text and the LLM reply to the web chat, synthesizes the
    reply with TTS, and sends it over ``websocket`` as an RTP command whose
    content is ``f"{call_id}##{wav_data}"`` (``wav_data`` from
    ``wav_handler.wav2base64``).

    Any failure is logged with its traceback and the session buffer is cleared.
    The input and response WAV files are always removed afterwards.
    """
    wav_path = None
    response_audio_path = Path(f"./output/response/{uuid4()}.wav")
    # exist_ok=True already tolerates an existing directory; no pre-check needed.
    response_audio_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        wav_path = wav_handler.hex2wav(session.flush(), session.codec)
        logger.info(f"WAV file converted at {wav_path}")
        audio_transcribe, language = stt.transcribe(wav_path)
        await broadcast_message(session.call_id, "user", audio_transcribe, language)
        llm_response = await llm_handler.generate_response(
            audio_transcribe, language, user_id=session.call_id
        )
        logger.info(f"LLM Response: {llm_response}")
        await broadcast_message(session.call_id, "assistant", llm_response, language)
        tts.generate(llm_response, response_audio_path, language)
        wav_data = wav_handler.wav2base64(response_audio_path)
        logger.info(f"WAV data: {wav_data[:30]}...")
        websocket.send(
            str(ws_cmd.builder(CommandType.RTP, f"{session.call_id}##{wav_data}"))
        )
        logger.info("Audio sent")
    except Exception as e:
        # A failed turn must not end the call, but keep the traceback for debugging.
        logger.error(
            f"Processing failed for call {session.call_id}: {e}", exc_info=True
        )
        session.clear()
    finally:
        logger.info("Cleaning up temporary files")
        if wav_path and wav_path.exists():
            wav_path.unlink()
        if response_audio_path.exists():
            response_audio_path.unlink()


async def main() -> None:
    """Run the call-center loop over the signalling WebSocket.

    Connects to ``WebSocketConfig().ws_url`` and consumes commands:
    CALL_ANS and BYE update the session; RTP packets are fed to the VAD-based
    ``RTPSession``; and every completed utterance is answered by
    ``_respond_to_utterance``. Exits the process with status 0 on Ctrl-C and 1
    on any other fatal error.
    """
    session = RTPSession()
    llm_backend = CacheServerAPIBackend(
        api_version=llm_server_config.api_version,
        server_endpoint_url=llm_server_config.api_url,
        system_prompt=SYSTEM_PROMPT,
    )
    llm_handler = LLM(llm_backend)
    stt = Speech2Text()
    tts = Text2Speech()
    command_handler = {
        CommandType.CALL_ANS: handle_ans,
        CommandType.BYE: handle_bye,
    }
    logger.info("LLM and STT/TTS initialized")
    # Number of well-formed RTP packets received on this connection.
    packet_count = 0
    ws_url = WebSocketConfig().ws_url
    try:
        with connect(ws_url) as websocket:
            logger.info("WebSocket connected")
            for message in websocket:
                try:
                    cmd = ws_cmd.parser(str(message))
                except Exception as e:
                    logger.warning(
                        f"Invalid message format for call {session.call_id}: {e}"
                    )
                    continue
                if not isinstance(cmd.content, str):
                    continue
                # handle CALL_ANS, BYE
                handler = command_handler.get(cmd.type)
                if handler:
                    handler(session, cmd.content)
                    continue
                # handle RTP
                if cmd.type != CommandType.RTP:
                    logger.info(cmd.type)
                    continue
                try:
                    payload_type, rtp_hex = cmd.content.split("##")
                    packet = RTPPacket.from_hex(int(payload_type), rtp_hex)
                except (ValueError, AttributeError) as e:
                    logger.warning(
                        f"Malformed RTP packet for call {session.call_id}, packet count {packet_count}: {e}"
                    )
                    continue
                except Exception as e:
                    logger.error(
                        f"Unexpected error: {e}",
                        exc_info=True,
                    )
                    continue
                packet_count += 1
                session.codec = packet.payload_type
                if session.call_id == "":
                    continue  # no active call: ignore audio
                if session.add_packet(packet):
                    logger.info(f"Processed {packet_count} packets")
                    await _respond_to_utterance(
                        websocket, session, llm_handler, stt, tts
                    )
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        sys.exit(0)
    except Exception as e:
        logger.critical(f"Fatal error: {e}", exc_info=True)
        sys.exit(1)
def start_web_server(host: str = "127.0.0.1", port: int = 8088) -> None:
    """Launch the web chat app with uvicorn on a daemon thread and return at once.

    Args:
        host: interface to bind.
        port: TCP port to listen on.

    NOTE(review): the log line is emitted right after the thread starts, so the
    server may not be accepting connections yet.
    """
    server = uvicorn.Server(
        uvicorn.Config(web_app, host=host, port=port, log_level="warning")
    )
    # daemon=True: the thread does not keep the process alive once main() exits.
    worker = threading.Thread(target=server.run, daemon=True)
    worker.start()
    logger.info(f"Web chat server started at http://{host}:{port}")
if __name__ == "__main__":
    # `sys` is already imported at module level; only asyncio is needed here.
    import asyncio

    log_format = (
        "[%(levelname)s] - %(asctime)s - %(message)s - %(pathname)s:%(lineno)d"
    )
    log_datefmt = "%y-%m-%d %H:%M:%S"
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        filemode="w+",
        filename=LoggingConfig().call_center_log_file,
        datefmt=log_datefmt,
    )
    # Mirror records to stdout with the same layout as the log file; without an
    # explicit formatter the console would show only the bare message.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(logging.Formatter(log_format, datefmt=log_datefmt))
    logger = logging.getLogger()
    logger.addHandler(console_handler)
    start_web_server()
    asyncio.run(main())