add example api url endpoint in placeholder (#1887)

Co-authored-by: takatost <takatost@gmail.com>
Chenhe Gu 2024-01-04 01:16:51 +08:00 committed by GitHub
parent 5ca4c4a44d
commit 77f9e8ce0f
3 changed files with 119 additions and 57 deletions

View File

@@ -1,5 +1,6 @@
 import logging
 from decimal import Decimal
+from urllib.parse import urljoin
 
 import requests
 import json
@@ -9,9 +10,12 @@ from typing import Optional, Generator, Union, List, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.utils import helper
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
-    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
+    AssistantPromptMessage, PromptMessageContent, \
+    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
+    ToolPromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
+    DefaultParameterName, \
     ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import InvokeError
@@ -89,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 headers["Authorization"] = f"Bearer {api_key}"
 
             endpoint_url = credentials['endpoint_url']
+            if not endpoint_url.endswith('/'):
+                endpoint_url += '/'
 
             # prepare the payload for a simple ping to the model
             data = {
@@ -105,8 +111,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                         "content": "ping"
                     },
                 ]
+                endpoint_url = urljoin(endpoint_url, 'chat/completions')
             elif completion_type is LLMMode.COMPLETION:
                 data['prompt'] = 'ping'
+                endpoint_url = urljoin(endpoint_url, 'completions')
             else:
                 raise ValueError("Unsupported completion type for model configuration.")
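
A note on the trailing-slash guard paired with the urljoin calls above: urllib.parse.urljoin drops the last path segment of a base URL that does not end with '/', so without the guard a base of https://api.openai.com/v1 would resolve to https://api.openai.com/chat/completions, silently losing the /v1 prefix. A minimal standalone sketch of that standard-library behavior:

    from urllib.parse import urljoin

    base = 'https://api.openai.com/v1'

    # Without a trailing slash, urljoin treats 'v1' as a file-like
    # segment and replaces it with the relative path:
    print(urljoin(base, 'chat/completions'))
    # -> https://api.openai.com/chat/completions

    # With the trailing slash, the relative path is appended under /v1/:
    print(urljoin(base + '/', 'chat/completions'))
    # -> https://api.openai.com/v1/chat/completions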
@@ -119,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             )
 
             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}')
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if (completion_type is LLMMode.CHAT
+                    and ('object' not in json_result or json_result['object'] != 'chat.completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'chat.completion\'')
+            elif (completion_type is LLMMode.COMPLETION
+                    and ('object' not in json_result or json_result['object'] != 'text_completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'text_completion\'')
+        except CredentialsValidateFailedError:
+            raise
         except Exception as ex:
             raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
@@ -134,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
-                ModelPropertyKey.MODE: 'chat'
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.MODE: credentials.get('mode'),
             },
             parameter_rules=[
                 ParameterRule(
@@ -197,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         return entity
 
     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  stream: bool = True, \
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
@@ -223,6 +247,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             headers["Authorization"] = f"Bearer {api_key}"
 
         endpoint_url = credentials["endpoint_url"]
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
 
         data = {
             "model": model,
@@ -233,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         completion_type = LLMMode.value_of(credentials['mode'])
 
         if completion_type is LLMMode.CHAT:
+            endpoint_url = urljoin(endpoint_url, 'chat/completions')
             data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         elif completion_type == LLMMode.COMPLETION:
+            endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
         else:
             raise ValueError("Unsupported completion type for model configuration.")
@@ -245,7 +273,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             data["tool_choice"] = "auto"
 
             for tool in tools:
-                formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool)))
+                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
 
             data["tools"] = formatted_tools
@@ -276,7 +304,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
     def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
-            prompt_messages: list[PromptMessage]) -> Generator:
+                                         prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -313,6 +341,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             if chunk:
                 decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
 
+                chunk_json = None
                 try:
                     chunk_json = json.loads(decoded_chunk)
                 # stream ended
@@ -323,41 +352,53 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                         finish_reason="Non-JSON encountered."
                     )
 
-                if len(chunk_json['choices']) == 0:
+                if not chunk_json or len(chunk_json['choices']) == 0:
                     continue
 
-                delta = chunk_json['choices'][0]['delta']
-                chunk_index = chunk_json['choices'][0]['index']
-
-                if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''):
-                    continue
-
-                assistant_message_tool_calls = delta.get('tool_calls', None)
-                # assistant_message_function_call = delta.delta.function_call
-
-                # extract tool calls from response
-                if assistant_message_tool_calls:
-                    tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                # function_call = self._extract_response_function_call(assistant_message_function_call)
-                # tool_calls = [function_call] if function_call else []
-
-                # transform assistant message to prompt message
-                assistant_prompt_message = AssistantPromptMessage(
-                    content=delta.get('content', ''),
-                    tool_calls=tool_calls if assistant_message_tool_calls else []
-                )
-
-                full_assistant_content += delta.get('content', '')
+                choice = chunk_json['choices'][0]
+                chunk_index = choice['index'] if 'index' in choice else chunk_index
+
+                if 'delta' in choice:
+                    delta = choice['delta']
+                    if delta.get('content') is None or delta.get('content') == '':
+                        continue
+
+                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    # assistant_message_function_call = delta.delta.function_call
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                    # function_call = self._extract_response_function_call(assistant_message_function_call)
+                    # tool_calls = [function_call] if function_call else []
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta.get('content', ''),
+                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                    )
+
+                    full_assistant_content += delta.get('content', '')
+                elif 'text' in choice:
+                    if choice.get('text') is None or choice.get('text') == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=choice.get('text', '')
+                    )
+
+                    full_assistant_content += choice.get('text', '')
+                else:
+                    continue
 
                 # check payload indicator for completion
                 if chunk_json['choices'][0].get('finish_reason') is not None:
                     yield create_final_llm_result_chunk(
                         index=chunk_index,
                         message=assistant_prompt_message,
                         finish_reason=chunk_json['choices'][0]['finish_reason']
                     )
                 else:
                     yield LLMResultChunk(
                         model=model,
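
The new branch on 'delta' versus 'text' mirrors the two body shapes OpenAI-compatible servers use when streaming: chat endpoints send chat.completion.chunk objects whose choices carry a delta, while legacy completion endpoints send text_completion objects whose choices carry text directly. Two abbreviated example chunks, with illustrative values:

    # Chat stream chunk: taken by the 'delta' branch above.
    chat_chunk = {
        "object": "chat.completion.chunk",
        "choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": None}],
    }

    # Completion stream chunk: taken by the 'text' branch above.
    completion_chunk = {
        "object": "text_completion",
        "choices": [{"index": 0, "text": "Hel", "finish_reason": None}],
    }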
@@ -374,8 +415,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 finish_reason="End of stream."
             )
 
+            chunk_index += 1
+
     def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
-            prompt_messages: list[PromptMessage]) -> LLMResult:
+                                  prompt_messages: list[PromptMessage]) -> LLMResult:
         response_json = response.json()
@@ -455,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in
+                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                                              in
                                               message.tool_calls]
             # function_call = message.tool_calls[0]
             # message_dict["function_call"] = {

View File

@@ -33,8 +33,8 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: 在此输入您的 API endpoint URL
-        en_US: Enter your API endpoint URL
+        zh_Hans: Base URL, eg. https://api.openai.com/v1
+        en_US: Base URL, eg. https://api.openai.com/v1
   - variable: mode
     show_on:
       - variable: __model_type
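
The new placeholder spells out the expected form of the value: a base URL up to and including the version prefix, with no resource path. Given the urljoin logic added in the LLM and embedding files, a value of that form expands to the full endpoints as sketched below (the host is just the placeholder's example):

    from urllib.parse import urljoin

    base = 'https://api.openai.com/v1'
    if not base.endswith('/'):
        base += '/'

    print(urljoin(base, 'chat/completions'))  # https://api.openai.com/v1/chat/completions
    print(urljoin(base, 'embeddings'))        # https://api.openai.com/v1/embeddings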

View File

@@ -1,6 +1,7 @@
 import time
 from decimal import Decimal
 from typing import Optional
+from urllib.parse import urljoin
 
 import requests
 import json
@@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"
 
-        endpoint_url = credentials['endpoint_url']
+        endpoint_url = credentials.get('endpoint_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
+        endpoint_url = urljoin(endpoint_url, 'embeddings')
 
         extra_model_kwargs = {}
         if user:
@@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"
 
-        endpoint_url = credentials['endpoint_url']
+        endpoint_url = credentials.get('endpoint_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
+        endpoint_url = urljoin(endpoint_url, 'embeddings')
 
         payload = {
             'input': 'ping',
@@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             )
 
             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f"Invalid response status: {response.status_code}")
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if 'model' not in json_result:
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response')
+        except CredentialsValidateFailedError:
+            raise
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             model_type=ModelType.TEXT_EMBEDDING,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                 ModelPropertyKey.MAX_CHUNKS: 1,
             },
             parameter_rules=[],
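
Both files now cast context_size to int because credential values collected through the schema's text-input fields arrive as strings. The embedding validation, in turn, only requires that the response body echo a model field, which OpenAI-compatible embeddings responses include. An abbreviated example of such a response, with illustrative values:

    # Abbreviated OpenAI-compatible embeddings response; values illustrative.
    embeddings_response = {
        "object": "list",
        "data": [{"object": "embedding", "index": 0, "embedding": [0.01, -0.02]}],
        "model": "text-embedding-ada-002",  # presence of this key is what validation checks
        "usage": {"prompt_tokens": 1, "total_tokens": 1},
    }

    assert 'model' in embeddings_response  # mirrors the check added above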