From f361c7004dbcaadb9b65d27965511566df8c33bb Mon Sep 17 00:00:00 2001
From: Minamiyama
Date: Tue, 7 May 2024 17:37:36 +0800
Subject: [PATCH] feat: support vision models from xinference (#4094)

Co-authored-by: Yeuoly
---
 .../model_providers/xinference/llm/llm.py     | 93 +++++++++++++------
 .../xinference/xinference_helper.py           | 14 ++-
 2 files changed, 72 insertions(+), 35 deletions(-)

diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py
index 602d0b749f..cc3ce17975 100644
--- a/api/core/model_runtime/model_providers/xinference/llm/llm.py
+++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py
@@ -28,7 +28,10 @@ from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
     ToolPromptMessage,
@@ -61,8 +64,8 @@ from core.model_runtime.utils import helper
 
 
 class XinferenceAILargeLanguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
-                model_parameters: dict, tools: list[PromptMessageTool] | None = None, 
+    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
@@ -99,7 +102,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         try:
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-            
+
             extra_param = XinferenceHelper.get_xinference_extra_parameter(
                 server_url=credentials['server_url'],
                 model_uid=credentials['model_uid']
@@ -111,10 +114,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 credentials['completion_type'] = 'completion'
             else:
                 raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
-            
+
             if extra_param.support_function_call:
                 credentials['support_function_call'] = True
 
+            if extra_param.support_vision:
+                credentials['support_vision'] = True
+
             if extra_param.context_length:
                 credentials['context_length'] = extra_param.context_length
 
@@ -135,7 +141,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         return self._num_tokens_from_messages(prompt_messages, tools)
 
-    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], 
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
                                   is_completion_model: bool = False) -> int:
         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)
@@ -155,7 +161,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     text = ''
                     for item in value:
                         if isinstance(item, dict) and item['type'] == 'text':
-                            text += item.text
+                            text += item['text']
 
                     value = text
 
@@ -191,7 +197,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             num_tokens += self._num_tokens_for_tools(tools)
 
         return num_tokens
-    
+
     def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
         """
         Calculate num tokens for tool calling
@@ -234,7 +240,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     num_tokens += tokens(required_field)
 
         return num_tokens
-    
+
     def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
         """
         convert prompt message to text
@@ -260,7 +266,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
-                raise ValueError("User message content must be str")
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                                "detail": message_content.detail.value
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
@@ -277,7 +302,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
-        
+
         return message_dict
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
@@ -338,8 +363,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             completion_type = LLMMode.COMPLETION.value
         else:
             raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
-        
+
+
+        features = []
+        support_function_call = credentials.get('support_function_call', False)
+        if support_function_call:
+            features.append(ModelFeature.TOOL_CALL)
+
+        support_vision = credentials.get('support_vision', False)
+        if support_vision:
+            features.append(ModelFeature.VISION)
+
 
         context_length = credentials.get('context_length', 2048)
 
         entity = AIModelEntity(
@@ -349,10 +384,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
-            features=[
-                ModelFeature.TOOL_CALL
-            ] if support_function_call else [],
-            model_properties={ 
+            features=features,
+            model_properties={
                 ModelPropertyKey.MODE: completion_type,
                 ModelPropertyKey.CONTEXT_SIZE: context_length
             },
@@ -360,22 +393,22 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         )
 
         return entity
-    
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
-                  tools: list[PromptMessageTool] | None = None, 
+                  tools: list[PromptMessageTool] | None = None,
                   stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
         generate text from LLM
 
         see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate`
-        
+
         extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
         """
         if 'server_url' not in credentials:
             raise CredentialsValidateFailedError('server_url is required in credentials')
-        
+
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]
 
@@ -408,11 +441,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     'function': helper.dump_model(tool)
                 } for tool in tools
             ]
-        
+        vision = credentials.get('support_vision', False)
         if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
             resp = client.chat.completions.create(
                 model=credentials['model_uid'],
-                messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], 
+                messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
                 stream=stream,
                 user=user,
                 **generate_config,
@@ -497,7 +530,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         if len(resp.choices) == 0:
             raise InvokeServerUnavailableError("Empty response")
-        
+
         assistant_message = resp.choices[0].message
 
         # convert tool call to assistant message tool call
@@ -527,7 +560,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         )
 
         return response
-    
+
     def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                      tools: list[PromptMessageTool],
                                      resp: Iterator[ChatCompletionChunk]) -> Generator:
@@ -544,7 +577,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
 
             if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
                 continue
-            
+
             # check if there is a tool call in the response
             function_call = None
             tool_calls = []
@@ -573,9 +606,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
                 completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
 
-                usage = self._calc_response_usage(model=model, credentials=credentials, 
+                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                   prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-                
+
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
@@ -608,7 +641,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         if len(resp.choices) == 0:
             raise InvokeServerUnavailableError("Empty response")
-        
+
        assistant_message = resp.choices[0].text

        # transform assistant message to prompt message
@@ -670,9 +703,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             completion_tokens = self._num_tokens_from_messages(
                 messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True
             )
-            usage = self._calc_response_usage(model=model, credentials=credentials, 
+            usage = self._calc_response_usage(model=model, credentials=credentials,
                                               prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-            
+
             yield LLMResultChunk(
                 model=model,
                 prompt_messages=prompt_messages,
diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
index 66dab65804..9a3fc9b193 100644
--- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py
+++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
@@ -14,13 +14,15 @@ class XinferenceModelExtraParameter:
     max_tokens: int = 512
     context_length: int = 2048
     support_function_call: bool = False
+    support_vision: bool = False
 
-    def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], 
-                 support_function_call: bool, max_tokens: int, context_length: int) -> None:
+    def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
+                 support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
         self.support_function_call = support_function_call
+        self.support_vision = support_vision
         self.max_tokens = max_tokens
         self.context_length = context_length
 
@@ -71,7 +73,7 @@ class XinferenceHelper:
             raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
         if response.status_code != 200:
             raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
-        
+
         response_json = response.json()
 
         model_format = response_json.get('model_format', 'ggmlv3')
@@ -87,17 +89,19 @@ class XinferenceHelper:
             model_handle_type = 'chat'
         else:
             raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
-        
+
         support_function_call = 'tools' in model_ability
+        support_vision = 'vision' in model_ability
 
         max_tokens = response_json.get('max_tokens', 512)
         context_length = response_json.get('context_length', 2048)
-        
+
         return XinferenceModelExtraParameter(
             model_format=model_format,
             model_handle_type=model_handle_type,
             model_ability=model_ability,
             support_function_call=support_function_call,
+            support_vision=support_vision,
             max_tokens=max_tokens,
             context_length=context_length
         )
\ No newline at end of file
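
Illustration (not part of the patch): the new `else` branch in `_convert_prompt_message_to_dict` serializes multimodal user messages into OpenAI-compatible content parts, which is the shape xinference's chat endpoint expects for vision models. The sketch below is a minimal standalone rendering of that logic; the two dataclasses are simplified stand-ins for Dify's `PromptMessageContent`/`ImagePromptMessageContent` entities, and the image URL is a placeholder.

```python
# Standalone sketch of the payload the new vision branch produces.
# The classes below are simplified stand-ins for Dify's message_entities;
# only the fields read by _convert_prompt_message_to_dict are modeled.
from dataclasses import dataclass


@dataclass
class TextContent:
    data: str
    type: str = "text"


@dataclass
class ImageContent:
    data: str            # URL or data: URI, as in ImagePromptMessageContent.data
    detail: str = "low"  # mirrors ImagePromptMessageContent.detail.value (assumed "low"/"high")
    type: str = "image"


def to_openai_content(parts: list) -> list[dict]:
    """Mirror the new user-message branch: text parts become "text"
    sub-messages, image parts become "image_url" sub-messages."""
    sub_messages = []
    for part in parts:
        if part.type == "text":
            sub_messages.append({"type": "text", "text": part.data})
        elif part.type == "image":
            sub_messages.append({
                "type": "image_url",
                "image_url": {"url": part.data, "detail": part.detail},
            })
    return sub_messages


message = {
    "role": "user",
    "content": to_openai_content([
        TextContent("What is in this picture?"),
        ImageContent("data:image/png;base64,iVBORw0..."),  # placeholder image
    ]),
}
print(message)
```

String-only user content keeps the old `{"role": "user", "content": "..."}` shape, so existing text-only setups are unaffected; the list form is used only when a message carries multiple content parts.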