Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly 2024-11-11 14:00:53 +08:00
commit 5cdbfe2f41
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
77 changed files with 851 additions and 581 deletions

36
.github/actions/setup-poetry/action.yml vendored Normal file
View File

@ -0,0 +1,36 @@
name: Setup Poetry and Python
inputs:
python-version:
description: Python version to use and the Poetry installed with
required: true
default: '3.10'
poetry-version:
description: Poetry version to set up
required: true
default: '1.8.4'
poetry-lockfile:
description: Path to the Poetry lockfile to restore cache from
required: true
default: ''
runs:
using: composite
steps:
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
cache: pip
- name: Install Poetry
shell: bash
run: pip install poetry==${{ inputs.poetry-version }}
- name: Restore Poetry cache
if: ${{ inputs.poetry-lockfile != '' }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
cache: poetry
cache-dependency-path: ${{ inputs.poetry-lockfile }}

View File

@ -28,15 +28,11 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
- name: Setup Poetry and Python ${{ matrix.python-version }}
uses: ./.github/actions/setup-poetry
with:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: api/poetry.lock
poetry-lockfile: api/poetry.lock
- name: Check Poetry lockfile
run: |

View File

@ -15,25 +15,15 @@ concurrency:
jobs:
db-migration-test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- "3.10"
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
- name: Setup Poetry and Python
uses: ./.github/actions/setup-poetry
with:
python-version: ${{ matrix.python-version }}
cache-dependency-path: |
api/pyproject.toml
api/poetry.lock
- name: Install Poetry
uses: abatilo/actions-poetry@v3
poetry-lockfile: api/poetry.lock
- name: Install dependencies
run: poetry install -C api

View File

@ -22,34 +22,28 @@ jobs:
id: changed-files
uses: tj-actions/changed-files@v45
with:
files: api/**
files: |
api/**
.github/workflows/style.yml
- name: Install Poetry
- name: Setup Poetry and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: abatilo/actions-poetry@v3
uses: ./.github/actions/setup-poetry
- name: Set up Python
uses: actions/setup-python@v5
if: steps.changed-files.outputs.any_changed == 'true'
with:
python-version: '3.10'
- name: Python dependencies
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry install -C api --only lint
- name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api ruff check ./api
run: |
poetry run -C api ruff check ./api
poetry run -C api ruff format --check ./api
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
- name: Ruff formatter check
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api ruff format --check ./api
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."

View File

@ -28,15 +28,11 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
- name: Setup Poetry and Python ${{ matrix.python-version }}
uses: ./.github/actions/setup-poetry
with:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: api/poetry.lock
poetry-lockfile: api/poetry.lock
- name: Check Poetry lockfile
run: |

View File

@ -177,3 +177,4 @@ To protect your privacy, please avoid posting security issues on GitHub. Instead
## License
This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions.

View File

@ -367,6 +367,10 @@ LOG_FILE=
LOG_FILE_MAX_SIZE=20
# Log file max backup count
LOG_FILE_BACKUP_COUNT=5
# Log dateformat
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
@ -407,3 +411,5 @@ MARKETPLACE_API_URL=https://marketplace.dify.ai
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
CREATE_TIDB_SERVICE_JOB_ENABLED=false

View File

@ -4,7 +4,7 @@ FROM python:3.10-slim-bookworm AS base
WORKDIR /app/api
# Install Poetry
ENV POETRY_VERSION=1.8.3
ENV POETRY_VERSION=1.8.4
# if you located in China, you can use aliyun mirror to speed up
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/

View File

@ -429,7 +429,7 @@ class LoggingConfig(BaseSettings):
LOG_TZ: Optional[str] = Field(
description="Timezone for log timestamps (e.g., 'America/New_York')",
default=None,
default="UTC",
)
@ -664,6 +664,11 @@ class DataSetConfig(BaseSettings):
default=500,
)
CREATE_TIDB_SERVICE_JOB_ENABLED: bool = Field(
description="Enable or disable create tidb service job",
default=False,
)
class WorkspaceConfig(BaseSettings):
"""

View File

@ -328,8 +328,11 @@ class DatasetInitApi(Resource):
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager()
model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=args["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"],
)
except InvokeAuthorizationError:
raise ProviderNotInitializeError(

View File

@ -62,9 +62,10 @@ class ConversationDetailApi(Resource):
conversation_id = str(c_id)
try:
return ConversationService.delete(app_model, conversation_id, end_user)
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 200
class ConversationRenameApi(Resource):

View File

@ -10,6 +10,7 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import message_file_fields
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser
from services.errors.message import SuggestedQuestionsAfterAnswerDisabledError
@ -55,7 +56,7 @@ class MessageListApi(Resource):
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields)),

View File

@ -29,6 +29,7 @@ from core.model_runtime.entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.extract_thread_messages import extract_thread_messages
@ -488,24 +489,27 @@ class BaseAgentRunner(AppRunner):
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
assert message.app_model_config
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
)
else:
file_objs = []
if not file_objs:
return UserPromptMessage(content=message.query)
else:
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
for file_obj in file_objs:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
return UserPromptMessage(content=prompt_message_contents)
else:
if not files:
return UserPromptMessage(content=message.query)
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if not file_extra_config:
return UserPromptMessage(content=message.query)
image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
)
if not file_objs:
return UserPromptMessage(content=message.query)
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
return UserPromptMessage(content=prompt_message_contents)

View File

@ -10,6 +10,7 @@ from core.model_runtime.entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.utils.encoders import jsonable_encoder
@ -37,9 +38,26 @@ class CotChatAgentRunner(CotAgentRunner):
Organize user query
"""
if self.files:
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:

View File

@ -22,6 +22,7 @@ from core.model_runtime.entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
@ -392,9 +393,26 @@ class FunctionCallAgentRunner(BaseAgentRunner):
Organize user query
"""
if self.files:
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:

View File

@ -4,7 +4,7 @@ from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator
from core.file import FileExtraConfig, FileTransferMethod, FileType
from core.file import FileTransferMethod, FileType, FileUploadConfig
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode
@ -211,7 +211,7 @@ class TracingConfigEntity(BaseModel):
class AppAdditionalFeatures(BaseModel):
file_upload: Optional[FileExtraConfig] = None
file_upload: Optional[FileUploadConfig] = None
opening_statement: Optional[str] = None
suggested_questions: list[str] = []
suggested_questions_after_answer: bool = False

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any
from core.file import FileExtraConfig
from core.file import FileUploadConfig
class FileUploadConfigManager:
@ -29,19 +29,18 @@ class FileUploadConfigManager:
if is_vision:
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
return FileExtraConfig.model_validate(data)
return FileUploadConfig.model_validate(data)
@classmethod
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for file upload feature
:param config: app model config args
:param is_vision: if True, the feature is vision feature
"""
if not config.get("file_upload"):
config["file_upload"] = {}
else:
FileExtraConfig.model_validate(config["file_upload"])
FileUploadConfig.model_validate(config["file_upload"])
return config, ["file_upload"]

View File

@ -52,9 +52,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config, is_vision=False
)
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# opening_statement

View File

@ -26,7 +26,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
@ -109,13 +108,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# parse files
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
@ -138,10 +134,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser
from models.enums import CreatedByRole
logger = logging.getLogger(__name__)
@ -108,8 +107,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
@ -117,8 +114,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
@ -140,10 +135,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -3,12 +3,11 @@ from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import VariableEntityType
from core.file import File, FileExtraConfig
from core.file import File, FileUploadConfig
from factories import file_factory
if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity
from models.enums import CreatedByRole
class BaseAppGenerator:
@ -17,8 +16,6 @@ class BaseAppGenerator:
*,
user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig",
user_id: str,
role: "CreatedByRole",
) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
@ -35,9 +32,7 @@ class BaseAppGenerator:
k: file_factory.build_from_mapping(
mapping=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
@ -51,9 +46,7 @@ class BaseAppGenerator:
k: file_factory.build_from_mappings(
mappings=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,

View File

@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, EndUser
logger = logging.getLogger(__name__)
@ -111,8 +110,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
@ -120,8 +117,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
@ -143,10 +138,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -22,7 +22,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Message
from models.enums import CreatedByRole
from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError
@ -98,8 +97,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
tenant_id=app_model.tenant_id, config=args.get("model_config")
)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
@ -107,8 +104,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
@ -120,7 +115,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id)
# init application generate entity
@ -128,7 +122,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query,
files=file_objs,
user_id=user.id,
@ -269,14 +264,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["model"] = model_dict
# parse files
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=message.message_files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:

View File

@ -46,9 +46,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config, is_vision=False
)
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech

View File

@ -25,7 +25,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Workflow
from models.enums import CreatedByRole
logger = logging.getLogger(__name__)
@ -82,15 +81,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
@ -112,7 +107,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
files=system_files,
user_id=user.id,
stream=stream,

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file.models import File
from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
from core.ops.ops_trace_manager import TraceQueueManager
@ -80,6 +80,7 @@ class AppGenerateEntity(BaseModel):
# app config
app_config: AppConfig
file_upload_config: Optional[FileUploadConfig] = None
inputs: Mapping[str, Any]
files: Sequence[File]

View File

@ -2,13 +2,13 @@ from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .models import (
File,
FileExtraConfig,
FileUploadConfig,
ImageConfig,
)
__all__ = [
"FileType",
"FileExtraConfig",
"FileUploadConfig",
"FileTransferMethod",
"FileBelongsTo",
"File",

View File

@ -33,25 +33,28 @@ def get_attr(*, file: File, attr: FileAttribute):
raise ValueError(f"Invalid file attribute: {attr}")
def to_prompt_message_content(f: File, /):
def to_prompt_message_content(
f: File,
/,
*,
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
):
"""
Convert a File object to an ImagePromptMessageContent object.
Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
This function takes a File object and converts it to an ImagePromptMessageContent
object, which can be used as a prompt for image-based AI models.
This function takes a File object and converts it to an appropriate PromptMessageContent
object, which can be used as a prompt for image or audio-based AI models.
Args:
file (File): The File object to convert. Must be of type FileType.IMAGE.
f (File): The File object to convert.
detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
Returns:
ImagePromptMessageContent: An object containing the image data and detail level.
Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
Raises:
ValueError: If the file is not an image or if the file data is missing.
Note:
The detail level of the image prompt is determined by the file's extra_config.
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
ValueError: If the file type is not supported or if required data is missing.
"""
match f.type:
case FileType.IMAGE:
@ -60,12 +63,7 @@ def to_prompt_message_content(f: File, /):
else:
data = _to_base64_data_string(f)
if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail:
detail = f._extra_config.image_config.detail
else:
detail = ImagePromptMessageContent.DETAIL.LOW
return ImagePromptMessageContent(data=data, detail=detail)
return ImagePromptMessageContent(data=data, detail=image_detail_config)
case FileType.AUDIO:
encoded_string = _file_to_encoded_string(f)
if f.extension is None:
@ -78,7 +76,7 @@ def to_prompt_message_content(f: File, /):
data = _to_base64_data_string(f)
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case _:
raise ValueError(f"file type {f.type} is not supported")
raise ValueError("file type f.type is not supported")
def download(f: File, /):

View File

@ -21,7 +21,7 @@ class ImageConfig(BaseModel):
detail: ImagePromptMessageContent.DETAIL | None = None
class FileExtraConfig(BaseModel):
class FileUploadConfig(BaseModel):
"""
File Upload Entity.
"""
@ -46,7 +46,6 @@ class File(BaseModel):
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
mime_type: Optional[str] = None
size: int = -1
_extra_config: FileExtraConfig | None = None
def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json")
@ -107,34 +106,4 @@ class File(BaseModel):
case FileTransferMethod.TOOL_FILE:
if not self.related_id:
raise ValueError("Missing file related_id")
# Validate the extra config.
if not self._extra_config:
return self
if self._extra_config.allowed_file_types:
if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
raise ValueError(f"Invalid file type: {self.type}")
if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
raise ValueError(f"Invalid file extension: {self.extension}")
if (
self._extra_config.allowed_upload_methods
and self.transfer_method not in self._extra_config.allowed_upload_methods
):
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
match self.type:
case FileType.IMAGE:
# NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
if not self._extra_config.image_config:
return self
# TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
if (
self._extra_config.image_config.transfer_methods
and self.transfer_method not in self._extra_config.image_config.transfer_methods
):
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
return self

View File

@ -0,0 +1,3 @@
from .code_executor import CodeExecutor, CodeLanguage
__all__ = ["CodeExecutor", "CodeLanguage"]

View File

@ -1,7 +1,8 @@
import logging
from collections.abc import Mapping
from enum import Enum
from threading import Lock
from typing import Optional
from typing import Any, Optional
from httpx import Timeout, post
from pydantic import BaseModel
@ -117,7 +118,7 @@ class CodeExecutor:
return response.data.stdout or ""
@classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict:
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict:
"""
Execute code
:param language: code language

View File

@ -2,6 +2,8 @@ import json
import re
from abc import ABC, abstractmethod
from base64 import b64encode
from collections.abc import Mapping
from typing import Any
class TemplateTransformer(ABC):
@ -10,7 +12,7 @@ class TemplateTransformer(ABC):
_result_tag: str = "<<RESULT>>"
@classmethod
def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
"""
Transform code to python runner
:param code: code
@ -48,13 +50,13 @@ class TemplateTransformer(ABC):
pass
@classmethod
def serialize_inputs(cls, inputs: dict) -> str:
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded
@classmethod
def assemble_runner_script(cls, code: str, inputs: dict) -> str:
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
# assemble runner script
script = cls.get_runner_script()
script = script.replace(cls._code_placeholder, code)

View File

@ -81,15 +81,18 @@ class TokenBufferMemory:
db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
)
if workflow_run:
if workflow_run and workflow_run.workflow:
file_extra_config = FileUploadConfigManager.convert(
workflow_run.workflow.features_dict, is_vision=False
)
detail = ImagePromptMessageContent.DETAIL.LOW
if file_extra_config and app_record:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
)
if file_extra_config.image_config and file_extra_config.image_config.detail:
detail = file_extra_config.image_config.detail
else:
file_objs = []
@ -98,12 +101,16 @@ class TokenBufferMemory:
else:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file_obj in file_objs:
if file_obj.type in {FileType.IMAGE, FileType.AUDIO}:
prompt_message = file_manager.to_prompt_message_content(file_obj)
for file in file_objs:
if file.type in {FileType.IMAGE, FileType.AUDIO}:
prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=message.query))

View File

@ -13,9 +13,9 @@ parameter_rules:
use_template: max_tokens
required: true
type: int
default: 4096
default: 8192
min: 1
max: 4096
max: 8192
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.

View File

@ -1,6 +1,6 @@
provider: vessl_ai
label:
en_US: vessl_ai
en_US: VESSL AI
icon_small:
en_US: icon_s_en.svg
icon_large:
@ -20,28 +20,28 @@ model_credential_schema:
label:
en_US: Model Name
placeholder:
en_US: Enter your model name
en_US: Enter model name
credential_form_schemas:
- variable: endpoint_url
label:
en_US: endpoint url
en_US: Endpoint Url
type: text-input
required: true
placeholder:
en_US: Enter the url of your endpoint url
en_US: Enter VESSL AI service endpoint url
- variable: api_key
required: true
label:
en_US: API Key
type: secret-input
placeholder:
en_US: Enter your VESSL AI secret key
en_US: Enter VESSL AI secret key
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
en_US: Completion Mode
type: select
required: false
default: chat

View File

@ -54,3 +54,7 @@ class LangSmithConfig(BaseTracingConfig):
raise ValueError("endpoint must start with https://")
return v
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@ -23,6 +23,11 @@ class BaseTraceInfo(BaseModel):
return v
return ""
class Config:
json_encoders = {
datetime: lambda v: v.isoformat(),
}
class WorkflowTraceInfo(BaseTraceInfo):
workflow_data: Any
@ -100,6 +105,12 @@ class GenerateNameTraceInfo(BaseTraceInfo):
tenant_id: str
class TaskData(BaseModel):
app_id: str
trace_info_type: str
trace_info: Any
trace_info_info_map = {
"WorkflowTraceInfo": WorkflowTraceInfo,
"MessageTraceInfo": MessageTraceInfo,

View File

@ -6,12 +6,13 @@ import threading
import time
from datetime import timedelta
from typing import Any, Optional, Union
from uuid import UUID
from uuid import UUID, uuid4
from flask import current_app
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
LangfuseConfig,
LangSmithConfig,
TracingProviderEnum,
@ -22,6 +23,7 @@ from core.ops.entities.trace_entity import (
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
TaskData,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
@ -30,6 +32,7 @@ from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun
from tasks.ops_trace_task import process_trace_tasks
@ -740,10 +743,17 @@ class TraceQueueManager:
def send_to_celery(self, tasks: list[TraceTask]):
with self.flask_app.app_context():
for task in tasks:
file_id = uuid4().hex
trace_info = task.execute()
task_data = {
task_data = TaskData(
app_id=task.app_id,
trace_info_type=type(trace_info).__name__,
trace_info=trace_info.model_dump() if trace_info else None,
)
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
file_info = {
"file_id": file_id,
"app_id": task.app_id,
"trace_info_type": type(trace_info).__name__,
"trace_info": trace_info.model_dump() if trace_info else {},
}
process_trace_tasks.delay(task_data)
process_trace_tasks.delay(file_info)

View File

@ -15,6 +15,7 @@ from core.model_runtime.entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@ -26,8 +27,13 @@ class AdvancedPromptTransform(PromptTransform):
Advanced Prompt Transform for Workflow LLM Node.
"""
def __init__(self, with_variable_tmpl: bool = False) -> None:
def __init__(
self,
with_variable_tmpl: bool = False,
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
) -> None:
self.with_variable_tmpl = with_variable_tmpl
self.image_detail_config = image_detail_config
def get_prompt(
self,

View File

@ -49,13 +49,7 @@ class CodeNode(BaseNode[CodeNodeData]):
for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=f"Variable `{variable_selector.value_selector}` not found",
)
variables[variable_name] = variable.to_object()
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(

View File

@ -13,6 +13,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser
from factories import file_factory
from models.workflow import WorkflowNodeExecutionStatus
from .entities import (
@ -161,16 +162,15 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
mimetype=content_type,
)
files.append(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file.id,
filename=filename,
extension=extension,
mime_type=content_type,
)
mapping = {
"tool_file_id": tool_file.id,
"type": FileType.IMAGE.value,
"transfer_method": FileTransferMethod.TOOL_FILE.value,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
return files

View File

@ -156,7 +156,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=0,
pre_iteration_output=None,
)
outputs: list[Any] = []
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self.node_data.is_parallel:
futures: list[Future] = []
@ -214,6 +214,8 @@ class IterationNode(BaseNode[IterationNodeData]):
graph_engine,
iteration_graph,
)
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@ -425,7 +427,7 @@ class IterationNode(BaseNode[IterationNodeData]):
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
outputs.insert(current_index, None)
outputs[current_index] = None
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
@ -473,7 +475,7 @@ class IterationNode(BaseNode[IterationNodeData]):
yield metadata_event
current_iteration_output = variable_pool.get(self.node_data.output_selector).value
outputs.insert(current_index, current_iteration_output)
outputs[current_index] = current_iteration_output
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])

View File

@ -49,8 +49,14 @@ class Limit(BaseModel):
size: int = -1
class ExtractConfig(BaseModel):
enabled: bool = False
serial: str = "1"
class ListOperatorNodeData(BaseNodeData):
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderBy
limit: Limit
extract_by: ExtractConfig

View File

@ -58,6 +58,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
if self.node_data.filter_by.enabled:
variable = self._apply_filter(variable)
# Extract
if self.node_data.extract_by.enabled:
variable = self._extract_slice(variable)
# Order
if self.node_data.order_by.enabled:
variable = self._apply_order(variable)
@ -140,6 +144,16 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
result = variable.value[: self.node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) - 1
if len(variable.value) > int(value):
result = variable.value[value]
else:
result = ""
return variable.model_copy(update={"value": [result]})
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
match key:

View File

@ -34,12 +34,7 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if value is None:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Variable {variable_name} not found in variable pool",
)
variables[variable_name] = value.to_object()
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(

View File

@ -21,7 +21,8 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.tools import ToolFile
from factories import file_factory
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ToolNodeData
@ -192,19 +193,17 @@ class ToolNode(BaseNode[ToolNodeData]):
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
files.append(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
remote_url=url,
related_id=tool_file_id,
filename=filename,
extension=ext,
mime_type=mimetype,
size=tool_file.size,
)
mapping = {
"tool_file_id": tool_file_id,
"type": FileType.IMAGE,
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)

View File

@ -5,10 +5,10 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
from core.app.app_config.entities import FileUploadConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File, FileTransferMethod, FileType, ImageConfig
from core.file.models import File, FileTransferMethod, ImageConfig
from core.workflow.callbacks import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
@ -22,6 +22,7 @@ from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.llm import LLMNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from factories import file_factory
from models.enums import UserFrom
from models.workflow import (
Workflow,
@ -372,19 +373,17 @@ class WorkflowEntry:
for item in input_value:
if isinstance(item, dict) and "type" in item and item["type"] == "image":
transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
file = File(
mapping = {
"id": item.get("id"),
"transfer_method": transfer_method,
"upload_file_id": item.get("upload_file_id"),
"url": item.get("url"),
}
config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None)
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
remote_url=item.get("url")
if transfer_method == FileTransferMethod.REMOTE_URL
else None,
related_id=item.get("upload_file_id")
if transfer_method == FileTransferMethod.LOCAL_FILE
else None,
_extra_config=FileExtraConfig(
image_config=ImageConfig(detail=detail) if detail else None
),
config=config,
)
new_value.append(file)

View File

@ -1,5 +1,6 @@
from datetime import timedelta
import pytz
from celery import Celery, Task
from celery.schedules import crontab
from flask import Flask
@ -43,6 +44,11 @@ def init_app(app: Flask) -> Celery:
result_backend=dify_config.CELERY_RESULT_BACKEND,
broker_transport_options=broker_transport_options,
broker_connection_retry_on_startup=True,
worker_log_format=dify_config.LOG_FORMAT,
worker_task_log_format=dify_config.LOG_FORMAT,
worker_logfile=dify_config.LOG_FILE,
worker_hijack_root_logger=False,
timezone=pytz.timezone(dify_config.LOG_TZ),
)
if dify_config.BROKER_USE_SSL:

View File

@ -1,23 +1,21 @@
import mimetypes
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from typing import Any
import httpx
from sqlalchemy import select
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
from core.helper import ssrf_proxy
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile
from models.enums import CreatedByRole
def build_from_message_files(
*,
message_files: Sequence["MessageFile"],
tenant_id: str,
config: FileExtraConfig,
config: FileUploadConfig,
) -> Sequence[File]:
results = [
build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
@ -31,7 +29,7 @@ def build_from_message_file(
*,
message_file: "MessageFile",
tenant_id: str,
config: FileExtraConfig,
config: FileUploadConfig,
):
mapping = {
"transfer_method": message_file.transfer_method,
@ -43,8 +41,6 @@ def build_from_message_file(
return build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
user_id=message_file.created_by,
role=CreatedByRole(message_file.created_by_role),
config=config,
)
@ -53,38 +49,30 @@ def build_from_mapping(
*,
mapping: Mapping[str, Any],
tenant_id: str,
user_id: str,
role: "CreatedByRole",
config: FileExtraConfig,
):
config: FileUploadConfig | None = None,
) -> File:
config = config or FileUploadConfig()
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
match transfer_method:
case FileTransferMethod.REMOTE_URL:
file = _build_from_remote_url(
mapping=mapping,
tenant_id=tenant_id,
config=config,
transfer_method=transfer_method,
)
case FileTransferMethod.LOCAL_FILE:
file = _build_from_local_file(
mapping=mapping,
tenant_id=tenant_id,
user_id=user_id,
role=role,
config=config,
transfer_method=transfer_method,
)
case FileTransferMethod.TOOL_FILE:
file = _build_from_tool_file(
mapping=mapping,
tenant_id=tenant_id,
user_id=user_id,
config=config,
transfer_method=transfer_method,
)
case _:
raise ValueError(f"Invalid file transfer method: {transfer_method}")
build_functions: dict[FileTransferMethod, Callable] = {
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
FileTransferMethod.REMOTE_URL: _build_from_remote_url,
FileTransferMethod.TOOL_FILE: _build_from_tool_file,
}
build_func = build_functions.get(transfer_method)
if not build_func:
raise ValueError(f"Invalid file transfer method: {transfer_method}")
file = build_func(
mapping=mapping,
tenant_id=tenant_id,
transfer_method=transfer_method,
)
if not _is_file_valid_with_config(file=file, config=config):
raise ValueError(f"File validation failed for file: {file.filename}")
return file
@ -92,10 +80,8 @@ def build_from_mapping(
def build_from_mappings(
*,
mappings: Sequence[Mapping[str, Any]],
config: FileExtraConfig | None,
config: FileUploadConfig | None,
tenant_id: str,
user_id: str,
role: "CreatedByRole",
) -> Sequence[File]:
if not config:
return []
@ -104,8 +90,6 @@ def build_from_mappings(
build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
user_id=user_id,
role=role,
config=config,
)
for mapping in mappings
@ -128,31 +112,20 @@ def _build_from_local_file(
*,
mapping: Mapping[str, Any],
tenant_id: str,
user_id: str,
role: "CreatedByRole",
config: FileExtraConfig,
transfer_method: FileTransferMethod,
):
# check if the upload file exists.
) -> File:
file_type = FileType.value_of(mapping.get("type"))
stmt = select(UploadFile).where(
UploadFile.id == mapping.get("upload_file_id"),
UploadFile.tenant_id == tenant_id,
UploadFile.created_by == user_id,
UploadFile.created_by_role == role,
)
if file_type == FileType.IMAGE:
stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS))
elif file_type == FileType.VIDEO:
stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS))
elif file_type == FileType.AUDIO:
stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS))
elif file_type == FileType.DOCUMENT:
stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS))
row = db.session.scalar(stmt)
if row is None:
raise ValueError("Invalid upload file")
file = File(
return File(
id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
@ -162,23 +135,37 @@ def _build_from_local_file(
transfer_method=transfer_method,
remote_url=row.source_url,
related_id=mapping.get("upload_file_id"),
_extra_config=config,
size=row.size,
)
return file
def _build_from_remote_url(
*,
mapping: Mapping[str, Any],
tenant_id: str,
config: FileExtraConfig,
transfer_method: FileTransferMethod,
):
) -> File:
url = mapping.get("url")
if not url:
raise ValueError("Invalid file url")
mime_type, filename, file_size = _get_remote_file_info(url)
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
return File(
id=mapping.get("id"),
filename=filename,
tenant_id=tenant_id,
type=FileType.value_of(mapping.get("type")),
transfer_method=transfer_method,
remote_url=url,
mime_type=mime_type,
extension=extension,
size=file_size,
)
def _get_remote_file_info(url: str):
mime_type = mimetypes.guess_type(url)[0] or ""
file_size = -1
filename = url.split("/")[-1].split("?")[0] or "unknown_file"
@ -186,56 +173,34 @@ def _build_from_remote_url(
resp = ssrf_proxy.head(url, follow_redirects=True)
if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"):
filename = content_disposition.split("filename=")[-1].strip('"')
filename = str(content_disposition.split("filename=")[-1].strip('"'))
file_size = int(resp.headers.get("Content-Length", file_size))
mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
# Determine file extension
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
if not mime_type:
mime_type, _ = mimetypes.guess_type(url)
file = File(
id=mapping.get("id"),
filename=filename,
tenant_id=tenant_id,
type=FileType.value_of(mapping.get("type")),
transfer_method=transfer_method,
remote_url=url,
_extra_config=config,
mime_type=mime_type,
extension=extension,
size=file_size,
)
return file
return mime_type, filename, file_size
def _build_from_tool_file(
*,
mapping: Mapping[str, Any],
tenant_id: str,
user_id: str,
config: FileExtraConfig,
transfer_method: FileTransferMethod,
):
) -> File:
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == mapping.get("tool_file_id"),
ToolFile.tenant_id == tenant_id,
ToolFile.user_id == user_id,
)
.first()
)
if tool_file is None:
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
path = tool_file.file_key
if "." in path:
extension = "." + path.split("/")[-1].split(".")[-1]
else:
extension = ".bin"
file = File(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
return File(
id=mapping.get("id"),
tenant_id=tenant_id,
filename=tool_file.name,
@ -246,6 +211,21 @@ def _build_from_tool_file(
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
_extra_config=config,
)
return file
def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool:
if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM:
return False
if config.allowed_extensions and file.extension not in config.allowed_extensions:
return False
if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods:
return False
if file.type == FileType.IMAGE and config.image_config:
if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods:
return False
return True

View File

@ -23,7 +23,7 @@ v0_9_0_release_date= '2024-09-29 12:00:00'
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
sql = f"""UPDATE
public.messages
messages
SET
parent_message_id = '{UUID_NIL}'
WHERE
@ -37,7 +37,7 @@ WHERE
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
sql = f"""UPDATE
public.messages
messages
SET
parent_message_id = NULL
WHERE

View File

@ -18,7 +18,7 @@ from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, mapped_column
from configs import dify_config
from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
from core.file.tool_file_parser import ToolFileParser
from extensions.ext_database import db
@ -962,9 +962,6 @@ class Message(Base):
"type": message_file.type,
},
tenant_id=current_app.tenant_id,
user_id=self.from_account_id or self.from_end_user_id or "",
role=CreatedByRole(message_file.created_by_role),
config=FileExtraConfig(),
)
elif message_file.transfer_method == "remote_url":
if message_file.url is None:
@ -977,9 +974,6 @@ class Message(Base):
"url": message_file.url,
},
tenant_id=current_app.tenant_id,
user_id=self.from_account_id or self.from_end_user_id or "",
role=CreatedByRole(message_file.created_by_role),
config=FileExtraConfig(),
)
elif message_file.transfer_method == "tool_file":
if message_file.upload_file_id is None:
@ -994,9 +988,6 @@ class Message(Base):
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=current_app.tenant_id,
user_id=self.from_account_id or self.from_end_user_id or "",
role=CreatedByRole(message_file.created_by_role),
config=FileExtraConfig(),
)
else:
raise ValueError(

174
api/poetry.lock generated
View File

@ -835,13 +835,13 @@ files = [
[[package]]
name = "blinker"
version = "1.8.2"
version = "1.9.0"
description = "Fast, simple object-to-object and broadcast signaling"
optional = false
python-versions = ">=3.8"
python-versions = ">=3.9"
files = [
{file = "blinker-1.8.2-py3-none-any.whl", hash = "sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01"},
{file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"},
{file = "blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc"},
{file = "blinker-1.9.0.tar.gz", hash = "sha256:b4ce2265a7abece45e7cc896e98dbebe6cead56bcf805a3d23136d145f5445bf"},
]
[[package]]
@ -865,13 +865,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.35.55"
version = "1.35.57"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.8"
files = [
{file = "botocore-1.35.55-py3-none-any.whl", hash = "sha256:3d54739e498534c9d7a6e9732ae2d17ed29c7d5e29fe36c956d8488b859538b0"},
{file = "botocore-1.35.55.tar.gz", hash = "sha256:61ae18f688250372d7b6046e35c86f8fd09a7c0f0064b52688f3490b4d6c9d6b"},
{file = "botocore-1.35.57-py3-none-any.whl", hash = "sha256:92ddd02469213766872cb2399269dd20948f90348b42bf08379881d5e946cc34"},
{file = "botocore-1.35.57.tar.gz", hash = "sha256:d96306558085baf0bcb3b022d7a8c39c93494f031edb376694d2b2dcd0e81327"},
]
[package.dependencies]
@ -1116,13 +1116,13 @@ files = [
[[package]]
name = "celery"
version = "5.3.6"
version = "5.4.0"
description = "Distributed Task Queue."
optional = false
python-versions = ">=3.8"
files = [
{file = "celery-5.3.6-py3-none-any.whl", hash = "sha256:9da4ea0118d232ce97dff5ed4974587fb1c0ff5c10042eb15278487cdd27d1af"},
{file = "celery-5.3.6.tar.gz", hash = "sha256:870cc71d737c0200c397290d730344cc991d13a057534353d124c9380267aab9"},
{file = "celery-5.4.0-py3-none-any.whl", hash = "sha256:369631eb580cf8c51a82721ec538684994f8277637edde2dfc0dacd73ed97f64"},
{file = "celery-5.4.0.tar.gz", hash = "sha256:504a19140e8d3029d5acad88330c541d4c3f64c789d85f94756762d8bca7e706"},
]
[package.dependencies]
@ -1138,7 +1138,7 @@ vine = ">=5.1.0,<6.0"
[package.extras]
arangodb = ["pyArango (>=2.0.2)"]
auth = ["cryptography (==41.0.5)"]
auth = ["cryptography (==42.0.5)"]
azureblockblob = ["azure-storage-blob (>=12.15.0)"]
brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"]
cassandra = ["cassandra-driver (>=3.25.0,<4)"]
@ -1148,22 +1148,23 @@ couchbase = ["couchbase (>=3.0.0)"]
couchdb = ["pycouchdb (==1.14.2)"]
django = ["Django (>=2.2.28)"]
dynamodb = ["boto3 (>=1.26.143)"]
elasticsearch = ["elastic-transport (<=8.10.0)", "elasticsearch (<=8.11.0)"]
elasticsearch = ["elastic-transport (<=8.13.0)", "elasticsearch (<=8.13.0)"]
eventlet = ["eventlet (>=0.32.0)"]
gcs = ["google-cloud-storage (>=2.10.0)"]
gevent = ["gevent (>=1.5.0)"]
librabbitmq = ["librabbitmq (>=2.0.0)"]
memcache = ["pylibmc (==1.6.3)"]
mongodb = ["pymongo[srv] (>=4.0.2)"]
msgpack = ["msgpack (==1.0.7)"]
pymemcache = ["python-memcached (==1.59)"]
msgpack = ["msgpack (==1.0.8)"]
pymemcache = ["python-memcached (>=1.61)"]
pyro = ["pyro4 (==4.82)"]
pytest = ["pytest-celery (==0.0.0)"]
pytest = ["pytest-celery[all] (>=1.0.0)"]
redis = ["redis (>=4.5.2,!=4.5.5,<6.0.0)"]
s3 = ["boto3 (>=1.26.143)"]
slmq = ["softlayer-messaging (>=1.0.3)"]
solar = ["ephem (==4.1.5)"]
sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"]
sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.0)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"]
sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.4)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"]
tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"]
yaml = ["PyYAML (>=3.10)"]
zookeeper = ["kazoo (>=1.3.1)"]
@ -2036,17 +2037,20 @@ tokenizer = ["tiktoken"]
[[package]]
name = "dataclass-wizard"
version = "0.26.0"
version = "0.27.0"
description = "Marshal dataclasses to/from JSON. Use field properties with initial values. Construct a dataclass schema with JSON input."
optional = false
python-versions = "*"
files = [
{file = "dataclass-wizard-0.26.0.tar.gz", hash = "sha256:227fa229332a2fcbfc1dca4dc5e090b01f313939f78b078a6f1fd3b5687a98a7"},
{file = "dataclass_wizard-0.26.0-py2.py3-none-any.whl", hash = "sha256:5e5821b6010f3c19309c889f6b46e6f50b9c4514a46a5bac42f90c8bbf09345f"},
{file = "dataclass-wizard-0.27.0.tar.gz", hash = "sha256:6bb5d7101949e8e6c0a3a305ceb9e68b24e231858aad8ed4a83c12414ded1d0d"},
{file = "dataclass_wizard-0.27.0-py2.py3-none-any.whl", hash = "sha256:a9ef05297c54823f6d82382123fd675347f6a1d02ee5a1c988a63855208aa6fb"},
]
[package.dependencies]
typing-extensions = {version = ">=4", markers = "python_version == \"3.9\" or python_version == \"3.10\""}
[package.extras]
dev = ["Sphinx (==5.3.0)", "bump2version (==1.0.1)", "coverage (>=6.2)", "dataclass-factory (==2.12)", "dataclasses-json (==0.5.6)", "flake8 (>=3)", "jsons (==1.6.1)", "pip (>=21.3.1)", "pytest (==7.0.1)", "pytest-cov (==3.0.0)", "pytest-mock (>=3.6.1)", "pytimeparse (==1.1.8)", "sphinx-issues (==3.0.1)", "sphinx-issues (==4.0.0)", "tox (==3.24.5)", "twine (==3.8.0)", "watchdog[watchmedo] (==2.1.6)", "wheel (==0.37.1)", "wheel (==0.42.0)"]
dev = ["Sphinx (==7.4.7)", "Sphinx (==8.1.3)", "bump2version (==1.0.1)", "coverage (>=6.2)", "dataclass-factory (==2.16)", "dataclasses-json (==0.6.7)", "flake8 (>=3)", "jsons (==1.6.3)", "pip (>=21.3.1)", "pytest (==8.3.3)", "pytest-cov (==6.0.0)", "pytest-mock (>=3.6.1)", "pytimeparse (==1.1.8)", "sphinx-issues (==5.0.0)", "tox (==4.23.2)", "twine (==5.1.1)", "watchdog[watchmedo] (==6.0.0)", "wheel (==0.45.0)"]
timedelta = ["pytimeparse (>=1.1.7)"]
yaml = ["PyYAML (>=5.3)"]
@ -2275,13 +2279,13 @@ files = [
[[package]]
name = "duckduckgo-search"
version = "6.3.3"
version = "6.3.4"
description = "Search for words, documents, images, news, maps and text translation using the DuckDuckGo.com search engine."
optional = false
python-versions = ">=3.8"
files = [
{file = "duckduckgo_search-6.3.3-py3-none-any.whl", hash = "sha256:63e5d6b958bd532016bc8a53e8b18717751bf7ef51b1c83e59b9f5780c79e64c"},
{file = "duckduckgo_search-6.3.3.tar.gz", hash = "sha256:4d49508f01f85c8675765fdd4cc25eedbb3450e129b35209897fded874f6568f"},
{file = "duckduckgo_search-6.3.4-py3-none-any.whl", hash = "sha256:0c18279fb43cbb43e51a251a2133cd0be09604f5a0395fe05409e213bed0cf00"},
{file = "duckduckgo_search-6.3.4.tar.gz", hash = "sha256:71317d0dee393cb2c0fb8d2eedc76bba0d8c93c752fe97be0030c39b89fd05f9"},
]
[package.dependencies]
@ -4662,13 +4666,13 @@ openai = ["openai (>=0.27.8)"]
[[package]]
name = "langsmith"
version = "0.1.140"
version = "0.1.142"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langsmith-0.1.140-py3-none-any.whl", hash = "sha256:3de70183ae19a4ada4d77a8a9f336ff95ca0ead98215771033ee889a2889fe19"},
{file = "langsmith-0.1.140.tar.gz", hash = "sha256:cb0a717d7b9e6d3145285d7ca0ab216e064cbe7a1ca4139fc04af57fb2315e70"},
{file = "langsmith-0.1.142-py3-none-any.whl", hash = "sha256:f639ca23c9a0bb77af5fb881679b2f66ff1f21f19d0bebf4e51375e7585a8b38"},
{file = "langsmith-0.1.142.tar.gz", hash = "sha256:f8a84d100f3052233ff0a1d66ae14c5dfc20b7e41a1601de011384f16ee6cb82"},
]
[package.dependencies]
@ -6016,13 +6020,13 @@ kerberos = ["requests-kerberos"]
[[package]]
name = "opentelemetry-api"
version = "1.28.0"
version = "1.28.1"
description = "OpenTelemetry Python API"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_api-1.28.0-py3-none-any.whl", hash = "sha256:8457cd2c59ea1bd0988560f021656cecd254ad7ef6be4ba09dbefeca2409ce52"},
{file = "opentelemetry_api-1.28.0.tar.gz", hash = "sha256:578610bcb8aa5cdcb11169d136cc752958548fb6ccffb0969c1036b0ee9e5353"},
{file = "opentelemetry_api-1.28.1-py3-none-any.whl", hash = "sha256:bfe86c95576cf19a914497f439fd79c9553a38de0adbdc26f7cfc46b0c00b16c"},
{file = "opentelemetry_api-1.28.1.tar.gz", hash = "sha256:6fa7295a12c707f5aebef82da3d9ec5afe6992f3e42bfe7bec0339a44b3518e7"},
]
[package.dependencies]
@ -6053,59 +6057,59 @@ test = ["pytest-grpc"]
[[package]]
name = "opentelemetry-instrumentation"
version = "0.49b0"
version = "0.49b1"
description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation-0.49b0-py3-none-any.whl", hash = "sha256:68364d73a1ff40894574cbc6138c5f98674790cae1f3b0865e21cf702f24dcb3"},
{file = "opentelemetry_instrumentation-0.49b0.tar.gz", hash = "sha256:398a93e0b9dc2d11cc8627e1761665c506fe08c6b2df252a2ab3ade53d751c46"},
{file = "opentelemetry_instrumentation-0.49b1-py3-none-any.whl", hash = "sha256:0a9d3821736104013693ef3b8a9d29b41f2f3a81ee2d8c9288b52d62bae5747c"},
{file = "opentelemetry_instrumentation-0.49b1.tar.gz", hash = "sha256:2d0e41181b7957ba061bb436b969ad90545ac3eba65f290830009b4264d2824e"},
]
[package.dependencies]
opentelemetry-api = ">=1.4,<2.0"
opentelemetry-semantic-conventions = "0.49b0"
opentelemetry-semantic-conventions = "0.49b1"
packaging = ">=18.0"
wrapt = ">=1.0.0,<2.0.0"
[[package]]
name = "opentelemetry-instrumentation-asgi"
version = "0.49b0"
version = "0.49b1"
description = "ASGI instrumentation for OpenTelemetry"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_asgi-0.49b0-py3-none-any.whl", hash = "sha256:722a90856457c81956c88f35a6db606cc7db3231046b708aae2ddde065723dbe"},
{file = "opentelemetry_instrumentation_asgi-0.49b0.tar.gz", hash = "sha256:959fd9b1345c92f20c6ef1d42f92ef6a76b3c3083fbc4104d59da6859b15b083"},
{file = "opentelemetry_instrumentation_asgi-0.49b1-py3-none-any.whl", hash = "sha256:8dcbc438cb138789fcb20ae38b6e7f23088e066d77b54bae205c5744856603c6"},
{file = "opentelemetry_instrumentation_asgi-0.49b1.tar.gz", hash = "sha256:d1a2b4cb76490be28bcad3c0f562c4b3c84157148c922ca298bb04ed9e36c005"},
]
[package.dependencies]
asgiref = ">=3.0,<4.0"
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.49b0"
opentelemetry-semantic-conventions = "0.49b0"
opentelemetry-util-http = "0.49b0"
opentelemetry-instrumentation = "0.49b1"
opentelemetry-semantic-conventions = "0.49b1"
opentelemetry-util-http = "0.49b1"
[package.extras]
instruments = ["asgiref (>=3.0,<4.0)"]
[[package]]
name = "opentelemetry-instrumentation-fastapi"
version = "0.49b0"
version = "0.49b1"
description = "OpenTelemetry FastAPI Instrumentation"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_fastapi-0.49b0-py3-none-any.whl", hash = "sha256:646e1b18523cbe6860ae9711eb2c7b9c85466c3c7697cd6b8fb5180d85d3fe6e"},
{file = "opentelemetry_instrumentation_fastapi-0.49b0.tar.gz", hash = "sha256:6d14935c41fd3e49328188b6a59dd4c37bd17a66b01c15b0c64afa9714a1f905"},
{file = "opentelemetry_instrumentation_fastapi-0.49b1-py3-none-any.whl", hash = "sha256:3398940102c8ef613b9c55fc4f179cc92413de456f6bec6eeb1995270de2b087"},
{file = "opentelemetry_instrumentation_fastapi-0.49b1.tar.gz", hash = "sha256:13d9d4d70b4bb831468b8e40807353731cad7fbfaeedde0070d93bcb2c417b07"},
]
[package.dependencies]
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.49b0"
opentelemetry-instrumentation-asgi = "0.49b0"
opentelemetry-semantic-conventions = "0.49b0"
opentelemetry-util-http = "0.49b0"
opentelemetry-instrumentation = "0.49b1"
opentelemetry-instrumentation-asgi = "0.49b1"
opentelemetry-semantic-conventions = "0.49b1"
opentelemetry-util-http = "0.49b1"
[package.extras]
instruments = ["fastapi (>=0.58,<1.0)"]
@ -6126,44 +6130,44 @@ protobuf = ">=3.19,<5.0"
[[package]]
name = "opentelemetry-sdk"
version = "1.28.0"
version = "1.28.1"
description = "OpenTelemetry Python SDK"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_sdk-1.28.0-py3-none-any.whl", hash = "sha256:4b37da81d7fad67f6683c4420288c97f4ed0d988845d5886435f428ec4b8429a"},
{file = "opentelemetry_sdk-1.28.0.tar.gz", hash = "sha256:41d5420b2e3fb7716ff4981b510d551eff1fc60eb5a95cf7335b31166812a893"},
{file = "opentelemetry_sdk-1.28.1-py3-none-any.whl", hash = "sha256:72aad7f5fcbe37113c4ab4899f6cdeb6ac77ed3e62f25a85e3627b12583dad0f"},
{file = "opentelemetry_sdk-1.28.1.tar.gz", hash = "sha256:100fa371b2046ffba6a340c18f0b2a0463acad7461e5177e126693b613a6ca57"},
]
[package.dependencies]
opentelemetry-api = "1.28.0"
opentelemetry-semantic-conventions = "0.49b0"
opentelemetry-api = "1.28.1"
opentelemetry-semantic-conventions = "0.49b1"
typing-extensions = ">=3.7.4"
[[package]]
name = "opentelemetry-semantic-conventions"
version = "0.49b0"
version = "0.49b1"
description = "OpenTelemetry Semantic Conventions"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_semantic_conventions-0.49b0-py3-none-any.whl", hash = "sha256:0458117f6ead0b12e3221813e3e511d85698c31901cac84682052adb9c17c7cd"},
{file = "opentelemetry_semantic_conventions-0.49b0.tar.gz", hash = "sha256:dbc7b28339e5390b6b28e022835f9bac4e134a80ebf640848306d3c5192557e8"},
{file = "opentelemetry_semantic_conventions-0.49b1-py3-none-any.whl", hash = "sha256:dd6f3ac8169d2198c752e1a63f827e5f5e110ae9b0ce33f2aad9a3baf0739743"},
{file = "opentelemetry_semantic_conventions-0.49b1.tar.gz", hash = "sha256:91817883b159ffb94c2ca9548509c4fe0aafce7c24f437aa6ac3fc613aa9a758"},
]
[package.dependencies]
deprecated = ">=1.2.6"
opentelemetry-api = "1.28.0"
opentelemetry-api = "1.28.1"
[[package]]
name = "opentelemetry-util-http"
version = "0.49b0"
version = "0.49b1"
description = "Web util for OpenTelemetry"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_util_http-0.49b0-py3-none-any.whl", hash = "sha256:8661bbd6aea1839badc44de067ec9c15c05eab05f729f496c856c50a1203caf1"},
{file = "opentelemetry_util_http-0.49b0.tar.gz", hash = "sha256:02928496afcffd58a7c15baf99d2cedae9b8325a8ac52b0d0877b2e8f936dd1b"},
{file = "opentelemetry_util_http-0.49b1-py3-none-any.whl", hash = "sha256:0290b942f7888b6310df6803e52e12f4043b8f224db0659f62dc7b70059eb94f"},
{file = "opentelemetry_util_http-0.49b1.tar.gz", hash = "sha256:6c2bc6f7e20e286dbdfcccb9d895fa290ec9d7c596cdf2e06bf1d8e434b2edd0"},
]
[[package]]
@ -6307,13 +6311,13 @@ files = [
[[package]]
name = "packaging"
version = "24.1"
version = "24.2"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"},
{file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"},
{file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
{file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
]
[[package]]
@ -7319,13 +7323,13 @@ rsa = ["cryptography"]
[[package]]
name = "pyobvector"
version = "0.1.8"
version = "0.1.10"
description = "A python SDK for OceanBase Vector Store, based on SQLAlchemy, compatible with Milvus API."
optional = false
python-versions = "<4.0,>=3.9"
files = [
{file = "pyobvector-0.1.8-py3-none-any.whl", hash = "sha256:d44c88df3930ea0f888dc1580d575e2434b17c9c5183f57d1da6fdfa43c6c252"},
{file = "pyobvector-0.1.8.tar.gz", hash = "sha256:78ab9b4d2e5d9903bb5f2dc7d356f5439eb1a8620d25335fad9f897921b75e2e"},
{file = "pyobvector-0.1.10-py3-none-any.whl", hash = "sha256:7ef0d20c640a948c7fe64f2f3bd4defda395e65c617152643340ed440056238c"},
{file = "pyobvector-0.1.10.tar.gz", hash = "sha256:30a7ad42ff8be0bf0c37a33d1acfb8b948e7f9b6ac3d482b85f9761c41af9bfb"},
]
[package.dependencies]
@ -9309,13 +9313,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"]
[[package]]
name = "tencentcloud-sdk-python-common"
version = "3.0.1261"
version = "3.0.1263"
description = "Tencent Cloud Common SDK for Python"
optional = false
python-versions = "*"
files = [
{file = "tencentcloud-sdk-python-common-3.0.1261.tar.gz", hash = "sha256:e34bd23f9b05f3da9b88263cad6f7d95a30e9a9d827924a848d10841171e6cbc"},
{file = "tencentcloud_sdk_python_common-3.0.1261-py2.py3-none-any.whl", hash = "sha256:c13cf1524abc550a2f2fb996bb7b426d428f2e1a0fa5d81ab321eae07ca5c680"},
{file = "tencentcloud-sdk-python-common-3.0.1263.tar.gz", hash = "sha256:3091024ece07982ec4829c661bc90474d2b9c5543965717f7136b9f66b201c34"},
{file = "tencentcloud_sdk_python_common-3.0.1263-py2.py3-none-any.whl", hash = "sha256:812cdc2d183d455472f8fee88d699acb869a8d8497cd09cd6d83596a98a8e6d7"},
]
[package.dependencies]
@ -9323,17 +9327,17 @@ requests = ">=2.16.0"
[[package]]
name = "tencentcloud-sdk-python-hunyuan"
version = "3.0.1261"
version = "3.0.1263"
description = "Tencent Cloud Hunyuan SDK for Python"
optional = false
python-versions = "*"
files = [
{file = "tencentcloud-sdk-python-hunyuan-3.0.1261.tar.gz", hash = "sha256:da9bef80d5491dab6b7ba3d7615d6597e546361cf8650a1a12db22472379ba19"},
{file = "tencentcloud_sdk_python_hunyuan-3.0.1261-py2.py3-none-any.whl", hash = "sha256:b925dac1d7cd98f475462f3e07131a7d5091cb8a5cdfff90e2b0720c5e02c189"},
{file = "tencentcloud-sdk-python-hunyuan-3.0.1263.tar.gz", hash = "sha256:4e9c0120ca7eca48983afec7ff6a04a4bd75c347070f942a7edd378c5f9b2767"},
{file = "tencentcloud_sdk_python_hunyuan-3.0.1263-py2.py3-none-any.whl", hash = "sha256:37446ef71d50a91dfe06d7c1704b1841aab079da29dc91099d2b793779e18dc2"},
]
[package.dependencies]
tencentcloud-sdk-python-common = "3.0.1261"
tencentcloud-sdk-python-common = "3.0.1263"
[[package]]
name = "termcolor"
@ -9729,13 +9733,13 @@ requests = ">=2.0.0"
[[package]]
name = "typer"
version = "0.12.5"
version = "0.13.0"
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
optional = false
python-versions = ">=3.7"
files = [
{file = "typer-0.12.5-py3-none-any.whl", hash = "sha256:62fe4e471711b147e3365034133904df3e235698399bc4de2b36c8579298d52b"},
{file = "typer-0.12.5.tar.gz", hash = "sha256:f592f089bedcc8ec1b974125d64851029c3b1af145f04aca64d69410f0c9b722"},
{file = "typer-0.13.0-py3-none-any.whl", hash = "sha256:d85fe0b777b2517cc99c8055ed735452f2659cd45e451507c76f48ce5c1d00e2"},
{file = "typer-0.13.0.tar.gz", hash = "sha256:f1c7198347939361eec90139ffa0fd8b3df3a2259d5852a0f7400e476d95985c"},
]
[package.dependencies]
@ -9884,13 +9888,13 @@ files = [
[[package]]
name = "unstructured"
version = "0.16.4"
version = "0.16.5"
description = "A library that prepares raw documents for downstream ML tasks."
optional = false
python-versions = "<3.13,>=3.9.0"
files = [
{file = "unstructured-0.16.4-py3-none-any.whl", hash = "sha256:300e4a9e630c6d55484a62e90df23075e5abd04f17bf15043898bc0eff6c4070"},
{file = "unstructured-0.16.4.tar.gz", hash = "sha256:da05433db186f8251fc0b1b1b273f584ca2a71363d541ec1ab82ef55dc49055d"},
{file = "unstructured-0.16.5-py3-none-any.whl", hash = "sha256:d867e6d5c002c159997bb44df82c43531570c32fa87a010a0aae8a7a0e22ec49"},
{file = "unstructured-0.16.5.tar.gz", hash = "sha256:2c36de777f88529e0f7c306eb8116b755963928d50d331bbfee56e2f61fe023f"},
]
[package.dependencies]
@ -10744,13 +10748,13 @@ multidict = ">=4.0"
[[package]]
name = "yfinance"
version = "0.2.48"
version = "0.2.49"
description = "Download market data from Yahoo! Finance API"
optional = false
python-versions = "*"
files = [
{file = "yfinance-0.2.48-py2.py3-none-any.whl", hash = "sha256:eda797145faa4536595eb629f869d3616e58ed7e71de36856b19f1abaef71a5b"},
{file = "yfinance-0.2.48.tar.gz", hash = "sha256:1434cd8bf22f345fa27ef1ed82bfdd291c1bb5b6fe3067118a94e256aa90c4eb"},
{file = "yfinance-0.2.49-py2.py3-none-any.whl", hash = "sha256:cc9c7d09826e7eaee96d179395e814b911e083fbfb325c2fe693cae019b47f38"},
{file = "yfinance-0.2.49.tar.gz", hash = "sha256:e6b45f8392feb11360450630f86f96a46dfa708d77c334d5376564a9eead952b"},
]
[package.dependencies]
@ -10790,13 +10794,13 @@ pyjwt = ">=2.8.0,<2.9.0"
[[package]]
name = "zipp"
version = "3.20.2"
version = "3.21.0"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
python-versions = ">=3.9"
files = [
{file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"},
{file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"},
{file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"},
{file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"},
]
[package.extras]
@ -10994,4 +10998,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "bb8385625eb61de086b7a7156745066b4fb171d9ca67afd1d092fa7e872f3abd"
content-hash = "f20bd678044926913dbbc24bd0cf22503a75817aa55f59457ff7822032139b77"

View File

@ -118,7 +118,7 @@ beautifulsoup4 = "4.12.2"
boto3 = "1.35.17"
bs4 = "~0.0.1"
cachetools = "~5.3.0"
celery = "~5.3.6"
celery = "~5.4.0"
chardet = "~5.1.0"
cohere = "~5.2.4"
dashscope = { version = "~1.17.0", extras = ["tokenizer"] }

View File

@ -12,6 +12,8 @@ from models.dataset import TidbAuthBinding
@app.celery.task(queue="dataset")
def create_tidb_serverless_task():
click.echo(click.style("Start create tidb serverless task.", fg="green"))
if not dify_config.CREATE_TIDB_SERVICE_JOB_ENABLED:
return
tidb_serverless_number = dify_config.TIDB_SERVERLESS_NUMBER
start_at = time.perf_counter()
while True:

View File

@ -14,8 +14,6 @@ from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.models.document import Document as RAGDocument
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
@ -37,6 +35,7 @@ from models.dataset import (
)
from models.model import UploadFile
from models.source import DataSourceOauthBinding
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateEntity
from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
@ -1415,9 +1414,13 @@ class SegmentService:
created_by=current_user.id,
)
if document.doc_form == "qa_model":
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
db.session.add(segment_document)
# update document word count
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()
# save vector index
@ -1436,6 +1439,7 @@ class SegmentService:
@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
lock_name = "multi_add_segment_lock_document_id_{}".format(document.id)
increment_word_count = 0
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
@ -1461,7 +1465,10 @@ class SegmentService:
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
if document.doc_form == "qa_model":
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
else:
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
@ -1479,6 +1486,8 @@ class SegmentService:
)
if document.doc_form == "qa_model":
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
db.session.add(segment_document)
segment_data_list.append(segment_document)
@ -1487,7 +1496,9 @@ class SegmentService:
keywords_list.append(segment_item["keywords"])
else:
keywords_list.append(None)
# update document word count
document.word_count += increment_word_count
db.session.add(document)
try:
# save vector index
VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
@ -1503,12 +1514,13 @@ class SegmentService:
@classmethod
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
segment_update_entity = SegmentUpdateEntity(**args)
indexing_cache_key = "segment_{}_indexing".format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise ValueError("Segment is indexing, please try again later")
if "enabled" in args and args["enabled"] is not None:
action = args["enabled"]
if segment_update_entity.enabled is not None:
action = segment_update_entity.enabled
if segment.enabled != action:
if not action:
segment.enabled = action
@ -1521,37 +1533,34 @@ class SegmentService:
disable_segment_from_index_task.delay(segment.id)
return segment
if not segment.enabled:
if "enabled" in args and args["enabled"] is not None:
if not args["enabled"]:
if segment_update_entity.enabled is not None:
if not segment_update_entity.enabled:
raise ValueError("Can't update disabled segment")
else:
raise ValueError("Can't update disabled segment")
try:
content = args["content"]
word_count_change = segment.word_count
content = segment_update_entity.content
if segment.content == content:
segment.word_count = len(content)
if document.doc_form == "qa_model":
segment.answer = args["answer"]
if args.get("keywords"):
segment.keywords = args["keywords"]
segment.answer = segment_update_entity.answer
segment.word_count += len(segment_update_entity.answer)
word_count_change = segment.word_count - word_count_change
if segment_update_entity.keywords:
segment.keywords = segment_update_entity.keywords
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
db.session.add(segment)
db.session.commit()
# update document word count
if word_count_change != 0:
document.word_count = max(0, document.word_count + word_count_change)
db.session.add(document)
# update segment index task
if "keywords" in args:
keyword = Keyword(dataset)
keyword.delete_by_ids([segment.index_node_id])
document = RAGDocument(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
keyword.add_texts([document], keywords_list=[args["keywords"]])
if segment_update_entity.enabled:
VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset)
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
@ -1565,7 +1574,10 @@ class SegmentService:
)
# calc embedding use tokens
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
if document.doc_form == "qa_model":
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
else:
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
segment.content = content
segment.index_node_hash = segment_hash
segment.word_count = len(content)
@ -1579,11 +1591,17 @@ class SegmentService:
segment.disabled_at = None
segment.disabled_by = None
if document.doc_form == "qa_model":
segment.answer = args["answer"]
segment.answer = segment_update_entity.answer
segment.word_count += len(segment_update_entity.answer)
word_count_change = segment.word_count - word_count_change
# update document word count
if word_count_change != 0:
document.word_count = max(0, document.word_count + word_count_change)
db.session.add(document)
db.session.add(segment)
db.session.commit()
# update segment vector index
VectorService.update_segment_vector(args["keywords"], segment, dataset)
VectorService.update_segment_vector(segment_update_entity.keywords, segment, dataset)
except Exception as e:
logging.exception("update segment index failed")
@ -1608,6 +1626,9 @@ class SegmentService:
redis_client.setex(indexing_cache_key, 600, 1)
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
db.session.delete(segment)
# update document word count
document.word_count -= segment.word_count
db.session.add(document)
db.session.commit()

View File

@ -0,0 +1,10 @@
from typing import Optional
from pydantic import BaseModel
class SegmentUpdateEntity(BaseModel):
content: str
answer: Optional[str] = None
keywords: Optional[list[str]] = None
enabled: Optional[bool] = None

View File

@ -13,7 +13,7 @@ from core.app.app_config.entities import (
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.file.models import FileExtraConfig
from core.file.models import FileUploadConfig
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder
@ -381,7 +381,7 @@ class WorkflowConverter:
graph: dict,
model_config: ModelConfigEntity,
prompt_template: PromptTemplateEntity,
file_upload: Optional[FileExtraConfig] = None,
file_upload: Optional[FileUploadConfig] = None,
external_data_variable_node_mapping: dict[str, str] | None = None,
) -> dict:
"""

View File

@ -57,7 +57,7 @@ def batch_create_segment_to_index_task(
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
word_count_change = 0
for segment in content:
content = segment["content"]
doc_id = str(uuid.uuid4())
@ -86,8 +86,13 @@ def batch_create_segment_to_index_task(
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
db.session.add(segment_document)
document_segments.append(segment_document)
# update document word count
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
# add index to db
indexing_runner = IndexingRunner()
indexing_runner.batch_add_segments(document_segments, dataset)

View File

@ -1,17 +1,20 @@
import json
import logging
import time
from celery import shared_task
from flask import current_app
from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
from core.ops.entities.trace_entity import trace_info_info_map
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.model import Message
from models.workflow import WorkflowRun
@shared_task(queue="ops_trace")
def process_trace_tasks(tasks_data):
def process_trace_tasks(file_info):
"""
Async process trace tasks
:param tasks_data: List of dictionaries containing task data
@ -20,9 +23,12 @@ def process_trace_tasks(tasks_data):
"""
from core.ops.ops_trace_manager import OpsTraceManager
trace_info = tasks_data.get("trace_info")
app_id = tasks_data.get("app_id")
trace_info_type = tasks_data.get("trace_info_type")
app_id = file_info.get("app_id")
file_id = file_info.get("file_id")
file_path = f"{OPS_FILE_PATH}{app_id}/{file_id}.json"
file_data = json.loads(storage.load(file_path))
trace_info = file_data.get("trace_info")
trace_info_type = file_data.get("trace_info_type")
trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
if trace_info.get("message_data"):
@ -39,6 +45,10 @@ def process_trace_tasks(tasks_data):
if trace_type:
trace_info = trace_type(**trace_info)
trace_instance.trace(trace_info)
end_at = time.perf_counter()
logging.info(f"Processing trace tasks success, app_id: {app_id}")
except Exception:
logging.exception("Processing trace tasks failed")
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
logging.info(f"Processing trace tasks failed, app_id: {app_id}")
finally:
storage.delete(file_path)

View File

@ -430,37 +430,3 @@ def test_multi_colons_parse(setup_http_mock):
assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "")
# assert "http://example3.com" == resp.get("headers", {}).get("referer")
def test_image_file(monkeypatch):
from types import SimpleNamespace
monkeypatch.setattr(
"core.tools.tool_file_manager.ToolFileManager.create_file_by_raw",
lambda *args, **kwargs: SimpleNamespace(id="1"),
)
node = init_http_node(
config={
"id": "1",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "https://cloud.dify.ai/logo/logo-site.png",
"authorization": {
"type": "no-auth",
"config": None,
},
"params": "",
"headers": "",
"body": None,
},
}
)
result = node._run()
assert result.process_data is not None
assert result.outputs is not None
resp = result.outputs
assert len(resp.get("files", [])) == 1

View File

@ -0,0 +1,61 @@
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file.models import FileTransferMethod, FileUploadConfig, ImageConfig
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
def test_convert_with_vision():
config = {
"file_upload": {
"enabled": True,
"number_limits": 5,
"allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
"image": {"detail": "high"},
}
}
result = FileUploadConfigManager.convert(config, is_vision=True)
expected = FileUploadConfig(
image_config=ImageConfig(
number_limits=5,
transfer_methods=[FileTransferMethod.REMOTE_URL],
detail=ImagePromptMessageContent.DETAIL.HIGH,
)
)
assert result == expected
def test_convert_without_vision():
config = {
"file_upload": {
"enabled": True,
"number_limits": 5,
"allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
}
}
result = FileUploadConfigManager.convert(config, is_vision=False)
expected = FileUploadConfig(
image_config=ImageConfig(number_limits=5, transfer_methods=[FileTransferMethod.REMOTE_URL])
)
assert result == expected
def test_validate_and_set_defaults():
config = {}
result, keys = FileUploadConfigManager.validate_and_set_defaults(config)
assert "file_upload" in result
assert keys == ["file_upload"]
def test_validate_and_set_defaults_with_existing_config():
config = {
"file_upload": {
"enabled": True,
"number_limits": 5,
"allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
}
}
result, keys = FileUploadConfigManager.validate_and_set_defaults(config)
assert "file_upload" in result
assert keys == ["file_upload"]
assert result["file_upload"]["enabled"] is True
assert result["file_upload"]["number_limits"] == 5
assert result["file_upload"]["allowed_file_upload_methods"] == [FileTransferMethod.REMOTE_URL]

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.app.app_config.entities import ModelConfigEntity
from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig
from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@ -134,7 +134,6 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
_extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)),
)
]

View File

@ -4,7 +4,14 @@ import pytest
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.nodes.list_operator.entities import FilterBy, FilterCondition, Limit, ListOperatorNodeData, OrderBy
from core.workflow.nodes.list_operator.entities import (
ExtractConfig,
FilterBy,
FilterCondition,
Limit,
ListOperatorNodeData,
OrderBy,
)
from core.workflow.nodes.list_operator.exc import InvalidKeyError
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
from models.workflow import WorkflowNodeExecutionStatus
@ -22,6 +29,7 @@ def list_operator_node():
),
"order_by": OrderBy(enabled=False, value="asc"),
"limit": Limit(enabled=False, size=0),
"extract_by": ExtractConfig(enabled=False, serial="1"),
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)

View File

@ -384,6 +384,7 @@ services:
NOTION_INTERNAL_SECRET: you-internal-secret
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: 1000
CREATE_TIDB_SERVICE_JOB_ENABLED: false
depends_on:
- db
- redis

View File

@ -54,6 +54,10 @@ LOG_FILE=
LOG_FILE_MAX_SIZE=20
# Log file max backup count
LOG_FILE_BACKUP_COUNT=5
# Log dateformat
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Debug mode, default is false.
# It is recommended to turn on this configuration for local development
@ -583,12 +587,13 @@ CODE_GENERATION_MAX_TOKENS=1024
# Multi-modal Configuration
# ------------------------------
# The format of the image sent when the multi-modal model is input,
# The format of the image/video sent when the multi-modal model is input,
# the default is base64, optional url.
# The delay of the call in url mode will be lower than that in base64 mode.
# It is generally recommended to use the more compatible base64 mode.
# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image.
# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video.
MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
# Upload image file size limit, default 10M.
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
@ -906,3 +911,6 @@ POSITION_PROVIDER_EXCLUDES=
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
CSP_WHITELIST=
# Enable or disable create tidb service job
CREATE_TIDB_SERVICE_JOB_ENABLED=false

View File

@ -4,6 +4,10 @@ x-shared-env: &shared-api-worker-env
LOG_FILE: ${LOG_FILE:-}
LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20}
LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5}
# Log dateformat
LOG_DATEFORMAT: ${LOG_DATEFORMAT:-%Y-%m-%d %H:%M:%S}
# Log Timezone
LOG_TZ: ${LOG_TZ:-UTC}
DEBUG: ${DEBUG:-false}
FLASK_DEBUG: ${FLASK_DEBUG:-false}
SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U}
@ -214,6 +218,7 @@ x-shared-env: &shared-api-worker-env
PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64}
MULTIMODAL_SEND_VIDEO_FORMAT: ${MULTIMODAL_SEND_VIDEO_FORMAT:-base64}
UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50}
@ -270,6 +275,7 @@ x-shared-env: &shared-api-worker-env
OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test}
OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai}
OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G}
CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false}
services:
# API service

View File

@ -1,3 +1,4 @@
import React, { useEffect, useState } from 'react'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
@ -32,20 +33,31 @@ const MarkdownForm = ({ node }: any) => {
// </form>
const { onSend } = useChatContext()
const getFormValues = (children: any) => {
const formValues: { [key: string]: any } = {}
children.forEach((child: any) => {
if (child.tagName === SUPPORTED_TAGS.INPUT)
formValues[child.properties.name] = child.properties.value
if (child.tagName === SUPPORTED_TAGS.TEXTAREA)
formValues[child.properties.name] = child.properties.value
const [formValues, setFormValues] = useState<{ [key: string]: any }>({})
useEffect(() => {
const initialValues: { [key: string]: any } = {}
node.children.forEach((child: any) => {
if ([SUPPORTED_TAGS.INPUT, SUPPORTED_TAGS.TEXTAREA].includes(child.tagName))
initialValues[child.properties.name] = child.properties.value
})
return formValues
setFormValues(initialValues)
}, [node.children])
const getFormValues = (children: any) => {
const values: { [key: string]: any } = {}
children.forEach((child: any) => {
if ([SUPPORTED_TAGS.INPUT, SUPPORTED_TAGS.TEXTAREA].includes(child.tagName))
values[child.properties.name] = formValues[child.properties.name]
})
return values
}
const onSubmit = (e: any) => {
e.preventDefault()
const format = node.properties.dataFormat || DATA_FORMAT.TEXT
const result = getFormValues(node.children)
if (format === DATA_FORMAT.JSON) {
onSend?.(JSON.stringify(result))
}
@ -77,25 +89,22 @@ const MarkdownForm = ({ node }: any) => {
</label>
)
}
if (child.tagName === SUPPORTED_TAGS.INPUT) {
if (Object.values(SUPPORTED_TYPES).includes(child.properties.type)) {
return (
<Input
key={index}
type={child.properties.type}
name={child.properties.name}
placeholder={child.properties.placeholder}
value={child.properties.value}
onChange={(e) => {
e.preventDefault()
child.properties.value = e.target.value
}}
/>
)
}
else {
return <p key={index}>Unsupported input type: {child.properties.type}</p>
}
if (child.tagName === SUPPORTED_TAGS.INPUT && Object.values(SUPPORTED_TYPES).includes(child.properties.type)) {
return (
<Input
key={index}
type={child.properties.type}
name={child.properties.name}
placeholder={child.properties.placeholder}
value={formValues[child.properties.name]}
onChange={(e) => {
setFormValues(prevValues => ({
...prevValues,
[child.properties.name]: e.target.value,
}))
}}
/>
)
}
if (child.tagName === SUPPORTED_TAGS.TEXTAREA) {
return (
@ -103,10 +112,12 @@ const MarkdownForm = ({ node }: any) => {
key={index}
name={child.properties.name}
placeholder={child.properties.placeholder}
value={child.properties.value}
value={formValues[child.properties.name]}
onChange={(e) => {
e.preventDefault()
child.properties.value = e.target.value
setFormValues(prevValues => ({
...prevValues,
[child.properties.name]: e.target.value,
}))
}}
/>
)

View File

@ -15,7 +15,6 @@ import Category from './category'
import Tools from './tools'
import cn from '@/utils/classnames'
import I18n from '@/context/i18n'
import { getLanguage } from '@/i18n/language'
import Drawer from '@/app/components/base/drawer'
import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading'
@ -44,13 +43,15 @@ const AddToolModal: FC<Props> = ({
}) => {
const { t } = useTranslation()
const { locale } = useContext(I18n)
const language = getLanguage(locale)
const [currentType, setCurrentType] = useState('builtin')
const [currentCategory, setCurrentCategory] = useState('')
const [keywords, setKeywords] = useState<string>('')
const handleKeywordsChange = (value: string) => {
setKeywords(value)
}
const isMatchingKeywords = (text: string, keywords: string) => {
return text.toLowerCase().includes(keywords.toLowerCase())
}
const [toolList, setToolList] = useState<ToolWithProvider[]>([])
const [listLoading, setListLoading] = useState(true)
const getAllTools = async () => {
@ -82,13 +83,16 @@ const AddToolModal: FC<Props> = ({
else
return toolWithProvider.labels.includes(currentCategory)
}).filter((toolWithProvider) => {
return toolWithProvider.tools.some((tool) => {
return Object.values(tool.label).some((label) => {
return label.toLowerCase().includes(keywords.toLowerCase())
return (
isMatchingKeywords(toolWithProvider.name, keywords)
|| toolWithProvider.tools.some((tool) => {
return Object.values(tool.label).some((label) => {
return isMatchingKeywords(label, keywords)
})
})
})
)
})
}, [currentType, currentCategory, toolList, keywords, language])
}, [currentType, currentCategory, toolList, keywords])
const {
modelConfig,

View File

@ -11,7 +11,6 @@ import { ToolTypeEnum } from './types'
import Tools from './tools'
import { useToolTabs } from './hooks'
import cn from '@/utils/classnames'
import { useGetLanguage } from '@/context/i18n'
type AllToolsProps = {
searchText: string
@ -21,13 +20,16 @@ const AllTools = ({
searchText,
onSelect,
}: AllToolsProps) => {
const language = useGetLanguage()
const tabs = useToolTabs()
const [activeTab, setActiveTab] = useState(ToolTypeEnum.All)
const buildInTools = useStore(s => s.buildInTools)
const customTools = useStore(s => s.customTools)
const workflowTools = useStore(s => s.workflowTools)
const isMatchingKeywords = (text: string, keywords: string) => {
return text.toLowerCase().includes(keywords.toLowerCase())
}
const tools = useMemo(() => {
let mergedTools: ToolWithProvider[] = []
if (activeTab === ToolTypeEnum.All)
@ -40,11 +42,14 @@ const AllTools = ({
mergedTools = workflowTools
return mergedTools.filter((toolWithProvider) => {
return toolWithProvider.tools.some((tool) => {
return tool.label[language].toLowerCase().includes(searchText.toLowerCase())
return isMatchingKeywords(toolWithProvider.name, searchText)
|| toolWithProvider.tools.some((tool) => {
return Object.values(tool.label).some((label) => {
return isMatchingKeywords(label, searchText)
})
})
})
}, [activeTab, buildInTools, customTools, workflowTools, searchText, language])
}, [activeTab, buildInTools, customTools, workflowTools, searchText])
return (
<div>
<div className='flex items-center px-3 h-8 space-x-1 bg-gray-25 border-b-[0.5px] border-black/[0.08] shadow-xs'>

View File

@ -0,0 +1,51 @@
'use client'
import type { FC } from 'react'
import React, { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { VarType } from '../../../types'
import type { Var } from '../../../types'
import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list'
import cn from '@/utils/classnames'
import Input from '@/app/components/workflow/nodes/_base/components/input-support-select-var'
type Props = {
nodeId: string
readOnly: boolean
value: string
onChange: (value: string) => void
}
const ExtractInput: FC<Props> = ({
nodeId,
readOnly,
value,
onChange,
}) => {
const { t } = useTranslation()
const [isFocus, setIsFocus] = useState(false)
const { availableVars, availableNodesWithParent } = useAvailableVarList(nodeId, {
onlyLeafNodeVar: false,
filterVar: (varPayload: Var) => {
return [VarType.number].includes(varPayload.type)
},
})
return (
<div className='flex items-start space-x-1'>
<Input
instanceId='http-extract-number'
className={cn(isFocus ? 'shadow-xs bg-gray-50 border-gray-300' : 'bg-gray-100 border-gray-100', 'w-0 grow rounded-lg px-3 py-[6px] border')}
value={value}
onChange={onChange}
readOnly={readOnly}
nodesOutputVars={availableVars}
availableNodes={availableNodesWithParent}
onFocusChange={setIsFocus}
placeholder={!readOnly ? t('workflow.nodes.http.extractListPlaceholder')! : ''}
placeholderClassName='!leading-[21px]'
/>
</div >
)
}
export default React.memo(ExtractInput)

View File

@ -12,6 +12,10 @@ const nodeDefault: NodeDefault<ListFilterNodeType> = {
enabled: false,
conditions: [],
},
extract_by: {
enabled: false,
serial: '1',
},
order_by: {
enabled: false,
key: '',

View File

@ -13,6 +13,7 @@ import FilterCondition from './components/filter-condition'
import Field from '@/app/components/workflow/nodes/_base/components/field'
import { type NodePanelProps } from '@/app/components/workflow/types'
import Switch from '@/app/components/base/switch'
import ExtractInput from '@/app/components/workflow/nodes/list-operator/components/extract-input'
const i18nPrefix = 'workflow.nodes.listFilter'
@ -32,6 +33,8 @@ const Panel: FC<NodePanelProps<ListFilterNodeType>> = ({
filterVar,
handleFilterEnabledChange,
handleFilterChange,
handleExtractsEnabledChange,
handleExtractsChange,
handleLimitChange,
handleOrderByEnabledChange,
handleOrderByKeyChange,
@ -79,6 +82,41 @@ const Panel: FC<NodePanelProps<ListFilterNodeType>> = ({
: null}
</Field>
<Split />
<Field
title={t(`${i18nPrefix}.extractsCondition`)}
operations={
<Switch
defaultValue={inputs.extract_by?.enabled}
onChange={handleExtractsEnabledChange}
size='md'
disabled={readOnly}
/>
}
>
{inputs.extract_by?.enabled
? (
<div className='flex items-center justify-between'>
{hasSubVariable && (
<div className='grow mr-2'>
<ExtractInput
value={inputs.extract_by.serial as string}
onChange={handleExtractsChange}
readOnly={readOnly}
nodeId={id}
/>
</div>
)}
</div>
)
: null}
</Field>
<Split />
<LimitConfig
config={inputs.limit}
onChange={handleLimitChange}
readonly={readOnly}
/>
<Split />
<Field
title={t(`${i18nPrefix}.orderBy`)}
operations={
@ -118,13 +156,7 @@ const Panel: FC<NodePanelProps<ListFilterNodeType>> = ({
: null}
</Field>
<Split />
<LimitConfig
config={inputs.limit}
onChange={handleLimitChange}
readonly={readOnly}
/>
</div>
<Split />
<div className='px-4 pt-4 pb-2'>
<OutputVars>
<>

View File

@ -25,6 +25,10 @@ export type ListFilterNodeType = CommonNodeType & {
enabled: boolean
conditions: Condition[]
}
extract_by: {
enabled: boolean
serial?: string
}
order_by: {
enabled: boolean
key: ValueSelector | string

View File

@ -119,6 +119,22 @@ const useConfig = (id: string, payload: ListFilterNodeType) => {
setInputs(newInputs)
}, [inputs, setInputs])
const handleExtractsEnabledChange = useCallback((enabled: boolean) => {
const newInputs = produce(inputs, (draft) => {
draft.extract_by.enabled = enabled
if (enabled)
draft.extract_by.serial = '1'
})
setInputs(newInputs)
}, [inputs, setInputs])
const handleExtractsChange = useCallback((value: string) => {
const newInputs = produce(inputs, (draft) => {
draft.extract_by.serial = value
})
setInputs(newInputs)
}, [inputs, setInputs])
const handleOrderByEnabledChange = useCallback((enabled: boolean) => {
const newInputs = produce(inputs, (draft) => {
draft.order_by.enabled = enabled
@ -162,6 +178,8 @@ const useConfig = (id: string, payload: ListFilterNodeType) => {
handleOrderByEnabledChange,
handleOrderByKeyChange,
handleOrderByTypeChange,
handleExtractsEnabledChange,
handleExtractsChange,
}
}

View File

@ -369,6 +369,7 @@ const translation = {
inputVars: 'Input Variables',
api: 'API',
apiPlaceholder: 'Enter URL, type / insert variable',
extractListPlaceholder: 'Enter list item index, type / insert variable',
notStartWithHttp: 'API should start with http:// or https://',
key: 'Key',
type: 'Type',
@ -569,8 +570,8 @@ const translation = {
errorResponseMethod: 'Error response method',
ErrorMethod: {
operationTerminated: 'terminated',
continueOnError: 'continue-on-error',
removeAbnormalOutput: 'remove-abnormal-output',
continueOnError: 'continue on error',
removeAbnormalOutput: 'remove abnormal output',
},
answerNodeWarningDesc: 'Parallel mode warning: Answer nodes, conversation variable assignments, and persistent read/write operations within iterations may cause exceptions.',
},
@ -605,6 +606,7 @@ const translation = {
inputVar: 'Input Variable',
filterCondition: 'Filter Condition',
filterConditionKey: 'Filter Condition Key',
extractsCondition: 'Extract the N item',
filterConditionComparisonOperator: 'Filter Condition Comparison Operator',
filterConditionComparisonValue: 'Filter Condition value',
selectVariableKeyPlaceholder: 'Select sub variable key',

View File

@ -369,6 +369,7 @@ const translation = {
inputVars: '输入变量',
api: 'API',
apiPlaceholder: '输入 URL输入变量时请键入/',
extractListPlaceholder: '输入提取列表编号,输入变量时请键入‘/',
notStartWithHttp: 'API 应该以 http:// 或 https:// 开头',
key: '键',
type: '类型',
@ -608,6 +609,7 @@ const translation = {
filterConditionComparisonOperator: '过滤条件比较操作符',
filterConditionComparisonValue: '过滤条件比较值',
selectVariableKeyPlaceholder: '选择子变量的 Key',
extractsCondition: '取第 N 项',
limit: '取前 N 项',
orderBy: '排序',
asc: '升序',

View File

@ -321,7 +321,9 @@ const baseFetch = <T>(
}
const urlPrefix = isPublicAPI ? PUBLIC_API_PREFIX : API_PREFIX
let urlWithPrefix = `${urlPrefix}${url.startsWith('/') ? url : `/${url}`}`
let urlWithPrefix = (url.startsWith('http://') || url.startsWith('https://'))
? url
: `${urlPrefix}${url.startsWith('/') ? url : `/${url}`}`
const { method, params, body } = options
// handle query
@ -494,7 +496,9 @@ export const ssePost = (
getAbortController?.(abortController)
const urlPrefix = isPublicAPI ? PUBLIC_API_PREFIX : API_PREFIX
const urlWithPrefix = `${urlPrefix}${url.startsWith('/') ? url : `/${url}`}`
const urlWithPrefix = (url.startsWith('http://') || url.startsWith('https://'))
? url
: `${urlPrefix}${url.startsWith('/') ? url : `/${url}`}`
const { body } = options
if (body)