Selaa lähdekoodia

Fix STT: use subprocess instead of threads to avoid eventlet conflict

websocket-client threads deadlock with eventlet monkey-patched sockets.
Now STT runs as a separate subprocess (stt_worker.py) communicating
via stdin (length-prefixed PCM binary) and stdout (JSON lines).
This completely isolates the STT connection from the eventlet event loop.
Paweł Chodaczek 1 kuukausi sitten
vanhempi
commit
d2b671be1c
2 muutettua tiedostoa jossa 179 lisäystä ja 124 poistoa
  1. 85 124
      stt_bridge.py
  2. 94 0
      stt_worker.py

+ 85 - 124
stt_bridge.py

@@ -1,27 +1,25 @@
 """Bridge between AudioEngine and stt.mm.mk WebSocket STT service.
 
-Connects to wss://stt.mm.mk/ws/transcribe, streams processed audio,
-and forwards transcription results back via a callback.
-Uses a queue + sender thread so feed_audio() never blocks the audio callback.
+Runs STT WebSocket in a **subprocess** to avoid conflicts with eventlet.
+Communication: stdin (length-prefixed PCM binary) / stdout (JSON lines).
 """
 from __future__ import annotations
 
 import json
 import logging
-import queue
+import os
+import subprocess
 import threading
 import time
 from dataclasses import dataclass
+from pathlib import Path
 
 import numpy as np
 
 logger = logging.getLogger("stt_bridge")
 
-try:
-    import websocket as ws_client  # websocket-client library
-except ImportError:
-    ws_client = None
-    logger.warning("websocket-client not installed — STT bridge unavailable")
+_WORKER_PATH = str(Path(__file__).resolve().parent / "stt_worker.py")
+_VENV_PYTHON = str(Path(__file__).resolve().parent / ".venv" / "bin" / "python")
 
 
 @dataclass
@@ -39,22 +37,17 @@ class SttSettings:
 
 
 class SttBridge:
-    """Manages WebSocket connection to stt.mm.mk and streams audio."""
+    """Manages subprocess STT worker that connects to stt.mm.mk."""
 
     STT_URL = "wss://stt.mm.mk/ws/transcribe"
-    # Max queued audio chunks before dropping (prevent memory buildup)
-    MAX_QUEUE = 50
 
     def __init__(self, on_message=None):
         self._lock = threading.Lock()
         self._settings = SttSettings()
         self._on_message = on_message
-        self._ws = None
-        self._ws_thread: threading.Thread | None = None
-        self._sender_thread: threading.Thread | None = None
-        self._audio_queue: queue.Queue[bytes] = queue.Queue(maxsize=self.MAX_QUEUE)
+        self._process: subprocess.Popen | None = None
+        self._reader_thread: threading.Thread | None = None
         self._connected = False
-        self._should_run = False
         self._sample_rate = 16000
 
     def get_settings(self) -> dict:
@@ -95,51 +88,34 @@ class SttBridge:
 
         if changed_enabled:
             if self._settings.enabled:
-                self._start_connection()
+                self._start_worker()
             else:
-                self._stop_connection()
-        elif self._settings.enabled and self._connected and need_reconnect:
-            self._stop_connection()
-            self._start_connection()
+                self._stop_worker()
+        elif self._settings.enabled and self._process is not None and need_reconnect:
+            self._stop_worker()
+            self._start_worker()
 
         return self.get_settings()
 
     def feed_audio(self, audio: np.ndarray, sample_rate: int) -> None:
-        """Feed processed audio to STT. Non-blocking — drops if queue is full."""
-        if not self._connected or not self._settings.enabled:
+        """Feed processed audio to STT subprocess. Non-blocking."""
+        proc = self._process
+        if proc is None or proc.poll() is not None:
+            return
+        if not self._settings.enabled:
             return
 
         self._sample_rate = sample_rate
         pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
+        payload = pcm16.tobytes()
 
         try:
-            self._audio_queue.put_nowait(pcm16.tobytes())
-        except queue.Full:
-            # Drop oldest chunk to make room
-            try:
-                self._audio_queue.get_nowait()
-            except queue.Empty:
-                pass
-            try:
-                self._audio_queue.put_nowait(pcm16.tobytes())
-            except queue.Full:
-                pass
-
-    def _sender_loop(self):
-        """Background thread that drains the queue and sends to WebSocket."""
-        while self._should_run:
-            try:
-                data = self._audio_queue.get(timeout=0.2)
-            except queue.Empty:
-                continue
-
-            if not self._connected or self._ws is None:
-                continue
-
-            try:
-                self._ws.send(data, opcode=0x2)
-            except Exception:
-                pass
+            # Length-prefixed binary: 4 bytes little-endian length + PCM data
+            header = len(payload).to_bytes(4, "little")
+            proc.stdin.write(header + payload)
+            proc.stdin.flush()
+        except (BrokenPipeError, OSError):
+            pass
 
     def _build_url(self) -> str:
         s = self._settings
@@ -163,91 +139,76 @@ class SttBridge:
             parts.append("vad_min_ms=" + str(s.vad_min_ms))
         return self.STT_URL + "?" + "&".join(parts)
 
-    def _start_connection(self):
-        if ws_client is None:
-            logger.error("websocket-client not installed")
-            return
-
-        self._should_run = True
-
-        # Clear queue
-        while not self._audio_queue.empty():
-            try:
-                self._audio_queue.get_nowait()
-            except queue.Empty:
-                break
-
+    def _start_worker(self):
         url = self._build_url()
-        logger.info("STT connecting to %s", url)
-
-        bridge = self
+        logger.info("STT starting worker subprocess: %s", url)
 
-        def on_open(ws):
-            bridge._connected = True
-            logger.info("STT WebSocket connected")
-            if bridge._on_message:
-                bridge._on_message({"type": "stt_status", "connected": True})
+        python = _VENV_PYTHON if os.path.exists(_VENV_PYTHON) else "python3"
 
-        def on_message(ws, message):
-            try:
-                msg = json.loads(message)
-                if bridge._on_message:
-                    bridge._on_message(msg)
-            except Exception as e:
-                logger.error("STT message parse error: %s", e)
-
-        def on_error(ws, error):
-            logger.error("STT WebSocket error: %s", error)
-
-        def on_close(ws, close_status_code, close_msg):
-            bridge._connected = False
-            logger.info("STT WebSocket closed: %s %s", close_status_code, close_msg)
-            if bridge._on_message:
-                bridge._on_message({"type": "stt_status", "connected": False})
-            if bridge._should_run:
-                time.sleep(2)
-                if bridge._should_run:
-                    bridge._start_connection()
-
-        self._ws = ws_client.WebSocketApp(
-            url,
-            on_open=on_open,
-            on_message=on_message,
-            on_error=on_error,
-            on_close=on_close,
-        )
+        try:
+            self._process = subprocess.Popen(
+                [python, _WORKER_PATH, url],
+                stdin=subprocess.PIPE,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.DEVNULL,
+                bufsize=0,
+            )
+        except Exception as e:
+            logger.error("Failed to start STT worker: %s", e)
+            return
 
-        self._ws_thread = threading.Thread(
-            target=self._ws.run_forever,
-            kwargs={"ping_interval": 20, "ping_timeout": 10},
+        self._reader_thread = threading.Thread(
+            target=self._read_stdout,
             daemon=True,
         )
-        self._ws_thread.start()
+        self._reader_thread.start()
 
-        # Start sender thread
-        self._sender_thread = threading.Thread(
-            target=self._sender_loop,
-            daemon=True,
-        )
-        self._sender_thread.start()
+    def _read_stdout(self):
+        """Read JSON lines from worker stdout and forward via callback."""
+        proc = self._process
+        if proc is None:
+            return
 
-    def _stop_connection(self):
-        self._should_run = False
+        try:
+            for line in proc.stdout:
+                line = line.strip()
+                if not line:
+                    continue
+                try:
+                    msg = json.loads(line)
+                except (json.JSONDecodeError, ValueError):
+                    continue
+
+                if msg.get("type") == "stt_status":
+                    self._connected = bool(msg.get("connected", False))
+
+                if self._on_message:
+                    self._on_message(msg)
+        except Exception:
+            pass
+        finally:
+            self._connected = False
+            if self._on_message:
+                self._on_message({"type": "stt_status", "connected": False})
+
+    def _stop_worker(self):
         self._connected = False
-        if self._ws:
+        proc = self._process
+        self._process = None
+        if proc is not None:
             try:
-                self._ws.close()
+                proc.stdin.close()
             except Exception:
                 pass
-            self._ws = None
-        self._ws_thread = None
-        self._sender_thread = None
-        # Drain queue
-        while not self._audio_queue.empty():
             try:
-                self._audio_queue.get_nowait()
-            except queue.Empty:
-                break
+                proc.terminate()
+                proc.wait(timeout=3)
+            except Exception:
+                try:
+                    proc.kill()
+                except Exception:
+                    pass
+        self._reader_thread = None
 
     def stop(self):
-        self._stop_connection()
+        self._stop_worker()

+ 94 - 0
stt_worker.py

@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+"""STT worker subprocess — connects to stt.mm.mk via WebSocket.
+
+Reads raw PCM int16 binary chunks from stdin, sends to STT server.
+Writes JSON messages (one per line) to stdout.
+
+Usage: python3 stt_worker.py <ws_url>
+"""
+import json
+import sys
+import threading
+import time
+
+import websocket
+
+def main():
+    if len(sys.argv) < 2:
+        sys.exit(1)
+
+    url = sys.argv[1]
+    connected = False
+    ws = None
+
+    def on_open(w):
+        nonlocal connected
+        connected = True
+        _emit({"type": "stt_status", "connected": True})
+
+    def on_message(w, message):
+        try:
+            msg = json.loads(message)
+            _emit(msg)
+        except Exception:
+            pass
+
+    def on_error(w, error):
+        _emit({"type": "stt_error", "error": str(error)})
+
+    def on_close(w, code, msg):
+        nonlocal connected
+        connected = False
+        _emit({"type": "stt_status", "connected": False})
+
+    def _emit(obj):
+        try:
+            line = json.dumps(obj, ensure_ascii=True)
+            sys.stdout.write(line + "\n")
+            sys.stdout.flush()
+        except Exception:
+            pass
+
+    ws = websocket.WebSocketApp(
+        url,
+        on_open=on_open,
+        on_message=on_message,
+        on_error=on_error,
+        on_close=on_close,
+    )
+
+    # Run WebSocket in background thread
+    ws_thread = threading.Thread(
+        target=ws.run_forever,
+        kwargs={"ping_interval": 20, "ping_timeout": 10},
+        daemon=True,
+    )
+    ws_thread.start()
+
+    # Read PCM chunks from stdin and forward to WebSocket
+    try:
+        while True:
+            # Read 4-byte length prefix, then that many bytes of PCM
+            header = sys.stdin.buffer.read(4)
+            if not header or len(header) < 4:
+                break
+            length = int.from_bytes(header, "little")
+            if length <= 0 or length > 1_000_000:
+                continue
+            data = sys.stdin.buffer.read(length)
+            if not data or len(data) < length:
+                break
+            if connected and ws:
+                try:
+                    ws.send(data, opcode=0x2)
+                except Exception:
+                    pass
+    except (BrokenPipeError, KeyboardInterrupt):
+        pass
+    finally:
+        if ws:
+            ws.close()
+
+
+if __name__ == "__main__":
+    main()