diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index 049ad67a77..7b4416f44e 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -11,6 +11,8 @@ - groq - replicate - huggingface_hub +- xinference +- triton_inference_server - zhipuai - baichuan - spark @@ -20,7 +22,6 @@ - moonshot - jina - chatglm -- xinference - yi - openllm - localai diff --git a/api/core/model_runtime/model_providers/triton_inference_server/__init__.py b/api/core/model_runtime/model_providers/triton_inference_server/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_l_en.png new file mode 100644 index 0000000000..dd32d45803 Binary files /dev/null and b/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_s_en.svg new file mode 100644 index 0000000000..9fc02f9164 --- /dev/null +++ b/api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_s_en.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/triton_inference_server/llm/__init__.py b/api/core/model_runtime/model_providers/triton_inference_server/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py new file mode 100644 index 0000000000..95272a41c2 --- /dev/null +++ b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py @@ -0,0 +1,267 @@ +from collections.abc import Generator + +from httpx import Response, post +from yarl import URL + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + + +class TritonInferenceAILargeLanguageModel(LargeLanguageModel): + def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + """ + invoke LLM + + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + """ + return self._generate( + model=model, credentials=credentials, prompt_messages=prompt_messages, 
model_parameters=model_parameters,
+            tools=tools, stop=stop, stream=stream, user=user,
+        )
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        validate credentials
+        """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+
+        try:
+            self._invoke(model=model, credentials=credentials, prompt_messages=[
+                UserPromptMessage(content='ping')
+            ], model_parameters={}, stream=False)
+        except InvokeError as ex:
+            raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}')
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: list[PromptMessageTool] | None = None) -> int:
+        """
+        get number of tokens
+
+        Since Triton Inference Server serves customized models, we cannot detect which tokenizer to use,
+        so we fall back to the GPT-2 tokenizer by default.
+        """
+        return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages))
+
+    def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
+        """
+        convert prompt message to text
+        """
+        text = ''
+        for item in message:
+            if isinstance(item, UserPromptMessage):
+                text += f'User: {item.content}'
+            elif isinstance(item, SystemPromptMessage):
+                text += f'System: {item.content}'
+            elif isinstance(item, AssistantPromptMessage):
+                text += f'Assistant: {item.content}'
+            else:
+                raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
+        return text
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        used to define customizable model schema
+        """
+        rules = [
+            ParameterRule(
+                name='temperature',
+                type=ParameterType.FLOAT,
+                use_template='temperature',
+                label=I18nObject(
+                    zh_Hans='温度',
+                    en_US='Temperature'
+                ),
+            ),
+            ParameterRule(
+                name='top_p',
+                type=ParameterType.FLOAT,
+                use_template='top_p',
+                label=I18nObject(
+                    zh_Hans='Top P',
+                    en_US='Top P'
+                )
+            ),
+            ParameterRule(
+                name='max_tokens',
+                type=ParameterType.INT,
+                use_template='max_tokens',
+                min=1,
+                max=int(credentials.get('context_length', 2048)),
+                default=min(512, int(credentials.get('context_length', 2048))),
+                label=I18nObject(
+                    zh_Hans='最大生成长度',
+                    en_US='Max Tokens'
+                )
+            )
+        ]
+
+        completion_type = None
+
+        if 'completion_type' in credentials:
+            if credentials['completion_type'] == 'chat':
+                completion_type = LLMMode.CHAT.value
+            elif credentials['completion_type'] == 'completion':
+                completion_type = LLMMode.COMPLETION.value
+            else:
+                raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            parameter_rules=rules,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.LLM,
+            model_properties={
+                ModelPropertyKey.MODE: completion_type,
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_length', 2048)),
+            },
+        )
+
+        return entity
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                  model_parameters: dict,
+                  tools: list[PromptMessageTool] | None = None,
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
+        """
+        generate text from LLM
+        """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+
+        if 'stream' in credentials and not bool(credentials['stream']) and stream:
+            raise ValueError(f'stream is not supported by model {model}')
+
+        try:
+            parameters = {}
+            if 'temperature' in model_parameters:
+                parameters['temperature'] = model_parameters['temperature']
+            if 'top_p' in model_parameters:
+                parameters['top_p'] = model_parameters['top_p']
+            if 'top_k' in model_parameters:
+                parameters['top_k'] = model_parameters['top_k']
+            if 'presence_penalty' in model_parameters:
+                parameters['presence_penalty'] = model_parameters['presence_penalty']
+            if 'frequency_penalty' in model_parameters:
+                parameters['frequency_penalty'] = model_parameters['frequency_penalty']
+
+            response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={
+                'text_input': self._convert_prompt_message_to_text(prompt_messages),
+                'max_tokens': model_parameters.get('max_tokens', 512),
+                'parameters': {
+                    'stream': False,
+                    **parameters
+                },
+            }, timeout=(10, 120))
+            response.raise_for_status()
+            if response.status_code != 200:
+                raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}')
+
+            if stream:
+                return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
+                                                         tools=tools, resp=response)
+            return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
+                                                       tools=tools, resp=response)
+        except Exception as ex:
+            raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}')
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                       tools: list[PromptMessageTool],
+                                       resp: Response) -> LLMResult:
+        """
+        handle normal chat generate response
+        """
+        text = resp.json()['text_output']
+
+        usage = LLMUsage.empty_usage()
+        usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+        usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
+
+        return LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=AssistantPromptMessage(
+                content=text
+            ),
+            usage=usage
+        )
+
+    def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                     tools: list[PromptMessageTool],
+                                     resp: Response) -> Generator:
+        """
+        handle chat stream response
+        """
+        text = resp.json()['text_output']
+
+        usage = LLMUsage.empty_usage()
+        usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+        usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
+
+        yield LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(
+                    content=text
+                ),
+                usage=usage
+            )
+        )
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+            ],
+            InvokeServerUnavailableError: [
+            ],
+            InvokeRateLimitError: [
+            ],
+            InvokeAuthorizationError: [
+            ],
+            InvokeBadRequestError: [
+                ValueError
+            ]
+        }
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py
new file mode 100644
index 0000000000..06846825ab
--- /dev/null
+++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py
@@ -0,0 +1,9 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+class TritonInferenceAIProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass
diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml
new file mode 100644
index 0000000000..50a804743d
--- /dev/null
+++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml
@@ -0,0 +1,84 @@
+provider: triton_inference_server
+label:
+  en_US: Triton Inference Server
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.png
+background: "#EFFDFD"
+help:
+  title:
+    en_US: How to deploy Triton Inference Server
+    zh_Hans: 如何部署 Triton Inference Server
+  url:
+    en_US: https://github.com/triton-inference-server/server
+supported_model_types:
+  - llm
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: server_url
+      label:
+        zh_Hans: 服务器URL
+        en_US: Server URL
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000
+        en_US: Enter the URL of your Triton Inference Server, e.g. http://192.168.1.100:8000
+    - variable: context_length
+      label:
+        zh_Hans: 上下文大小
+        en_US: Context size
+      type: text-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的上下文大小
+        en_US: Enter the context size
+      default: 2048
+    - variable: completion_type
+      label:
+        zh_Hans: 补全类型
+        en_US: Completion type
+      type: select
+      required: true
+      default: chat
+      placeholder:
+        zh_Hans: 在此输入您的补全类型
+        en_US: Enter the completion type
+      options:
+        - label:
+            zh_Hans: 补全模型
+            en_US: Completion model
+          value: completion
+        - label:
+            zh_Hans: 对话模型
+            en_US: Chat model
+          value: chat
+    - variable: stream
+      label:
+        zh_Hans: 流式输出
+        en_US: Stream output
+      type: select
+      required: true
+      default: true
+      placeholder:
+        zh_Hans: 是否支持流式输出
+        en_US: Whether to support stream output
+      options:
+        - label:
+            zh_Hans: 是
+            en_US: "Yes"
+          value: true
+        - label:
+            zh_Hans: 否
+            en_US: "No"
+          value: false
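
Note: the provider in llm.py above talks to Triton's HTTP generate endpoint (POST {server_url}/v2/models/{model}/generate) and reads the generated text from the `text_output` field of the JSON response. Below is a minimal standalone sketch of that same call, useful for sanity-checking a Triton deployment before entering the credentials in Dify. The server URL and model name are placeholders, and it assumes a Triton server with the generate extension enabled.

```python
# Minimal sketch of the request llm.py issues to Triton; not part of the PR.
# server_url and model below are hypothetical values for your own deployment.
from httpx import post
from yarl import URL

server_url = 'http://192.168.1.100:8000'  # hypothetical server_url credential
model = 'my_llm'                          # hypothetical model name on the server

response = post(
    str(URL(server_url) / 'v2' / 'models' / model / 'generate'),
    json={
        'text_input': 'User: ping',                 # flattened prompt text
        'max_tokens': 64,
        'parameters': {'stream': False, 'temperature': 0.7},
    },
    timeout=120.0,
)
response.raise_for_status()
print(response.json()['text_output'])               # generated text returned by Triton
```

If this prints a completion, the same server_url and model name should work when configured as a customizable model under the Triton Inference Server provider.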