stt_bridge.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. """Bridge between AudioEngine and stt.mm.mk WebSocket STT service.
  2. Runs STT WebSocket in a **subprocess** to avoid conflicts with eventlet.
  3. Communication: stdin (length-prefixed PCM binary) / stdout (JSON lines).
  4. feed_audio() is non-blocking: puts data in a bounded queue, a writer
  5. thread drains the queue to subprocess stdin.
  6. """
  7. from __future__ import annotations
  8. import json
  9. import logging
  10. import os
  11. import queue
  12. import subprocess
  13. import threading
  14. import time
  15. from dataclasses import dataclass
  16. from pathlib import Path
  17. import numpy as np
  18. logger = logging.getLogger("stt_bridge")
  19. _WORKER_PATH = str(Path(__file__).resolve().parent / "stt_worker.py")
  20. _VENV_PYTHON = str(Path(__file__).resolve().parent / ".venv" / "bin" / "python")
  21. @dataclass
  22. class SttSettings:
  23. enabled: bool = False
  24. language: str = "pl"
  25. timestamps: bool = True
  26. diarize: bool = True
  27. itn: bool = True
  28. detect_emotion: bool = False
  29. server_vad: bool = False
  30. vad_threshold: float = 0.3
  31. vad_pad_ms: int = 400
  32. vad_min_ms: int = 100
  33. class SttBridge:
  34. """Manages subprocess STT worker that connects to stt.mm.mk."""
  35. STT_URL = "wss://stt.mm.mk/ws/transcribe"
  36. MAX_QUEUE = 30
  37. def __init__(self, on_message=None):
  38. self._lock = threading.Lock()
  39. self._settings = SttSettings()
  40. self._on_message = on_message
  41. self._process: subprocess.Popen | None = None
  42. self._reader_thread: threading.Thread | None = None
  43. self._writer_thread: threading.Thread | None = None
  44. self._audio_queue: queue.Queue[bytes] = queue.Queue(maxsize=self.MAX_QUEUE)
  45. self._should_run = False
  46. self._connected = False
  47. self._sample_rate = 16000
  48. def get_settings(self) -> dict:
  49. with self._lock:
  50. return {
  51. "stt_enabled": self._settings.enabled,
  52. "stt_language": self._settings.language,
  53. "stt_timestamps": self._settings.timestamps,
  54. "stt_diarize": self._settings.diarize,
  55. "stt_itn": self._settings.itn,
  56. "stt_detect_emotion": self._settings.detect_emotion,
  57. "stt_server_vad": self._settings.server_vad,
  58. "stt_vad_threshold": self._settings.vad_threshold,
  59. "stt_vad_pad_ms": self._settings.vad_pad_ms,
  60. "stt_vad_min_ms": self._settings.vad_min_ms,
  61. "stt_connected": self._connected,
  62. }
  63. def update_settings(self, **kwargs) -> dict:
  64. reconnect_keys = {
  65. "language", "timestamps", "diarize", "itn",
  66. "detect_emotion", "server_vad", "vad_threshold",
  67. "vad_pad_ms", "vad_min_ms",
  68. }
  69. changed_enabled = False
  70. need_reconnect = False
  71. with self._lock:
  72. for key, val in kwargs.items():
  73. attr = key.replace("stt_", "")
  74. if hasattr(self._settings, attr):
  75. old = getattr(self._settings, attr)
  76. setattr(self._settings, attr, type(old)(val))
  77. if attr == "enabled" and old != self._settings.enabled:
  78. changed_enabled = True
  79. if attr in reconnect_keys:
  80. need_reconnect = True
  81. if changed_enabled:
  82. if self._settings.enabled:
  83. self._start_worker()
  84. else:
  85. self._stop_worker()
  86. elif self._settings.enabled and self._process is not None and need_reconnect:
  87. self._stop_worker()
  88. self._start_worker()
  89. return self.get_settings()
  90. def feed_audio(self, audio: np.ndarray, sample_rate: int) -> None:
  91. """Feed processed audio to STT. Completely non-blocking.
  92. Puts length-prefixed PCM into a bounded queue. If queue is full,
  93. drops the chunk silently (better than blocking the audio callback).
  94. """
  95. if not self._should_run or not self._settings.enabled:
  96. return
  97. self._sample_rate = sample_rate
  98. pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
  99. payload = pcm16.tobytes()
  100. header = len(payload).to_bytes(4, "little")
  101. try:
  102. self._audio_queue.put_nowait(header + payload)
  103. except queue.Full:
  104. # Drop — never block the audio callback
  105. try:
  106. self._audio_queue.get_nowait() # drop oldest
  107. except queue.Empty:
  108. pass
  109. try:
  110. self._audio_queue.put_nowait(header + payload)
  111. except queue.Full:
  112. pass
  113. def _writer_loop(self):
  114. """Background thread: drains queue → writes to subprocess stdin."""
  115. while self._should_run:
  116. try:
  117. data = self._audio_queue.get(timeout=0.2)
  118. except queue.Empty:
  119. continue
  120. proc = self._process
  121. if proc is None or proc.poll() is not None:
  122. continue
  123. try:
  124. proc.stdin.write(data)
  125. proc.stdin.flush()
  126. except (BrokenPipeError, OSError):
  127. pass
  128. def _build_url(self) -> str:
  129. s = self._settings
  130. parts = [
  131. "language=" + s.language,
  132. "rate=" + str(self._sample_rate),
  133. "stream_id=mic-system-" + str(int(time.time())),
  134. ]
  135. if s.timestamps:
  136. parts.append("timestamps=1")
  137. if s.diarize:
  138. parts.append("diarize=1")
  139. if s.itn:
  140. parts.append("itn=1")
  141. if s.detect_emotion:
  142. parts.append("detect_emotion=1")
  143. if s.server_vad:
  144. parts.append("vad=1")
  145. parts.append("vad_threshold=" + str(s.vad_threshold))
  146. parts.append("vad_pad_ms=" + str(s.vad_pad_ms))
  147. parts.append("vad_min_ms=" + str(s.vad_min_ms))
  148. return self.STT_URL + "?" + "&".join(parts)
  149. def _start_worker(self):
  150. self._should_run = True
  151. # Drain queue
  152. while not self._audio_queue.empty():
  153. try:
  154. self._audio_queue.get_nowait()
  155. except queue.Empty:
  156. break
  157. url = self._build_url()
  158. logger.info("STT starting worker subprocess: %s", url)
  159. python = _VENV_PYTHON if os.path.exists(_VENV_PYTHON) else "python3"
  160. try:
  161. self._process = subprocess.Popen(
  162. [python, _WORKER_PATH, url],
  163. stdin=subprocess.PIPE,
  164. stdout=subprocess.PIPE,
  165. stderr=subprocess.DEVNULL,
  166. bufsize=0,
  167. )
  168. except Exception as e:
  169. logger.error("Failed to start STT worker: %s", e)
  170. self._should_run = False
  171. return
  172. self._reader_thread = threading.Thread(
  173. target=self._read_stdout,
  174. daemon=True,
  175. )
  176. self._reader_thread.start()
  177. self._writer_thread = threading.Thread(
  178. target=self._writer_loop,
  179. daemon=True,
  180. )
  181. self._writer_thread.start()
  182. def _read_stdout(self):
  183. """Read JSON lines from worker stdout and forward via callback."""
  184. proc = self._process
  185. if proc is None:
  186. return
  187. try:
  188. for line in proc.stdout:
  189. line = line.strip()
  190. if not line:
  191. continue
  192. try:
  193. msg = json.loads(line)
  194. except (json.JSONDecodeError, ValueError):
  195. continue
  196. if msg.get("type") == "stt_status":
  197. self._connected = bool(msg.get("connected", False))
  198. if self._on_message:
  199. self._on_message(msg)
  200. except Exception:
  201. pass
  202. finally:
  203. self._connected = False
  204. if self._on_message:
  205. self._on_message({"type": "stt_status", "connected": False})
  206. def _stop_worker(self):
  207. self._should_run = False
  208. self._connected = False
  209. proc = self._process
  210. self._process = None
  211. # Drain queue
  212. while not self._audio_queue.empty():
  213. try:
  214. self._audio_queue.get_nowait()
  215. except queue.Empty:
  216. break
  217. if proc is not None:
  218. try:
  219. proc.stdin.close()
  220. except Exception:
  221. pass
  222. try:
  223. proc.terminate()
  224. proc.wait(timeout=3)
  225. except Exception:
  226. try:
  227. proc.kill()
  228. except Exception:
  229. pass
  230. self._reader_thread = None
  231. self._writer_thread = None
  232. def stop(self):
  233. self._stop_worker()