mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 01:39:04 +08:00

feat: Allow using file variables directly in the LLM node and support more file types. (#10679)

Co-authored-by: Joel <iamjoel007@gmail.com>

This commit is contained in:
parent 535c72cad7
commit c5f7d650b5
@@ -27,7 +27,6 @@ class DifyConfig(
         # read from dotenv format config file
         env_file=".env",
         env_file_encoding="utf-8",
-        frozen=True,
         # ignore extra attributes
         extra="ignore",
     )
@@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager

 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@@ -38,27 +38,23 @@ class ModelConfigConverter:
         )

         if model_credentials is None:
-            if not skip_check:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            else:
-                model_credentials = {}
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

-        if not skip_check:
-            # check model
-            provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model, model_type=ModelType.LLM
-            )
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )

         if provider_model is None:
             model_name = model_config.model
             raise ValueError(f"Model {model_name} not exist.")

         if provider_model.status == ModelStatus.NO_CONFIGURE:
             raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
         elif provider_model.status == ModelStatus.NO_PERMISSION:
             raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
         elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
             raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

         # model config
         completion_params = model_config.parameters
@@ -76,7 +72,7 @@ class ModelConfigConverter:

         model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)

-        if not skip_check and not model_schema:
+        if not model_schema:
             raise ValueError(f"Model {model_name} not exist.")

         return ModelConfigWithCredentialsEntity(
@@ -217,9 +217,12 @@ class WorkflowCycleManage:
         ).total_seconds()
         db.session.commit()

-        db.session.refresh(workflow_run)
         db.session.close()

+        with Session(db.engine, expire_on_commit=False) as session:
+            session.add(workflow_run)
+            session.refresh(workflow_run)
+
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
@@ -3,7 +3,12 @@ import base64
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
+    ImagePromptMessageContent,
+    VideoPromptMessageContent,
+)
 from extensions.ext_database import db
 from extensions.ext_storage import storage
@@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute):
             return file.remote_url
         case FileAttribute.EXTENSION:
             return file.extension
-        case _:
-            raise ValueError(f"Invalid file attribute: {attr}")


 def to_prompt_message_content(
     f: File,
     /,
     *,
-    image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ):
-    """
-    Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
-
-    This function takes a File object and converts it to an appropriate PromptMessageContent
-    object, which can be used as a prompt for image or audio-based AI models.
-
-    Args:
-        f (File): The File object to convert.
-        detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
-            If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
-
-    Returns:
-        Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
-
-    Raises:
-        ValueError: If the file type is not supported or if required data is missing.
-    """
     match f.type:
         case FileType.IMAGE:
+            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
             if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
                 data = _to_url(f)
             else:
@@ -65,7 +52,7 @@ def to_prompt_message_content(

             return ImagePromptMessageContent(data=data, detail=image_detail_config)
         case FileType.AUDIO:
-            encoded_string = _file_to_encoded_string(f)
+            encoded_string = _get_encoded_string(f)
             if f.extension is None:
                 raise ValueError("Missing file extension")
             return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
@@ -74,9 +61,20 @@ def to_prompt_message_content(
                 data = _to_url(f)
             else:
                 data = _to_base64_data_string(f)
+            if f.extension is None:
+                raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
+        case FileType.DOCUMENT:
+            data = _get_encoded_string(f)
+            if f.mime_type is None:
+                raise ValueError("Missing file mime_type")
+            return DocumentPromptMessageContent(
+                encode_format="base64",
+                mime_type=f.mime_type,
+                data=data,
+            )
         case _:
-            raise ValueError("file type f.type is not supported")
+            raise ValueError(f"file type {f.type} is not supported")


 def download(f: File, /):
@@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /):
         case FileTransferMethod.REMOTE_URL:
             response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
-            content = response.content
-            encoded_string = base64.b64encode(content).decode("utf-8")
-            return encoded_string
+            data = response.content
         case FileTransferMethod.LOCAL_FILE:
             upload_file = file_repository.get_upload_file(session=db.session(), file=f)
             data = _download_file_content(upload_file.key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
         case FileTransferMethod.TOOL_FILE:
             tool_file = file_repository.get_tool_file(session=db.session(), file=f)
             data = _download_file_content(tool_file.file_key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
-        case _:
-            raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
+    encoded_string = base64.b64encode(data).decode("utf-8")
+    return encoded_string


 def _to_base64_data_string(f: File, /):
@@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /):
     return f"data:{f.mime_type};base64,{encoded_string}"


-def _file_to_encoded_string(f: File, /):
-    match f.type:
-        case FileType.IMAGE:
-            return _to_base64_data_string(f)
-        case FileType.VIDEO:
-            return _to_base64_data_string(f)
-        case FileType.AUDIO:
-            return _get_encoded_string(f)
-        case _:
-            raise ValueError(f"file type {f.type} is not supported")
-
-
 def _to_url(f: File, /):
     if f.transfer_method == FileTransferMethod.REMOTE_URL:
         if f.remote_url is None:
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import Optional

 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -27,7 +28,7 @@ class TokenBufferMemory:

     def get_history_prompt_messages(
         self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> list[PromptMessage]:
+    ) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit
@@ -100,10 +100,10 @@ class ModelInstance:

     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Sequence[PromptMessageTool] | None = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import Optional

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -31,7 +32,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -60,7 +61,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ):
@@ -90,7 +91,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -120,7 +121,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from .message_entities import (
     AssistantPromptMessage,
     AudioPromptMessageContent,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContent,
@@ -37,4 +38,5 @@ __all__ = [
     "LLMResultChunk",
     "LLMResultChunkDelta",
     "AudioPromptMessageContent",
+    "DocumentPromptMessageContent",
 ]
@@ -1,6 +1,7 @@
 from abc import ABC
+from collections.abc import Sequence
 from enum import Enum
-from typing import Optional
+from typing import Literal, Optional

 from pydantic import BaseModel, Field, field_validator

@@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
     IMAGE = "image"
     AUDIO = "audio"
     VIDEO = "video"
+    DOCUMENT = "document"


 class PromptMessageContent(BaseModel):
@@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent):
     detail: DETAIL = DETAIL.LOW


+class DocumentPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    encode_format: Literal["base64"]
+    mime_type: str
+    data: str
+
+
 class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
     """

     role: PromptMessageRole
-    content: Optional[str | list[PromptMessageContent]] = None
+    content: Optional[str | Sequence[PromptMessageContent]] = None
     name: Optional[str] = None

     def is_empty(self) -> bool:
@@ -87,6 +87,9 @@ class ModelFeature(Enum):
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
     STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"


 class DefaultParameterName(str, Enum):
@@ -2,7 +2,7 @@ import logging
 import re
 import time
 from abc import abstractmethod
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
 from typing import Optional, Union

 from pydantic import ConfigDict
@@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
         prompt_messages: list[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -212,7 +212,7 @@ if you are not sure about the structure.
             )

             model_parameters.pop("response_format")
-            stop = stop or []
+            stop = list(stop) if stop is not None else []
             stop.extend(["\n```", "```\n"])
             block_prompts = block_prompts.replace("{{block}}", code_block)

@@ -408,7 +408,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -479,7 +479,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
@@ -601,7 +601,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -647,7 +647,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -694,7 +694,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -742,7 +742,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -7,6 +7,7 @@ features:
 - vision
 - tool-call
 - stream-tool-call
+- document
 model_properties:
   mode: chat
   context_size: 200000
@@ -7,6 +7,7 @@ features:
 - vision
 - tool-call
 - stream-tool-call
+- document
 model_properties:
   mode: chat
   context_size: 200000
@@ -1,7 +1,7 @@
 import base64
 import io
 import json
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast

 import anthropic
@@ -21,9 +21,9 @@ from httpx import Timeout
 from PIL import Image

 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
@@ -33,6 +33,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -86,10 +87,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         self,
         model: str,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
@@ -130,9 +131,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         # Add the new header for claude-3-5-sonnet-20240620 model
         extra_headers = {}
         if model == "claude-3-5-sonnet-20240620":
-            if model_parameters.get("max_tokens") > 4096:
+            if model_parameters.get("max_tokens", 0) > 4096:
                 extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"

+        if any(
+            isinstance(content, DocumentPromptMessageContent)
+            for prompt_message in prompt_messages
+            if isinstance(prompt_message.content, list)
+            for content in prompt_message.content
+        ):
+            extra_headers["anthropic-beta"] = "pdfs-2024-09-25"
+
         if tools:
             extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
             response = client.beta.tools.messages.create(
@@ -504,6 +513,21 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                             "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
                         }
                         sub_messages.append(sub_message_dict)
+                    elif isinstance(message_content, DocumentPromptMessageContent):
+                        if message_content.mime_type != "application/pdf":
+                            raise ValueError(
+                                f"Unsupported document type {message_content.mime_type}, "
+                                "only support application/pdf"
+                            )
+                        sub_message_dict = {
+                            "type": "document",
+                            "source": {
+                                "type": message_content.encode_format,
+                                "media_type": message_content.mime_type,
+                                "data": message_content.data,
+                            },
+                        }
+                        sub_messages.append(sub_message_dict)
                 prompt_message_dicts.append({"role": "user", "content": sub_messages})
             elif isinstance(message, AssistantPromptMessage):
                 message = cast(AssistantPromptMessage, message)
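For reference, roughly the request fragment that the converter above builds when a PDF part is present; the values below are illustrative placeholders, not taken from the diff.

# Sketch of the Anthropic "document" content block produced for a PDF part.
sub_message_dict = {
    "type": "document",
    "source": {
        "type": "base64",                  # DocumentPromptMessageContent.encode_format
        "media_type": "application/pdf",   # DocumentPromptMessageContent.mime_type (PDF only)
        "data": "JVBERi0xLjQK...",         # base64-encoded file body (truncated placeholder)
    },
}
extra_headers = {"anthropic-beta": "pdfs-2024-09-25"}  # sent whenever a document block is present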
@@ -7,6 +7,7 @@ features:
 - multi-tool-call
 - agent-thought
 - stream-tool-call
+- audio
 model_properties:
   mode: chat
   context_size: 128000
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import cast

 from core.model_runtime.entities import (
@@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode

 class PromptMessageUtil:
     @staticmethod
-    def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
+    def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
         """
         Prompt messages to prompt for saving.
         :param model_mode: model mode
@@ -118,11 +118,11 @@ class FileSegment(Segment):

     @property
     def log(self) -> str:
-        return str(self.value)
+        return ""

     @property
     def text(self) -> str:
-        return str(self.value)
+        return ""


 class ArrayAnySegment(ArraySegment):
@@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment):
         for item in self.value:
             items.append(item.markdown)
         return "\n".join(items)
+
+    @property
+    def log(self) -> str:
+        return ""
+
+    @property
+    def text(self) -> str:
+        return ""
@@ -39,7 +39,14 @@ class VisionConfig(BaseModel):


 class PromptConfig(BaseModel):
-    jinja2_variables: Optional[list[VariableSelector]] = None
+    jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
+
+    @field_validator("jinja2_variables", mode="before")
+    @classmethod
+    def convert_none_jinja2_variables(cls, v: Any):
+        if v is None:
+            return []
+        return v


 class LLMNodeChatModelMessage(ChatModelMessage):
@@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
 class LLMNodeData(BaseNodeData):
     model: ModelConfig
     prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-    prompt_config: Optional[PromptConfig] = None
+    prompt_config: PromptConfig = Field(default_factory=PromptConfig)
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig = Field(default_factory=VisionConfig)
+
+    @field_validator("prompt_config", mode="before")
+    @classmethod
+    def convert_none_prompt_config(cls, v: Any):
+        if v is None:
+            return PromptConfig()
+        return v
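The two mode="before" validators above exist so that node configs persisted with null fields still parse after the schema change. A minimal standalone sketch of the same pattern follows; DemoPromptConfig is illustrative and is not the Dify class.

# Standalone sketch of the "coerce None before validation" pattern (assumes pydantic v2).
from collections.abc import Sequence
from typing import Any

from pydantic import BaseModel, Field, field_validator


class DemoPromptConfig(BaseModel):
    jinja2_variables: Sequence[str] = Field(default_factory=list)

    @field_validator("jinja2_variables", mode="before")
    @classmethod
    def convert_none_jinja2_variables(cls, v: Any):
        # Stored configs may contain null; treat it as "no variables".
        return [] if v is None else v


print(DemoPromptConfig(jinja2_variables=None).jinja2_variables)  # [] rather than a validation error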
@@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError):

 class NoPromptFoundError(LLMNodeError):
     """Raised when no prompt is found in the LLM configuration."""
+
+
+class NotSupportedPromptTypeError(LLMNodeError):
+    """Raised when the prompt type is not supported."""
+
+
+class MemoryRolePrefixRequiredError(LLMNodeError):
+    """Raised when memory role prefix is required for completion model."""
@@ -1,4 +1,5 @@
 import json
+import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional, cast

@@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.file import FileType, file_manager
+from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
-    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
     TextPromptMessageContent,
-    VideoPromptMessageContent,
 )
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageRole,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.variables import (
@@ -32,8 +38,9 @@ from core.variables import (
     ObjectSegment,
     StringSegment,
 )
-from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
+from core.workflow.entities.variable_entities import VariableSelector
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base import BaseNode
@ -62,14 +69,18 @@ from .exc import (
|
|||||||
InvalidVariableTypeError,
|
InvalidVariableTypeError,
|
||||||
LLMModeRequiredError,
|
LLMModeRequiredError,
|
||||||
LLMNodeError,
|
LLMNodeError,
|
||||||
|
MemoryRolePrefixRequiredError,
|
||||||
ModelNotExistError,
|
ModelNotExistError,
|
||||||
NoPromptFoundError,
|
NoPromptFoundError,
|
||||||
|
NotSupportedPromptTypeError,
|
||||||
VariableNotFoundError,
|
VariableNotFoundError,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMNode(BaseNode[LLMNodeData]):
|
class LLMNode(BaseNode[LLMNodeData]):
|
||||||
_node_data_cls = LLMNodeData
|
_node_data_cls = LLMNodeData
|
||||||
@ -123,17 +134,13 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
|
|
||||||
# fetch prompt messages
|
# fetch prompt messages
|
||||||
if self.node_data.memory:
|
if self.node_data.memory:
|
||||||
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
query = self.node_data.memory.query_prompt_template
|
||||||
if not query:
|
|
||||||
raise VariableNotFoundError("Query not found")
|
|
||||||
query = query.text
|
|
||||||
else:
|
else:
|
||||||
query = None
|
query = None
|
||||||
|
|
||||||
prompt_messages, stop = self._fetch_prompt_messages(
|
prompt_messages, stop = self._fetch_prompt_messages(
|
||||||
system_query=query,
|
user_query=query,
|
||||||
inputs=inputs,
|
user_files=files,
|
||||||
files=files,
|
|
||||||
context=context,
|
context=context,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@ -141,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
memory_config=self.node_data.memory,
|
memory_config=self.node_data.memory,
|
||||||
vision_enabled=self.node_data.vision.enabled,
|
vision_enabled=self.node_data.vision.enabled,
|
||||||
vision_detail=self.node_data.vision.configs.detail,
|
vision_detail=self.node_data.vision.configs.detail,
|
||||||
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
|
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||||
)
|
)
|
||||||
|
|
||||||
process_data = {
|
process_data = {
|
||||||
@ -181,6 +190,17 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Node {self.node_id} failed to run")
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
error=str(e),
|
||||||
|
inputs=node_inputs,
|
||||||
|
process_data=process_data,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||||
|
|
||||||
@ -203,8 +223,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
self,
|
self,
|
||||||
node_data_model: ModelConfig,
|
node_data_model: ModelConfig,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
) -> Generator[NodeEvent, None, None]:
|
) -> Generator[NodeEvent, None, None]:
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
@ -519,9 +539,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
def _fetch_prompt_messages(
|
def _fetch_prompt_messages(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
system_query: str | None = None,
|
user_query: str | None = None,
|
||||||
inputs: dict[str, str] | None = None,
|
user_files: Sequence["File"],
|
||||||
files: Sequence["File"],
|
|
||||||
context: str | None = None,
|
context: str | None = None,
|
||||||
memory: TokenBufferMemory | None = None,
|
memory: TokenBufferMemory | None = None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
@ -529,58 +548,146 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
memory_config: MemoryConfig | None = None,
|
memory_config: MemoryConfig | None = None,
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
vision_detail: ImagePromptMessageContent.DETAIL,
|
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
variable_pool: VariablePool,
|
||||||
inputs = inputs or {}
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
||||||
|
prompt_messages = []
|
||||||
|
|
||||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
if isinstance(prompt_template, list):
|
||||||
prompt_messages = prompt_transform.get_prompt(
|
# For chat model
|
||||||
prompt_template=prompt_template,
|
prompt_messages.extend(
|
||||||
inputs=inputs,
|
_handle_list_messages(
|
||||||
query=system_query or "",
|
messages=prompt_template,
|
||||||
files=files,
|
context=context,
|
||||||
context=context,
|
jinja2_variables=jinja2_variables,
|
||||||
memory_config=memory_config,
|
variable_pool=variable_pool,
|
||||||
memory=memory,
|
vision_detail_config=vision_detail,
|
||||||
model_config=model_config,
|
)
|
||||||
)
|
)
|
||||||
stop = model_config.stop
|
|
||||||
|
# Get memory messages for chat mode
|
||||||
|
memory_messages = _handle_memory_chat_mode(
|
||||||
|
memory=memory,
|
||||||
|
memory_config=memory_config,
|
||||||
|
model_config=model_config,
|
||||||
|
)
|
||||||
|
# Extend prompt_messages with memory messages
|
||||||
|
prompt_messages.extend(memory_messages)
|
||||||
|
|
||||||
|
# Add current query to the prompt messages
|
||||||
|
if user_query:
|
||||||
|
message = LLMNodeChatModelMessage(
|
||||||
|
text=user_query,
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
)
|
||||||
|
prompt_messages.extend(
|
||||||
|
_handle_list_messages(
|
||||||
|
messages=[message],
|
||||||
|
context="",
|
||||||
|
jinja2_variables=[],
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
vision_detail_config=vision_detail,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||||
|
# For completion model
|
||||||
|
prompt_messages.extend(
|
||||||
|
_handle_completion_template(
|
||||||
|
template=prompt_template,
|
||||||
|
context=context,
|
||||||
|
jinja2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get memory text for completion model
|
||||||
|
memory_text = _handle_memory_completion_mode(
|
||||||
|
memory=memory,
|
||||||
|
memory_config=memory_config,
|
||||||
|
model_config=model_config,
|
||||||
|
)
|
||||||
|
# Insert histories into the prompt
|
||||||
|
prompt_content = prompt_messages[0].content
|
||||||
|
if "#histories#" in prompt_content:
|
||||||
|
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||||
|
else:
|
||||||
|
prompt_content = memory_text + "\n" + prompt_content
|
||||||
|
prompt_messages[0].content = prompt_content
|
||||||
|
|
||||||
|
# Add current query to the prompt message
|
||||||
|
if user_query:
|
||||||
|
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
|
||||||
|
prompt_messages[0].content = prompt_content
|
||||||
|
else:
|
||||||
|
errmsg = f"Prompt type {type(prompt_template)} is not supported"
|
||||||
|
logger.warning(errmsg)
|
||||||
|
raise NotSupportedPromptTypeError(errmsg)
|
||||||
|
|
||||||
|
if vision_enabled and user_files:
|
||||||
|
file_prompts = []
|
||||||
|
for file in user_files:
|
||||||
|
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||||
|
file_prompts.append(file_prompt)
|
||||||
|
if (
|
||||||
|
len(prompt_messages) > 0
|
||||||
|
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||||
|
and isinstance(prompt_messages[-1].content, list)
|
||||||
|
):
|
||||||
|
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
|
||||||
|
else:
|
||||||
|
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||||
|
|
||||||
|
# Filter prompt messages
|
||||||
filtered_prompt_messages = []
|
filtered_prompt_messages = []
|
||||||
for prompt_message in prompt_messages:
|
for prompt_message in prompt_messages:
|
||||||
if prompt_message.is_empty():
|
if isinstance(prompt_message.content, list):
|
||||||
continue
|
|
||||||
|
|
||||||
if not isinstance(prompt_message.content, str):
|
|
||||||
prompt_message_content = []
|
prompt_message_content = []
|
||||||
for content_item in prompt_message.content or []:
|
for content_item in prompt_message.content:
|
||||||
# Skip image if vision is disabled
|
# Skip content if features are not defined
|
||||||
if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
|
if not model_config.model_schema.features:
|
||||||
|
if content_item.type != PromptMessageContentType.TEXT:
|
||||||
|
continue
|
||||||
|
prompt_message_content.append(content_item)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(content_item, ImagePromptMessageContent):
|
# Skip content if corresponding feature is not supported
|
||||||
# Override vision config if LLM node has vision config,
|
if (
|
||||||
# cuz vision detail is related to the configuration from FileUpload feature.
|
(
|
||||||
content_item.detail = vision_detail
|
content_item.type == PromptMessageContentType.IMAGE
|
||||||
prompt_message_content.append(content_item)
|
and ModelFeature.VISION not in model_config.model_schema.features
|
||||||
elif isinstance(
|
)
|
||||||
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.DOCUMENT
|
||||||
|
and ModelFeature.DOCUMENT not in model_config.model_schema.features
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.VIDEO
|
||||||
|
and ModelFeature.VIDEO not in model_config.model_schema.features
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.AUDIO
|
||||||
|
and ModelFeature.AUDIO not in model_config.model_schema.features
|
||||||
|
)
|
||||||
):
|
):
|
||||||
prompt_message_content.append(content_item)
|
continue
|
||||||
|
prompt_message_content.append(content_item)
|
||||||
if len(prompt_message_content) > 1:
|
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||||
prompt_message.content = prompt_message_content
|
|
||||||
elif (
|
|
||||||
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
|
|
||||||
):
|
|
||||||
prompt_message.content = prompt_message_content[0].data
|
prompt_message.content = prompt_message_content[0].data
|
||||||
|
else:
|
||||||
|
prompt_message.content = prompt_message_content
|
||||||
|
if prompt_message.is_empty():
|
||||||
|
continue
|
||||||
filtered_prompt_messages.append(prompt_message)
|
filtered_prompt_messages.append(prompt_message)
|
||||||
|
|
||||||
if not filtered_prompt_messages:
|
if len(filtered_prompt_messages) == 0:
|
||||||
raise NoPromptFoundError(
|
raise NoPromptFoundError(
|
||||||
"No prompt found in the LLM configuration. "
|
"No prompt found in the LLM configuration. "
|
||||||
"Please ensure a prompt is properly configured before proceeding."
|
"Please ensure a prompt is properly configured before proceeding."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
stop = model_config.stop
|
||||||
return filtered_prompt_messages, stop
|
return filtered_prompt_messages, stop
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -715,3 +822,198 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
|
||||||
|
match role:
|
||||||
|
case PromptMessageRole.USER:
|
||||||
|
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||||
|
case PromptMessageRole.ASSISTANT:
|
||||||
|
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||||
|
case PromptMessageRole.SYSTEM:
|
||||||
|
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||||
|
raise NotImplementedError(f"Role {role} is not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def _render_jinja2_message(
|
||||||
|
*,
|
||||||
|
template: str,
|
||||||
|
jinjia2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
):
|
||||||
|
if not template:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
jinjia2_inputs = {}
|
||||||
|
for jinja2_variable in jinjia2_variables:
|
||||||
|
variable = variable_pool.get(jinja2_variable.value_selector)
|
||||||
|
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
|
||||||
|
code_execute_resp = CodeExecutor.execute_workflow_code_template(
|
||||||
|
language=CodeLanguage.JINJA2,
|
||||||
|
code=template,
|
||||||
|
inputs=jinjia2_inputs,
|
||||||
|
)
|
||||||
|
result_text = code_execute_resp["result"]
|
||||||
|
return result_text
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_list_messages(
|
||||||
|
*,
|
||||||
|
messages: Sequence[LLMNodeChatModelMessage],
|
||||||
|
context: Optional[str],
|
||||||
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
prompt_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if message.edition_type == "jinja2":
|
||||||
|
result_text = _render_jinja2_message(
|
||||||
|
template=message.jinja2_text or "",
|
||||||
|
jinjia2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
|
||||||
|
prompt_messages.append(prompt_message)
|
||||||
|
else:
|
||||||
|
# Get segment group from basic message
|
||||||
|
if context:
|
||||||
|
template = message.text.replace("{#context#}", context)
|
||||||
|
else:
|
||||||
|
template = message.text
|
||||||
|
segment_group = variable_pool.convert_template(template)
|
||||||
|
|
||||||
|
# Process segments for images
|
        file_contents = []
        for segment in segment_group.value:
            if isinstance(segment, ArrayFileSegment):
                for file in segment.value:
                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                        file_content = file_manager.to_prompt_message_content(
                            file, image_detail_config=vision_detail_config
                        )
                        file_contents.append(file_content)
            if isinstance(segment, FileSegment):
                file = segment.value
                if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                    file_content = file_manager.to_prompt_message_content(
                        file, image_detail_config=vision_detail_config
                    )
                    file_contents.append(file_content)

        # Create message with text from all segments
        plain_text = segment_group.text
        if plain_text:
            prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
            prompt_messages.append(prompt_message)

        if file_contents:
            # Create message with image contents
            prompt_message = UserPromptMessage(content=file_contents)
            prompt_messages.append(prompt_message)

    return prompt_messages


def _calculate_rest_token(
    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
    rest_tokens = 2000

    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if model_context_tokens:
        model_instance = ModelInstance(
            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
        )

        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or (
                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
            ):
                max_tokens = (
                    model_config.parameters.get(parameter_rule.name)
                    or model_config.parameters.get(str(parameter_rule.use_template))
                    or 0
                )

        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
        rest_tokens = max(rest_tokens, 0)

    return rest_tokens


def _handle_memory_chat_mode(
    *,
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
    memory_messages = []
    # Get messages from memory for chat model
    if memory and memory_config:
        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
        memory_messages = memory.get_history_prompt_messages(
            max_token_limit=rest_tokens,
            message_limit=memory_config.window.size if memory_config.window.enabled else None,
        )
    return memory_messages


def _handle_memory_completion_mode(
    *,
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> str:
    memory_text = ""
    # Get history text from memory for completion model
    if memory and memory_config:
        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
        if not memory_config.role_prefix:
            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
        memory_text = memory.get_history_prompt_text(
            max_token_limit=rest_tokens,
            message_limit=memory_config.window.size if memory_config.window.enabled else None,
            human_prefix=memory_config.role_prefix.user,
            ai_prefix=memory_config.role_prefix.assistant,
        )
    return memory_text


def _handle_completion_template(
    *,
    template: LLMNodeCompletionModelPromptTemplate,
    context: Optional[str],
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
    """Handle completion template processing outside of LLMNode class.

    Args:
        template: The completion model prompt template
        context: Optional context string
        jinja2_variables: Variables for jinja2 template rendering
        variable_pool: Variable pool for template conversion

    Returns:
        Sequence of prompt messages
    """
    prompt_messages = []
    if template.edition_type == "jinja2":
        result_text = _render_jinja2_message(
            template=template.jinja2_text or "",
            jinjia2_variables=jinja2_variables,
            variable_pool=variable_pool,
        )
    else:
        if context:
            template_text = template.text.replace("{#context#}", context)
        else:
            template_text = template.text
        result_text = variable_pool.convert_template(template_text).text
    prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
    prompt_messages.append(prompt_message)
    return prompt_messages
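
A minimal sketch of the rest-token arithmetic used by _calculate_rest_token above, with assumed numbers (the context size, max_tokens and prompt size below are illustrative, not values from this change): the budget left for memory history is the model's context size minus the reserved completion budget minus the tokens the prompt already uses, floored at zero.

# Illustrative only: assumed numbers, mirroring the arithmetic above.
context_size = 16385        # assumed CONTEXT_SIZE of the model
max_tokens = 1024           # assumed completion budget from the model parameters
curr_message_tokens = 2500  # assumed token count of the current prompt messages

rest_tokens = max(context_size - max_tokens - curr_message_tokens, 0)
print(rest_tokens)  # 12861 tokens remain for memory history in this example
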
@@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode):
             )
         prompt_messages, stop = self._fetch_prompt_messages(
             prompt_template=prompt_template,
-            system_query=query,
+            user_query=query,
             memory=memory,
             model_config=model_config,
-            files=files,
+            user_files=files,
             vision_enabled=node_data.vision.enabled,
             vision_detail=node_data.vision.configs.detail,
+            variable_pool=variable_pool,
+            jinja2_variables=[],
         )

         # handle invoke result

api/poetry.lock (generated, 17 lines changed)
@@ -2423,6 +2423,21 @@ files = [
 [package.extras]
 test = ["pytest (>=6)"]

+[[package]]
+name = "faker"
+version = "32.1.0"
+description = "Faker is a Python package that generates fake data for you."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"},
+    {file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"},
+]
+
+[package.dependencies]
+python-dateutil = ">=2.4"
+typing-extensions = "*"
+
 [[package]]
 name = "fal-client"
 version = "0.5.6"
@@ -11041,4 +11056,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "69a3f471f85dce9e5fb889f739e148a4a6d95aaf94081414503867c7157dba69"
+content-hash = "d149b24ce7a203fa93eddbe8430d8ea7e5160a89c8d348b1b747c19899065639"

@@ -268,6 +268,7 @@ weaviate-client = "~3.21.0"
 optional = true
 [tool.poetry.group.dev.dependencies]
 coverage = "~7.2.4"
+faker = "~32.1.0"
 pytest = "~8.3.2"
 pytest-benchmark = "~4.0.0"
 pytest-env = "~1.1.3"
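
The faker dev dependency added above backs the faker fixture used by the rewritten LLM node unit tests later in this commit. A minimal sketch of the kind of values those tests draw, using standard Faker calls and shown outside pytest purely for illustration (not part of the commit):

# Illustrative sketch of the Faker calls used by the new tests.
from faker import Faker

faker = Faker()
fake_query = faker.sentence()
fake_remote_url = faker.url()
fake_window_size = faker.random_int(min=1, max=3)
fake_detail = faker.random_element(["high", "low"])
print(fake_query, fake_remote_url, fake_window_size, fake_detail)
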
@@ -11,7 +11,6 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
-from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock


 @pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)

@@ -4,29 +4,21 @@ import pytest

 from core.model_runtime.entities.rerank_entities import RerankResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel
+from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel


 def test_validate_credentials():
-    model = AzureAIStudioRerankModel()
+    model = AzureRerankModel()

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
             model="azure-ai-studio-rerank-v1",
             credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
-            query="What is the capital of the United States?",
-            docs=[
-                "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
-                "Census, Carson City had a population of 55,274.",
-                "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
-                "are a political division controlled by the United States. Its capital is Saipan.",
-            ],
-            score_threshold=0.8,
         )


 def test_invoke_model():
-    model = AzureAIStudioRerankModel()
+    model = AzureRerankModel()

     result = model.invoke(
         model="azure-ai-studio-rerank-v1",

@@ -1,125 +1,484 @@
+from collections.abc import Sequence
+from typing import Optional
+
 import pytest

-from core.app.entities.app_invoke_entities import InvokeFrom
+from configs import dify_config
+from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
 from core.file import File, FileTransferMethod, FileType
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageRole,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
+from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 from core.workflow.nodes.answer import AnswerStreamGenerateRoute
 from core.workflow.nodes.end import EndStreamParam
-from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
+from core.workflow.nodes.llm.entities import (
+    ContextConfig,
+    LLMNodeChatModelMessage,
+    LLMNodeData,
+    ModelConfig,
+    VisionConfig,
+    VisionConfigOptions,
+)
 from core.workflow.nodes.llm.node import LLMNode
 from models.enums import UserFrom
+from models.provider import ProviderType
 from models.workflow import WorkflowType
+from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario

-class TestLLMNode:
-    @pytest.fixture
-    def llm_node(self):
-        data = LLMNodeData(
-            title="Test LLM",
-            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
-            prompt_template=[],
-            memory=None,
-            context=ContextConfig(enabled=False),
-            vision=VisionConfig(
-                enabled=True,
-                configs=VisionConfigOptions(
-                    variable_selector=["sys", "files"],
-                    detail=ImagePromptMessageContent.DETAIL.HIGH,
-                ),
-            ),
-        )
-        variable_pool = VariablePool(
-            system_variables={},
-            user_inputs={},
-        )
-        node = LLMNode(
-            id="1",
-            config={
-                "id": "1",
-                "data": data.model_dump(),
-            },
-            graph_init_params=GraphInitParams(
-                tenant_id="1",
-                app_id="1",
-                workflow_type=WorkflowType.WORKFLOW,
-                workflow_id="1",
-                graph_config={},
-                user_id="1",
-                user_from=UserFrom.ACCOUNT,
-                invoke_from=InvokeFrom.SERVICE_API,
-                call_depth=0,
-            ),
-            graph=Graph(
-                root_node_id="1",
-                answer_stream_generate_routes=AnswerStreamGenerateRoute(
-                    answer_dependencies={},
-                    answer_generate_route={},
-                ),
-                end_stream_param=EndStreamParam(
-                    end_dependencies={},
-                    end_stream_variable_selector_mapping={},
-                ),
-            ),
-            graph_runtime_state=GraphRuntimeState(
-                variable_pool=variable_pool,
-                start_at=0,
-            ),
-        )
-        return node
-
-    def test_fetch_files_with_file_segment(self, llm_node):
-        file = File(
+class MockTokenBufferMemory:
+    def __init__(self, history_messages=None):
+        self.history_messages = history_messages or []
+
+    def get_history_prompt_messages(
+        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
+    ) -> Sequence[PromptMessage]:
+        if message_limit is not None:
+            return self.history_messages[-message_limit * 2 :]
+        return self.history_messages
+
+
+@pytest.fixture
+def llm_node():
+    data = LLMNodeData(
+        title="Test LLM",
+        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
+        prompt_template=[],
+        memory=None,
+        context=ContextConfig(enabled=False),
+        vision=VisionConfig(
+            enabled=True,
+            configs=VisionConfigOptions(
+                variable_selector=["sys", "files"],
+                detail=ImagePromptMessageContent.DETAIL.HIGH,
+            ),
+        ),
+    )
+    variable_pool = VariablePool(
+        system_variables={},
+        user_inputs={},
+    )
+    node = LLMNode(
+        id="1",
+        config={
+            "id": "1",
+            "data": data.model_dump(),
+        },
+        graph_init_params=GraphInitParams(
+            tenant_id="1",
+            app_id="1",
+            workflow_type=WorkflowType.WORKFLOW,
+            workflow_id="1",
+            graph_config={},
+            user_id="1",
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.SERVICE_API,
+            call_depth=0,
+        ),
+        graph=Graph(
+            root_node_id="1",
+            answer_stream_generate_routes=AnswerStreamGenerateRoute(
+                answer_dependencies={},
+                answer_generate_route={},
+            ),
+            end_stream_param=EndStreamParam(
+                end_dependencies={},
+                end_stream_variable_selector_mapping={},
+            ),
+        ),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=variable_pool,
+            start_at=0,
+        ),
+    )
+    return node
+
+
+@pytest.fixture
+def model_config():
+    # Create actual provider and model type instances
+    model_provider_factory = ModelProviderFactory()
+    provider_instance = model_provider_factory.get_provider_instance("openai")
+    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+
+    # Create a ProviderModelBundle
+    provider_model_bundle = ProviderModelBundle(
+        configuration=ProviderConfiguration(
+            tenant_id="1",
+            provider=provider_instance.get_provider_schema(),
+            preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(enabled=False),
+            custom_configuration=CustomConfiguration(provider=None),
+            model_settings=[],
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance,
+    )
+
+    # Create and return a ModelConfigWithCredentialsEntity
+    return ModelConfigWithCredentialsEntity(
+        provider="openai",
+        model="gpt-3.5-turbo",
+        model_schema=AIModelEntity(
+            model="gpt-3.5-turbo",
+            label=I18nObject(en_US="GPT-3.5 Turbo"),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={},
+        ),
+        mode="chat",
+        credentials={},
+        parameters={},
+        provider_model_bundle=provider_model_bundle,
+    )
+
+
-            id="1",
-            tenant_id="test",
-            type=FileType.IMAGE,
-            filename="test.jpg",
-            transfer_method=FileTransferMethod.LOCAL_FILE,
-            related_id="1",
-        )
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
-
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == [file]
-
-    def test_fetch_files_with_array_file_segment(self, llm_node):
-        files = [
-            File(
-                id="1",
-                tenant_id="test",
-                type=FileType.IMAGE,
-                filename="test1.jpg",
-                transfer_method=FileTransferMethod.LOCAL_FILE,
-                related_id="1",
-            ),
-            File(
-                id="2",
-                tenant_id="test",
-                type=FileType.IMAGE,
-                filename="test2.jpg",
-                transfer_method=FileTransferMethod.LOCAL_FILE,
-                related_id="2",
-            ),
-        ]
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
-
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == files
-
-    def test_fetch_files_with_none_segment(self, llm_node):
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
+def test_fetch_files_with_file_segment(llm_node):
+    file = File(
+        id="1",
+        tenant_id="test",
+        type=FileType.IMAGE,
+        filename="test.jpg",
+        transfer_method=FileTransferMethod.LOCAL_FILE,
+        related_id="1",
+    )
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == [file]
+
+
+def test_fetch_files_with_array_file_segment(llm_node):
+    files = [
+        File(
+            id="1",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test1.jpg",
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="1",
+        ),
+        File(
+            id="2",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test2.jpg",
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="2",
+        ),
+    ]
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == files
+
+
+def test_fetch_files_with_none_segment(llm_node):
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_files_with_array_any_segment(llm_node):
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_files_with_non_existent_variable(llm_node):
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
+    prompt_template = []
+    llm_node.node_data.prompt_template = prompt_template
+
+    fake_vision_detail = faker.random_element(
+        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
+    )
+    fake_remote_url = faker.url()
+    files = [
+        File(
+            id="1",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test1.jpg",
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            remote_url=fake_remote_url,
+        )
+    ]
+
+    fake_query = faker.sentence()
+
+    prompt_messages, _ = llm_node._fetch_prompt_messages(
+        user_query=fake_query,
+        user_files=files,
+        context=None,
+        memory=None,
+        model_config=model_config,
+        prompt_template=prompt_template,
+        memory_config=None,
+        vision_enabled=False,
+        vision_detail=fake_vision_detail,
+        variable_pool=llm_node.graph_runtime_state.variable_pool,
+        jinja2_variables=[],
+    )
+
+    assert prompt_messages == [UserPromptMessage(content=fake_query)]
+
+
-
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
-
-    def test_fetch_files_with_array_any_segment(self, llm_node):
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
-
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
-
-    def test_fetch_files_with_non_existent_variable(self, llm_node):
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
+def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
+    # Setup dify config
+    dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
+    dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
+
+    # Generate fake values for prompt template
+    fake_assistant_prompt = faker.sentence()
+    fake_query = faker.sentence()
+    fake_context = faker.sentence()
+    fake_window_size = faker.random_int(min=1, max=3)
+    fake_vision_detail = faker.random_element(
+        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
+    )
+    fake_remote_url = faker.url()
+
+    # Setup mock memory with history messages
+    mock_history = [
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+    ]
+
+    # Setup memory configuration
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
+        window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
+        query_prompt_template=None,
+    )
+
+    memory = MockTokenBufferMemory(history_messages=mock_history)
+
+    # Test scenarios covering different file input combinations
+    test_scenarios = [
+        LLMNodeTestScenario(
+            description="No files",
+            user_query=fake_query,
+            user_files=[],
+            features=[],
+            vision_enabled=False,
+            vision_detail=None,
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(content=fake_query),
+            ],
+        ),
+        LLMNodeTestScenario(
+            description="User files",
+            user_query=fake_query,
+            user_files=[
+                File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            ],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(
+                    content=[
+                        TextPromptMessageContent(data=fake_query),
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ],
+        ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=False,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                UserPromptMessage(
+                    content=[
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            },
+        ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File without vision feature",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            },
+        ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File with video file and vision feature",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.VIDEO,
+                    filename="test1.mp4",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                    extension="mp4",
+                )
+            },
+        ),
+    ]
+
+    for scenario in test_scenarios:
+        model_config.model_schema.features = scenario.features
+
+        for k, v in scenario.file_variables.items():
+            selector = k.split(".")
+            llm_node.graph_runtime_state.variable_pool.add(selector, v)
+
+        # Call the method under test
+        prompt_messages, _ = llm_node._fetch_prompt_messages(
+            user_query=scenario.user_query,
+            user_files=scenario.user_files,
+            context=fake_context,
+            memory=memory,
+            model_config=model_config,
+            prompt_template=scenario.prompt_template,
+            memory_config=memory_config,
+            vision_enabled=scenario.vision_enabled,
+            vision_detail=scenario.vision_detail,
+            variable_pool=llm_node.graph_runtime_state.variable_pool,
+            jinja2_variables=[],
+        )
+
+        # Verify the result
+        assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
+        assert (
+            prompt_messages == scenario.expected_messages
+        ), f"Message content mismatch in scenario: {scenario.description}"

@@ -0,0 +1,25 @@
+from collections.abc import Mapping, Sequence
+
+from pydantic import BaseModel, Field
+
+from core.file import File
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.model_entities import ModelFeature
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
+
+
+class LLMNodeTestScenario(BaseModel):
+    """Test scenario for LLM node testing."""
+
+    description: str = Field(..., description="Description of the test scenario")
+    user_query: str = Field(..., description="User query input")
+    user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
+    vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
+    vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
+    features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
+    window_size: int = Field(..., description="Window size for memory")
+    prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
+    file_variables: Mapping[str, File | Sequence[File]] = Field(
+        default_factory=dict, description="List of file variables"
+    )
+    expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
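
For illustration, a scenario built from this model might look like the sketch below; the field values are placeholders (the real tests above generate them with Faker), and only the import path comes from the commit:

# Hypothetical example instance; values are placeholders, not from the commit.
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario

scenario = LLMNodeTestScenario(
    description="No files",
    user_query="What is in this image?",
    user_files=[],
    features=[],
    vision_enabled=False,
    vision_detail=None,
    window_size=2,
    prompt_template=[],
    expected_messages=[],
)
print(scenario.description)
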
@@ -160,6 +160,7 @@ const CodeEditor: FC<Props> = ({
           hideSearch
           vars={availableVars}
           onChange={handleSelectVar}
+          isSupportFileVar={false}
         />
       </div>
     )}

@@ -18,6 +18,7 @@ type Props = {
   isSupportConstantValue?: boolean
   onlyLeafNodeVar?: boolean
   filterVar?: (payload: Var, valueSelector: ValueSelector) => boolean
+  isSupportFileVar?: boolean
 }

 const VarList: FC<Props> = ({
@@ -29,6 +30,7 @@ const VarList: FC<Props> = ({
   isSupportConstantValue,
   onlyLeafNodeVar,
   filterVar,
+  isSupportFileVar = true,
 }) => {
   const { t } = useTranslation()

@@ -94,6 +96,7 @@ const VarList: FC<Props> = ({
           defaultVarKindType={item.variable_type}
           onlyLeafNodeVar={onlyLeafNodeVar}
           filterVar={filterVar}
+          isSupportFileVar={isSupportFileVar}
         />
         {!readonly && (
           <RemoveButton

@@ -59,6 +59,7 @@ type Props = {
   isInTable?: boolean
   onRemove?: () => void
   typePlaceHolder?: string
+  isSupportFileVar?: boolean
 }

 const VarReferencePicker: FC<Props> = ({
@@ -81,6 +82,7 @@ const VarReferencePicker: FC<Props> = ({
   isInTable,
   onRemove,
   typePlaceHolder,
+  isSupportFileVar = true,
 }) => {
   const { t } = useTranslation()
   const store = useStoreApi()
@@ -382,6 +384,7 @@ const VarReferencePicker: FC<Props> = ({
             vars={outputVars}
             onChange={handleVarReferenceChange}
             itemWidth={isAddBtnTrigger ? 260 : triggerWidth}
+            isSupportFileVar={isSupportFileVar}
           />
         )}
       </PortalToFollowElemContent>

@@ -8,11 +8,13 @@ type Props = {
   vars: NodeOutPutVar[]
   onChange: (value: ValueSelector, varDetail: Var) => void
   itemWidth?: number
+  isSupportFileVar?: boolean
 }
 const VarReferencePopup: FC<Props> = ({
   vars,
   onChange,
   itemWidth,
+  isSupportFileVar = true,
 }) => {
   // max-h-[300px] overflow-y-auto todo: use portal to handle long list
   return (
@@ -24,7 +26,7 @@ const VarReferencePopup: FC<Props> = ({
         vars={vars}
         onChange={onChange}
         itemWidth={itemWidth}
-        isSupportFileVar
+        isSupportFileVar={isSupportFileVar}
       />
     </div >
   )

@@ -89,6 +89,7 @@ const Panel: FC<NodePanelProps<CodeNodeType>> = ({
           list={inputs.variables}
           onChange={handleVarListChange}
           filterVar={filterVar}
+          isSupportFileVar={false}
         />
       </Field>
       <Split />

@@ -144,6 +144,7 @@ const ConfigPromptItem: FC<Props> = ({
         onEditionTypeChange={onEditionTypeChange}
         varList={varList}
         handleAddVariable={handleAddVariable}
+        isSupportFileVar
       />
     )
 }

@@ -67,6 +67,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
     handleStop,
     varInputs,
     runResult,
+    filterJinjia2InputVar,
   } = useConfig(id, data)

   const model = inputs.model
@@ -194,7 +195,8 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
               list={inputs.prompt_config?.jinja2_variables || []}
               onChange={handleVarListChange}
               onVarNameChange={handleVarNameChange}
-              filterVar={filterVar}
+              filterVar={filterJinjia2InputVar}
+              isSupportFileVar={false}
             />
           </Field>
         )}
@@ -233,6 +235,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
             hasSetBlockStatus={hasSetBlockStatus}
             nodesOutputVars={availableVars}
             availableNodes={availableNodesWithParent}
+            isSupportFileVar
           />

         {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (

@@ -278,11 +278,15 @@ const useConfig = (id: string, payload: LLMNodeType) => {
   }, [inputs, setInputs])

   const filterInputVar = useCallback((varPayload: Var) => {
+    return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type)
+  }, [])
+
+  const filterJinjia2InputVar = useCallback((varPayload: Var) => {
     return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
   }, [])

   const filterMemoryPromptVar = useCallback((varPayload: Var) => {
-    return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
+    return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type)
   }, [])

   const {
@@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => {
     handleRun,
     handleStop,
     runResult,
+    filterJinjia2InputVar,
   }
 }

@@ -64,6 +64,7 @@ const Panel: FC<NodePanelProps<TemplateTransformNodeType>> = ({
           onChange={handleVarListChange}
           onVarNameChange={handleVarNameChange}
           filterVar={filterVar}
+          isSupportFileVar={false}
         />
       </Field>
       <Split />