feat: support vision models from xinference (#4094)

Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
Minamiyama 2024-05-07 17:37:36 +08:00 committed by GitHub
parent bb7c62777d
commit f361c7004d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 35 deletions

View File

@ -28,7 +28,10 @@ from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool, PromptMessageTool,
SystemPromptMessage, SystemPromptMessage,
ToolPromptMessage, ToolPromptMessage,
@ -61,8 +64,8 @@ from core.model_runtime.utils import helper
class XinferenceAILargeLanguageModel(LargeLanguageModel): class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None, model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator: -> LLMResult | Generator:
""" """
@ -99,7 +102,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
try: try:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
extra_param = XinferenceHelper.get_xinference_extra_parameter( extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'], server_url=credentials['server_url'],
model_uid=credentials['model_uid'] model_uid=credentials['model_uid']
@ -111,10 +114,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
credentials['completion_type'] = 'completion' credentials['completion_type'] = 'completion'
else: else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type') raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
if extra_param.support_function_call: if extra_param.support_function_call:
credentials['support_function_call'] = True credentials['support_function_call'] = True
if extra_param.support_vision:
credentials['support_vision'] = True
if extra_param.context_length: if extra_param.context_length:
credentials['context_length'] = extra_param.context_length credentials['context_length'] = extra_param.context_length
@ -135,7 +141,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
""" """
return self._num_tokens_from_messages(prompt_messages, tools) return self._num_tokens_from_messages(prompt_messages, tools)
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
is_completion_model: bool = False) -> int: is_completion_model: bool = False) -> int:
def tokens(text: str): def tokens(text: str):
return self._get_num_tokens_by_gpt2(text) return self._get_num_tokens_by_gpt2(text)
@ -155,7 +161,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
text = '' text = ''
for item in value: for item in value:
if isinstance(item, dict) and item['type'] == 'text': if isinstance(item, dict) and item['type'] == 'text':
text += item.text text += item['text']
value = text value = text
@ -191,7 +197,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
num_tokens += self._num_tokens_for_tools(tools) num_tokens += self._num_tokens_for_tools(tools)
return num_tokens return num_tokens
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
""" """
Calculate num tokens for tool calling Calculate num tokens for tool calling
@ -234,7 +240,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
num_tokens += tokens(required_field) num_tokens += tokens(required_field)
return num_tokens return num_tokens
def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
""" """
convert prompt message to text convert prompt message to text
@ -260,7 +266,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
if isinstance(message.content, str): if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content} message_dict = {"role": "user", "content": message.content}
else: else:
raise ValueError("User message content must be str") sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
"detail": message_content.detail.value
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage): elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message) message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content} message_dict = {"role": "assistant", "content": message.content}
@ -277,7 +302,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else: else:
raise ValueError(f"Unknown message type {type(message)}") raise ValueError(f"Unknown message type {type(message)}")
return message_dict return message_dict
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
@ -338,8 +363,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
completion_type = LLMMode.COMPLETION.value completion_type = LLMMode.COMPLETION.value
else: else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
features = []
support_function_call = credentials.get('support_function_call', False) support_function_call = credentials.get('support_function_call', False)
if support_function_call:
features.append(ModelFeature.TOOL_CALL)
support_vision = credentials.get('support_vision', False)
if support_vision:
features.append(ModelFeature.VISION)
context_length = credentials.get('context_length', 2048) context_length = credentials.get('context_length', 2048)
entity = AIModelEntity( entity = AIModelEntity(
@ -349,10 +384,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
), ),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM, model_type=ModelType.LLM,
features=[ features=features,
ModelFeature.TOOL_CALL model_properties={
] if support_function_call else [],
model_properties={
ModelPropertyKey.MODE: completion_type, ModelPropertyKey.MODE: completion_type,
ModelPropertyKey.CONTEXT_SIZE: context_length ModelPropertyKey.CONTEXT_SIZE: context_length
}, },
@ -360,22 +393,22 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
) )
return entity return entity
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
tools: list[PromptMessageTool] | None = None, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator: -> LLMResult | Generator:
""" """
generate text from LLM generate text from LLM
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate`
extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
""" """
if 'server_url' not in credentials: if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials') raise CredentialsValidateFailedError('server_url is required in credentials')
if credentials['server_url'].endswith('/'): if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1] credentials['server_url'] = credentials['server_url'][:-1]
@ -408,11 +441,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
'function': helper.dump_model(tool) 'function': helper.dump_model(tool)
} for tool in tools } for tool in tools
] ]
vision = credentials.get('support_vision', False)
if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
resp = client.chat.completions.create( resp = client.chat.completions.create(
model=credentials['model_uid'], model=credentials['model_uid'],
messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
stream=stream, stream=stream,
user=user, user=user,
**generate_config, **generate_config,
@ -497,7 +530,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
""" """
if len(resp.choices) == 0: if len(resp.choices) == 0:
raise InvokeServerUnavailableError("Empty response") raise InvokeServerUnavailableError("Empty response")
assistant_message = resp.choices[0].message assistant_message = resp.choices[0].message
# convert tool call to assistant message tool call # convert tool call to assistant message tool call
@ -527,7 +560,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
) )
return response return response
def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool], tools: list[PromptMessageTool],
resp: Iterator[ChatCompletionChunk]) -> Generator: resp: Iterator[ChatCompletionChunk]) -> Generator:
@ -544,7 +577,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
continue continue
# check if there is a tool call in the response # check if there is a tool call in the response
function_call = None function_call = None
tool_calls = [] tool_calls = []
@ -573,9 +606,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
usage = self._calc_response_usage(model=model, credentials=credentials, usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
yield LLMResultChunk( yield LLMResultChunk(
model=model, model=model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
@ -608,7 +641,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
""" """
if len(resp.choices) == 0: if len(resp.choices) == 0:
raise InvokeServerUnavailableError("Empty response") raise InvokeServerUnavailableError("Empty response")
assistant_message = resp.choices[0].text assistant_message = resp.choices[0].text
# transform assistant message to prompt message # transform assistant message to prompt message
@ -670,9 +703,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
completion_tokens = self._num_tokens_from_messages( completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True
) )
usage = self._calc_response_usage(model=model, credentials=credentials, usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
yield LLMResultChunk( yield LLMResultChunk(
model=model, model=model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,

View File

@ -14,13 +14,15 @@ class XinferenceModelExtraParameter:
max_tokens: int = 512 max_tokens: int = 512
context_length: int = 2048 context_length: int = 2048
support_function_call: bool = False support_function_call: bool = False
support_vision: bool = False
def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
support_function_call: bool, max_tokens: int, context_length: int) -> None: support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None:
self.model_format = model_format self.model_format = model_format
self.model_handle_type = model_handle_type self.model_handle_type = model_handle_type
self.model_ability = model_ability self.model_ability = model_ability
self.support_function_call = support_function_call self.support_function_call = support_function_call
self.support_vision = support_vision
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.context_length = context_length self.context_length = context_length
@ -71,7 +73,7 @@ class XinferenceHelper:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200: if response.status_code != 200:
raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
response_json = response.json() response_json = response.json()
model_format = response_json.get('model_format', 'ggmlv3') model_format = response_json.get('model_format', 'ggmlv3')
@ -87,17 +89,19 @@ class XinferenceHelper:
model_handle_type = 'chat' model_handle_type = 'chat'
else: else:
raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported') raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
support_function_call = 'tools' in model_ability support_function_call = 'tools' in model_ability
support_vision = 'vision' in model_ability
max_tokens = response_json.get('max_tokens', 512) max_tokens = response_json.get('max_tokens', 512)
context_length = response_json.get('context_length', 2048) context_length = response_json.get('context_length', 2048)
return XinferenceModelExtraParameter( return XinferenceModelExtraParameter(
model_format=model_format, model_format=model_format,
model_handle_type=model_handle_type, model_handle_type=model_handle_type,
model_ability=model_ability, model_ability=model_ability,
support_function_call=support_function_call, support_function_call=support_function_call,
support_vision=support_vision,
max_tokens=max_tokens, max_tokens=max_tokens,
context_length=context_length context_length=context_length
) )