mirror of
https://git.mirrors.martin98.com/https://github.com/bytedance/deer-flow
synced 2025-08-19 20:29:11 +08:00
chore: add github workflow for lint&ut (#9)
This commit is contained in:
parent
9260c84005
commit
8cb4c3b7cd
28
.github/lint.yaml
vendored
Normal file
28
.github/lint.yaml
vendored
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
name: Lint Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ 'main' ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ '*' ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Install the latest version of uv
|
||||||
|
uses: astral-sh/setup-uv@v5
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
uv venv --python 3.12
|
||||||
|
uv pip install -e ".[dev]"
|
||||||
|
|
||||||
|
- name: Run linters
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate
|
||||||
|
make lint
|
29
.github/unittest.yaml
vendored
Normal file
29
.github/unittest.yaml
vendored
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
name: Test Cases Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ 'main' ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ '*' ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Install the latest version of uv
|
||||||
|
uses: astral-sh/setup-uv@v5
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
uv venv --python 3.12
|
||||||
|
uv pip install -e ".[dev]"
|
||||||
|
uv pip install -e ".[test]"
|
||||||
|
|
||||||
|
- name: Run test cases
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate
|
||||||
|
TAVILY_API_KEY=mock-key make test
|
@ -13,7 +13,9 @@ from mcp.client.sse import sse_client
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def _get_tools_from_client_session(client_context_manager: Any, timeout_seconds: int = 10) -> List:
|
async def _get_tools_from_client_session(
|
||||||
|
client_context_manager: Any, timeout_seconds: int = 10
|
||||||
|
) -> List:
|
||||||
"""
|
"""
|
||||||
Helper function to get tools from a client session.
|
Helper function to get tools from a client session.
|
||||||
|
|
||||||
@ -76,7 +78,9 @@ async def load_mcp_tools(
|
|||||||
env=env, # Optional environment variables
|
env=env, # Optional environment variables
|
||||||
)
|
)
|
||||||
|
|
||||||
return await _get_tools_from_client_session(stdio_client(server_params), timeout_seconds)
|
return await _get_tools_from_client_session(
|
||||||
|
stdio_client(server_params), timeout_seconds
|
||||||
|
)
|
||||||
|
|
||||||
elif server_type == "sse":
|
elif server_type == "sse":
|
||||||
if not url:
|
if not url:
|
||||||
@ -84,7 +88,9 @@ async def load_mcp_tools(
|
|||||||
status_code=400, detail="URL is required for sse type"
|
status_code=400, detail="URL is required for sse type"
|
||||||
)
|
)
|
||||||
|
|
||||||
return await _get_tools_from_client_session(sse_client(url=url), timeout_seconds)
|
return await _get_tools_from_client_session(
|
||||||
|
sse_client(url=url), timeout_seconds
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
231
tests/integration/test_tts.py
Normal file
231
tests/integration/test_tts.py
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
import uuid
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from src.tools.tts import VolcengineTTS
|
||||||
|
|
||||||
|
|
||||||
|
class TestVolcengineTTS:
|
||||||
|
"""Test suite for the VolcengineTTS class."""
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
"""Test that VolcengineTTS can be properly initialized."""
|
||||||
|
tts = VolcengineTTS(
|
||||||
|
appid="test_appid",
|
||||||
|
access_token="test_token",
|
||||||
|
cluster="test_cluster",
|
||||||
|
voice_type="test_voice",
|
||||||
|
host="test.host.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tts.appid == "test_appid"
|
||||||
|
assert tts.access_token == "test_token"
|
||||||
|
assert tts.cluster == "test_cluster"
|
||||||
|
assert tts.voice_type == "test_voice"
|
||||||
|
assert tts.host == "test.host.com"
|
||||||
|
assert tts.api_url == "https://test.host.com/api/v1/tts"
|
||||||
|
assert tts.header == {"Authorization": "Bearer;test_token"}
|
||||||
|
|
||||||
|
def test_initialization_with_defaults(self):
|
||||||
|
"""Test initialization with default values."""
|
||||||
|
tts = VolcengineTTS(
|
||||||
|
appid="test_appid",
|
||||||
|
access_token="test_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tts.appid == "test_appid"
|
||||||
|
assert tts.access_token == "test_token"
|
||||||
|
assert tts.cluster == "volcano_tts"
|
||||||
|
assert tts.voice_type == "BV700_V2_streaming"
|
||||||
|
assert tts.host == "openspeech.bytedance.com"
|
||||||
|
assert tts.api_url == "https://openspeech.bytedance.com/api/v1/tts"
|
||||||
|
|
||||||
|
@patch("src.tools.tts.requests.post")
|
||||||
|
def test_text_to_speech_success(self, mock_post):
|
||||||
|
"""Test successful text-to-speech conversion."""
|
||||||
|
# Mock response
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
# Create a base64 encoded string for the mock audio data
|
||||||
|
mock_audio_data = base64.b64encode(b"audio_data").decode()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"code": 0,
|
||||||
|
"message": "success",
|
||||||
|
"data": mock_audio_data,
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
# Create TTS client
|
||||||
|
tts = VolcengineTTS(
|
||||||
|
appid="test_appid",
|
||||||
|
access_token="test_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = tts.text_to_speech("Hello, world!")
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["audio_data"] == mock_audio_data
|
||||||
|
assert "response" in result
|
||||||
|
|
||||||
|
# Verify the request
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
args, _ = mock_post.call_args
|
||||||
|
assert args[0] == "https://openspeech.bytedance.com/api/v1/tts"
|
||||||
|
|
||||||
|
# Verify request JSON - the data is passed as the second positional argument
|
||||||
|
request_json = json.loads(args[1])
|
||||||
|
assert request_json["app"]["appid"] == "test_appid"
|
||||||
|
assert request_json["app"]["token"] == "test_token"
|
||||||
|
assert request_json["app"]["cluster"] == "volcano_tts"
|
||||||
|
assert request_json["audio"]["voice_type"] == "BV700_V2_streaming"
|
||||||
|
assert request_json["audio"]["encoding"] == "mp3"
|
||||||
|
assert request_json["request"]["text"] == "Hello, world!"
|
||||||
|
|
||||||
|
@patch("src.tools.tts.requests.post")
|
||||||
|
def test_text_to_speech_api_error(self, mock_post):
|
||||||
|
"""Test error handling when API returns an error."""
|
||||||
|
# Mock response
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 400
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"code": 400,
|
||||||
|
"message": "Bad request",
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
# Create TTS client
|
||||||
|
tts = VolcengineTTS(
|
||||||
|
appid="test_appid",
|
||||||
|
access_token="test_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = tts.text_to_speech("Hello, world!")
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result["success"] is False
|
||||||
|
assert result["error"] == {"code": 400, "message": "Bad request"}
|
||||||
|
assert result["audio_data"] is None
|
||||||
|
|
||||||
|
@patch("src.tools.tts.requests.post")
|
||||||
|
def test_text_to_speech_no_data(self, mock_post):
|
||||||
|
"""Test error handling when API response doesn't contain data."""
|
||||||
|
# Mock response
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"code": 0,
|
||||||
|
"message": "success",
|
||||||
|
# No data field
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
# Create TTS client
|
||||||
|
tts = VolcengineTTS(
|
||||||
|
appid="test_appid",
|
||||||
|
access_token="test_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = tts.text_to_speech("Hello, world!")
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result["success"] is False
|
||||||
|
assert result["error"] == "No audio data returned"
|
||||||
|
assert result["audio_data"] is None
|
||||||
|
|
||||||
|
@patch("src.tools.tts.requests.post")
|
||||||
|
def test_text_to_speech_with_custom_parameters(self, mock_post):
|
||||||
|
"""Test text_to_speech with custom parameters."""
|
||||||
|
# Mock response
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
# Create a base64 encoded string for the mock audio data
|
||||||
|
mock_audio_data = base64.b64encode(b"audio_data").decode()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"code": 0,
|
||||||
|
"message": "success",
|
||||||
|
"data": mock_audio_data,
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
# Create TTS client
|
||||||
|
tts = VolcengineTTS(
|
||||||
|
appid="test_appid",
|
||||||
|
access_token="test_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the method with custom parameters
|
||||||
|
result = tts.text_to_speech(
|
||||||
|
text="Custom text",
|
||||||
|
encoding="wav",
|
||||||
|
speed_ratio=1.2,
|
||||||
|
volume_ratio=0.8,
|
||||||
|
pitch_ratio=1.1,
|
||||||
|
text_type="ssml",
|
||||||
|
with_frontend=0,
|
||||||
|
frontend_type="custom",
|
||||||
|
uid="custom-uid",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["audio_data"] == mock_audio_data
|
||||||
|
|
||||||
|
# Verify request JSON - the data is passed as the second positional argument
|
||||||
|
args, kwargs = mock_post.call_args
|
||||||
|
request_json = json.loads(args[1])
|
||||||
|
assert request_json["audio"]["encoding"] == "wav"
|
||||||
|
assert request_json["audio"]["speed_ratio"] == 1.2
|
||||||
|
assert request_json["audio"]["volume_ratio"] == 0.8
|
||||||
|
assert request_json["audio"]["pitch_ratio"] == 1.1
|
||||||
|
assert request_json["request"]["text"] == "Custom text"
|
||||||
|
assert request_json["request"]["text_type"] == "ssml"
|
||||||
|
assert request_json["request"]["with_frontend"] == 0
|
||||||
|
assert request_json["request"]["frontend_type"] == "custom"
|
||||||
|
assert request_json["user"]["uid"] == "custom-uid"
|
||||||
|
|
||||||
|
@patch("src.tools.tts.requests.post")
|
||||||
|
@patch("src.tools.tts.uuid.uuid4")
|
||||||
|
def test_text_to_speech_auto_generated_uid(self, mock_uuid, mock_post):
|
||||||
|
"""Test that UUID is auto-generated if not provided."""
|
||||||
|
# Mock UUID
|
||||||
|
mock_uuid_value = "test-uuid-value"
|
||||||
|
mock_uuid.return_value = mock_uuid_value
|
||||||
|
|
||||||
|
# Mock response
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
# Create a base64 encoded string for the mock audio data
|
||||||
|
mock_audio_data = base64.b64encode(b"audio_data").decode()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"code": 0,
|
||||||
|
"message": "success",
|
||||||
|
"data": mock_audio_data,
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
# Create TTS client
|
||||||
|
tts = VolcengineTTS(
|
||||||
|
appid="test_appid",
|
||||||
|
access_token="test_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the method without providing a UID
|
||||||
|
result = tts.text_to_speech("Hello, world!")
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["audio_data"] == mock_audio_data
|
||||||
|
|
||||||
|
# Verify the request JSON - the data is passed as the second positional argument
|
||||||
|
args, kwargs = mock_post.call_args
|
||||||
|
request_json = json.loads(args[1])
|
||||||
|
assert request_json["user"]["uid"] == str(mock_uuid_value)
|
Loading…
x
Reference in New Issue
Block a user