diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 8173bee58e..7c632f8a34 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -76,7 +76,7 @@ jobs: - name: Run Workflow run: poetry run -C api bash dev/pytest/pytest_workflow.sh - - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale) + - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, Elasticsearch) uses: hoverkraft-tech/compose-action@v2.0.0 with: compose-file: | @@ -90,5 +90,6 @@ jobs: pgvecto-rs pgvector chroma + elasticsearch - name: Test Vector Stores run: poetry run -C api bash dev/pytest/pytest_vdb.sh diff --git a/.github/workflows/expose_service_ports.sh b/.github/workflows/expose_service_ports.sh index 3418bf0c6f..ae3e0ee69d 100755 --- a/.github/workflows/expose_service_ports.sh +++ b/.github/workflows/expose_service_ports.sh @@ -6,5 +6,6 @@ yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml +yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml -echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs." \ No newline at end of file +echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch." \ No newline at end of file diff --git a/api/.env.example b/api/.env.example index cf3a0f302d..775149f8fd 100644 --- a/api/.env.example +++ b/api/.env.example @@ -130,6 +130,12 @@ TENCENT_VECTOR_DB_DATABASE=dify TENCENT_VECTOR_DB_SHARD=1 TENCENT_VECTOR_DB_REPLICAS=2 +# Elasticsearch configuration +ELASTICSEARCH_HOST=127.0.0.1 +ELASTICSEARCH_PORT=9200 +ELASTICSEARCH_USERNAME=elastic +ELASTICSEARCH_PASSWORD=elastic + # PGVECTO_RS configuration PGVECTO_RS_HOST=localhost PGVECTO_RS_PORT=5431 diff --git a/api/commands.py b/api/commands.py index c7ffb47b51..82a32f0f5b 100644 --- a/api/commands.py +++ b/api/commands.py @@ -344,6 +344,14 @@ def migrate_knowledge_vector_database(): "vector_store": {"class_prefix": collection_name} } dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == VectorType.ELASTICSEARCH: + dataset_id = dataset.id + index_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = { + "type": 'elasticsearch', + "vector_store": {"class_prefix": index_name} + } + dataset.index_struct = json.dumps(index_struct_dict) else: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 1104e298b1..247fcde655 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description='Dify version', - default='0.6.16', + default='0.7.0', ) COMMIT_SHA: str = Field( diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 306fac3a93..b6b18f5c5b 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -1,3 +1,7 @@ from contextvars import ContextVar -tenant_id: ContextVar[str] = ContextVar('tenant_id') \ No newline at end of file +from core.workflow.entities.variable_pool import VariablePool + +tenant_id: ContextVar[str] = 
ContextVar('tenant_id') + +workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool') diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 3e98843280..a5bc2dd86a 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -555,7 +555,7 @@ class DatasetRetrievalSettingApi(Resource): RetrievalMethod.SEMANTIC_SEARCH.value ] } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE: + case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: return { 'retrieval_method': [ RetrievalMethod.SEMANTIC_SEARCH.value, @@ -579,7 +579,7 @@ class DatasetRetrievalSettingMockApi(Resource): RetrievalMethod.SEMANTIC_SEARCH.value ] } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE: + case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: return { 'retrieval_method': [ RetrievalMethod.SEMANTIC_SEARCH.value, diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index afe0ca7c69..976b97660a 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -178,11 +178,20 @@ class DatasetDocumentListApi(Resource): .subquery() query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \ - .order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0))) + .order_by( + sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) elif sort == 'created_at': - query = query.order_by(sort_logic(Document.created_at)) + query = query.order_by( + sort_logic(Document.created_at), + sort_logic(Document.position), + ) else: - query = query.order_by(desc(Document.created_at)) + query = query.order_by( + desc(Document.created_at), + desc(Document.position), + ) paginated_documents = query.paginate( page=page, per_page=limit, max_per_page=100, error_out=False) diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index ec17db5f06..f4e6675bd4 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -93,6 +93,7 @@ class DatasetConfigManager: reranking_model=dataset_configs.get('reranking_model'), weights=dataset_configs.get('weights'), reranking_enabled=dataset_configs.get('reranking_enabled', True), + rerank_mode=dataset_configs["reranking_mode"], ) ) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 0cde659992..351eb05d8a 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -8,6 +8,8 @@ from typing import Union from flask import Flask, current_app from pydantic import ValidationError +from sqlalchemy import select +from sqlalchemy.orm import Session import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -18,15 +20,20 @@ from core.app.apps.advanced_chat.generate_task_pipeline import 
AdvancedChatAppGe from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + InvokeFrom, +) from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow +from models.workflow import ConversationVariable, Workflow logger = logging.getLogger(__name__) @@ -120,7 +127,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, stream=stream ) - + def single_iteration_generate(self, app_model: App, workflow: Workflow, node_id: str, @@ -140,10 +147,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): """ if not node_id: raise ValueError('node_id is required') - + if args.get('inputs') is None: raise ValueError('inputs is required') - + extras = { "auto_generate_conversation_name": False } @@ -209,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # update conversation features conversation.override_model_configs = workflow.features db.session.commit() - db.session.refresh(conversation) + # db.session.refresh(conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -221,15 +228,69 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message_id=message.id ) + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id + ) + with Session(db.engine) as session: + conversation_variables = session.scalars(stmt).all() + if not conversation_variables: + # Create conversation variables if they don't exist. + conversation_variables = [ + ConversationVariable.from_variable( + app_id=conversation.app_id, conversation_id=conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(conversation_variables) + # Convert database entities to variables. + conversation_variables = [item.to_variable() for item in conversation_variables] + + session.commit() + + # Increment dialogue count. + conversation.dialogue_count += 1 + + conversation_id = conversation.id + conversation_dialogue_count = conversation.dialogue_count + db.session.commit() + db.session.refresh(conversation) + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + user_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = application_generate_entity.user_id + + # Create a variable pool. 
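+        # The pool is the single lookup table for this run: it will hold the system inputs built below (query, files, conversation id, user id, dialogue count), the user inputs, the workflow's environment variables and the conversation variables loaded above. +        # It is published via contexts.workflow_variable_pool so later stages, e.g. the generate task pipeline, can resolve selectors such as ['sys', 'query'] against it.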
+ system_inputs = { + SystemVariable.QUERY: query, + SystemVariable.FILES: files, + SystemVariable.CONVERSATION_ID: conversation_id, + SystemVariable.USER_ID: user_id, + SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count, + } + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + contexts.workflow_variable_pool.set(variable_pool) + # new thread worker_thread = threading.Thread(target=self._generate_worker, kwargs={ 'flask_app': current_app._get_current_object(), 'application_generate_entity': application_generate_entity, 'queue_manager': queue_manager, - 'conversation_id': conversation.id, 'message_id': message.id, - 'user': user, - 'context': contextvars.copy_context() + 'context': contextvars.copy_context(), }) worker_thread.start() @@ -242,7 +303,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) return AdvancedChatAppGenerateResponseConverter.convert( @@ -253,9 +314,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): def _generate_worker(self, flask_app: Flask, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, - conversation_id: str, message_id: str, - user: Account, context: contextvars.Context) -> None: """ Generate worker in a new thread. @@ -282,8 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user_id=application_generate_entity.user_id ) else: - # get conversation and message - conversation = self._get_conversation(conversation_id) + # get message message = self._get_message(message_id) # chatbot app @@ -291,7 +349,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): runner.run( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - conversation=conversation, message=message ) except GenerateTaskStoppedException: @@ -314,14 +371,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): finally: db.session.close() - def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False) \ - -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: + def _handle_advanced_chat_response( + self, + *, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, + ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Handle response. 
:param application_generate_entity: application generate entity @@ -341,7 +401,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) try: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 47c53531f6..5dc03979cf 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -4,9 +4,6 @@ import time from collections.abc import Mapping from typing import Any, Optional, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -19,13 +16,10 @@ from core.app.entities.app_invoke_entities import ( from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import SystemVariable -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.model import App, Conversation, EndUser, Message -from models.workflow import ConversationVariable, Workflow +from models import App, Message, Workflow logger = logging.getLogger(__name__) @@ -39,7 +33,6 @@ class AdvancedChatAppRunner(AppRunner): self, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, - conversation: Conversation, message: Message, ) -> None: """ @@ -63,15 +56,6 @@ class AdvancedChatAppRunner(AppRunner): inputs = application_generate_entity.inputs query = application_generate_entity.query - files = application_generate_entity.files - - user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = application_generate_entity.user_id # moderation if self.handle_input_moderation( @@ -103,38 +87,6 @@ class AdvancedChatAppRunner(AppRunner): if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id - ) - with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: - conversation_variables = [ - ConversationVariable.from_variable( - app_id=conversation.app_id, conversation_id=conversation.id, variable=variable - ) - for variable in workflow.conversation_variables - ] - session.add_all(conversation_variables) - session.commit() - # Convert database entities to variables - conversation_variables = [item.to_variable() for item in conversation_variables] - - # Create a variable pool. 
- system_inputs = { - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, - ) - # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( @@ -146,7 +98,6 @@ class AdvancedChatAppRunner(AppRunner): invoke_from=application_generate_entity.invoke_from, callbacks=workflow_callbacks, call_depth=application_generate_entity.call_depth, - variable_pool=variable_pool, ) def single_iteration_run( @@ -155,7 +106,7 @@ class AdvancedChatAppRunner(AppRunner): """ Single iteration run """ - app_record: App = db.session.query(App).filter(App.id == app_id).first() + app_record = db.session.query(App).filter(App.id == app_id).first() if not app_record: raise ValueError('App not found') diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 91a43ed449..f8efcb5960 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,6 +4,7 @@ import time from collections.abc import Generator from typing import Any, Optional, Union, cast +import contexts from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -47,7 +48,8 @@ from core.file.file_obj import FileVar from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariable from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk from events.message_event import message_was_created @@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow _user: Union[Account, EndUser] + # Deprecated _workflow_system_variables: dict[SystemVariable, Any] _iteration_nested_relations: dict[str, list[str]] @@ -81,7 +84,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc conversation: Conversation, message: Message, user: Union[Account, EndUser], - stream: bool + stream: bool, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. 
@@ -103,11 +106,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._workflow = workflow self._conversation = conversation self._message = message + # Deprecated self._workflow_system_variables = { SystemVariable.QUERY: message.query, SystemVariable.FILES: application_generate_entity.files, SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id + SystemVariable.USER_ID: user_id, } self._task_state = AdvancedChatTaskState( @@ -613,7 +617,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if route_chunk_node_id == 'sys': # system variable - value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1])) + value = contexts.workflow_variable_pool.get().get(value_selector) + if value: + value = value.text elif route_chunk_node_id in self._iteration_nested_relations: # it's a iteration variable if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations: diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index c5cd686402..12f69f1528 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): return introduction - def _get_conversation(self, conversation_id: str) -> Conversation: + def _get_conversation(self, conversation_id: str): """ Get conversation by conversation id :param conversation_id: conversation id @@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator): .first() ) + if not conversation: + raise ConversationNotExistsError() + return conversation def _get_message(self, message_id: str) -> Message: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 17a99cf1c5..994919391e 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -11,8 +11,8 @@ from core.app.entities.app_invoke_entities import ( WorkflowAppGenerateEntity, ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2b4362150f..5022eb0438 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -42,7 +42,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariable from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account @@ -519,7 +520,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa """ nodes = graph.get('nodes') - iteration_ids = 
[node.get('id') for node in nodes + iteration_ids = [node.get('id') for node in nodes if node.get('data', {}).get('type') in [ NodeType.ITERATION.value, NodeType.LOOP.value, @@ -530,4 +531,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id ] for iteration_id in iteration_ids } - \ No newline at end of file diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index 174e241261..7de06dfb96 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -2,7 +2,6 @@ from .segment_group import SegmentGroup from .segments import ( ArrayAnySegment, ArraySegment, - FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -13,11 +12,9 @@ from .segments import ( from .types import SegmentType from .variables import ( ArrayAnyVariable, - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileVariable, FloatVariable, IntegerVariable, NoneVariable, @@ -32,7 +29,6 @@ __all__ = [ 'FloatVariable', 'ObjectVariable', 'SecretVariable', - 'FileVariable', 'StringVariable', 'ArrayAnyVariable', 'Variable', @@ -45,11 +41,9 @@ __all__ = [ 'FloatSegment', 'ObjectSegment', 'ArrayAnySegment', - 'FileSegment', 'StringSegment', 'ArrayStringVariable', 'ArrayNumberVariable', 'ArrayObjectVariable', - 'ArrayFileVariable', 'ArraySegment', ] diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index 91ff1fdb3d..e6e9ce9774 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -2,12 +2,10 @@ from collections.abc import Mapping from typing import Any from configs import dify_config -from core.file.file_obj import FileVar from .exc import VariableError from .segments import ( ArrayAnySegment, - FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -17,11 +15,9 @@ from .segments import ( ) from .types import SegmentType from .variables import ( - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileVariable, FloatVariable, IntegerVariable, ObjectVariable, @@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): raise VariableError(f'invalid number value {value}') - case SegmentType.FILE: - result = FileVariable.model_validate(mapping) case SegmentType.OBJECT if isinstance(value, dict): result = ObjectVariable.model_validate(mapping) case SegmentType.ARRAY_STRING if isinstance(value, list): @@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: result = ArrayNumberVariable.model_validate(mapping) case SegmentType.ARRAY_OBJECT if isinstance(value, list): result = ArrayObjectVariable.model_validate(mapping) - case SegmentType.ARRAY_FILE if isinstance(value, list): - mapping = dict(mapping) - mapping['value'] = [{'value': v} for v in value] - result = ArrayFileVariable.model_validate(mapping) case _: raise VariableError(f'not supported value type {value_type}') if result.size > dify_config.MAX_VARIABLE_SIZE: @@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment: return ObjectSegment(value=value) if isinstance(value, list): return ArrayAnySegment(value=value) - if isinstance(value, FileVar): - return FileSegment(value=value) raise ValueError(f'not supported value {value}') diff --git a/api/core/app/segments/segments.py 
b/api/core/app/segments/segments.py index 7653e1085f..321bc0ad02 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -5,8 +5,6 @@ from typing import Any from pydantic import BaseModel, ConfigDict, field_validator -from core.file.file_obj import FileVar - from .types import SegmentType @@ -78,14 +76,7 @@ class IntegerSegment(Segment): value: int -class FileSegment(Segment): - value_type: SegmentType = SegmentType.FILE - # TODO: embed FileVar in this model. - value: FileVar - @property - def markdown(self) -> str: - return self.value.to_markdown() class ObjectSegment(Segment): @@ -130,7 +121,3 @@ class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT value: Sequence[Mapping[str, Any]] - -class ArrayFileSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[FileSegment] diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py index a371058ef5..cdd2b0b4b0 100644 --- a/api/core/app/segments/types.py +++ b/api/core/app/segments/types.py @@ -10,8 +10,6 @@ class SegmentType(str, Enum): ARRAY_STRING = 'array[string]' ARRAY_NUMBER = 'array[number]' ARRAY_OBJECT = 'array[object]' - ARRAY_FILE = 'array[file]' OBJECT = 'object' - FILE = 'file' GROUP = 'group' diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index ac26e16542..8fef707fcf 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -4,11 +4,9 @@ from core.helper import encrypter from .segments import ( ArrayAnySegment, - ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, - FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable): pass -class FileVariable(FileSegment, Variable): - pass - - class ObjectVariable(ObjectSegment, Variable): pass @@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable): pass -class ArrayFileVariable(ArrayFileSegment, Variable): - pass - class SecretVariable(StringVariable): value_type: SegmentType = SegmentType.SECRET diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index 545f31fddf..8baa8ba09e 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -2,7 +2,7 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.entities.node_entities import SystemVariable +from core.workflow.enums import SystemVariable from models.account import Account from models.model import EndUser from models.workflow import Workflow @@ -13,4 +13,4 @@ class WorkflowCycleStateManager: _workflow: Workflow _user: Union[Account, EndUser] _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] \ No newline at end of file + _workflow_system_variables: dict[SystemVariable, Any] diff --git a/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml new file mode 100644 index 0000000000..aad07f5673 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml @@ -0,0 +1,81 @@ +model: farui-plus +label: + en_US: farui-plus +model_type: 
llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 12288 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说，temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值，使得更多的低概率词被选择，生成结果更加多样化；而较低的temperature值则会增强概率分布的峰值，使得高概率词更容易被选择，生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value lowers the peak of the probability distribution, allowing more low-probability words to be selected and producing more diverse results; a lower temperature value sharpens the peak of the probability distribution, making it easier for high-probability words to be selected and producing more deterministic results. + - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量，它定义了生成的上限，但不保证每次都会生成到这个数量。 + en_US: Specifies the maximum number of tokens the model can generate. It defines an upper limit for generation, but does not guarantee that this number will be reached every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值，例如，取值为0.8时，仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0)，取值越大，生成的随机性越高；取值越低，生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during generation. For example, when the value is 0.8, only the smallest set of the most likely tokens whose probabilities add up to at least 0.8 is retained as the candidate set. The value range is (0,1.0); the larger the value, the more random the generation, and the lower the value, the more deterministic the generation. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时，采样候选集的大小。例如，取值为50时，仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大，生成的随机性越高；取值越小，生成的确定性越高。 + en_US: The size of the sampling candidate set during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation step form the candidate set for random sampling. The larger the value, the more random the generation; the smaller the value, the more deterministic the generation. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子，用户控制模型生成内容的随机性。支持无符号64位整数，默认值为 1234。在使用seed时，模型将尽可能生成相同或相似的结果，但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used during generation; it lets the user control the randomness of the generated content. Unsigned 64-bit integers are supported, and the default value is 1234. When a seed is set, the model will try to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition in the model's output; 1.0 means no penalty.
+ - name: enable_search + type: boolean + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + use_template: response_format +pricing: + input: '0.02' + output: '0.02' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml index eed09f95de..f4303c53d3 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml @@ -2,3 +2,8 @@ model: text-embedding-v1 model_type: text-embedding model_properties: context_size: 2048 + max_chunks: 25 +pricing: + input: "0.0007" + unit: "0.001" + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml index db2fa861e6..f6be3544ed 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml @@ -2,3 +2,8 @@ model: text-embedding-v2 model_type: text-embedding model_properties: context_size: 2048 + max_chunks: 25 +pricing: + input: "0.0007" + unit: "0.001" + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index c207ffc1e3..e7e1b5c764 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -2,6 +2,7 @@ import time from typing import Optional import dashscope +import numpy as np from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( @@ -21,11 +22,11 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): """ def _invoke( - self, - model: str, - credentials: dict, - texts: list[str], - user: Optional[str] = None, + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,16 +38,44 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - embeddings, embedding_used_tokens = self.embed_documents( - credentials_kwargs=credentials_kwargs, - model=model, - texts=texts - ) + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + inputs = [] + indices = [] + used_tokens = 0 + + for i, text in enumerate(texts): + + # Here token count is only an approximation based on the GPT2 tokenizer + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) + # if num 
tokens is larger than context length, only use the start + inputs.append(text[0:cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + + for i in _iter: + embeddings_batch, embedding_used_tokens = self.embed_documents( + credentials_kwargs=credentials_kwargs, + model=model, + texts=inputs[i : i + max_chunks], + ) + used_tokens += embedding_used_tokens + batched_embeddings += embeddings_batch + + # calc usage + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=used_tokens + ) return TextEmbeddingResult( - embeddings=embeddings, - usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), - model=model + embeddings=batched_embeddings, usage=usage, model=model ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -79,12 +108,16 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): credentials_kwargs = self._to_credential_kwargs(credentials) # call embedding model - self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]) + self.embed_documents( + credentials_kwargs=credentials_kwargs, model=model, texts=["ping"] + ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @staticmethod - def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]: + def embed_documents( + credentials_kwargs: dict, model: str, texts: list[str] + ) -> tuple[list[list[float]], int]: """Call out to Tongyi's embedding endpoint. Args: @@ -102,7 +135,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): api_key=credentials_kwargs["dashscope_api_key"], model=model, input=text, - text_type="document" + text_type="document", ) data = response.output["embeddings"][0] embeddings.append(data["embedding"]) @@ -111,7 +144,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): return [list(map(float, e)) for e in embeddings], embedding_used_tokens def _calc_response_usage( - self, model: str, credentials: dict, tokens: int + self, model: str, credentials: dict, tokens: int ) -> EmbeddingUsage: """ Calculate response usage @@ -125,7 +158,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -136,7 +169,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/upstage/llm/_position.yaml b/api/core/model_runtime/model_providers/upstage/llm/_position.yaml index d4f03e1988..7992843dcb 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/upstage/llm/_position.yaml @@ -1 +1 @@ -- soloar-1-mini-chat +- solar-1-mini-chat diff --git a/api/core/rag/datasource/vdb/elasticsearch/__init__.py b/api/core/rag/datasource/vdb/elasticsearch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py new file mode 100644 index 0000000000..01ba6fb324 --- /dev/null +++ 
b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -0,0 +1,191 @@ +import json +from typing import Any + +import requests +from elasticsearch import Elasticsearch +from flask import current_app +from pydantic import BaseModel, model_validator + +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.models.document import Document +from models.dataset import Dataset + + +class ElasticSearchConfig(BaseModel): + host: str + port: str + username: str + password: str + + @model_validator(mode='before') + def validate_config(cls, values: dict) -> dict: + if not values['host']: + raise ValueError("config HOST is required") + if not values['port']: + raise ValueError("config PORT is required") + if not values['username']: + raise ValueError("config USERNAME is required") + if not values['password']: + raise ValueError("config PASSWORD is required") + return values + + +class ElasticSearchVector(BaseVector): + def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list): + super().__init__(index_name.lower()) + self._client = self._init_client(config) + self._attributes = attributes + + def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: + try: + client = Elasticsearch( + hosts=f'{config.host}:{config.port}', + basic_auth=(config.username, config.password), + request_timeout=100000, + retry_on_timeout=True, + max_retries=10000, + ) + except requests.exceptions.ConnectionError: + raise ConnectionError("Vector database connection error") + + return client + + def get_type(self) -> str: + return 'elasticsearch' + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + if not self._client.indices.exists(index=self._collection_name): + dim = len(embeddings[0]) + mapping = { + "properties": { + "text": { + "type": "text" + }, + "vector": { + "type": "dense_vector", + "index": True, + "dims": dim, + "similarity": "l2_norm" + }, + } + } + self._client.indices.create(index=self._collection_name, mappings=mapping) + + added_ids = [] + for i, text in enumerate(texts): + self._client.index(index=self._collection_name, + id=uuids[i], + document={ + "text": text, + "vector": embeddings[i] if embeddings[i] else None, + "metadata": metadatas[i] if metadatas[i] else {}, + }) + added_ids.append(uuids[i]) + + self._client.indices.refresh(index=self._collection_name) + return uuids + + def text_exists(self, id: str) -> bool: + return bool(self._client.exists(index=self._collection_name, id=id)) + + def delete_by_ids(self, ids: list[str]) -> None: + for id in ids: + self._client.delete(index=self._collection_name, id=id) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + query_str = { + 'query': { + 'match': { + f'metadata.{key}': f'{value}' + } + } + } + results = self._client.search(index=self._collection_name, body=query_str) + ids = [hit['_id'] for hit in results['hits']['hits']] + if ids: + self.delete_by_ids(ids) + + def delete(self) -> None: + self._client.indices.delete(index=self._collection_name) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + query_str = { + "query": { + "script_score": { + "query": { "match_all": {} + }, + # cosineSimilarity returns values in [-1, 1]; adding 1.0 keeps scores non-negative + "script": { + "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", + "params": { + "query_vector": query_vector + } + } + } + } + } + + results = self._client.search(index=self._collection_name, body=query_str) + + docs_and_scores = [] + for hit in results['hits']['hits']: + docs_and_scores.append( + (Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score'])) + + # Keep only hits above the optional score threshold + docs = [] + for doc, score in docs_and_scores: + score_threshold = kwargs.get('score_threshold') or 0.0 + if score > score_threshold: + doc.metadata['score'] = score + docs.append(doc) + + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + query_str = { + "match": { + "text": query + } + } + results = self._client.search(index=self._collection_name, query=query_str) + docs = [] + for hit in results['hits']['hits']: + docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata'])) + + return docs + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + return self.add_texts(texts, embeddings, **kwargs) + + +class ElasticSearchVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) + + config = current_app.config + return ElasticSearchVector( + index_name=collection_name, + config=ElasticSearchConfig( + host=config.get('ELASTICSEARCH_HOST'), + port=config.get('ELASTICSEARCH_PORT'), + username=config.get('ELASTICSEARCH_USERNAME'), + password=config.get('ELASTICSEARCH_PASSWORD'), + ), + attributes=[] + ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index cff9293baa..4ae1a3395b 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -93,7 +93,7 @@ class MyScaleVector(BaseVector): @staticmethod def escape_str(value: Any) -> str: - return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value)) + return "".join(" " if c in ("\\", "'") else c for c in str(value)) def text_exists(self, id: str) -> bool: results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") @@ -118,7 +118,7 @@ class MyScaleVector(BaseVector): return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs) + return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs) def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index fad60ecf45..3e9ca8e1fe 100644 ---
a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -71,6 +71,9 @@ class Vector: case VectorType.RELYT: from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory return RelytVectorFactory + case VectorType.ELASTICSEARCH: + from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory + return ElasticSearchVectorFactory case VectorType.TIDB_VECTOR: from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory return TiDBVectorFactory diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 77495044df..317ca6abc8 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -15,3 +15,4 @@ class VectorType(str, Enum): OPENSEARCH = 'opensearch' TENCENT = 'tencent' ORACLE = 'oracle' + ELASTICSEARCH = 'elasticsearch' diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 569a1d3238..2e4433d9f6 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -46,7 +46,7 @@ class ToolProviderType(Enum): if mode.value == value: return mode raise ValueError(f'invalid mode value {value}') - + class ApiProviderSchemaType(Enum): """ Enum class for api provider schema type. @@ -68,7 +68,7 @@ class ApiProviderSchemaType(Enum): if mode.value == value: return mode raise ValueError(f'invalid mode value {value}') - + class ApiProviderAuthType(Enum): """ Enum class for api provider auth type. @@ -103,8 +103,8 @@ class ToolInvokeMessage(BaseModel): """ plain text, image url or link url """ - message: Union[str, bytes, dict] = None - meta: dict[str, Any] = None + message: str | bytes | dict | None = None + meta: dict[str, Any] | None = None save_as: str = '' class ToolInvokeMessageBinary(BaseModel): @@ -154,8 +154,8 @@ class ToolParameter(BaseModel): options: Optional[list[ToolParameterOption]] = None @classmethod - def get_simple_instance(cls, - name: str, llm_description: str, type: ToolParameterType, + def get_simple_instance(cls, + name: str, llm_description: str, type: ToolParameterType, required: bool, options: Optional[list[str]] = None) -> 'ToolParameter': """ get a simple tool parameter @@ -222,7 +222,7 @@ class ToolProviderCredentials(BaseModel): if mode.value == value: return mode raise ValueError(f'invalid mode value {value}') - + @staticmethod def default(value: str) -> str: return "" @@ -290,7 +290,7 @@ class ToolRuntimeVariablePool(BaseModel): 'tenant_id': self.tenant_id, 'pool': [variable.model_dump() for variable in self.pool], } - + def set_text(self, tool_name: str, name: str, value: str) -> None: """ set a text variable @@ -301,7 +301,7 @@ class ToolRuntimeVariablePool(BaseModel): variable = cast(ToolRuntimeTextVariable, variable) variable.value = value return - + variable = ToolRuntimeTextVariable( type=ToolRuntimeVariableType.TEXT, name=name, @@ -334,7 +334,7 @@ class ToolRuntimeVariablePool(BaseModel): variable = cast(ToolRuntimeImageVariable, variable) variable.value = value return - + variable = ToolRuntimeImageVariable( type=ToolRuntimeVariableType.IMAGE, name=name, @@ -388,21 +388,21 @@ class ToolInvokeMeta(BaseModel): Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) - + @classmethod def error_instance(cls, error: str) -> 'ToolInvokeMeta': """ Get an instance of ToolInvokeMeta with error """ return cls(time_cost=0.0, error=error, 
 tool_config={})
-    
+
     def to_dict(self) -> dict:
         return {
             'time_cost': self.time_cost,
             'error': self.error,
             'tool_config': self.tool_config,
         }
-    
+
 class ToolLabel(BaseModel):
     """
     Tool label
@@ -416,4 +416,4 @@ class ToolInvokeFrom(Enum):
     Enum class for tool invoke
     """
     WORKFLOW = "workflow"
-    AGENT = "agent"
\ No newline at end of file
+    AGENT = "agent"
diff --git a/api/core/tools/provider/builtin/gitlab/_assets/gitlab.svg b/api/core/tools/provider/builtin/gitlab/_assets/gitlab.svg
new file mode 100644
index 0000000000..07734077d5
--- /dev/null
+++ b/api/core/tools/provider/builtin/gitlab/_assets/gitlab.svg
@@ -0,0 +1,2 @@
+
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.py b/api/core/tools/provider/builtin/gitlab/gitlab.py
new file mode 100644
index 0000000000..fca34ae15f
--- /dev/null
+++ b/api/core/tools/provider/builtin/gitlab/gitlab.py
@@ -0,0 +1,34 @@
+from typing import Any
+
+import requests
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class GitlabProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        try:
+            if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
+                raise ToolProviderCredentialValidationError("GitLab access token is required.")
+
+            if 'site_url' not in credentials or not credentials.get('site_url'):
+                site_url = 'https://gitlab.com'
+            else:
+                site_url = credentials.get('site_url')
+
+            try:
+                headers = {
+                    "Content-Type": "application/vnd.text+json",
+                    "Authorization": f"Bearer {credentials.get('access_tokens')}",
+                }
+
+                response = requests.get(
+                    url=f"{site_url}/api/v4/user",
+                    headers=headers)
+                if response.status_code != 200:
+                    raise ToolProviderCredentialValidationError((response.json()).get('message'))
+            except Exception as e:
+                raise ToolProviderCredentialValidationError("GitLab access token is invalid: {}".format(e))
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.yaml b/api/core/tools/provider/builtin/gitlab/gitlab.yaml
new file mode 100644
index 0000000000..b5feea2382
--- /dev/null
+++ b/api/core/tools/provider/builtin/gitlab/gitlab.yaml
@@ -0,0 +1,38 @@
+identity:
+  author: Leo.Wang
+  name: gitlab
+  label:
+    en_US: Gitlab
+    zh_Hans: Gitlab
+  description:
+    en_US: Gitlab plugin for querying commits
+    zh_Hans: 用于获取Gitlab commit的插件
+  icon: gitlab.svg
+credentials_for_provider:
+  access_tokens:
+    type: secret-input
+    required: true
+    label:
+      en_US: Gitlab access token
+      zh_Hans: Gitlab access token
+    placeholder:
+      en_US: Please input your Gitlab access token
+      zh_Hans: 请输入你的 Gitlab access token
+    help:
+      en_US: Get your Gitlab access token from Gitlab
+      zh_Hans: 从 Gitlab 获取您的 access token
+    url: https://docs.gitlab.com/16.9/ee/api/oauth2.html
+  site_url:
+    type: text-input
+    required: false
+    default: 'https://gitlab.com'
+    label:
+      en_US: Gitlab site url
+      zh_Hans: Gitlab site url
+    placeholder:
+      en_US: Please input your Gitlab site url
+      zh_Hans: 请输入你的 Gitlab site url
+    help:
+      en_US: Find your Gitlab url
+      zh_Hans: 找到你的Gitlab url
+    url: https://gitlab.com/help
diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py
new file mode 100644
index 0000000000..212bdb03ab
--- /dev/null
+++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py
@@ -0,0 +1,101 @@
+import json
+from datetime import datetime, timedelta
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class GitlabCommitsTool(BuiltinTool):
+    def _invoke(self,
+                user_id: str,
+                tool_parameters: dict[str, Any]
+                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+
+        project = tool_parameters.get('project', '')
+        employee = tool_parameters.get('employee', '')
+        start_time = tool_parameters.get('start_time', '')
+        end_time = tool_parameters.get('end_time', '')
+
+        if not project:
+            return self.create_text_message('Project is required')
+
+        if not start_time:
+            start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
+        if not end_time:
+            end_time = datetime.utcnow().isoformat()
+
+        access_token = self.runtime.credentials.get('access_tokens')
+        site_url = self.runtime.credentials.get('site_url')
+
+        if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
+            return self.create_text_message("GitLab API access token is required.")
+        if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
+            site_url = 'https://gitlab.com'
+
+        # Get commit content
+        result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time)
+
+        return self.create_text_message(json.dumps(result, ensure_ascii=False))
+
+    def fetch(self, user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]:
+        domain = site_url
+        headers = {"PRIVATE-TOKEN": access_token}
+        results = []
+
+        try:
+            # Get all projects
+            url = f"{domain}/api/v4/projects"
+            response = requests.get(url, headers=headers)
+            response.raise_for_status()
+            projects = response.json()
+
+            filtered_projects = [p for p in projects if project == "*" or p['name'] == project]
+
+            for project in filtered_projects:
+                project_id = project['id']
+                project_name = project['name']
+                print(f"Project: {project_name}")
+
+                # Get all project commits
+                commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
+                params = {
+                    'since': start_time,
+                    'until': end_time
+                }
+                if employee:
+                    params['author'] = employee
+
+                commits_response = requests.get(commits_url, headers=headers, params=params)
+                commits_response.raise_for_status()
+                commits = commits_response.json()
+
+                for commit in commits:
+                    commit_sha = commit['id']
+                    print(f"\tCommit SHA: {commit_sha}")
+
+                    diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
+                    diff_response = requests.get(diff_url, headers=headers)
+                    diff_response.raise_for_status()
+                    diffs = diff_response.json()
+
+                    for diff in diffs:
+                        # Calculate the number of changed lines
+                        added_lines = diff['diff'].count('\n+')
+                        removed_lines = diff['diff'].count('\n-')
+                        total_changes = added_lines + removed_lines
+
+                        if total_changes > 1:
+                            final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
+                            results.append({
+                                "project": project_name,
+                                "commit_sha": commit_sha,
+                                "diff": final_code
+                            })
+                            print(f"Commit code:{final_code}")
+        except requests.RequestException as e:
+            print(f"Error fetching data from GitLab: {e}")
+
+        return results
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml
new file mode 100644
index 0000000000..fc4e7eb7bb
--- /dev/null
+++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml
@@ -0,0 +1,56 @@
+identity:
+  name: gitlab_commits
+  author: Leo.Wang
+  label:
+    en_US: Gitlab Commits
+    zh_Hans: Gitlab代码提交内容
+description:
+  human:
+    en_US: A tool for querying Gitlab commits. Input should be an existing username or project.
+    zh_Hans: 一个用于查询gitlab代码提交记录的工具,输入的内容应该是一个已存在的用户名或者项目名。
+  llm: A tool for querying Gitlab commits. Input should be an existing username or project.
+parameters:
+  - name: employee
+    type: string
+    required: false
+    label:
+      en_US: employee
+      zh_Hans: 员工用户名
+    human_description:
+      en_US: employee
+      zh_Hans: 员工用户名
+    llm_description: employee for gitlab
+    form: llm
+  - name: project
+    type: string
+    required: true
+    label:
+      en_US: project
+      zh_Hans: 项目名
+    human_description:
+      en_US: project
+      zh_Hans: 项目名
+    llm_description: project for gitlab
+    form: llm
+  - name: start_time
+    type: string
+    required: false
+    label:
+      en_US: start_time
+      zh_Hans: 开始时间
+    human_description:
+      en_US: start_time
+      zh_Hans: 开始时间
+    llm_description: start_time for gitlab
+    form: llm
+  - name: end_time
+    type: string
+    required: false
+    label:
+      en_US: end_time
+      zh_Hans: 结束时间
+    human_description:
+      en_US: end_time
+      zh_Hans: 结束时间
+    llm_description: end_time for gitlab
+    form: llm
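As a quick sanity check of the fetch logic in this new tool, the same flow can be run standalone against the documented GitLab REST endpoints (`/projects/:id/repository/commits` and `.../commits/:sha/diff`). The token and project ID below are hypothetical placeholders. One nuance: counting `'\n+'` occurrences, as the tool does, also matches the `+++` file header line, so this sketch filters headers explicitly:

```python
import requests

GITLAB_API = "https://gitlab.com/api/v4"
TOKEN = "glpat-xxxxxxxx"  # hypothetical personal access token
PROJECT_ID = 12345        # hypothetical project ID


def count_changed_lines(diff_text: str) -> tuple[int, int]:
    # Count added/removed lines in a unified diff body,
    # skipping the '+++' / '---' file headers.
    lines = diff_text.splitlines()
    added = sum(1 for line in lines if line.startswith('+') and not line.startswith('+++'))
    removed = sum(1 for line in lines if line.startswith('-') and not line.startswith('---'))
    return added, removed


headers = {"PRIVATE-TOKEN": TOKEN}
commits = requests.get(
    f"{GITLAB_API}/projects/{PROJECT_ID}/repository/commits",
    headers=headers,
    params={"since": "2024-08-01T00:00:00Z"},
).json()

for commit in commits:
    diffs = requests.get(
        f"{GITLAB_API}/projects/{PROJECT_ID}/repository/commits/{commit['id']}/diff",
        headers=headers,
    ).json()
    for diff in diffs:
        added, removed = count_changed_lines(diff["diff"])
        print(commit["id"][:8], diff["new_path"], f"+{added}/-{removed}")
```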
""" + START = 'start' END = 'end' ANSWER = 'answer' @@ -44,33 +45,11 @@ class NodeType(Enum): raise ValueError(f'invalid node type value {value}') -class SystemVariable(Enum): - """ - System Variables. - """ - QUERY = 'query' - FILES = 'files' - CONVERSATION_ID = 'conversation_id' - USER_ID = 'user_id' - - @classmethod - def value_of(cls, value: str) -> 'SystemVariable': - """ - Get value of given system variable. - - :param value: system variable value - :return: system variable - """ - for system_variable in cls: - if system_variable.value == value: - return system_variable - raise ValueError(f'invalid system variable value {value}') - - class NodeRunMetadataKey(Enum): """ Node Run Metadata Key. """ + TOTAL_TOKENS = 'total_tokens' TOTAL_PRICE = 'total_price' CURRENCY = 'currency' @@ -83,6 +62,7 @@ class NodeRunResult(BaseModel): """ Node Run Result. """ + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING inputs: Optional[Mapping[str, Any]] = None # node inputs diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index a96a26f794..9fe3356faa 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -6,7 +6,7 @@ from typing_extensions import deprecated from core.app.segments import Segment, Variable, factory from core.file.file_obj import FileVar -from core.workflow.entities.node_entities import SystemVariable +from core.workflow.enums import SystemVariable VariableValue = Union[str, int, float, dict, list, FileVar] diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py new file mode 100644 index 0000000000..4757cf32f8 --- /dev/null +++ b/api/core/workflow/enums.py @@ -0,0 +1,25 @@ +from enum import Enum + + +class SystemVariable(str, Enum): + """ + System Variables. + """ + QUERY = 'query' + FILES = 'files' + CONVERSATION_ID = 'conversation_id' + USER_ID = 'user_id' + DIALOGUE_COUNT = 'dialogue_count' + + @classmethod + def value_of(cls, value: str): + """ + Get value of given system variable. 
+
+        :param value: system variable value
+        :return: system variable
+        """
+        for system_variable in cls:
+            if system_variable.value == value:
+                return system_variable
+        raise ValueError(f'invalid system variable value {value}')
diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py
index bbe5f9ad43..1facf8a4f4 100644
--- a/api/core/workflow/nodes/http_request/http_request_node.py
+++ b/api/core/workflow/nodes/http_request/http_request_node.py
@@ -133,9 +133,6 @@ class HttpRequestNode(BaseNode):
         """
         files = []
         mimetype, file_binary = response.extract_file()
-        # if not image, return directly
-        if 'image' not in mimetype:
-            return files
 
         if mimetype:
             # extract filename from url
diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index 4431259a57..97b64d4b05 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -23,8 +23,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.llm.entities import (
     LLMNodeChatModelMessage,
@@ -201,8 +202,8 @@ class LLMNode(BaseNode):
             usage = LLMUsage.empty_usage()
 
         return full_text, usage
-    
-    def _transform_chat_messages(self,
+
+    def _transform_chat_messages(self,
         messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
     ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
         """
@@ -249,13 +250,13 @@ class LLMNode(BaseNode):
             # check if it's a context structure
             if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
                 return d['content']
-            
+
             # else, parse the dict
             try:
                 return json.dumps(d, ensure_ascii=False)
             except Exception:
                 return str(d)
-        
+
         if isinstance(value, str):
             value = value
         elif isinstance(value, list):
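One detail worth calling out in the SystemVariable move: the relocated enum now inherits from str, so members compare equal to their raw values and serialize cleanly, which is why call sites such as `variable_pool.get(['sys', SystemVariable.FILES.value])` keep working unchanged. A small self-contained illustration of the str-Enum behavior (not Dify code, just the mechanism):

```python
from enum import Enum

class SystemVariable(str, Enum):
    QUERY = 'query'
    FILES = 'files'
    CONVERSATION_ID = 'conversation_id'
    USER_ID = 'user_id'
    DIALOGUE_COUNT = 'dialogue_count'

# str subclassing makes members interchangeable with plain strings...
assert SystemVariable.QUERY == 'query'
assert f"sys.{SystemVariable.DIALOGUE_COUNT.value}" == 'sys.dialogue_count'

# ...and the built-in value lookup mirrors the value_of() classmethod above.
assert SystemVariable('files') is SystemVariable.FILES
```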
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 87bfa5beae..554e3b6074 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence
 from os import path
 from typing import Any, cast
 
-from core.app.segments import parser
+from core.app.segments import ArrayAnyVariable, parser
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
 from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
 from core.tools.tool_engine import ToolEngine
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
-from models.workflow import WorkflowNodeExecutionStatus
+from models import WorkflowNodeExecutionStatus
 
 
 class ToolNode(BaseNode):
@@ -140,9 +141,9 @@ class ToolNode(BaseNode):
         return result
 
     def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
-        # FIXME: ensure this is a ArrayVariable contains FileVariable.
         variable = variable_pool.get(['sys', SystemVariable.FILES.value])
-        return [file_var.value for file_var in variable.value] if variable else []
+        assert isinstance(variable, ArrayAnyVariable)
+        return list(variable.value) if variable else []
 
     def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
         """
diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py
index f299f84efb..3157eedfee 100644
--- a/api/core/workflow/workflow_engine_manager.py
+++ b/api/core/workflow/workflow_engine_manager.py
@@ -3,6 +3,7 @@ import time
 from collections.abc import Mapping, Sequence
 from typing import Any, Optional, cast
 
+import contexts
 from configs import dify_config
 from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -97,16 +98,16 @@ class WorkflowEngineManager:
                      invoke_from: InvokeFrom,
                      callbacks: Sequence[WorkflowCallback],
                      call_depth: int = 0,
-                     variable_pool: VariablePool,
+                     variable_pool: VariablePool | None = None,
                      ) -> None:
         """
         :param workflow: Workflow instance
         :param user_id: user id
         :param user_from: user from
-        :param user_inputs: user variables inputs
-        :param system_inputs: system inputs, like: query, files
+        :param invoke_from: invoke from
         :param callbacks: workflow callbacks
         :param call_depth: call depth
+        :param variable_pool: variable pool
         """
         # fetch workflow graph
         graph = workflow.graph_dict
@@ -128,6 +129,8 @@ class WorkflowEngineManager:
             raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
 
         # init workflow run state
+        if not variable_pool:
+            variable_pool = contexts.workflow_variable_pool.get()
         workflow_run_state = WorkflowRunState(
             workflow=workflow,
             start_at=time.perf_counter(),
diff --git a/api/libs/bearer_data_source.py b/api/libs/bearer_data_source.py
index 04de1fb6da..c1aee7b819 100644
--- a/api/libs/bearer_data_source.py
+++ b/api/libs/bearer_data_source.py
@@ -2,10 +2,10 @@ from abc import abstractmethod
 
 import requests
 
-from api.models.source import DataSourceBearerBinding
 from flask_login import current_user
 
 from extensions.ext_database import db
+from models.source import DataSourceBearerBinding
 
 
 class BearerDataSource:
diff --git a/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py
new file mode 100644
index 0000000000..eba78e2e77
--- /dev/null
+++ b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py
@@ -0,0 +1,33 @@
+"""add conversations.dialogue_count
+
+Revision ID: 8782057ff0dc
+Revises: 63a83fcf12ba
+Create Date: 2024-08-14 13:54:25.161324
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '8782057ff0dc'
+down_revision = '63a83fcf12ba'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('conversations', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('conversations', schema=None) as batch_op:
+        batch_op.drop_column('dialogue_count')
+
+    # ### end Alembic commands ###
diff --git a/api/models/__init__.py b/api/models/__init__.py
index f831356841..4012611471 100644
--- a/api/models/__init__.py
+++ b/api/models/__init__.py
@@ -1,10 +1,10 @@
 from enum import Enum
 
-from .model import AppMode
+from .model import App, AppMode, Message
 from .types import StringUUID
-from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
+from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus
 
-__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
+__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message']
 
 
 class CreatedByRole(Enum):
diff --git a/api/models/model.py b/api/models/model.py
index 9909b10dc0..5426d3bc83 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -7,6 +7,7 @@ from typing import Optional
 from flask import request
 from flask_login import UserMixin
 from sqlalchemy import Float, func, text
+from sqlalchemy.orm import Mapped, mapped_column
 
 from configs import dify_config
 from core.file.tool_file_parser import ToolFileParser
@@ -512,12 +513,12 @@ class Conversation(db.Model):
     from_account_id = db.Column(StringUUID)
     read_at = db.Column(db.DateTime)
     read_account_id = db.Column(StringUUID)
+    dialogue_count: Mapped[int] = mapped_column(default=0)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
     messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
-    message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select',
-                                          passive_deletes="all")
+    message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
 
     is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
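Because the column is added with server_default='0' and nullable=False, existing conversation rows backfill to zero with no data migration, and new ORM rows default to 0. A hedged, self-contained sketch of the pattern, including a race-free SQL-side increment (table and IDs here are illustrative, not Dify's actual schema):

```python
from sqlalchemy import Integer, String, create_engine, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Conversation(Base):
    __tablename__ = 'conversations'
    id: Mapped[str] = mapped_column(String, primary_key=True)
    # Mirrors the new column: non-null, zero-filled at the database level.
    dialogue_count: Mapped[int] = mapped_column(Integer, default=0, server_default='0', nullable=False)


engine = create_engine('sqlite://')  # in-memory database for the sketch
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Conversation(id='c1'))
    session.commit()
    # Atomic increment in SQL rather than read-modify-write in Python.
    session.execute(
        update(Conversation)
        .where(Conversation.id == 'c1')
        .values(dialogue_count=Conversation.dialogue_count + 1)
    )
    session.commit()
    assert session.scalar(select(Conversation.dialogue_count).where(Conversation.id == 'c1')) == 1
```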
"pytest-asyncio", "pytest-cov", "pytest-httpserver", "pytest-mock", "requests", "respx", "sphinx (>2)", "sphinx-autodoc-typehints", "trustme"] + +[[package]] +name = "elasticsearch" +version = "8.14.0" +description = "Python client for Elasticsearch" +optional = false +python-versions = ">=3.7" +files = [ + {file = "elasticsearch-8.14.0-py3-none-any.whl", hash = "sha256:cef8ef70a81af027f3da74a4f7d9296b390c636903088439087b8262a468c130"}, + {file = "elasticsearch-8.14.0.tar.gz", hash = "sha256:aa2490029dd96f4015b333c1827aa21fd6c0a4d223b00dfb0fe933b8d09a511b"}, +] + +[package.dependencies] +elastic-transport = ">=8.13,<9" + +[package.extras] +async = ["aiohttp (>=3,<4)"] +orjson = ["orjson (>=3)"] +requests = ["requests (>=2.4.0,!=2.32.2,<3.0.0)"] +vectorstore-mmr = ["numpy (>=1)", "simsimd (>=3)"] + [[package]] name = "emoji" version = "2.12.1" @@ -9546,4 +9584,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "2b822039247a445f72e04e967aef84f841781e2789b70071acad022f36ba26a5" +content-hash = "05dfa6b9bce9ed8ac21caf58eff1596f146080ab2ab6987924b189be673c22cf" diff --git a/api/pyproject.toml b/api/pyproject.toml index 058d67c42f..82e2aaeb2b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -181,6 +181,7 @@ zhipuai = "1.0.7" rank-bm25 = "~0.2.2" openpyxl = "^3.1.5" kaleido = "0.2.1" +elasticsearch = "8.14.0" ############################################################ # Tool dependencies required by tool implementations diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index e16e5c715c..838585c575 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -13,9 +13,9 @@ from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) -current_dsl_version = "0.1.0" +current_dsl_version = "0.1.1" dsl_to_dify_version_mapping: dict[str, str] = { - "0.1.0": "0.6.0", # dsl version -> from dify version + "0.1.1": "0.6.0", # dsl version -> from dify version } diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py index 2f66d707ca..c2fe95974b 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -1,5 +1,4 @@ - -from api.core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter +from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter class MockTEIClass: @@ -12,7 +11,7 @@ class MockTEIClass: model_type = 'embedding' return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1) - + @staticmethod def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: # Use space as token separator, and split the text into tokens diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py index da65c7dfc7..ed371fbc07 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py @@ -1,12 +1,12 @@ import os import pytest -from api.core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from 
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
     HuggingfaceTeiTextEmbeddingModel,
+    TeiHelper,
 )
 from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
diff --git a/api/tests/integration_tests/vdb/elasticsearch/__init__.py b/api/tests/integration_tests/vdb/elasticsearch/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py
new file mode 100644
index 0000000000..b1c1cc10d9
--- /dev/null
+++ b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py
@@ -0,0 +1,25 @@
+from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
+from tests.integration_tests.vdb.test_vector_store import (
+    AbstractVectorTest,
+    setup_mock_redis,
+)
+
+
+class ElasticSearchVectorTest(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
+        self.vector = ElasticSearchVector(
+            index_name=self.collection_name.lower(),
+            config=ElasticSearchConfig(
+                host='http://localhost',
+                port='9200',
+                username='elastic',
+                password='elastic'
+            ),
+            attributes=self.attributes
+        )
+
+
+def test_elasticsearch_vector(setup_mock_redis):
+    ElasticSearchVectorTest().run_all_tests()
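Outside this test harness, the pinned elasticsearch==8.14.0 client is handy for a manual smoke test against a local node started with the defaults used in this PR. The index name, dims, and vectors below are made-up illustrations, not Dify's actual index layout:

```python
from elasticsearch import Elasticsearch

# Connection details match the compose/env defaults in this PR; adjust as needed.
es = Elasticsearch("http://localhost:9200", basic_auth=("elastic", "elastic"))
print(es.info()["version"]["number"])

# A minimal index with a dense_vector field ("dims" here is illustrative).
es.indices.create(index="dify-demo", mappings={
    "properties": {
        "page_content": {"type": "text"},
        "vector": {"type": "dense_vector", "dims": 4, "index": True, "similarity": "cosine"},
    },
})
es.index(index="dify-demo", id="1",
         document={"page_content": "hello dify", "vector": [0.1, 0.2, 0.3, 0.4]},
         refresh=True)

# 8.x kNN search against the stored vector.
resp = es.search(index="dify-demo", knn={
    "field": "vector",
    "query_vector": [0.1, 0.2, 0.3, 0.4],
    "k": 1,
    "num_candidates": 10,
})
print(resp["hits"]["hits"][0]["_source"]["page_content"])
```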
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index ac704e4eaf..4686ce0675 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -10,8 +10,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers import ModelProviderFactory
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.llm.llm_node import LLMNode
 from extensions.ext_database import db
@@ -236,4 +236,4 @@
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
     assert 'sunny' in json.dumps(result.process_data)
-    assert 'what\'s the weather today?' in json.dumps(result.process_data)
\ No newline at end of file
+    assert 'what\'s the weather today?' in json.dumps(result.process_data)
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 312ad47026..adf5ffe3ca 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -12,8 +12,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from extensions.ext_database import db
@@ -363,7 +363,7 @@ def test_extract_json_response():
     {
         "location": "kawaii"
     }
-    hello world.
+    hello world.
     """)
 
     assert result['location'] == 'kawaii'
@@ -445,4 +445,4 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
             assert latest_role != prompt.get('role')
 
         if prompt.get('role') in ['user', 'assistant']:
-            latest_role = prompt.get('role')
\ No newline at end of file
+            latest_role = prompt.get('role')
diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py
index a8429b9c1b..afd0fa50b5 100644
--- a/api/tests/unit_tests/core/app/segments/test_factory.py
+++ b/api/tests/unit_tests/core/app/segments/test_factory.py
@@ -3,12 +3,9 @@ from uuid import uuid4
 import pytest
 
 from core.app.segments import (
-    ArrayFileVariable,
     ArrayNumberVariable,
     ArrayObjectVariable,
     ArrayStringVariable,
-    FileSegment,
-    FileVariable,
     FloatVariable,
     IntegerVariable,
     ObjectSegment,
@@ -149,83 +146,6 @@ def test_array_object_variable():
     assert isinstance(variable.value[1]['key2'], int)
 
 
-def test_file_variable():
-    mapping = {
-        'id': str(uuid4()),
-        'value_type': 'file',
-        'name': 'test_file',
-        'description': 'Description of the variable.',
-        'value': {
-            'id': str(uuid4()),
-            'tenant_id': 'tenant_id',
-            'type': 'image',
-            'transfer_method': 'local_file',
-            'url': 'url',
-            'related_id': 'related_id',
-            'extra_config': {
-                'image_config': {
-                    'width': 100,
-                    'height': 100,
-                },
-            },
-            'filename': 'filename',
-            'extension': 'extension',
-            'mime_type': 'mime_type',
-        },
-    }
-    variable = factory.build_variable_from_mapping(mapping)
-    assert isinstance(variable, FileVariable)
-
-
-def test_array_file_variable():
-    mapping = {
-        'id': str(uuid4()),
-        'value_type': 'array[file]',
-        'name': 'test_array_file',
-        'description': 'Description of the variable.',
-        'value': [
-            {
-                'id': str(uuid4()),
-                'tenant_id': 'tenant_id',
-                'type': 'image',
-                'transfer_method': 'local_file',
-                'url': 'url',
-                'related_id': 'related_id',
-                'extra_config': {
-                    'image_config': {
-                        'width': 100,
-                        'height': 100,
-                    },
-                },
-                'filename': 'filename',
-                'extension': 'extension',
-                'mime_type': 'mime_type',
-            },
-            {
-                'id': str(uuid4()),
-                'tenant_id': 'tenant_id',
-                'type': 'image',
-                'transfer_method': 'local_file',
-                'url': 'url',
-                'related_id': 'related_id',
-                'extra_config': {
-                    'image_config': {
-                        'width': 100,
-                        'height': 100,
-                    },
-                },
-                'filename': 'filename',
-                'extension': 'extension',
-                'mime_type': 'mime_type',
-            },
-        ],
-    }
-    variable = factory.build_variable_from_mapping(mapping)
-    assert isinstance(variable, ArrayFileVariable)
-    assert isinstance(variable.value[0], FileSegment)
-    assert isinstance(variable.value[1], FileSegment)
-
-
 def test_variable_cannot_large_than_5_kb():
     with pytest.raises(VariableError):
         factory.build_variable_from_mapping(
diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py
index 414404b7d0..7e3e69ffbf 100644
--- a/api/tests/unit_tests/core/app/segments/test_segment.py
+++ b/api/tests/unit_tests/core/app/segments/test_segment.py
@@ -1,7 +1,7 @@
 from core.app.segments import SecretVariable, StringSegment, parser
 from core.helper import encrypter
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 
 
 def test_segment_group_to_text():
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py
index 3a32829e37..4617b6a42f 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py
@@ -1,8 +1,8 @@
 from unittest.mock import MagicMock
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.answer.answer_node import AnswerNode
 from core.workflow.nodes.base_node import UserFrom
 from extensions.ext_database import db
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index 4662c5ff2b..d21b7785c4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -1,8 +1,8 @@
 from unittest.mock import MagicMock
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.if_else.if_else_node import IfElseNode
 from extensions.ext_database import db
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
index 8706ba05ce..0b37d06fc0 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
@@ -3,8 +3,8 @@ from uuid import uuid4
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.segments import ArrayStringVariable, StringVariable
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
 
diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh
index c954c528fb..0b23200dc3 100755
--- a/dev/pytest/pytest_vdb.sh
+++ b/dev/pytest/pytest_vdb.sh
@@ -7,4 +7,5 @@ pytest api/tests/integration_tests/vdb/chroma \
   api/tests/integration_tests/vdb/pgvector \
   api/tests/integration_tests/vdb/qdrant \
   api/tests/integration_tests/vdb/weaviate \
+  api/tests/integration_tests/vdb/elasticsearch \
   api/tests/integration_tests/vdb/test_vector_store.py
\ No newline at end of file
diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml
index 807946f3fe..aed2586053 100644
--- a/docker-legacy/docker-compose.yaml
+++ b/docker-legacy/docker-compose.yaml
@@ -2,7 +2,7 @@ version: '3'
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.6.16
+    image: langgenius/dify-api:0.7.0
     restart: always
     environment:
       # Startup mode, 'api' starts the API server.
@@ -169,6 +169,11 @@ services:
       CHROMA_DATABASE: default_database
       CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
       CHROMA_AUTH_CREDENTIALS: xxxxxx
+      # ElasticSearch Config
+      ELASTICSEARCH_HOST: 127.0.0.1
+      ELASTICSEARCH_PORT: 9200
+      ELASTICSEARCH_USERNAME: elastic
+      ELASTICSEARCH_PASSWORD: elastic
       # Mail configuration, support: resend, smtp
       MAIL_TYPE: ''
      # default send from email address, if not specified
@@ -224,7 +229,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.6.16
+    image: langgenius/dify-api:0.7.0
     restart: always
     environment:
       CONSOLE_WEB_URL: ''
@@ -371,6 +376,11 @@ services:
       CHROMA_DATABASE: default_database
       CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
       CHROMA_AUTH_CREDENTIALS: xxxxxx
+      # ElasticSearch Config
+      ELASTICSEARCH_HOST: 127.0.0.1
+      ELASTICSEARCH_PORT: 9200
+      ELASTICSEARCH_USERNAME: elastic
+      ELASTICSEARCH_PASSWORD: elastic
       # Notion import configuration, support public and internal
       NOTION_INTEGRATION_TYPE: public
       NOTION_CLIENT_SECRET: you-client-secret
@@ -390,7 +400,7 @@ services:
 
   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.6.16
+    image: langgenius/dify-web:0.7.0
     restart: always
     environment:
       # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 2b10fbc2cc..f3151bbc2a 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -125,6 +125,10 @@ x-shared-env: &shared-api-worker-env
   CHROMA_DATABASE: ${CHROMA_DATABASE:-default_database}
   CHROMA_AUTH_PROVIDER: ${CHROMA_AUTH_PROVIDER:-chromadb.auth.token_authn.TokenAuthClientProvider}
   CHROMA_AUTH_CREDENTIALS: ${CHROMA_AUTH_CREDENTIALS:-}
+  ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-127.0.0.1}
+  ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200}
+  ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}
+  ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
   # AnalyticDB configuration
   ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-}
   ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-}
@@ -187,7 +191,7 @@
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.6.16
+    image: langgenius/dify-api:0.7.0
     restart: always
     environment:
       # Use the shared environment variables.
@@ -207,7 +211,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.6.16
+    image: langgenius/dify-api:0.7.0
     restart: always
     environment:
       # Use the shared environment variables.
@@ -226,7 +230,7 @@ services:
 
   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.6.16
+    image: langgenius/dify-web:0.7.0
     restart: always
     environment:
       CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@@ -583,7 +587,7 @@ services:
   # MyScale vector database
   myscale:
     container_name: myscale
-    image: myscale/myscaledb:1.6
+    image: myscale/myscaledb:1.6.4
     profiles:
       - myscale
     restart: always
@@ -595,6 +599,27 @@ services:
     ports:
       - "${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}"
 
+  elasticsearch:
+    image: docker.elastic.co/elasticsearch/elasticsearch:8.14.3
+    container_name: elasticsearch
+    profiles:
+      - elasticsearch
+    restart: always
+    environment:
+      - "ELASTIC_PASSWORD=${ELASTICSEARCH_PASSWORD:-elastic}"
+      - "cluster.name=dify-es-cluster"
+      - "node.name=dify-es0"
+      - "discovery.type=single-node"
+      - "xpack.security.http.ssl.enabled=false"
+      - "xpack.license.self_generated.type=trial"
+    ports:
+      - "${ELASTICSEARCH_PORT:-9200}:${ELASTICSEARCH_PORT:-9200}"
+    healthcheck:
+      test: ["CMD", "curl", "-s", "http://localhost:9200/_cluster/health?pretty"]
+      interval: 30s
+      timeout: 10s
+      retries: 50
+
   # unstructured .
   # (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.)
   unstructured:
diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx
index 36395d391d..44c5964d77 100644
--- a/web/app/(commonLayout)/datasets/template/template.en.mdx
+++ b/web/app/(commonLayout)/datasets/template/template.en.mdx
@@ -922,6 +922,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
         Knowledge ID
       
+      
+        Document ID
+      
       
         Document Segment ID
       
@@ -965,6 +968,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
         Knowledge ID
       
+      
+        Document ID
+      
       
         Document Segment ID
       
diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx
index a624c0594f..9f79b0f900 100644
--- a/web/app/(commonLayout)/datasets/template/template.zh.mdx
+++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx
@@ -922,6 +922,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
         知识库 ID
       
+      
+        文档 ID
+      
       
         文档分段ID
       
@@ -965,6 +968,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
         知识库 ID
       
+      
+        文档 ID
+      
       
         文档分段ID
       
diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx
index 3adb4d75e1..02d00334b6 100644
--- a/web/app/components/base/markdown.tsx
+++ b/web/app/components/base/markdown.tsx
@@ -1,4 +1,5 @@
 import ReactMarkdown from 'react-markdown'
+import ReactEcharts from 'echarts-for-react'
 import 'katex/dist/katex.min.css'
 import RemarkMath from 'remark-math'
 import RemarkBreaks from 'remark-breaks'
@@ -30,6 +31,7 @@ const capitalizationLanguageNameMap: Record<string, string> = {
   mermaid: 'Mermaid',
   markdown: 'MarkDown',
   makefile: 'MakeFile',
+  echarts: 'ECharts',
 }
 const getCorrectCapitalizationLanguageName = (language: string) => {
   if (!language)
@@ -107,6 +109,14 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }
   const match = /language-(\w+)/.exec(className || '')
   const language = match?.[1]
   const languageShowName = getCorrectCapitalizationLanguageName(language || '')
+  let chartData = { title: { text: 'Something went wrong.' } }
+  if (language === 'echarts') {
+    try {
+      chartData = JSON.parse(String(children).replace(/\n$/, ''))
+    }
+    catch (error) {
+    }
+  }
 
   // Use `useMemo` to ensure that `SyntaxHighlighter` only re-renders when necessary
   return useMemo(() => {
@@ -136,19 +146,25 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }
         {(language === 'mermaid' && isSVG)
           ? (<Flowchart PrimitiveCode={String(children).replace(/\n$/, '')} />)
-          : (<SyntaxHighlighter {...props} style={atelierHeathLight} language={match[1]} showLineNumbers PreTag="div">
-            {String(children).replace(/\n$/, '')}
-          </SyntaxHighlighter>)}
+          : (
+            (language === 'echarts')
+              ? (<div style={{ minHeight: '250px', minWidth: '250px' }}>
+                <ReactEcharts option={chartData} />
+              </div>)
+              : (<SyntaxHighlighter {...props} style={atelierHeathLight} language={match[1]} showLineNumbers PreTag="div">
+                {String(children).replace(/\n$/, '')}
+              </SyntaxHighlighter>))}
       )
     : (
diff --git a/web/app/components/tools/edit-custom-collection-modal/index.tsx b/web/app/components/tools/edit-custom-collection-modal/index.tsx
index b55c224164..03523aa0cb 100644
--- a/web/app/components/tools/edit-custom-collection-modal/index.tsx
+++ b/web/app/components/tools/edit-custom-collection-modal/index.tsx
@@ -329,36 +329,36 @@ const EditCustomCollectionModal: FC = ({
+          {showEmojiPicker && <EmojiPicker
+            onSelect={(icon, icon_background) => {
+              setEmoji({ content: icon, background: icon_background })
+              setShowEmojiPicker(false)
+            }}
+            onClose={() => {
+              setShowEmojiPicker(false)
+            }}
+          />}
+          {credentialsModalShow && (
+            setCredentialsModalShow(false)}
+            />)
+          }
+          {isShowTestApi && (
+            setIsShowTestApi(false)}
+            />
+          )}
         }
         isShowMask={true}
         clickOutsideNotOpen={true}
       />
-      {showEmojiPicker && <EmojiPicker
-        onSelect={(icon, icon_background) => {
-          setEmoji({ content: icon, background: icon_background })
-          setShowEmojiPicker(false)
-        }}
-        onClose={() => {
-          setShowEmojiPicker(false)
-        }}
-      />}
-      {credentialsModalShow && (
-        setCredentialsModalShow(false)}
-        />)
-      }
-      {isShowTestApi && (
-        setIsShowTestApi(false)}
-        />
-      )}
     )
diff --git a/web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts b/web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts
index ef3d665910..b81feab805 100644
--- a/web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts
+++ b/web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts
@@ -7,12 +7,16 @@ import {
 import type { ValueSelector, Var } from '@/app/components/workflow/types'
 type Params = {
   onlyLeafNodeVar?: boolean
+  hideEnv?: boolean
+  hideChatVar?: boolean
   filterVar: (payload: Var, selector: ValueSelector) => boolean
 }
 
 const useAvailableVarList = (nodeId: string, {
   onlyLeafNodeVar,
   filterVar,
+  hideEnv,
+  hideChatVar,
 }: Params = {
   onlyLeafNodeVar: false,
   filterVar: () => true,
@@ -32,6 +36,8 @@ const useAvailableVarList = (nodeId: string, {
     beforeNodes: availableNodes,
     isChatMode,
     filterVar,
+    hideEnv,
+    hideChatVar,
   })
 
   return {
diff --git a/web/app/components/workflow/nodes/answer/panel.tsx b/web/app/components/workflow/nodes/answer/panel.tsx
index feb07c36c9..daa5be4e66 100644
--- a/web/app/components/workflow/nodes/answer/panel.tsx
+++ b/web/app/components/workflow/nodes/answer/panel.tsx
@@ -23,6 +23,8 @@ const Panel: FC> = ({
   const { availableVars, availableNodesWithParent } = useAvailableVarList(id, {
     onlyLeafNodeVar: false,
+    hideChatVar: true,
+    hideEnv: true,
     filterVar,
   })
 
diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx
index 8a671adc93..61b289de53 100644
--- a/web/app/components/workflow/run/index.tsx
+++ b/web/app/components/workflow/run/index.tsx
@@ -62,30 +62,25 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe
   const formatNodeList = useCallback((list: NodeTracing[]) => {
     const allItems = list.reverse()
     const result: NodeTracing[] = []
-    let iterationIndexInfos: {
-      start: number
-      end: number
-    }[] = []
+    let iterationIndex = 0
 
     allItems.forEach((item) => {
-      const { node_type, index, execution_metadata } = item
+      const { node_type, execution_metadata } = item
       if (node_type !== BlockEnum.Iteration) {
-        let isInIteration = false
-        let isIterationFirstNode = false
-        iterationIndexInfos.forEach(({ start, end }) => {
-          if (index >= start && index < end) {
-            if (index === start)
-              isIterationFirstNode = true
+        const isInIteration = !!execution_metadata?.iteration_id
 
-            isInIteration = true
-          }
-        })
         if (isInIteration) {
          const iterationDetails = result[result.length - 1].details!
-          if (isIterationFirstNode)
-            iterationDetails!.push([item])
+          const currentIterationIndex = execution_metadata?.iteration_index
+          const isIterationFirstNode = iterationIndex !== currentIterationIndex || iterationDetails.length === 0
 
-          else
+          if (isIterationFirstNode) {
+            iterationDetails!.push([item])
+            iterationIndex = currentIterationIndex!
+          }
+
+          else {
             iterationDetails[iterationDetails.length - 1].push(item)
+          }
 
           return
         }
@@ -95,26 +90,6 @@
         return
       }
 
-      const { steps_boundary } = execution_metadata
-      iterationIndexInfos = []
-      steps_boundary.forEach((boundary, index) => {
-        if (index === 0) {
-          iterationIndexInfos.push({
-            start: boundary,
-            end: 0,
-          })
-        }
-        else if (index === steps_boundary.length - 1) {
-          iterationIndexInfos[iterationIndexInfos.length - 1].end = boundary
-        }
-        else {
-          iterationIndexInfos[iterationIndexInfos.length - 1].end = boundary
-          iterationIndexInfos.push({
-            start: boundary,
-            end: 0,
-          })
-        }
-      })
       result.push({
         ...item,
         details: [],
diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx
index f5df961d21..f0f7ec5173 100644
--- a/web/app/components/workflow/run/node.tsx
+++ b/web/app/components/workflow/run/node.tsx
@@ -123,7 +123,7 @@ const NodePanel: FC = ({
-              {t('workflow.nodes.iteration.iteration', { count: nodeInfo.metadata?.iterator_length || (nodeInfo.execution_metadata?.steps_boundary?.length - 1) })}
+              {t('workflow.nodes.iteration.iteration', { count: nodeInfo.metadata?.iterator_length })}
              {justShowIterationNavArrow ? (
diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts
index c00578ca6a..8f506bcb46 100644
--- a/web/i18n/ja-JP/workflow.ts
+++ b/web/i18n/ja-JP/workflow.ts
@@ -94,11 +94,38 @@ const translation = {
     },
     export: {
       title: 'シークレット環境変数をエクスポートしますか?',
-      checkbox: 'シクレート値をエクスポート',
+      checkbox: 'シークレット値をエクスポート',
       ignore: 'DSLをエクスポート',
-      export: 'シクレート値を含むDSLをエクスポート',
+      export: 'シークレット値を含むDSLをエクスポート',
     },
   },
+  chatVariable: {
+    panelTitle: '会話変数',
+    panelDescription: '会話変数は、LLMが記憶すべき対話情報を保存するために使用されます。この情報には、対話の履歴、アップロードされたファイル、ユーザーの好みなどが含まれます。読み書きが可能です。',
+    docLink: '詳しくはドキュメントをご覧ください。',
+    button: '変数を追加',
+    modal: {
+      title: '会話変数を追加',
+      editTitle: '会話変数を編集',
+      name: '名前',
+      namePlaceholder: '変数名',
+      type: 'タイプ',
+      value: 'デフォルト値',
+      valuePlaceholder: 'デフォルト値、設定しない場合は空白にしてください',
+      description: '説明',
+      descriptionPlaceholder: '変数の説明',
+      editInJSON: 'JSONで編集する',
+      oneByOne: '次々に追加する',
+      editInForm: 'フォームで編集',
+      arrayValue: '値',
+      addArrayValue: '値を追加',
+      objectKey: 'キー',
+      objectType: 'タイプ',
+      objectValue: 'デフォルト値',
+    },
+    storedContent: '保存されたコンテンツ',
+    updatedAt: '更新日は',
+  },
   changeHistory: {
     title: '変更履歴',
     placeholder: 'まだ何も変更していません',
@@ -149,6 +176,7 @@ const translation = {
   tabs: {
     'searchBlock': 'ブロックを検索',
     'blocks': 'ブロック',
+    'searchTool': '検索ツール',
     'tools': 'ツール',
     'allTool': 'すべて',
     'workflowTool': 'ワークフロー',
@@ -171,8 +199,9 @@ const translation = {
     'code': 'コード',
     'template-transform': 'テンプレート',
     'http-request': 'HTTPリクエスト',
-    'variable-assigner': '変数代入',
+    'variable-assigner': '変数代入器',
     'variable-aggregator': '変数集約器',
+    'assigner': '変数代入',
     'iteration-start': 'イテレーション開始',
     'iteration': 'イテレーション',
     'parameter-extractor': 'パラメーター抽出',
@@ -189,6 +218,7 @@ const translation = {
     'template-transform': 'Jinjaテンプレート構文を使用してデータを文字列に変換します',
     'http-request': 'HTTPプロトコル経由でサーバーリクエストを送信できます',
     'variable-assigner': '複数のブランチの変数を1つの変数に集約し、下流のノードに対して統一された設定を行います。',
+    'assigner': '変数代入ノードは、書き込み可能な変数(例えば、会話変数)に値を割り当てるために使用されます。',
     'variable-aggregator': '複数のブランチの変数を1つの変数に集約し、下流のノードに対して統一された設定を行います。',
     'iteration': 'リストオブジェクトに対して複数のステップを実行し、すべての結果が出力されるまで繰り返します。',
     'parameter-extractor': '自然言語からツールの呼び出しやHTTPリクエストのための構造化されたパラメーターを抽出するためにLLMを使用します。',
@@ -215,6 +245,7 @@ const translation = {
     checklistResolved: 'すべての問題が解決されました',
     organizeBlocks: 'ブロックを整理',
     change: '変更',
+    optional: '(オプション)',
   },
   nodes: {
     common: {
@@ -406,6 +437,17 @@ const translation = {
       },
       setAssignVariable: '代入された変数を設定',
     },
+    assigner: {
+      'assignedVariable': '代入された変数',
+      'writeMode': '書き込みモード',
+      'writeModeTip': '代入された変数が配列の場合、末尾に追記モードを追加する。',
+      'over-write': '上書き',
+      'append': '追記',
+      'plus': 'プラス',
+      'clear': 'クリア',
+      'setVariable': '変数を設定する',
+      'variable': '変数',
+    },
     tool: {
       toAuthorize: '承認するには',
       inputVars: '入力変数',
diff --git a/web/package.json b/web/package.json
index 2d6cd0a511..9b8e50885c 100644
--- a/web/package.json
+++ b/web/package.json
@@ -1,6 +1,6 @@
 {
   "name": "dify-web",
-  "version": "0.6.16",
+  "version": "0.7.0",
   "private": true,
   "engines": {
     "node": ">=18.17.0"
diff --git a/web/types/workflow.ts b/web/types/workflow.ts
index 72475ce55c..f7991bc4e0 100644
--- a/web/types/workflow.ts
+++ b/web/types/workflow.ts
@@ -24,7 +24,8 @@ export type NodeTracing = {
     total_tokens: number
     total_price: number
     currency: string
-    steps_boundary: number[]
+    iteration_id?: string
+    iteration_index?: number
   }
   metadata: {
     iterator_length: number