From 228de1f12acccbccadd5a990cd0bb33765fc0323 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Fri, 10 May 2024 18:14:48 +0800 Subject: [PATCH] fix: miss usage of os.path.join for URL assembly and add tests on yarl (#4224) --- .../model_providers/chatglm/llm/llm.py | 4 ++-- .../provider/builtin/dalle/tools/dalle2.py | 4 ++-- .../provider/builtin/dalle/tools/dalle3.py | 4 ++-- api/tests/unit_tests/libs/test_yarl.py | 23 +++++++++++++++++++ 4 files changed, 29 insertions(+), 6 deletions(-) create mode 100644 api/tests/unit_tests/libs/test_yarl.py diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index 12dc75aece..e83d08af71 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -1,6 +1,5 @@ import logging from collections.abc import Generator -from os.path import join from typing import Optional, cast from httpx import Timeout @@ -19,6 +18,7 @@ from openai import ( ) from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_message import FunctionCall +from yarl import URL from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -265,7 +265,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": join(credentials['api_base'], 'v1') + "base_url": str(URL(credentials['api_base']) / 'v1') } return client_kwargs diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py index e41cbd9f65..450e782281 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -1,8 +1,8 @@ from base64 import b64decode -from os.path import join from typing import Any, Union from openai import OpenAI +from yarl import URL from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -23,7 +23,7 @@ class DallE2Tool(BuiltinTool): if not openai_base_url: openai_base_url = None else: - openai_base_url = join(openai_base_url, 'v1') + openai_base_url = str(URL(openai_base_url) / 'v1') client = OpenAI( api_key=self.runtime.credentials['openai_api_key'], diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index dc53025b02..87d18f68e0 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -1,8 +1,8 @@ from base64 import b64decode -from os.path import join from typing import Any, Union from openai import OpenAI +from yarl import URL from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -23,7 +23,7 @@ class DallE3Tool(BuiltinTool): if not openai_base_url: openai_base_url = None else: - openai_base_url = join(openai_base_url, 'v1') + openai_base_url = str(URL(openai_base_url) / 'v1') client = OpenAI( api_key=self.runtime.credentials['openai_api_key'], diff --git a/api/tests/unit_tests/libs/test_yarl.py b/api/tests/unit_tests/libs/test_yarl.py new file mode 100644 index 0000000000..75a5344126 --- /dev/null +++ b/api/tests/unit_tests/libs/test_yarl.py @@ -0,0 +1,23 @@ +import pytest +from yarl import URL + + +def test_yarl_urls(): + expected_1 = 'https://dify.ai/api' + assert str(URL('https://dify.ai') / 'api') == expected_1 + assert str(URL('https://dify.ai/') / 'api') == expected_1 + + expected_2 = 'http://dify.ai:12345/api' + assert str(URL('http://dify.ai:12345') / 'api') == expected_2 + assert str(URL('http://dify.ai:12345/') / 'api') == expected_2 + + expected_3 = 'https://dify.ai/api/v1' + assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3 + assert str(URL('https://dify.ai') / 'api/v1') == expected_3 + assert str(URL('https://dify.ai/') / 'api/v1') == expected_3 + assert str(URL('https://dify.ai/api') / 'v1') == expected_3 + assert str(URL('https://dify.ai/api/') / 'v1') == expected_3 + + with pytest.raises(ValueError) as e1: + str(URL('https://dify.ai') / '/api') + assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"