From a6ab97c970060de23a49b6dcfd7da78f274cea33 Mon Sep 17 00:00:00 2001 From: He Tao Date: Fri, 18 Apr 2025 15:28:31 +0800 Subject: [PATCH] feat: integrate volcengine tts functionality --- .env.example | 8 ++- README.md | 33 +++++++++- src/server/app.py | 65 +++++++++++++++++- src/server/chat_request.py | 18 ++++- src/tools/__init__.py | 4 ++ src/tools/tts.py | 131 +++++++++++++++++++++++++++++++++++++ 6 files changed, 251 insertions(+), 8 deletions(-) create mode 100644 src/tools/tts.py diff --git a/.env.example b/.env.example index 8227804..4309b55 100644 --- a/.env.example +++ b/.env.example @@ -2,10 +2,14 @@ DEBUG=True APP_ENV=development -# Add other environment variables as needed -# tavily, duckduckgo, brave_search, arxiv +# Search Engine SEARCH_API=tavily TAVILY_API_KEY=tvly-xxx BRAVE_SEARCH_API_KEY=brave-xxx # JINA_API_KEY=jina_xxx # Optional, default is None +# Volcengine TTS +VOLCENGINE_TTS_APPID=xxx +VOLCENGINE_TTS_ACCESS_TOKEN=xxx +# VOLCENGINE_TTS_CLUSTER=volcano_tts # Optional, default is volcano_tts +# VOLCENGINE_TTS_VOICE_TYPE=BV700_V2_streaming # Optional, default is BV700_V2_streaming diff --git a/README.md b/README.md index f7c6f28..24e7d66 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,13 @@ cd deer-flow # Install dependencies, uv will take care of the python interpreter and venv creation, and install the required packages uv sync -# Configure .env with your Search Engine API keys +# Configure .env with your API keys # Tavily: https://app.tavily.com/home # Brave_SEARCH: https://brave.com/search/api/ +# volcengine TTS: Add your TTS credentials if you have them cp .env.example .env -# See the 'Supported Search Engines' section below for all available options +# See the 'Supported Search Engines' and 'Text-to-Speech Integration' sections below for all available options # Configure conf.yaml for your LLM model and API keys # Gemini: https://ai.google.dev/gemini-api/docs/openai @@ -120,6 +121,34 @@ The system employs a streamlined workflow with the following components: - Processes and structures the collected information - Generates comprehensive research reports +## Text-to-Speech Integration + +DeerFlow now includes a Text-to-Speech (TTS) feature that allows you to convert research reports to speech. This feature uses the volcengine TTS API to generate high-quality audio from text. + +### Features + +- Convert any text or research report to natural-sounding speech +- Adjust speech parameters like speed, volume, and pitch +- Support for multiple voice types +- Available through both API and web interface + +### Using the TTS API + +You can access the TTS functionality through the `/api/tts` endpoint: + +```bash +# Example API call using curl +curl --location 'http://localhost:8000/api/tts' \ +--header 'Content-Type: application/json' \ +--data '{ + "text": "This is a test of the text-to-speech functionality.", + "speed_ratio": 1.0, + "volume_ratio": 1.0, + "pitch_ratio": 1.0 +}' \ +--output speech.mp3 +``` + ## Examples The following examples demonstrate the capabilities of DeerFlow: diff --git a/src/server/app.py b/src/server/app.py index 1eca654..3f6a982 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -1,19 +1,22 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import base64 import json import logging +import os from typing import List, cast from uuid import uuid4 -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, Response from langchain_core.messages import AIMessageChunk, ToolMessage from langgraph.types import Command from src.graph.builder import build_graph -from src.server.chat_request import ChatMessage, ChatRequest +from src.server.chat_request import ChatMessage, ChatRequest, TTSRequest +from src.tools import VolcengineTTS logger = logging.getLogger(__name__) @@ -137,3 +140,59 @@ def _make_event(event_type: str, data: dict[str, any]): if data.get("content") == "": data.pop("content") return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + + +@app.post("/api/tts") +async def text_to_speech(request: TTSRequest): + """Convert text to speech using volcengine TTS API.""" + try: + app_id = os.getenv("VOLCENGINE_TTS_APPID", "") + if not app_id: + raise HTTPException( + status_code=400, detail="VOLCENGINE_TTS_APPID is not set" + ) + access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "") + if not access_token: + raise HTTPException( + status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set" + ) + cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts") + voice_type = os.getenv("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming") + + tts_client = VolcengineTTS( + appid=app_id, + access_token=access_token, + cluster=cluster, + voice_type=voice_type, + ) + # Call the TTS API + result = tts_client.text_to_speech( + text=request.text[:1024], + encoding=request.encoding, + speed_ratio=request.speed_ratio, + volume_ratio=request.volume_ratio, + pitch_ratio=request.pitch_ratio, + text_type=request.text_type, + with_frontend=request.with_frontend, + frontend_type=request.frontend_type, + ) + + if not result["success"]: + raise HTTPException(status_code=500, detail=str(result["error"])) + + # Decode the base64 audio data + audio_data = base64.b64decode(result["audio_data"]) + + # Return the audio file + return Response( + content=audio_data, + media_type=f"audio/{request.encoding}", + headers={ + "Content-Disposition": ( + f"attachment; filename=tts_output.{request.encoding}" + ) + }, + ) + except Exception as e: + logger.exception(f"Error in TTS endpoint: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/server/chat_request.py b/src/server/chat_request.py index 55d2472..4601ad7 100644 --- a/src/server/chat_request.py +++ b/src/server/chat_request.py @@ -1,7 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT -from typing import List, Optional, Union +from typing import List, Optional, Union, Dict, Any from pydantic import BaseModel, Field @@ -44,3 +44,19 @@ class ChatRequest(BaseModel): interrupt_feedback: Optional[str] = Field( None, description="Interrupt feedback from the user on the plan" ) + + +class TTSRequest(BaseModel): + text: str = Field(..., description="The text to convert to speech") + voice_type: Optional[str] = Field( + "BV700_V2_streaming", description="The voice type to use" + ) + encoding: Optional[str] = Field("mp3", description="The audio encoding format") + speed_ratio: Optional[float] = Field(1.0, description="Speech speed ratio") + volume_ratio: Optional[float] = Field(1.0, description="Speech volume ratio") + pitch_ratio: Optional[float] = Field(1.0, description="Speech pitch ratio") + text_type: Optional[str] = Field("plain", description="Text type (plain or ssml)") + with_frontend: Optional[int] = Field( + 1, description="Whether to use frontend processing" + ) + frontend_type: Optional[str] = Field("unitTson", description="Frontend type") diff --git a/src/tools/__init__.py b/src/tools/__init__.py index 774938f..7854f94 100644 --- a/src/tools/__init__.py +++ b/src/tools/__init__.py @@ -1,6 +1,8 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import os + from .crawl import crawl_tool from .python_repl import python_repl_tool from .search import ( @@ -9,6 +11,7 @@ from .search import ( brave_search_tool, arxiv_search_tool, ) +from .tts import VolcengineTTS from src.config import SELECTED_SEARCH_ENGINE, SearchEngine # Map search engine names to their respective tools @@ -25,4 +28,5 @@ __all__ = [ "crawl_tool", "web_search_tool", "python_repl_tool", + "VolcengineTTS", ] diff --git a/src/tools/tts.py b/src/tools/tts.py new file mode 100644 index 0000000..58e4c2f --- /dev/null +++ b/src/tools/tts.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +Text-to-Speech module using volcengine TTS API. +""" + +import json +import uuid +import logging +import requests +from typing import Optional, Dict, Any + +logger = logging.getLogger(__name__) + + +class VolcengineTTS: + """ + Client for volcengine Text-to-Speech API. + """ + + def __init__( + self, + appid: str, + access_token: str, + cluster: str = "volcano_tts", + voice_type: str = "BV700_V2_streaming", + host: str = "openspeech.bytedance.com", + ): + """ + Initialize the volcengine TTS client. + + Args: + appid: Platform application ID + access_token: Access token for authentication + cluster: TTS cluster name + voice_type: Voice type to use + host: API host + """ + self.appid = appid + self.access_token = access_token + self.cluster = cluster + self.voice_type = voice_type + self.host = host + self.api_url = f"https://{host}/api/v1/tts" + self.header = {"Authorization": f"Bearer;{access_token}"} + + def text_to_speech( + self, + text: str, + encoding: str = "mp3", + speed_ratio: float = 1.0, + volume_ratio: float = 1.0, + pitch_ratio: float = 1.0, + text_type: str = "plain", + with_frontend: int = 1, + frontend_type: str = "unitTson", + uid: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Convert text to speech using volcengine TTS API. + + Args: + text: Text to convert to speech + encoding: Audio encoding format + speed_ratio: Speech speed ratio + volume_ratio: Speech volume ratio + pitch_ratio: Speech pitch ratio + text_type: Text type (plain or ssml) + with_frontend: Whether to use frontend processing + frontend_type: Frontend type + uid: User ID (generated if not provided) + + Returns: + Dictionary containing the API response and base64-encoded audio data + """ + if not uid: + uid = str(uuid.uuid4()) + + request_json = { + "app": { + "appid": self.appid, + "token": self.access_token, + "cluster": self.cluster, + }, + "user": {"uid": uid}, + "audio": { + "voice_type": self.voice_type, + "encoding": encoding, + "speed_ratio": speed_ratio, + "volume_ratio": volume_ratio, + "pitch_ratio": pitch_ratio, + }, + "request": { + "reqid": str(uuid.uuid4()), + "text": text, + "text_type": text_type, + "operation": "query", + "with_frontend": with_frontend, + "frontend_type": frontend_type, + }, + } + + try: + logger.debug(f"Sending TTS request for text: {text[:50]}...") + response = requests.post( + self.api_url, json.dumps(request_json), headers=self.header + ) + response_json = response.json() + + if response.status_code != 200: + logger.error(f"TTS API error: {response_json}") + return {"success": False, "error": response_json, "audio_data": None} + + if "data" not in response_json: + logger.error(f"TTS API returned no data: {response_json}") + return { + "success": False, + "error": "No audio data returned", + "audio_data": None, + } + + return { + "success": True, + "response": response_json, + "audio_data": response_json["data"], # Base64 encoded audio data + } + + except Exception as e: + logger.exception(f"Error in TTS API call: {str(e)}") + return {"success": False, "error": str(e), "audio_data": None}