File size: 3,891 Bytes
e954acb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import json
import threading
import time
import urllib.parse
from typing import Optional, Callable
import websocket


class FireworksTranscription:
    """Fireworks AI streaming transcription client for Gradio integration.

    Audio chunks are pushed over a WebSocket with ``_send_audio_chunk``; the
    server's transcription messages are merged into ``self.segments`` (keyed
    by segment id) and the assembled text is delivered to an optional
    callback registered with ``set_callback``.
    """

    # NOTE(review): plain ws:// sends the Authorization header (the API key)
    # unencrypted; confirm whether this endpoint accepts wss:// and switch.
    WEBSOCKET_URL = "ws://audio-streaming.us-virginia-1.direct.fireworks.ai/v1/audio/transcriptions/streaming"

    def __init__(self, api_key: str):
        """Store credentials and initialise empty state (no network I/O here).

        Args:
            api_key: Value sent verbatim as the ``Authorization`` header.
        """
        self.api_key = api_key
        self.websocket_client = None
        # True between on_open and on_close of the current client.
        self.is_connected = False
        # segment id -> latest text for that segment; guarded by self.lock.
        self.segments = {}
        self.lock = threading.Lock()
        self.transcription_callback: Optional[Callable[[str], None]] = None
        # Set once the current connection attempt has either opened or
        # closed, so _connect can stop waiting as soon as the outcome is known.
        self._connection_settled = threading.Event()

    def set_callback(self, callback: Callable[[str], None]):
        """Set callback to receive live transcription updates.

        The callback receives the full transcription assembled so far, not
        just the newest segment.
        """
        self.transcription_callback = callback

    def _connect(self, timeout: float = 5.0) -> bool:
        """Open the WebSocket on a background daemon thread and wait for it.

        Any previously created client is closed first, so its thread stops
        feeding messages into ``self.segments``.

        Args:
            timeout: Maximum seconds to wait for the connection to open.
                Defaults to 5.

        Returns:
            True if the connection opened within ``timeout``; False on
            timeout, on an early close, or if setting up the client raised.
        """
        try:
            params = urllib.parse.urlencode({"language": "en"})
            full_url = f"{self.WEBSOCKET_URL}?{params}"

            previous = self.websocket_client
            if previous is not None:
                previous.close()

            self.websocket_client = websocket.WebSocketApp(
                full_url,
                header={"Authorization": self.api_key},
                on_open=self._on_open,
                on_message=self._on_message,
                on_error=self._on_error,
                on_close=self._on_close,
            )
            # Reset only after swapping clients: a late close notification from
            # the previous client is then ignored by the identity check in
            # _on_close instead of clobbering the new attempt's state.
            self.is_connected = False
            self._connection_settled.clear()

            # Start WebSocket in background thread
            ws_thread = threading.Thread(
                target=self.websocket_client.run_forever, daemon=True
            )
            ws_thread.start()

            # Woken by _on_open (success) or _on_close (failure); otherwise
            # gives up after `timeout` seconds.
            self._connection_settled.wait(timeout)
            return self.is_connected

        except Exception as e:
            print(f"Connection error: {e}")
            return False

    def _send_audio_chunk(self, chunk: bytes) -> bool:
        """Send one binary audio frame to Fireworks.

        Returns:
            True if the frame was handed to the socket; False when not
            connected or when the send raised.
        """
        # Snapshot the client so a concurrent _connect cannot swap it between
        # the check and the send.
        client = self.websocket_client
        if not self.is_connected or not client:
            return False

        try:
            client.send(chunk, opcode=websocket.ABNF.OPCODE_BINARY)
            return True
        except Exception as e:
            print(f"Error sending audio chunk: {e}")
            return False

    def _on_open(self, ws):
        """Handle WebSocket connection opening."""
        self.is_connected = True
        self._connection_settled.set()
        print("✅ Connected to Fireworks transcription service")

    def _on_close(self, ws, *args):
        """Handle WebSocket closure by marking the client disconnected.

        ``*args`` absorbs the close status code / reason, whose presence
        depends on the websocket-client version.
        """
        if ws is not self.websocket_client:
            # Stale notification from a client that _connect already replaced.
            return
        self.is_connected = False
        # Wake a waiting _connect so a failed attempt returns immediately.
        self._connection_settled.set()

    def _on_message(self, ws, message):
        """Merge a server message's segments and push the assembled text.

        Messages that are not JSON objects with a ``"segments"`` list are
        ignored. Parsing and processing errors are printed, not raised,
        because this runs on the WebSocket's background thread.
        """
        try:
            data = json.loads(message)
            if not isinstance(data, dict) or "segments" not in data:
                return

            with self.lock:
                # Later messages overwrite earlier text for the same segment id.
                for segment in data["segments"]:
                    self.segments[segment["id"]] = segment["text"]
                complete_text = self._build_complete_text()

            # Invoke the user callback outside the lock: a slow or re-entrant
            # callback must not block or deadlock segment processing.
            callback = self.transcription_callback
            if callback and complete_text.strip():
                callback(complete_text)

        except json.JSONDecodeError as e:
            print(f"Failed to parse message: {e}")
        except Exception as e:
            print(f"Error processing message: {e}")

    @staticmethod
    def _on_error(ws, error):
        """Handle WebSocket errors by printing them."""
        print(f"WebSocket error: {error}")

    def _build_complete_text(self) -> str:
        """Join all non-blank segment texts in ascending numeric id order.

        Called from ``_on_message`` while ``self.lock`` is held. Raises
        ``ValueError`` if a segment id is not int-convertible; ``_on_message``
        catches and prints it.
        """
        if not self.segments:
            return ""

        sorted_segments = sorted(self.segments.items(), key=lambda x: int(x[0]))
        return " ".join(text for _, text in sorted_segments if text.strip())