fix: miss usage of os.path.join for URL assembly and add tests on yarl (#4224)

This commit is contained in:
Bowen Liang 2024-05-10 18:14:48 +08:00 committed by GitHub
parent 01555463d2
commit 228de1f12a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 29 additions and 6 deletions

View File

@ -1,6 +1,5 @@
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from os.path import join
from typing import Optional, cast from typing import Optional, cast
from httpx import Timeout from httpx import Timeout
@ -19,6 +18,7 @@ from openai import (
) )
from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall from openai.types.chat.chat_completion_message import FunctionCall
from yarl import URL
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
@ -265,7 +265,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
client_kwargs = { client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1", "api_key": "1",
"base_url": join(credentials['api_base'], 'v1') "base_url": str(URL(credentials['api_base']) / 'v1')
} }
return client_kwargs return client_kwargs

View File

@ -1,8 +1,8 @@
from base64 import b64decode from base64 import b64decode
from os.path import join
from typing import Any, Union from typing import Any, Union
from openai import OpenAI from openai import OpenAI
from yarl import URL
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
@ -23,7 +23,7 @@ class DallE2Tool(BuiltinTool):
if not openai_base_url: if not openai_base_url:
openai_base_url = None openai_base_url = None
else: else:
openai_base_url = join(openai_base_url, 'v1') openai_base_url = str(URL(openai_base_url) / 'v1')
client = OpenAI( client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'], api_key=self.runtime.credentials['openai_api_key'],

View File

@ -1,8 +1,8 @@
from base64 import b64decode from base64 import b64decode
from os.path import join
from typing import Any, Union from typing import Any, Union
from openai import OpenAI from openai import OpenAI
from yarl import URL
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
@ -23,7 +23,7 @@ class DallE3Tool(BuiltinTool):
if not openai_base_url: if not openai_base_url:
openai_base_url = None openai_base_url = None
else: else:
openai_base_url = join(openai_base_url, 'v1') openai_base_url = str(URL(openai_base_url) / 'v1')
client = OpenAI( client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'], api_key=self.runtime.credentials['openai_api_key'],

View File

@ -0,0 +1,23 @@
import pytest
from yarl import URL
def test_yarl_urls():
expected_1 = 'https://dify.ai/api'
assert str(URL('https://dify.ai') / 'api') == expected_1
assert str(URL('https://dify.ai/') / 'api') == expected_1
expected_2 = 'http://dify.ai:12345/api'
assert str(URL('http://dify.ai:12345') / 'api') == expected_2
assert str(URL('http://dify.ai:12345/') / 'api') == expected_2
expected_3 = 'https://dify.ai/api/v1'
assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3
assert str(URL('https://dify.ai') / 'api/v1') == expected_3
assert str(URL('https://dify.ai/') / 'api/v1') == expected_3
assert str(URL('https://dify.ai/api') / 'v1') == expected_3
assert str(URL('https://dify.ai/api/') / 'v1') == expected_3
with pytest.raises(ValueError) as e1:
str(URL('https://dify.ai') / '/api')
assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"