feat: support vision models from xinference (#4094)

Co-authored-by: Yeuoly <admin@srmxy.cn>
Minamiyama 2024-05-07 17:37:36 +08:00, committed by GitHub
parent bb7c62777d
commit f361c7004d
2 changed files with 72 additions and 35 deletions

Changed file 1 of 2:

@@ -28,7 +28,10 @@ from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
     ToolPromptMessage,
@@ -115,6 +118,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         if extra_param.support_function_call:
             credentials['support_function_call'] = True
 
+        if extra_param.support_vision:
+            credentials['support_vision'] = True
+
         if extra_param.context_length:
             credentials['context_length'] = extra_param.context_length
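After this change, a vision-capable model's credentials carry the new flag alongside the existing ones. A minimal sketch of the resulting dict (the server URL and model UID are hypothetical examples, not values from this commit):

# Hypothetical credentials after the extra parameters are fetched,
# assuming a chat model that reports 'vision' among its abilities:
credentials = {
    'server_url': 'http://127.0.0.1:9997',  # assumed local xinference server
    'model_uid': 'qwen-vl-chat',            # assumed example model UID
    'support_function_call': False,
    'support_vision': True,
    'context_length': 2048,
}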
@@ -155,7 +161,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 text = ''
                 for item in value:
                     if isinstance(item, dict) and item['type'] == 'text':
-                        text += item.text
+                        text += item['text']
 
                 value = text
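The one-line fix above matters because each item here is a plain dict (an OpenAI-style content part), not an object with attributes:

# item is a plain dict, so subscript access is the correct form:
item = {'type': 'text', 'text': 'hello'}
item['text']  # -> 'hello'
# item.text   # would raise AttributeError: 'dict' object has no attribute 'text'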
@@ -260,7 +266,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
-                raise ValueError("User message content must be str")
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                                "detail": message_content.detail.value
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
@@ -339,7 +364,17 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         else:
             raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
 
+        features = []
+
         support_function_call = credentials.get('support_function_call', False)
+        if support_function_call:
+            features.append(ModelFeature.TOOL_CALL)
+
+        support_vision = credentials.get('support_vision', False)
+        if support_vision:
+            features.append(ModelFeature.VISION)
+
         context_length = credentials.get('context_length', 2048)
 
         entity = AIModelEntity(
@@ -349,9 +384,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
-            features=[
-                ModelFeature.TOOL_CALL
-            ] if support_function_call else [],
+            features=features,
             model_properties={
                 ModelPropertyKey.MODE: completion_type,
                 ModelPropertyKey.CONTEXT_SIZE: context_length
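With both capability flags present in the credentials, the entity now advertises both features rather than the previous tool-call-only list. A sketch of the outcome:

# Sketch: for a model with both flags set, the feature checks above yield
features = [ModelFeature.TOOL_CALL, ModelFeature.VISION]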
@@ -408,7 +441,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     'function': helper.dump_model(tool)
                 } for tool in tools
             ]
+        vision = credentials.get('support_vision', False)
 
         if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
             resp = client.chat.completions.create(
                 model=credentials['model_uid'],

Changed file 2 of 2:

@@ -14,13 +14,15 @@ class XinferenceModelExtraParameter:
     max_tokens: int = 512
     context_length: int = 2048
     support_function_call: bool = False
+    support_vision: bool = False
 
     def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
-                 support_function_call: bool, max_tokens: int, context_length: int) -> None:
+                 support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
         self.support_function_call = support_function_call
+        self.support_vision = support_vision
         self.max_tokens = max_tokens
         self.context_length = context_length
@@ -89,6 +91,7 @@ class XinferenceHelper:
             raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
 
         support_function_call = 'tools' in model_ability
+        support_vision = 'vision' in model_ability
         max_tokens = response_json.get('max_tokens', 512)
         context_length = response_json.get('context_length', 2048)
@@ -98,6 +101,7 @@ class XinferenceHelper:
             model_handle_type=model_handle_type,
             model_ability=model_ability,
             support_function_call=support_function_call,
+            support_vision=support_vision,
             max_tokens=max_tokens,
             context_length=context_length
         )
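Both flags are derived purely from the model_ability list reported by the xinference server, and every caller of the widened constructor must now pass the extra argument, as the keyword call above does. A minimal sketch of the detection for a vision chat model (the model_format and model_handle_type values are hypothetical):

# Sketch, assuming the server reports model_ability=['chat', 'vision']:
model_ability = ['chat', 'vision']
extra_param = XinferenceModelExtraParameter(
    model_format='pytorch',        # assumed example value
    model_handle_type='chat',      # assumed example value
    model_ability=model_ability,
    support_function_call='tools' in model_ability,  # False here
    support_vision='vision' in model_ability,        # True here
    max_tokens=512,
    context_length=2048,
)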