Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-12 12:39:04 +08:00)

add example api url endpoint in placeholder (#1887)

Co-authored-by: takatost <takatost@gmail.com>

commit 77f9e8ce0f (parent 5ca4c4a44d)
@@ -1,5 +1,6 @@
 import logging
 from decimal import Decimal
+from urllib.parse import urljoin

 import requests
 import json
@@ -9,9 +10,12 @@ from typing import Optional, Generator, Union, List, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.utils import helper

-from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
-    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
+    AssistantPromptMessage, PromptMessageContent, \
+    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
+    ToolPromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
+    DefaultParameterName, \
     ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import InvokeError
@@ -89,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             headers["Authorization"] = f"Bearer {api_key}"

         endpoint_url = credentials['endpoint_url']
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'

         # prepare the payload for a simple ping to the model
         data = {
@@ -105,8 +111,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     "content": "ping"
                 },
             ]
+            endpoint_url = urljoin(endpoint_url, 'chat/completions')
         elif completion_type is LLMMode.COMPLETION:
             data['prompt'] = 'ping'
+            endpoint_url = urljoin(endpoint_url, 'completions')
         else:
             raise ValueError("Unsupported completion type for model configuration.")

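Note: the trailing-slash normalization and the urljoin calls above work together. urljoin drops the last path segment of a base URL that lacks a trailing slash, which would silently strip a /v1 prefix. A minimal sketch of the stdlib behavior this change relies on (the base URL is just the placeholder example, not a required value):

    from urllib.parse import urljoin

    # Base without a trailing slash: the final segment is treated as a
    # resource and replaced, so the '/v1' prefix is lost.
    urljoin('https://api.openai.com/v1', 'chat/completions')
    # -> 'https://api.openai.com/chat/completions'

    # Base with the trailing slash enforced above: the prefix survives.
    urljoin('https://api.openai.com/v1/', 'chat/completions')
    # -> 'https://api.openai.com/v1/chat/completions'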
@@ -119,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             )

             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}')
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if (completion_type is LLMMode.CHAT
+                    and ('object' not in json_result or json_result['object'] != 'chat.completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'chat.completion\'')
+            elif (completion_type is LLMMode.COMPLETION
+                    and ('object' not in json_result or json_result['object'] != 'text_completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'text_completion\'')
+        except CredentialsValidateFailedError:
+            raise
         except Exception as ex:
             raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')

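Note: validation no longer trusts the status code alone; it now also requires a JSON body whose object field matches the configured mode, following OpenAI's response schema. An abridged body that would pass the chat-mode check might look like this (values illustrative):

    {
      "object": "chat.completion",
      "choices": [{"index": 0, "message": {"role": "assistant", "content": "..."}}]
    }

A completion-mode endpoint must instead return "object": "text_completion".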
@@ -134,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
-                ModelPropertyKey.MODE: 'chat'
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.MODE: credentials.get('mode'),
             },
             parameter_rules=[
                 ParameterRule(
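Note: both property values here come from the user-supplied credentials form, presumably as strings, hence the new int() cast for context_size and the mode lookup replacing the hard-coded 'chat'. Roughly:

    credentials = {'context_size': '4096', 'mode': 'chat'}   # form values arrive as text
    int(credentials.get('context_size'))                     # -> 4096, usable as a numeric property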
@@ -197,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

         return entity


     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  stream: bool = True, \
+                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model

@@ -223,6 +247,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             headers["Authorization"] = f"Bearer {api_key}"

         endpoint_url = credentials["endpoint_url"]
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'

         data = {
             "model": model,
@@ -233,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         completion_type = LLMMode.value_of(credentials['mode'])

         if completion_type is LLMMode.CHAT:
+            endpoint_url = urljoin(endpoint_url, 'chat/completions')
             data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         elif completion_type == LLMMode.COMPLETION:
+            endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
         else:
             raise ValueError("Unsupported completion type for model configuration.")
@@ -245,7 +273,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             data["tool_choice"] = "auto"

             for tool in tools:
-                formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool)))
+                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))

             data["tools"] = formatted_tools

@@ -276,7 +304,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         return self._handle_generate_response(model, credentials, response, prompt_messages)

     def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response

@@ -313,6 +341,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             if chunk:
                 decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()

+                chunk_json = None
                 try:
                     chunk_json = json.loads(decoded_chunk)
                 # stream ended
@@ -323,41 +352,53 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                         finish_reason="Non-JSON encountered."
                     )

-                if len(chunk_json['choices']) == 0:
+                if not chunk_json or len(chunk_json['choices']) == 0:
                     continue

-                delta = chunk_json['choices'][0]['delta']
-                chunk_index = chunk_json['choices'][0]['index']
+                choice = chunk_json['choices'][0]
+                chunk_index = choice['index'] if 'index' in choice else chunk_index

-                if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''):
+                if 'delta' in choice:
+                    delta = choice['delta']
+                    if delta.get('content') is None or delta.get('content') == '':
+                        continue
+
+                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    # assistant_message_function_call = delta.delta.function_call
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        # function_call = self._extract_response_function_call(assistant_message_function_call)
+                        # tool_calls = [function_call] if function_call else []
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta.get('content', ''),
+                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                    )
+
+                    full_assistant_content += delta.get('content', '')
+                elif 'text' in choice:
+                    if choice.get('text') is None or choice.get('text') == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=choice.get('text', '')
+                    )
+
+                    full_assistant_content += choice.get('text', '')
+                else:
                     continue

-                assistant_message_tool_calls = delta.get('tool_calls', None)
-                # assistant_message_function_call = delta.delta.function_call
-
-                # extract tool calls from response
-                if assistant_message_tool_calls:
-                    tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                    # function_call = self._extract_response_function_call(assistant_message_function_call)
-                    # tool_calls = [function_call] if function_call else []
-
-                # transform assistant message to prompt message
-                assistant_prompt_message = AssistantPromptMessage(
-                    content=delta.get('content', ''),
-                    tool_calls=tool_calls if assistant_message_tool_calls else []
-                )
-
-                full_assistant_content += delta.get('content', '')
-
                 # check payload indicator for completion
                 if chunk_json['choices'][0].get('finish_reason') is not None:
                     yield create_final_llm_result_chunk(
                         index=chunk_index,
                         message=assistant_prompt_message,
                         finish_reason=chunk_json['choices'][0]['finish_reason']
                     )
                 else:
                     yield LLMResultChunk(
                         model=model,
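Note: the loop now inspects the shape of each streamed choice instead of assuming a chat-style delta: chat endpoints stream incremental delta objects, completion endpoints stream text fragments, and anything else is skipped. Two abridged SSE lines showing the shapes being distinguished (illustrative payloads, not from a specific provider):

    data: {"object": "chat.completion.chunk", "choices": [{"index": 0, "delta": {"content": "Hel"}}]}
    data: {"object": "text_completion", "choices": [{"index": 0, "text": "Hel"}]}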
@@ -374,8 +415,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     finish_reason="End of stream."
                 )

+            chunk_index += 1
+
     def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
                                   prompt_messages: list[PromptMessage]) -> LLMResult:

         response_json = response.json()

@@ -455,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in
+                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                                              in
                                               message.tool_calls]
                 # function_call = message.tool_calls[0]
                 # message_dict["function_call"] = {
@@ -33,8 +33,8 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: 在此输入您的 API endpoint URL
-        en_US: Enter your API endpoint URL
+        zh_Hans: Base URL, eg. https://api.openai.com/v1
+        en_US: Base URL, eg. https://api.openai.com/v1
   - variable: mode
     show_on:
       - variable: __model_type
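Note: this placeholder change is the commit's headline: the field now suggests a concrete base URL instead of a generic prompt. Combined with the urljoin handling above, entering the example value routes requests as follows (illustrative):

    https://api.openai.com/v1  ->  https://api.openai.com/v1/chat/completions   (chat)
    https://api.openai.com/v1  ->  https://api.openai.com/v1/completions        (completion)
    https://api.openai.com/v1  ->  https://api.openai.com/v1/embeddings         (embeddings)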
@@ -1,6 +1,7 @@
 import time
 from decimal import Decimal
 from typing import Optional
+from urllib.parse import urljoin
 import requests
 import json

@@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"

-        endpoint_url = credentials['endpoint_url']
+        endpoint_url = credentials.get('endpoint_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
+        endpoint_url = urljoin(endpoint_url, 'embeddings')

         extra_model_kwargs = {}
         if user:
@@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"

-        endpoint_url = credentials['endpoint_url']
+        endpoint_url = credentials.get('endpoint_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
+        endpoint_url = urljoin(endpoint_url, 'embeddings')

         payload = {
             'input': 'ping',
@@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             )

             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f"Invalid response status: {response.status_code}")
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if 'model' not in json_result:
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response')
+
+        except CredentialsValidateFailedError:
+            raise
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))

@@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             model_type=ModelType.TEXT_EMBEDDING,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                 ModelPropertyKey.MAX_CHUNKS: 1,
             },
             parameter_rules=[],