feat: support vision models from xinference (#4094)
Co-authored-by: Yeuoly <admin@srmxy.cn>
Parent: bb7c62777d
Commit: f361c7004d
@@ -28,7 +28,10 @@ from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
     ToolPromptMessage,
@@ -115,6 +118,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         if extra_param.support_function_call:
             credentials['support_function_call'] = True

+        if extra_param.support_vision:
+            credentials['support_vision'] = True
+
         if extra_param.context_length:
             credentials['context_length'] = extra_param.context_length

@@ -155,7 +161,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 text = ''
                 for item in value:
                     if isinstance(item, dict) and item['type'] == 'text':
-                        text += item.text
+                        text += item['text']

                 value = text

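The hunk above also fixes a plain bug: the items being summed are dicts, so attribute access (`item.text`) raises AttributeError and subscripting is required. A standalone repro of the fixed logic, with invented sample data:

    # Repro of the fixed text-accumulation path; the sample `value` list
    # is invented for illustration.
    value = [
        {'type': 'text', 'text': 'Hello, '},
        {'type': 'image_url', 'image_url': {'url': 'https://example.com/cat.png'}},
        {'type': 'text', 'text': 'world'},
    ]

    text = ''
    for item in value:
        if isinstance(item, dict) and item['type'] == 'text':
            text += item['text']  # item.text would raise AttributeError on a dict

    value = text
    assert value == 'Hello, world'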
@@ -260,7 +266,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
-                raise ValueError("User message content must be str")
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                                "detail": message_content.detail.value
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
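For a user message that mixes text and an image, the conversion branch above no longer raises but emits OpenAI-style content parts. A hedged sketch of the payload it builds (all values invented; `detail` comes from ImagePromptMessageContent.detail.value):

    # Illustrative result of converting a mixed text+image user message.
    message_dict = {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': 'What is in this picture?'},
            {
                'type': 'image_url',
                'image_url': {
                    'url': 'https://example.com/cat.png',  # or a data: URI
                    'detail': 'low',                       # assumed sample detail value
                },
            },
        ],
    }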
@@ -339,7 +364,17 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         else:
             raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')

+
+        features = []
+
         support_function_call = credentials.get('support_function_call', False)
+        if support_function_call:
+            features.append(ModelFeature.TOOL_CALL)
+
+        support_vision = credentials.get('support_vision', False)
+        if support_vision:
+            features.append(ModelFeature.VISION)
+
         context_length = credentials.get('context_length', 2048)

         entity = AIModelEntity(
@@ -349,9 +384,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
-            features=[
-                ModelFeature.TOOL_CALL
-            ] if support_function_call else [],
+            features=features,
             model_properties={
                 ModelPropertyKey.MODE: completion_type,
                 ModelPropertyKey.CONTEXT_SIZE: context_length
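Taken together, the last two hunks replace the hard-coded, function-call-only feature list with one assembled from the stored credentials, so TOOL_CALL and VISION can be advertised independently. A minimal self-contained sketch of that assembly (ModelFeature stubbed with plain strings so the snippet runs on its own):

    # Feature assembly as in the diff; sample credential values.
    TOOL_CALL, VISION = 'tool-call', 'vision'

    credentials = {'support_function_call': True, 'support_vision': True}

    features = []
    if credentials.get('support_function_call', False):
        features.append(TOOL_CALL)
    if credentials.get('support_vision', False):
        features.append(VISION)

    assert features == [TOOL_CALL, VISION]  # each flag contributes on its own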
@@ -408,7 +441,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 'function': helper.dump_model(tool)
             } for tool in tools
         ]
-
+        vision = credentials.get('support_vision', False)
         if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
             resp = client.chat.completions.create(
                 model=credentials['model_uid'],
@@ -14,13 +14,15 @@ class XinferenceModelExtraParameter:
     max_tokens: int = 512
     context_length: int = 2048
     support_function_call: bool = False
+    support_vision: bool = False

     def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
-                 support_function_call: bool, max_tokens: int, context_length: int) -> None:
+                 support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
         self.support_function_call = support_function_call
+        self.support_vision = support_vision
         self.max_tokens = max_tokens
         self.context_length = context_length

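Note that the new `support_vision` parameter lands between `support_function_call` and `max_tokens`, so any positional caller of this constructor must be updated; keyword arguments sidestep that. A hypothetical construction against a stub with the same widened signature (values invented):

    # Stand-in for XinferenceModelExtraParameter, for illustration only.
    class ExtraParameterStub:
        def __init__(self, model_format: str, model_handle_type: str,
                     model_ability: list[str], support_function_call: bool,
                     support_vision: bool, max_tokens: int, context_length: int) -> None:
            self.support_function_call = support_function_call
            self.support_vision = support_vision

    param = ExtraParameterStub(
        model_format='pytorch',
        model_handle_type='chat',
        model_ability=['chat', 'vision'],
        support_function_call=False,
        support_vision=True,
        max_tokens=512,
        context_length=4096,
    )
    assert param.support_vision and not param.support_function_call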
@@ -89,6 +91,7 @@ class XinferenceHelper:
             raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')

         support_function_call = 'tools' in model_ability
+        support_vision = 'vision' in model_ability
         max_tokens = response_json.get('max_tokens', 512)

         context_length = response_json.get('context_length', 2048)
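The vision flag is derived the same way as function calling: by probing the `model_ability` list that xinference reports for the deployed model. A sketch with a made-up response payload, showing only the keys the hunk reads:

    # Invented example of the model description returned by xinference.
    response_json = {
        'model_ability': ['chat', 'vision'],
        'max_tokens': 512,
        'context_length': 4096,
    }
    model_ability = response_json['model_ability']

    support_function_call = 'tools' in model_ability  # False for this payload
    support_vision = 'vision' in model_ability        # True for this payload
    max_tokens = response_json.get('max_tokens', 512)
    context_length = response_json.get('context_length', 2048)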
@@ -98,6 +101,7 @@ class XinferenceHelper:
             model_handle_type=model_handle_type,
             model_ability=model_ability,
             support_function_call=support_function_call,
+            support_vision=support_vision,
             max_tokens=max_tokens,
             context_length=context_length
         )