| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- """Bridge between AudioEngine and stt.mm.mk WebSocket STT service.
- Connects to wss://stt.mm.mk/ws/transcribe, streams processed audio,
- and forwards transcription results back via a callback.
- """
- from __future__ import annotations
- import json
- import logging
- import threading
- import time
- from dataclasses import dataclass
- import numpy as np
- logger = logging.getLogger("stt_bridge")
- try:
- import websocket as ws_client # websocket-client library
- except ImportError:
- ws_client = None
- logger.warning("websocket-client not installed — STT bridge unavailable")
@dataclass()
class SttSettings:
    """User-tunable options for the stt.mm.mk transcription stream.

    Apart from ``enabled``, every field ends up as a query parameter of the
    WebSocket URL, so changing one of them only affects the server after a
    (re)connect.
    """

    # Master switch: turning it on opens the connection, off closes it.
    enabled: bool = False
    # Language code sent as ``language=``.
    language: str = "pl"

    # Optional server features; each is sent as ``<flag>=1`` only when True.
    timestamps: bool = True
    diarize: bool = True
    itn: bool = True
    detect_emotion: bool = False

    # Server-side VAD (sent as ``vad=1``). The three tuning values below are
    # included in the URL only while server_vad is enabled.
    server_vad: bool = False
    vad_threshold: float = 0.3
    vad_pad_ms: int = 400
    vad_min_ms: int = 100
class SttBridge:
    """Manage the WebSocket connection to stt.mm.mk and stream audio to it.

    Parsed server messages and connection-status events
    (``{"type": "stt_status", "connected": bool}``) are passed to the
    ``on_message`` callback given to the constructor. Status/message events
    are raised on the websocket-client worker thread, except the
    "disconnected" event emitted by ``_stop_connection``, which runs on the
    thread calling ``update_settings``/``stop``.
    """

    STT_URL = "wss://stt.mm.mk/ws/transcribe"
    # Delay before re-dialling after the server closes the connection.
    RECONNECT_DELAY_S = 2.0
    # Spellings accepted as True when a boolean setting arrives as text.
    _TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})
    # Settings baked into the connection URL by _build_url(); changing one
    # requires reconnecting for the server to see it.
    _RECONNECT_KEYS = frozenset({
        "language", "timestamps", "diarize", "itn",
        "detect_emotion", "server_vad", "vad_threshold",
        "vad_pad_ms", "vad_min_ms",
    })
    # Every key update_settings() may write (SttSettings field names).
    _SETTING_KEYS = _RECONNECT_KEYS | {"enabled"}

    def __init__(self, on_message=None):
        """Create a disconnected bridge.

        Args:
            on_message: optional callable receiving each parsed server message
                and every ``stt_status`` event.
        """
        self._lock = threading.Lock()  # guards self._settings
        self._settings = SttSettings()
        self._on_message = on_message
        self._ws = None
        self._ws_thread: threading.Thread | None = None
        self._connected = False
        self._should_run = False
        self._sample_rate = 16000
        # Bumped on every start/stop. Each connection's callbacks remember the
        # generation they were created for and become no-ops once it is stale,
        # so a superseded socket can neither flip _connected nor schedule a
        # duplicate reconnect.
        self._generation = 0

    def get_settings(self) -> dict:
        """Return a snapshot of the settings plus the connection state."""
        with self._lock:
            return {
                "stt_enabled": self._settings.enabled,
                "stt_language": self._settings.language,
                "stt_timestamps": self._settings.timestamps,
                "stt_diarize": self._settings.diarize,
                "stt_itn": self._settings.itn,
                "stt_detect_emotion": self._settings.detect_emotion,
                "stt_server_vad": self._settings.server_vad,
                "stt_vad_threshold": self._settings.vad_threshold,
                "stt_vad_pad_ms": self._settings.vad_pad_ms,
                "stt_vad_min_ms": self._settings.vad_min_ms,
                "stt_connected": self._connected,
            }

    @classmethod
    def _coerce(cls, old, val):
        """Convert *val* to the type of the current setting value *old*.

        Booleans get special handling because ``bool("false")`` is True.
        Other types use their constructor (may raise ValueError/TypeError).
        """
        if isinstance(old, bool):
            if isinstance(val, str):
                return val.strip().lower() in cls._TRUE_STRINGS
            return bool(val)
        return type(old)(val)

    def update_settings(self, **kwargs) -> dict:
        """Apply ``stt_*`` (or bare) setting keys and adjust the connection.

        Unknown keys (including ``stt_connected``) are ignored. Toggling
        ``enabled`` opens/closes the connection; changing a URL-level setting
        while a connection is active or being dialled reconnects with the new
        URL.

        Returns:
            The resulting settings, as from ``get_settings()``.
        """
        changed_enabled = False
        need_reconnect = False
        with self._lock:
            for key, val in kwargs.items():
                attr = key.replace("stt_", "")
                if attr not in self._SETTING_KEYS:
                    continue
                old = getattr(self._settings, attr)
                new = self._coerce(old, val)
                setattr(self._settings, attr, new)
                if new == old:
                    continue  # no effective change -> no reconnect needed
                if attr == "enabled":
                    changed_enabled = True
                elif attr in self._RECONNECT_KEYS:
                    need_reconnect = True
            enabled = self._settings.enabled
        # Connection changes run outside the lock: they may invoke the
        # on_message callback, which could call get_settings() and deadlock.
        if changed_enabled:
            if enabled:
                self._start_connection()
            else:
                self._stop_connection()
        elif enabled and self._should_run and need_reconnect:
            self._stop_connection()
            self._start_connection()
        return self.get_settings()

    def feed_audio(self, audio: np.ndarray, sample_rate: int) -> None:
        """Send processed (post-beamforming/AGC) audio to the STT server.

        Args:
            audio: float samples; values are clipped to [-1.0, 1.0] and sent
                as 16-bit PCM in a binary frame.
            sample_rate: rate of *audio* in Hz. It is recorded even while
                disconnected so the next connection URL advertises it. A
                change during an open connection only applies after the next
                (re)connect.

        Does nothing when STT is disabled or not connected. Send failures are
        logged at debug level and the frame is dropped.
        """
        # Record the rate before the early return: _build_url() reads it when
        # dialling, which happens while we are (by definition) not connected.
        self._sample_rate = int(sample_rate)
        ws = self._ws  # local ref: _stop_connection may clear self._ws concurrently
        if ws is None or not self._connected or not self._settings.enabled:
            return
        pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        if pcm16.size == 0:
            return
        try:
            ws.send(pcm16.tobytes(), opcode=0x2)  # 0x2 = binary frame
        except Exception as exc:
            # Best effort: dropping a frame beats stalling the audio thread;
            # on_close takes care of reconnecting.
            logger.debug("STT audio send failed: %s", exc)

    def _build_url(self) -> str:
        """Return the WebSocket URL with the current settings as query params."""
        from urllib.parse import urlencode

        with self._lock:
            s = self._settings
            params = [
                ("language", s.language),
                ("rate", self._sample_rate),
                ("stream_id", "mic-system-" + str(int(time.time()))),
            ]
            if s.timestamps:
                params.append(("timestamps", 1))
            if s.diarize:
                params.append(("diarize", 1))
            if s.itn:
                params.append(("itn", 1))
            if s.detect_emotion:
                params.append(("detect_emotion", 1))
            if s.server_vad:
                params.append(("vad", 1))
                params.append(("vad_threshold", s.vad_threshold))
                params.append(("vad_pad_ms", s.vad_pad_ms))
                params.append(("vad_min_ms", s.vad_min_ms))
        # urlencode escapes values (e.g. a language string containing '&').
        return self.STT_URL + "?" + urlencode(params)

    def _notify(self, msg) -> None:
        """Pass *msg* to the on_message callback, logging instead of raising."""
        callback = self._on_message
        if not callback:
            return
        try:
            callback(msg)
        except Exception:
            logger.exception("STT on_message callback failed")

    def _start_connection(self):
        """Open a new connection served by a background daemon thread.

        The connection's callbacks are tied to a fresh generation number and
        ignore events once a later start/stop has superseded them. While the
        bridge should run, a close triggers a reconnect after
        RECONNECT_DELAY_S seconds.
        """
        if ws_client is None:
            logger.error("websocket-client not installed")
            return
        self._generation += 1
        gen = self._generation
        self._should_run = True
        url = self._build_url()
        logger.info("STT connecting to %s", url)

        def is_current():
            return self._generation == gen

        def on_open(ws):
            if not is_current():
                return
            self._connected = True
            logger.info("STT WebSocket connected")
            self._notify({"type": "stt_status", "connected": True})

        def on_message(ws, message):
            if not is_current():
                return
            try:
                msg = json.loads(message)
            except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
                logger.error("STT message parse error: %s", e)
                return
            self._notify(msg)

        def on_error(ws, error):
            logger.error("STT WebSocket error: %s", error)

        def on_close(ws, close_status_code, close_msg):
            if not is_current():
                # Superseded by a newer start/stop, which owns the state now.
                return
            self._connected = False
            logger.info("STT WebSocket closed: %s %s", close_status_code, close_msg)
            self._notify({"type": "stt_status", "connected": False})
            if self._should_run:
                time.sleep(self.RECONNECT_DELAY_S)
                # Re-check: we may have been stopped or restarted while asleep.
                if self._should_run and is_current():
                    self._start_connection()

        ws = ws_client.WebSocketApp(
            url,
            on_open=on_open,
            on_message=on_message,
            on_error=on_error,
            on_close=on_close,
        )
        self._ws = ws
        self._ws_thread = threading.Thread(
            target=ws.run_forever,
            kwargs={"ping_interval": 20, "ping_timeout": 10},
            name="stt-ws",
            daemon=True,
        )
        self._ws_thread.start()

    def _stop_connection(self):
        """Close the current connection (if any) and disable auto-reconnect.

        Because the old socket's on_close is now stale, the "disconnected"
        status event is emitted here instead.
        """
        self._generation += 1  # invalidate callbacks of the old socket
        self._should_run = False
        self._connected = False
        ws, self._ws = self._ws, None
        self._ws_thread = None
        if ws is None:
            return
        try:
            ws.close()
        except Exception as exc:
            logger.debug("STT WebSocket close failed: %s", exc)
        self._notify({"type": "stt_status", "connected": False})

    def stop(self):
        """Shut the bridge down: close the connection and stop reconnecting."""
        self._stop_connection()
|