From 8cb4c3b7cd7242859649238e6e3ddcbc79aeafc0 Mon Sep 17 00:00:00 2001 From: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com> Date: Thu, 8 May 2025 21:47:47 +0800 Subject: [PATCH] chore: add github workflow for lint&ut (#9) --- .github/lint.yaml | 28 +++++ .github/unittest.yaml | 29 +++++ src/server/mcp_utils.py | 12 +- tests/integration/test_tts.py | 231 ++++++++++++++++++++++++++++++++++ 4 files changed, 297 insertions(+), 3 deletions(-) create mode 100644 .github/lint.yaml create mode 100644 .github/unittest.yaml create mode 100644 tests/integration/test_tts.py diff --git a/.github/lint.yaml b/.github/lint.yaml new file mode 100644 index 0000000..5fa1cb0 --- /dev/null +++ b/.github/lint.yaml @@ -0,0 +1,28 @@ +name: Lint Check + +on: + push: + branches: [ 'main' ] + pull_request: + branches: [ '*' ] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v5 + with: + version: "latest" + + - name: Install dependencies + run: | + uv venv --python 3.12 + uv pip install -e ".[dev]" + + - name: Run linters + run: | + source .venv/bin/activate + make lint \ No newline at end of file diff --git a/.github/unittest.yaml b/.github/unittest.yaml new file mode 100644 index 0000000..7e5c9fb --- /dev/null +++ b/.github/unittest.yaml @@ -0,0 +1,29 @@ +name: Test Cases Check + +on: + push: + branches: [ 'main' ] + pull_request: + branches: [ '*' ] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v5 + with: + version: "latest" + + - name: Install dependencies + run: | + uv venv --python 3.12 + uv pip install -e ".[dev]" + uv pip install -e ".[test]" + + - name: Run test cases + run: | + source .venv/bin/activate + TAVILY_API_KEY=mock-key make test \ No newline at end of file diff --git a/src/server/mcp_utils.py b/src/server/mcp_utils.py index c6f3b2c..508ded9 100644 --- a/src/server/mcp_utils.py +++ b/src/server/mcp_utils.py @@ -13,7 +13,9 @@ from mcp.client.sse import sse_client logger = logging.getLogger(__name__) -async def _get_tools_from_client_session(client_context_manager: Any, timeout_seconds: int = 10) -> List: +async def _get_tools_from_client_session( + client_context_manager: Any, timeout_seconds: int = 10 +) -> List: """ Helper function to get tools from a client session. @@ -76,7 +78,9 @@ async def load_mcp_tools( env=env, # Optional environment variables ) - return await _get_tools_from_client_session(stdio_client(server_params), timeout_seconds) + return await _get_tools_from_client_session( + stdio_client(server_params), timeout_seconds + ) elif server_type == "sse": if not url: @@ -84,7 +88,9 @@ async def load_mcp_tools( status_code=400, detail="URL is required for sse type" ) - return await _get_tools_from_client_session(sse_client(url=url), timeout_seconds) + return await _get_tools_from_client_session( + sse_client(url=url), timeout_seconds + ) else: raise HTTPException( diff --git a/tests/integration/test_tts.py b/tests/integration/test_tts.py new file mode 100644 index 0000000..a22405d --- /dev/null +++ b/tests/integration/test_tts.py @@ -0,0 +1,231 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import json +import pytest +from unittest.mock import patch, MagicMock +import uuid +import base64 + +from src.tools.tts import VolcengineTTS + + +class TestVolcengineTTS: + """Test suite for the VolcengineTTS class.""" + + def test_initialization(self): + """Test that VolcengineTTS can be properly initialized.""" + tts = VolcengineTTS( + appid="test_appid", + access_token="test_token", + cluster="test_cluster", + voice_type="test_voice", + host="test.host.com", + ) + + assert tts.appid == "test_appid" + assert tts.access_token == "test_token" + assert tts.cluster == "test_cluster" + assert tts.voice_type == "test_voice" + assert tts.host == "test.host.com" + assert tts.api_url == "https://test.host.com/api/v1/tts" + assert tts.header == {"Authorization": "Bearer;test_token"} + + def test_initialization_with_defaults(self): + """Test initialization with default values.""" + tts = VolcengineTTS( + appid="test_appid", + access_token="test_token", + ) + + assert tts.appid == "test_appid" + assert tts.access_token == "test_token" + assert tts.cluster == "volcano_tts" + assert tts.voice_type == "BV700_V2_streaming" + assert tts.host == "openspeech.bytedance.com" + assert tts.api_url == "https://openspeech.bytedance.com/api/v1/tts" + + @patch("src.tools.tts.requests.post") + def test_text_to_speech_success(self, mock_post): + """Test successful text-to-speech conversion.""" + # Mock response + mock_response = MagicMock() + mock_response.status_code = 200 + # Create a base64 encoded string for the mock audio data + mock_audio_data = base64.b64encode(b"audio_data").decode() + mock_response.json.return_value = { + "code": 0, + "message": "success", + "data": mock_audio_data, + } + mock_post.return_value = mock_response + + # Create TTS client + tts = VolcengineTTS( + appid="test_appid", + access_token="test_token", + ) + + # Call the method + result = tts.text_to_speech("Hello, world!") + + # Verify the result + assert result["success"] is True + assert result["audio_data"] == mock_audio_data + assert "response" in result + + # Verify the request + mock_post.assert_called_once() + args, _ = mock_post.call_args + assert args[0] == "https://openspeech.bytedance.com/api/v1/tts" + + # Verify request JSON - the data is passed as the second positional argument + request_json = json.loads(args[1]) + assert request_json["app"]["appid"] == "test_appid" + assert request_json["app"]["token"] == "test_token" + assert request_json["app"]["cluster"] == "volcano_tts" + assert request_json["audio"]["voice_type"] == "BV700_V2_streaming" + assert request_json["audio"]["encoding"] == "mp3" + assert request_json["request"]["text"] == "Hello, world!" + + @patch("src.tools.tts.requests.post") + def test_text_to_speech_api_error(self, mock_post): + """Test error handling when API returns an error.""" + # Mock response + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.json.return_value = { + "code": 400, + "message": "Bad request", + } + mock_post.return_value = mock_response + + # Create TTS client + tts = VolcengineTTS( + appid="test_appid", + access_token="test_token", + ) + + # Call the method + result = tts.text_to_speech("Hello, world!") + + # Verify the result + assert result["success"] is False + assert result["error"] == {"code": 400, "message": "Bad request"} + assert result["audio_data"] is None + + @patch("src.tools.tts.requests.post") + def test_text_to_speech_no_data(self, mock_post): + """Test error handling when API response doesn't contain data.""" + # Mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "code": 0, + "message": "success", + # No data field + } + mock_post.return_value = mock_response + + # Create TTS client + tts = VolcengineTTS( + appid="test_appid", + access_token="test_token", + ) + + # Call the method + result = tts.text_to_speech("Hello, world!") + + # Verify the result + assert result["success"] is False + assert result["error"] == "No audio data returned" + assert result["audio_data"] is None + + @patch("src.tools.tts.requests.post") + def test_text_to_speech_with_custom_parameters(self, mock_post): + """Test text_to_speech with custom parameters.""" + # Mock response + mock_response = MagicMock() + mock_response.status_code = 200 + # Create a base64 encoded string for the mock audio data + mock_audio_data = base64.b64encode(b"audio_data").decode() + mock_response.json.return_value = { + "code": 0, + "message": "success", + "data": mock_audio_data, + } + mock_post.return_value = mock_response + + # Create TTS client + tts = VolcengineTTS( + appid="test_appid", + access_token="test_token", + ) + + # Call the method with custom parameters + result = tts.text_to_speech( + text="Custom text", + encoding="wav", + speed_ratio=1.2, + volume_ratio=0.8, + pitch_ratio=1.1, + text_type="ssml", + with_frontend=0, + frontend_type="custom", + uid="custom-uid", + ) + + # Verify the result + assert result["success"] is True + assert result["audio_data"] == mock_audio_data + + # Verify request JSON - the data is passed as the second positional argument + args, kwargs = mock_post.call_args + request_json = json.loads(args[1]) + assert request_json["audio"]["encoding"] == "wav" + assert request_json["audio"]["speed_ratio"] == 1.2 + assert request_json["audio"]["volume_ratio"] == 0.8 + assert request_json["audio"]["pitch_ratio"] == 1.1 + assert request_json["request"]["text"] == "Custom text" + assert request_json["request"]["text_type"] == "ssml" + assert request_json["request"]["with_frontend"] == 0 + assert request_json["request"]["frontend_type"] == "custom" + assert request_json["user"]["uid"] == "custom-uid" + + @patch("src.tools.tts.requests.post") + @patch("src.tools.tts.uuid.uuid4") + def test_text_to_speech_auto_generated_uid(self, mock_uuid, mock_post): + """Test that UUID is auto-generated if not provided.""" + # Mock UUID + mock_uuid_value = "test-uuid-value" + mock_uuid.return_value = mock_uuid_value + + # Mock response + mock_response = MagicMock() + mock_response.status_code = 200 + # Create a base64 encoded string for the mock audio data + mock_audio_data = base64.b64encode(b"audio_data").decode() + mock_response.json.return_value = { + "code": 0, + "message": "success", + "data": mock_audio_data, + } + mock_post.return_value = mock_response + + # Create TTS client + tts = VolcengineTTS( + appid="test_appid", + access_token="test_token", + ) + + # Call the method without providing a UID + result = tts.text_to_speech("Hello, world!") + + # Verify the result + assert result["success"] is True + assert result["audio_data"] == mock_audio_data + + # Verify the request JSON - the data is passed as the second positional argument + args, kwargs = mock_post.call_args + request_json = json.loads(args[1]) + assert request_json["user"]["uid"] == str(mock_uuid_value)