# stt_bridge.py
"""Bridge between AudioEngine and the stt.mm.mk WebSocket STT service.

Connects to wss://stt.mm.mk/ws/transcribe, streams processed audio,
and forwards transcription results back via a callback.
Uses a queue + sender thread so feed_audio() never blocks the audio callback.
"""
from __future__ import annotations

import json
import logging
import queue
import threading
import time
from dataclasses import dataclass
from urllib.parse import urlencode

import numpy as np
  14. logger = logging.getLogger("stt_bridge")
  15. try:
  16. import websocket as ws_client # websocket-client library
  17. except ImportError:
  18. ws_client = None
  19. logger.warning("websocket-client not installed — STT bridge unavailable")
  20. @dataclass
  21. class SttSettings:
  22. enabled: bool = False
  23. language: str = "pl"
  24. timestamps: bool = True
  25. diarize: bool = True
  26. itn: bool = True
  27. detect_emotion: bool = False
  28. server_vad: bool = False
  29. vad_threshold: float = 0.3
  30. vad_pad_ms: int = 400
  31. vad_min_ms: int = 100
  32. class SttBridge:
  33. """Manages WebSocket connection to stt.mm.mk and streams audio."""
  34. STT_URL = "wss://stt.mm.mk/ws/transcribe"
  35. # Max queued audio chunks before dropping (prevent memory buildup)
  36. MAX_QUEUE = 50
  37. def __init__(self, on_message=None):
  38. self._lock = threading.Lock()
  39. self._settings = SttSettings()
  40. self._on_message = on_message
  41. self._ws = None
  42. self._ws_thread: threading.Thread | None = None
  43. self._sender_thread: threading.Thread | None = None
  44. self._audio_queue: queue.Queue[bytes] = queue.Queue(maxsize=self.MAX_QUEUE)
  45. self._connected = False
  46. self._should_run = False
  47. self._sample_rate = 16000
  48. def get_settings(self) -> dict:
  49. with self._lock:
  50. return {
  51. "stt_enabled": self._settings.enabled,
  52. "stt_language": self._settings.language,
  53. "stt_timestamps": self._settings.timestamps,
  54. "stt_diarize": self._settings.diarize,
  55. "stt_itn": self._settings.itn,
  56. "stt_detect_emotion": self._settings.detect_emotion,
  57. "stt_server_vad": self._settings.server_vad,
  58. "stt_vad_threshold": self._settings.vad_threshold,
  59. "stt_vad_pad_ms": self._settings.vad_pad_ms,
  60. "stt_vad_min_ms": self._settings.vad_min_ms,
  61. "stt_connected": self._connected,
  62. }
  63. def update_settings(self, **kwargs) -> dict:
  64. reconnect_keys = {
  65. "language", "timestamps", "diarize", "itn",
  66. "detect_emotion", "server_vad", "vad_threshold",
  67. "vad_pad_ms", "vad_min_ms",
  68. }
  69. changed_enabled = False
  70. need_reconnect = False
  71. with self._lock:
  72. for key, val in kwargs.items():
  73. attr = key.replace("stt_", "")
  74. if hasattr(self._settings, attr):
  75. old = getattr(self._settings, attr)
  76. setattr(self._settings, attr, type(old)(val))
  77. if attr == "enabled" and old != self._settings.enabled:
  78. changed_enabled = True
  79. if attr in reconnect_keys:
  80. need_reconnect = True
  81. if changed_enabled:
  82. if self._settings.enabled:
  83. self._start_connection()
  84. else:
  85. self._stop_connection()
  86. elif self._settings.enabled and self._connected and need_reconnect:
  87. self._stop_connection()
  88. self._start_connection()
  89. return self.get_settings()
  90. def feed_audio(self, audio: np.ndarray, sample_rate: int) -> None:
  91. """Feed processed audio to STT. Non-blocking — drops if queue is full."""
  92. if not self._connected or not self._settings.enabled:
  93. return
  94. self._sample_rate = sample_rate
  95. pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
  96. try:
  97. self._audio_queue.put_nowait(pcm16.tobytes())
  98. except queue.Full:
  99. # Drop oldest chunk to make room
  100. try:
  101. self._audio_queue.get_nowait()
  102. except queue.Empty:
  103. pass
  104. try:
  105. self._audio_queue.put_nowait(pcm16.tobytes())
  106. except queue.Full:
  107. pass
  108. def _sender_loop(self):
  109. """Background thread that drains the queue and sends to WebSocket."""
  110. while self._should_run:
  111. try:
  112. data = self._audio_queue.get(timeout=0.2)
  113. except queue.Empty:
  114. continue
  115. if not self._connected or self._ws is None:
  116. continue
  117. try:
  118. self._ws.send(data, opcode=0x2)
  119. except Exception:
  120. pass
  121. def _build_url(self) -> str:
  122. s = self._settings
  123. parts = [
  124. "language=" + s.language,
  125. "rate=" + str(self._sample_rate),
  126. "stream_id=mic-system-" + str(int(time.time())),
  127. ]
  128. if s.timestamps:
  129. parts.append("timestamps=1")
  130. if s.diarize:
  131. parts.append("diarize=1")
  132. if s.itn:
  133. parts.append("itn=1")
  134. if s.detect_emotion:
  135. parts.append("detect_emotion=1")
  136. if s.server_vad:
  137. parts.append("vad=1")
  138. parts.append("vad_threshold=" + str(s.vad_threshold))
  139. parts.append("vad_pad_ms=" + str(s.vad_pad_ms))
  140. parts.append("vad_min_ms=" + str(s.vad_min_ms))
  141. return self.STT_URL + "?" + "&".join(parts)
  142. def _start_connection(self):
  143. if ws_client is None:
  144. logger.error("websocket-client not installed")
  145. return
  146. self._should_run = True
  147. # Clear queue
  148. while not self._audio_queue.empty():
  149. try:
  150. self._audio_queue.get_nowait()
  151. except queue.Empty:
  152. break
  153. url = self._build_url()
  154. logger.info("STT connecting to %s", url)
  155. bridge = self
  156. def on_open(ws):
  157. bridge._connected = True
  158. logger.info("STT WebSocket connected")
  159. if bridge._on_message:
  160. bridge._on_message({"type": "stt_status", "connected": True})
  161. def on_message(ws, message):
  162. try:
  163. msg = json.loads(message)
  164. if bridge._on_message:
  165. bridge._on_message(msg)
  166. except Exception as e:
  167. logger.error("STT message parse error: %s", e)
  168. def on_error(ws, error):
  169. logger.error("STT WebSocket error: %s", error)
  170. def on_close(ws, close_status_code, close_msg):
  171. bridge._connected = False
  172. logger.info("STT WebSocket closed: %s %s", close_status_code, close_msg)
  173. if bridge._on_message:
  174. bridge._on_message({"type": "stt_status", "connected": False})
  175. if bridge._should_run:
  176. time.sleep(2)
  177. if bridge._should_run:
  178. bridge._start_connection()
  179. self._ws = ws_client.WebSocketApp(
  180. url,
  181. on_open=on_open,
  182. on_message=on_message,
  183. on_error=on_error,
  184. on_close=on_close,
  185. )
  186. self._ws_thread = threading.Thread(
  187. target=self._ws.run_forever,
  188. kwargs={"ping_interval": 20, "ping_timeout": 10},
  189. daemon=True,
  190. )
  191. self._ws_thread.start()
  192. # Start sender thread
  193. self._sender_thread = threading.Thread(
  194. target=self._sender_loop,
  195. daemon=True,
  196. )
  197. self._sender_thread.start()
  198. def _stop_connection(self):
  199. self._should_run = False
  200. self._connected = False
  201. if self._ws:
  202. try:
  203. self._ws.close()
  204. except Exception:
  205. pass
  206. self._ws = None
  207. self._ws_thread = None
  208. self._sender_thread = None
  209. # Drain queue
  210. while not self._audio_queue.empty():
  211. try:
  212. self._audio_queue.get_nowait()
  213. except queue.Empty:
  214. break
  215. def stop(self):
  216. self._stop_connection()