stt_bridge.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. """Bridge between AudioEngine and stt.mm.mk WebSocket STT service.
  2. Connects to wss://stt.mm.mk/ws/transcribe, streams processed audio,
  3. and forwards transcription results back via a callback.
  4. """
  5. from __future__ import annotations
  6. import json
  7. import logging
  8. import threading
  9. import time
  10. from dataclasses import dataclass
  11. import numpy as np
  12. logger = logging.getLogger("stt_bridge")
  13. try:
  14. import websocket as ws_client # websocket-client library
  15. except ImportError:
  16. ws_client = None
  17. logger.warning("websocket-client not installed — STT bridge unavailable")
@dataclass
class SttSettings:
    """User-adjustable STT options held by ``SttBridge``.

    ``enabled`` decides whether the bridge keeps a connection open; the other
    fields determine the query parameters placed on the connection URL.
    """

    enabled: bool = False  # master on/off switch; not sent to the server
    language: str = "pl"  # sent as the ``language`` query parameter
    timestamps: bool = True  # adds ``timestamps=1`` to the URL when True
    diarize: bool = True  # adds ``diarize=1`` to the URL when True
    itn: bool = True  # adds ``itn=1`` to the URL when True
    detect_emotion: bool = False  # adds ``detect_emotion=1`` to the URL when True
    server_vad: bool = False  # adds ``vad=1`` and the three vad_* values below when True
    vad_threshold: float = 0.3  # sent as ``vad_threshold`` only when server_vad is True
    vad_pad_ms: int = 400  # sent as ``vad_pad_ms`` only when server_vad (milliseconds, per the name)
    vad_min_ms: int = 100  # sent as ``vad_min_ms`` only when server_vad (milliseconds, per the name)
  30. class SttBridge:
  31. """Manages WebSocket connection to stt.mm.mk and streams audio."""
  32. STT_URL = "wss://stt.mm.mk/ws/transcribe"
  33. def __init__(self, on_message=None):
  34. self._lock = threading.Lock()
  35. self._settings = SttSettings()
  36. self._on_message = on_message
  37. self._ws = None
  38. self._ws_thread: threading.Thread | None = None
  39. self._connected = False
  40. self._should_run = False
  41. self._sample_rate = 16000
  42. def get_settings(self) -> dict:
  43. with self._lock:
  44. return {
  45. "stt_enabled": self._settings.enabled,
  46. "stt_language": self._settings.language,
  47. "stt_timestamps": self._settings.timestamps,
  48. "stt_diarize": self._settings.diarize,
  49. "stt_itn": self._settings.itn,
  50. "stt_detect_emotion": self._settings.detect_emotion,
  51. "stt_server_vad": self._settings.server_vad,
  52. "stt_vad_threshold": self._settings.vad_threshold,
  53. "stt_vad_pad_ms": self._settings.vad_pad_ms,
  54. "stt_vad_min_ms": self._settings.vad_min_ms,
  55. "stt_connected": self._connected,
  56. }
  57. def update_settings(self, **kwargs) -> dict:
  58. reconnect_keys = {
  59. "language", "timestamps", "diarize", "itn",
  60. "detect_emotion", "server_vad", "vad_threshold",
  61. "vad_pad_ms", "vad_min_ms",
  62. }
  63. changed_enabled = False
  64. need_reconnect = False
  65. with self._lock:
  66. for key, val in kwargs.items():
  67. attr = key.replace("stt_", "")
  68. if hasattr(self._settings, attr):
  69. old = getattr(self._settings, attr)
  70. setattr(self._settings, attr, type(old)(val))
  71. if attr == "enabled" and old != self._settings.enabled:
  72. changed_enabled = True
  73. if attr in reconnect_keys:
  74. need_reconnect = True
  75. if changed_enabled:
  76. if self._settings.enabled:
  77. self._start_connection()
  78. else:
  79. self._stop_connection()
  80. elif self._settings.enabled and self._connected and need_reconnect:
  81. self._stop_connection()
  82. self._start_connection()
  83. return self.get_settings()
  84. def feed_audio(self, audio: np.ndarray, sample_rate: int) -> None:
  85. """Feed processed audio (post-beamforming/AGC) to STT."""
  86. if not self._connected or not self._settings.enabled:
  87. return
  88. self._sample_rate = sample_rate
  89. pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
  90. try:
  91. if self._ws and self._connected:
  92. self._ws.send(pcm16.tobytes(), opcode=0x2)
  93. except Exception:
  94. pass
  95. def _build_url(self) -> str:
  96. s = self._settings
  97. parts = [
  98. "language=" + s.language,
  99. "rate=" + str(self._sample_rate),
  100. "stream_id=mic-system-" + str(int(time.time())),
  101. ]
  102. if s.timestamps:
  103. parts.append("timestamps=1")
  104. if s.diarize:
  105. parts.append("diarize=1")
  106. if s.itn:
  107. parts.append("itn=1")
  108. if s.detect_emotion:
  109. parts.append("detect_emotion=1")
  110. if s.server_vad:
  111. parts.append("vad=1")
  112. parts.append("vad_threshold=" + str(s.vad_threshold))
  113. parts.append("vad_pad_ms=" + str(s.vad_pad_ms))
  114. parts.append("vad_min_ms=" + str(s.vad_min_ms))
  115. return self.STT_URL + "?" + "&".join(parts)
  116. def _start_connection(self):
  117. if ws_client is None:
  118. logger.error("websocket-client not installed")
  119. return
  120. self._should_run = True
  121. url = self._build_url()
  122. logger.info("STT connecting to %s", url)
  123. bridge = self
  124. def on_open(ws):
  125. bridge._connected = True
  126. logger.info("STT WebSocket connected")
  127. if bridge._on_message:
  128. bridge._on_message({"type": "stt_status", "connected": True})
  129. def on_message(ws, message):
  130. try:
  131. msg = json.loads(message)
  132. if bridge._on_message:
  133. bridge._on_message(msg)
  134. except Exception as e:
  135. logger.error("STT message parse error: %s", e)
  136. def on_error(ws, error):
  137. logger.error("STT WebSocket error: %s", error)
  138. def on_close(ws, close_status_code, close_msg):
  139. bridge._connected = False
  140. logger.info("STT WebSocket closed: %s %s", close_status_code, close_msg)
  141. if bridge._on_message:
  142. bridge._on_message({"type": "stt_status", "connected": False})
  143. if bridge._should_run:
  144. time.sleep(2)
  145. if bridge._should_run:
  146. bridge._start_connection()
  147. self._ws = ws_client.WebSocketApp(
  148. url,
  149. on_open=on_open,
  150. on_message=on_message,
  151. on_error=on_error,
  152. on_close=on_close,
  153. )
  154. self._ws_thread = threading.Thread(
  155. target=self._ws.run_forever,
  156. kwargs={"ping_interval": 20, "ping_timeout": 10},
  157. daemon=True,
  158. )
  159. self._ws_thread.start()
  160. def _stop_connection(self):
  161. self._should_run = False
  162. self._connected = False
  163. if self._ws:
  164. try:
  165. self._ws.close()
  166. except Exception:
  167. pass
  168. self._ws = None
  169. self._ws_thread = None
  170. def stop(self):
  171. self._stop_connection()