class OpenAISTTTranscriptionSession(StreamedTranscriptionSession):
    """A transcription session for OpenAI's speech-to-text (STT) models.

    Audio chunks read from ``input.queue`` are streamed over a Realtime API
    websocket opened with ``intent=transcription``; completed transcripts are
    yielded one turn at a time from :meth:`transcribe_turns`.
    """

    def __init__(
        self,
        input: StreamedAudioInput,
        client: AsyncOpenAI,
        model: str,
        settings: STTModelSettings,
        trace_include_sensitive_data: bool,
        trace_include_sensitive_audio_data: bool,
    ):
        self.connected: bool = False
        self._client = client
        self._model = model
        self._settings = settings
        self._turn_detection = settings.turn_detection or DEFAULT_TURN_DETECTION
        self._trace_include_sensitive_data = trace_include_sensitive_data
        self._trace_include_sensitive_audio_data = trace_include_sensitive_audio_data

        # Audio chunks come in on the input queue; transcripts and sentinels go out.
        self._input_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = input.queue
        self._output_queue: asyncio.Queue[str | ErrorSentinel | SessionCompleteSentinel] = (
            asyncio.Queue()
        )
        self._websocket: websockets.ClientConnection | None = None
        # Every decoded server event, consumed by _handle_events.
        self._event_queue: asyncio.Queue[dict[str, Any] | WebsocketDoneSentinel] = asyncio.Queue()
        # Session created/updated events only, consumed by _setup_connection.
        self._state_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        # Audio for the current turn; only filled when audio is traced.
        self._turn_audio_buffer: list[npt.NDArray[np.int16 | np.float32]] = []
        self._tracing_span: Span[TranscriptionSpanData] | None = None

        # Background tasks
        self._listener_task: asyncio.Task[Any] | None = None
        self._process_events_task: asyncio.Task[Any] | None = None
        self._stream_audio_task: asyncio.Task[Any] | None = None
        self._connection_task: asyncio.Task[Any] | None = None
        self._stored_exception: Exception | None = None

    def _start_turn(self) -> None:
        """Open and start a transcription tracing span for the next turn."""
        self._tracing_span = transcription_span(
            model=self._model,
            model_config={
                "temperature": self._settings.temperature,
                "language": self._settings.language,
                "prompt": self._settings.prompt,
                "turn_detection": self._turn_detection,
            },
        )
        self._tracing_span.start()

    def _end_turn(self, _transcript: str) -> None:
        """Finish the current turn's tracing span with a non-empty transcript.

        Empty transcripts are ignored; use ``_finish_tracing_span`` to close
        the span unconditionally (e.g. when the session ends).
        """
        if not _transcript:
            return
        self._finish_tracing_span(_transcript)

    def _finish_tracing_span(self, transcript: str) -> None:
        """Attach turn data to the active tracing span, if any, and finish it.

        Args:
            transcript: Recorded as the span output when sensitive data
                tracing is enabled.
        """
        span = self._tracing_span
        if span is None:
            return
        # Skip audio encoding when nothing was buffered (e.g. the session ended
        # before any audio for this turn was sent).
        if self._trace_include_sensitive_audio_data and self._turn_audio_buffer:
            span.span_data.input = _audio_to_base64(self._turn_audio_buffer)
        span.span_data.input_format = "pcm"
        if self._trace_include_sensitive_data:
            span.span_data.output = transcript
        span.finish()
        self._turn_audio_buffer = []
        self._tracing_span = None

    async def _event_listener(self) -> None:
        """Read websocket messages and fan them out to the internal queues.

        Every decoded event goes to the event queue; session created/updated
        events also go to the state queue. A ``WebsocketDoneSentinel`` is
        queued once the websocket iteration ends.

        Raises:
            STTWebsocketConnectionError: If a message cannot be parsed or the
                server sends an ``error`` event. An ``ErrorSentinel`` is put
                on the output queue first.
        """
        assert self._websocket is not None, "Websocket not initialized"
        async for message in self._websocket:
            try:
                event = json.loads(message)
                event_type = event.get("type")
            except Exception as e:
                await self._output_queue.put(ErrorSentinel(e))
                raise STTWebsocketConnectionError("Error parsing events") from e

            # Raised outside the parsing try-block so it is not mislabelled
            # as a parsing failure.
            if event_type == "error":
                error = STTWebsocketConnectionError(f"Error event: {event.get('error')}")
                await self._output_queue.put(ErrorSentinel(error))
                raise error

            if event_type in (
                "session.updated",
                "transcription_session.updated",
                "session.created",
                "transcription_session.created",
            ):
                await self._state_queue.put(event)
            await self._event_queue.put(event)
        await self._event_queue.put(WebsocketDoneSentinel())

    async def _configure_session(self) -> None:
        """Send the ``transcription_session.update`` message for this session."""
        assert self._websocket is not None, "Websocket not initialized"
        transcription_config: dict[str, Any] = {"model": self._model}
        # Forward the optional language/prompt settings; omitted when unset.
        if self._settings.language:
            transcription_config["language"] = self._settings.language
        if self._settings.prompt:
            transcription_config["prompt"] = self._settings.prompt
        await self._websocket.send(
            json.dumps(
                {
                    "type": "transcription_session.update",
                    "session": {
                        "input_audio_format": "pcm16",
                        "input_audio_transcription": transcription_config,
                        "turn_detection": self._turn_detection,
                    },
                }
            )
        )

    async def _setup_connection(self, ws: websockets.ClientConnection) -> None:
        """Attach *ws*, start the listener and complete the session handshake.

        Waits for a session-created event, sends the session configuration,
        then waits for a session-updated event. Every failure is reported on
        the output queue as an ``ErrorSentinel`` and re-raised; timeouts are
        wrapped in ``STTWebsocketConnectionError``.
        """
        self._websocket = ws
        self._listener_task = asyncio.create_task(self._event_listener())

        # asyncio.TimeoutError is only an alias of the builtin TimeoutError
        # from Python 3.11 on, so catch both.
        try:
            await _wait_for_event(
                self._state_queue,
                ["session.created", "transcription_session.created"],
                SESSION_CREATION_TIMEOUT,
            )
        except (TimeoutError, asyncio.TimeoutError) as e:
            wrapped_err = STTWebsocketConnectionError(
                "Timeout waiting for transcription_session.created event"
            )
            await self._output_queue.put(ErrorSentinel(wrapped_err))
            raise wrapped_err from e
        except Exception as e:
            await self._output_queue.put(ErrorSentinel(e))
            raise

        await self._configure_session()

        try:
            event = await _wait_for_event(
                self._state_queue,
                ["session.updated", "transcription_session.updated"],
                SESSION_UPDATE_TIMEOUT,
            )
        except (TimeoutError, asyncio.TimeoutError) as e:
            wrapped_err = STTWebsocketConnectionError(
                "Timeout waiting for transcription_session.updated event"
            )
            await self._output_queue.put(ErrorSentinel(wrapped_err))
            raise wrapped_err from e
        except Exception as e:
            await self._output_queue.put(ErrorSentinel(e))
            raise
        else:
            if _debug.DONT_LOG_MODEL_DATA:
                logger.debug("Session updated")
            else:
                logger.debug(f"Session updated: {event}")

    async def _handle_events(self) -> None:
        """Turn completed-transcription events into transcripts on the output queue.

        Stops when the listener signals the websocket is done, or when no event
        arrives within ``EVENT_INACTIVITY_TIMEOUT``; then queues a
        ``SessionCompleteSentinel``. Other errors are reported as an
        ``ErrorSentinel`` and re-raised.
        """
        while True:
            try:
                event = await asyncio.wait_for(
                    self._event_queue.get(), timeout=EVENT_INACTIVITY_TIMEOUT
                )
                if isinstance(event, WebsocketDoneSentinel):
                    # processed all events and websocket is done
                    break

                event_type = event.get("type", "unknown")
                if event_type == "conversation.item.input_audio_transcription.completed":
                    transcript = cast(str, event.get("transcript", ""))
                    if len(transcript) > 0:
                        self._end_turn(transcript)
                        self._start_turn()
                        await self._output_queue.put(transcript)
                await asyncio.sleep(0)  # yield control
            except asyncio.TimeoutError:
                # No new events for a while. Assume the session is done.
                break
            except Exception as e:
                await self._output_queue.put(ErrorSentinel(e))
                raise e
        await self._output_queue.put(SessionCompleteSentinel())

    async def _stream_audio(
        self, audio_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]]
    ) -> None:
        """Send audio chunks from *audio_queue* to the websocket as base64.

        A ``None`` item ends the stream, and so does a closed connection.
        Other send failures are reported as an ``ErrorSentinel`` and re-raised.
        """
        assert self._websocket is not None, "Websocket not initialized"
        self._start_turn()
        while True:
            buffer = await audio_queue.get()
            if buffer is None:
                break

            # The turn buffer is only read when audio is traced; don't let it
            # grow unboundedly otherwise.
            if self._trace_include_sensitive_audio_data:
                self._turn_audio_buffer.append(buffer)
            try:
                await self._websocket.send(
                    json.dumps(
                        {
                            "type": "input_audio_buffer.append",
                            "audio": base64.b64encode(buffer.tobytes()).decode("utf-8"),
                        }
                    )
                )
            except websockets.ConnectionClosed:
                break
            except Exception as e:
                await self._output_queue.put(ErrorSentinel(e))
                raise e

            await asyncio.sleep(0)  # yield control

    async def _process_websocket_connection(self) -> None:
        """Open the websocket, run the handshake and supervise the worker tasks.

        Runs until the listener task finishes. Any failure is reported as an
        ``ErrorSentinel`` and re-raised.
        """
        try:
            async with websockets.connect(
                "wss://api.openai.com/v1/realtime?intent=transcription",
                additional_headers={
                    "Authorization": f"Bearer {self._client.api_key}",
                    "OpenAI-Beta": "realtime=v1",
                    "OpenAI-Log-Session": "1",
                },
            ) as ws:
                await self._setup_connection(ws)
                self._process_events_task = asyncio.create_task(self._handle_events())
                self._stream_audio_task = asyncio.create_task(self._stream_audio(self._input_queue))
                self.connected = True
                if self._listener_task:
                    await self._listener_task
                else:
                    logger.error("Listener task not initialized")
                    raise AgentsException("Listener task not initialized")
        except Exception as e:
            await self._output_queue.put(ErrorSentinel(e))
            raise e

    def _check_errors(self) -> None:
        """Record the exception of any finished background task.

        Cancelled tasks are skipped because ``Task.exception()`` raises
        ``CancelledError`` for them. When several tasks failed, the last one
        in the order below wins.
        """
        for task in (
            self._connection_task,
            self._process_events_task,
            self._stream_audio_task,
            self._listener_task,
        ):
            if task is None or not task.done() or task.cancelled():
                continue
            exc = task.exception()
            if isinstance(exc, Exception):
                self._stored_exception = exc

    def _cleanup_tasks(self) -> None:
        """Cancel every background task that is still running."""
        for task in (
            self._listener_task,
            self._process_events_task,
            self._stream_audio_task,
            self._connection_task,
        ):
            if task is not None and not task.done():
                task.cancel()

    async def transcribe_turns(self) -> AsyncIterator[str]:
        """Yield completed transcripts, one per turn, until the session ends.

        The session ends on an error, on session completion, or when this
        generator's task is cancelled. Afterwards the last tracing span is
        finished, the websocket is closed and leftover tasks are cancelled.

        Raises:
            Exception: The exception recorded from a failed background task,
                if any.
        """
        self._connection_task = asyncio.create_task(self._process_websocket_connection())

        while True:
            try:
                turn = await self._output_queue.get()
            except asyncio.CancelledError:
                break

            if turn is None or isinstance(turn, (ErrorSentinel, SessionCompleteSentinel)):
                self._output_queue.task_done()
                break
            yield turn
            self._output_queue.task_done()

        # _end_turn ignores empty transcripts, so close the dangling span directly.
        if self._tracing_span:
            self._finish_tracing_span("")

        if self._websocket:
            await self._websocket.close()

        self._check_errors()
        # e.g. _stream_audio may still be blocked on the input queue.
        self._cleanup_tasks()
        if self._stored_exception:
            raise self._stored_exception

    async def close(self) -> None:
        """Close the websocket (if open) and cancel all background tasks."""
        if self._websocket:
            await self._websocket.close()

        self._cleanup_tasks()