class VoicePipeline:
"""一个主观化的语音代理流水线。它分为三个步骤:
1. 将音频输入转录为文本。
2. 运行提供的 `workflow`,生成一系列文本响应。
3. 将文本响应转换为流式音频输出。
"""
def __init__(
self,
*,
workflow: VoiceWorkflowBase,
stt_model: STTModel | str | None = None,
tts_model: TTSModel | str | None = None,
config: VoicePipelineConfig | None = None,
):
"""创建一个新的语音流水线。
参数:
workflow: 要运行的工作流。参见 `VoiceWorkflowBase`。
stt_model: 要使用的语音转文本模型。如果未提供,将使用默认的 OpenAI 模型。
tts_model: 要使用的文本转语音模型。如果未提供,将使用默认的 OpenAI 模型。
config: 流水线配置。如果未提供,将使用默认配置。
"""
self.workflow = workflow
self.stt_model = stt_model if isinstance(stt_model, STTModel) else None
self.tts_model = tts_model if isinstance(tts_model, TTSModel) else None
self._stt_model_name = stt_model if isinstance(stt_model, str) else None
self._tts_model_name = tts_model if isinstance(tts_model, str) else None
self.config = config or VoicePipelineConfig()
async def run(self, audio_input: AudioInput | StreamedAudioInput) -> StreamedAudioResult:
"""运行语音流水线。
参数:
audio_input: 要处理的音频输入。可以是 `AudioInput` 实例(单个静态缓冲区),
也可以是 `StreamedAudioInput` 实例(可以追加音频数据的流)。
返回:
一个 `StreamedAudioResult` 实例。你可以使用该对象来流式传输音频事件并播放它们。
"""
if isinstance(audio_input, AudioInput):
return await self._run_single_turn(audio_input)
elif isinstance(audio_input, StreamedAudioInput):
return await self._run_multi_turn(audio_input)
else:
raise UserError(f"Unsupported audio input type: {type(audio_input)}")
def _get_tts_model(self) -> TTSModel:
if not self.tts_model:
self.tts_model = self.config.model_provider.get_tts_model(self._tts_model_name)
return self.tts_model
def _get_stt_model(self) -> STTModel:
if not self.stt_model:
self.stt_model = self.config.model_provider.get_stt_model(self._stt_model_name)
return self.stt_model
async def _process_audio_input(self, audio_input: AudioInput) -> str:
model = self._get_stt_model()
return await model.transcribe(
audio_input,
self.config.stt_settings,
self.config.trace_include_sensitive_data,
self.config.trace_include_sensitive_audio_data,
)
async def _run_single_turn(self, audio_input: AudioInput) -> StreamedAudioResult:
# Since this is single turn, we can use the TraceCtxManager to manage starting/ending the
# trace
with TraceCtxManager(
workflow_name=self.config.workflow_name or "Voice Agent",
trace_id=None, # Automatically generated
group_id=self.config.group_id,
metadata=self.config.trace_metadata,
disabled=self.config.tracing_disabled,
):
input_text = await self._process_audio_input(audio_input)
output = StreamedAudioResult(
self._get_tts_model(), self.config.tts_settings, self.config
)
async def stream_events():
try:
async for text_event in self.workflow.run(input_text):
await output._add_text(text_event)
await output._turn_done()
await output._done()
except Exception as e:
logger.error(f"Error processing single turn: {e}")
await output._add_error(e)
raise e
output._set_task(asyncio.create_task(stream_events()))
return output
async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudioResult:
with TraceCtxManager(
workflow_name=self.config.workflow_name or "Voice Agent",
trace_id=None,
group_id=self.config.group_id,
metadata=self.config.trace_metadata,
disabled=self.config.tracing_disabled,
):
output = StreamedAudioResult(
self._get_tts_model(), self.config.tts_settings, self.config
)
transcription_session = await self._get_stt_model().create_session(
audio_input,
self.config.stt_settings,
self.config.trace_include_sensitive_data,
self.config.trace_include_sensitive_audio_data,
)
async def process_turns():
try:
async for input_text in transcription_session.transcribe_turns():
result = self.workflow.run(input_text)
async for text_event in result:
await output._add_text(text_event)
await output._turn_done()
except Exception as e:
logger.error(f"Error processing turns: {e}")
await output._add_error(e)
raise e
finally:
await transcription_session.close()
await output._done()
output._set_task(asyncio.create_task(process_turns()))
return output