stt_bridge.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
"""Bridge between AudioEngine and the stt.mm.mk WebSocket STT service.

The WebSocket client runs in a **subprocess** (stt_worker.py) to avoid
conflicts with eventlet. Communication with the subprocess:

* stdin  -- audio frames, each a 4-byte little-endian length followed by
  that many bytes of int16 PCM;
* stdout -- one JSON value per line.
"""
  5. from __future__ import annotations
  6. import json
  7. import logging
  8. import os
  9. import subprocess
  10. import threading
  11. import time
  12. from dataclasses import dataclass
  13. from pathlib import Path
  14. import numpy as np
  15. logger = logging.getLogger("stt_bridge")
  16. _WORKER_PATH = str(Path(__file__).resolve().parent / "stt_worker.py")
  17. _VENV_PYTHON = str(Path(__file__).resolve().parent / ".venv" / "bin" / "python")
  18. @dataclass
  19. class SttSettings:
  20. enabled: bool = False
  21. language: str = "pl"
  22. timestamps: bool = True
  23. diarize: bool = True
  24. itn: bool = True
  25. detect_emotion: bool = False
  26. server_vad: bool = False
  27. vad_threshold: float = 0.3
  28. vad_pad_ms: int = 400
  29. vad_min_ms: int = 100
  30. class SttBridge:
  31. """Manages subprocess STT worker that connects to stt.mm.mk."""
  32. STT_URL = "wss://stt.mm.mk/ws/transcribe"
  33. def __init__(self, on_message=None):
  34. self._lock = threading.Lock()
  35. self._settings = SttSettings()
  36. self._on_message = on_message
  37. self._process: subprocess.Popen | None = None
  38. self._reader_thread: threading.Thread | None = None
  39. self._connected = False
  40. self._sample_rate = 16000
  41. def get_settings(self) -> dict:
  42. with self._lock:
  43. return {
  44. "stt_enabled": self._settings.enabled,
  45. "stt_language": self._settings.language,
  46. "stt_timestamps": self._settings.timestamps,
  47. "stt_diarize": self._settings.diarize,
  48. "stt_itn": self._settings.itn,
  49. "stt_detect_emotion": self._settings.detect_emotion,
  50. "stt_server_vad": self._settings.server_vad,
  51. "stt_vad_threshold": self._settings.vad_threshold,
  52. "stt_vad_pad_ms": self._settings.vad_pad_ms,
  53. "stt_vad_min_ms": self._settings.vad_min_ms,
  54. "stt_connected": self._connected,
  55. }
  56. def update_settings(self, **kwargs) -> dict:
  57. reconnect_keys = {
  58. "language", "timestamps", "diarize", "itn",
  59. "detect_emotion", "server_vad", "vad_threshold",
  60. "vad_pad_ms", "vad_min_ms",
  61. }
  62. changed_enabled = False
  63. need_reconnect = False
  64. with self._lock:
  65. for key, val in kwargs.items():
  66. attr = key.replace("stt_", "")
  67. if hasattr(self._settings, attr):
  68. old = getattr(self._settings, attr)
  69. setattr(self._settings, attr, type(old)(val))
  70. if attr == "enabled" and old != self._settings.enabled:
  71. changed_enabled = True
  72. if attr in reconnect_keys:
  73. need_reconnect = True
  74. if changed_enabled:
  75. if self._settings.enabled:
  76. self._start_worker()
  77. else:
  78. self._stop_worker()
  79. elif self._settings.enabled and self._process is not None and need_reconnect:
  80. self._stop_worker()
  81. self._start_worker()
  82. return self.get_settings()
  83. def feed_audio(self, audio: np.ndarray, sample_rate: int) -> None:
  84. """Feed processed audio to STT subprocess. Non-blocking."""
  85. proc = self._process
  86. if proc is None or proc.poll() is not None:
  87. return
  88. if not self._settings.enabled:
  89. return
  90. self._sample_rate = sample_rate
  91. pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
  92. payload = pcm16.tobytes()
  93. try:
  94. # Length-prefixed binary: 4 bytes little-endian length + PCM data
  95. header = len(payload).to_bytes(4, "little")
  96. proc.stdin.write(header + payload)
  97. proc.stdin.flush()
  98. except (BrokenPipeError, OSError):
  99. pass
  100. def _build_url(self) -> str:
  101. s = self._settings
  102. parts = [
  103. "language=" + s.language,
  104. "rate=" + str(self._sample_rate),
  105. "stream_id=mic-system-" + str(int(time.time())),
  106. ]
  107. if s.timestamps:
  108. parts.append("timestamps=1")
  109. if s.diarize:
  110. parts.append("diarize=1")
  111. if s.itn:
  112. parts.append("itn=1")
  113. if s.detect_emotion:
  114. parts.append("detect_emotion=1")
  115. if s.server_vad:
  116. parts.append("vad=1")
  117. parts.append("vad_threshold=" + str(s.vad_threshold))
  118. parts.append("vad_pad_ms=" + str(s.vad_pad_ms))
  119. parts.append("vad_min_ms=" + str(s.vad_min_ms))
  120. return self.STT_URL + "?" + "&".join(parts)
  121. def _start_worker(self):
  122. url = self._build_url()
  123. logger.info("STT starting worker subprocess: %s", url)
  124. python = _VENV_PYTHON if os.path.exists(_VENV_PYTHON) else "python3"
  125. try:
  126. self._process = subprocess.Popen(
  127. [python, _WORKER_PATH, url],
  128. stdin=subprocess.PIPE,
  129. stdout=subprocess.PIPE,
  130. stderr=subprocess.DEVNULL,
  131. bufsize=0,
  132. )
  133. except Exception as e:
  134. logger.error("Failed to start STT worker: %s", e)
  135. return
  136. self._reader_thread = threading.Thread(
  137. target=self._read_stdout,
  138. daemon=True,
  139. )
  140. self._reader_thread.start()
  141. def _read_stdout(self):
  142. """Read JSON lines from worker stdout and forward via callback."""
  143. proc = self._process
  144. if proc is None:
  145. return
  146. try:
  147. for line in proc.stdout:
  148. line = line.strip()
  149. if not line:
  150. continue
  151. try:
  152. msg = json.loads(line)
  153. except (json.JSONDecodeError, ValueError):
  154. continue
  155. if msg.get("type") == "stt_status":
  156. self._connected = bool(msg.get("connected", False))
  157. if self._on_message:
  158. self._on_message(msg)
  159. except Exception:
  160. pass
  161. finally:
  162. self._connected = False
  163. if self._on_message:
  164. self._on_message({"type": "stt_status", "connected": False})
  165. def _stop_worker(self):
  166. self._connected = False
  167. proc = self._process
  168. self._process = None
  169. if proc is not None:
  170. try:
  171. proc.stdin.close()
  172. except Exception:
  173. pass
  174. try:
  175. proc.terminate()
  176. proc.wait(timeout=3)
  177. except Exception:
  178. try:
  179. proc.kill()
  180. except Exception:
  181. pass
  182. self._reader_thread = None
  183. def stop(self):
  184. self._stop_worker()