Merge branch 'main' into feat/attachments

StyleZhang 2024-08-15 10:57:59 +08:00
commit 33bfa4758e
76 changed files with 1076 additions and 404 deletions

View File

@@ -76,7 +76,7 @@ jobs:
      - name: Run Workflow
        run: poetry run -C api bash dev/pytest/pytest_workflow.sh
-     - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
+     - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch)
        uses: hoverkraft-tech/compose-action@v2.0.0
        with:
          compose-file: |
@@ -90,5 +90,6 @@ jobs:
            pgvecto-rs
            pgvector
            chroma
+           elasticsearch
      - name: Test Vector Stores
        run: poetry run -C api bash dev/pytest/pytest_vdb.sh

View File

@@ -6,5 +6,6 @@ yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
+yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
-echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs."
+echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch"

View File

@@ -130,6 +130,12 @@ TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2

+# ElasticSearch configuration
+ELASTICSEARCH_HOST=127.0.0.1
+ELASTICSEARCH_PORT=9200
+ELASTICSEARCH_USERNAME=elastic
+ELASTICSEARCH_PASSWORD=elastic
+
# PGVECTO_RS configuration
PGVECTO_RS_HOST=localhost
PGVECTO_RS_PORT=5431
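
These defaults line up with the elasticsearch service exposed in the CI script above. For reference, a minimal sketch of how the four settings map onto an elasticsearch-py client (the connection scheme and the smoke-test call are assumptions, not part of this diff):

```python
import os

from elasticsearch import Elasticsearch

# Defaults mirror the new .env entries above; override per deployment.
host = os.getenv('ELASTICSEARCH_HOST', '127.0.0.1')
port = os.getenv('ELASTICSEARCH_PORT', '9200')

client = Elasticsearch(
    hosts=f'http://{host}:{port}',
    basic_auth=(
        os.getenv('ELASTICSEARCH_USERNAME', 'elastic'),
        os.getenv('ELASTICSEARCH_PASSWORD', 'elastic'),
    ),
)
print(client.info())  # fails fast if the host or credentials are wrong
```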

View File

@@ -344,6 +344,14 @@ def migrate_knowledge_vector_database():
                    "vector_store": {"class_prefix": collection_name}
                }
                dataset.index_struct = json.dumps(index_struct_dict)
+           elif vector_type == VectorType.ELASTICSEARCH:
+               dataset_id = dataset.id
+               index_name = Dataset.gen_collection_name_by_id(dataset_id)
+               index_struct_dict = {
+                   "type": 'elasticsearch',
+                   "vector_store": {"class_prefix": index_name}
+               }
+               dataset.index_struct = json.dumps(index_struct_dict)
            else:
                raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
    CURRENT_VERSION: str = Field(
        description='Dify version',
-       default='0.6.16',
+       default='0.7.0',
    )

    COMMIT_SHA: str = Field(

View File

@@ -1,3 +1,7 @@
from contextvars import ContextVar

+from core.workflow.entities.variable_pool import VariablePool
+
tenant_id: ContextVar[str] = ContextVar('tenant_id')
+
+workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool')
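
workflow_variable_pool follows the same pattern as tenant_id: a module-level ContextVar set once per request and read anywhere downstream without threading the value through call signatures. A minimal stdlib sketch of that life cycle (names here are illustrative):

```python
from contextvars import ContextVar, copy_context

request_id: ContextVar[str] = ContextVar('request_id')

def handler() -> None:
    # get() raises LookupError if set() was never called in this context.
    print(request_id.get())

request_id.set('abc-123')
ctx = copy_context()  # snapshot the current values...
ctx.run(handler)      # ...and run code against that snapshot
```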

View File

@@ -555,7 +555,7 @@ class DatasetRetrievalSettingApi(Resource):
                        RetrievalMethod.SEMANTIC_SEARCH.value
                    ]
                }
-           case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
+           case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
                return {
                    'retrieval_method': [
                        RetrievalMethod.SEMANTIC_SEARCH.value,
@@ -579,7 +579,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                        RetrievalMethod.SEMANTIC_SEARCH.value
                    ]
                }
-           case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
+           case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
                return {
                    'retrieval_method': [
                        RetrievalMethod.SEMANTIC_SEARCH.value,

View File

@@ -178,11 +178,20 @@ class DatasetDocumentListApi(Resource):
                .subquery()

            query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \
-               .order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)))
+               .order_by(
+                   sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
+                   sort_logic(Document.position),
+               )
        elif sort == 'created_at':
-           query = query.order_by(sort_logic(Document.created_at))
+           query = query.order_by(
+               sort_logic(Document.created_at),
+               sort_logic(Document.position),
+           )
        else:
-           query = query.order_by(desc(Document.created_at))
+           query = query.order_by(
+               desc(Document.created_at),
+               desc(Document.position),
+           )

        paginated_documents = query.paginate(
            page=page, per_page=limit, max_per_page=100, error_out=False)

View File

@@ -93,6 +93,7 @@ class DatasetConfigManager:
                reranking_model=dataset_configs.get('reranking_model'),
                weights=dataset_configs.get('weights'),
                reranking_enabled=dataset_configs.get('reranking_enabled', True),
+               rerank_mode=dataset_configs["reranking_mode"],
            )
        )

View File

@@ -8,6 +8,8 @@ from typing import Union
from flask import Flask, current_app
from pydantic import ValidationError
+from sqlalchemy import select
+from sqlalchemy.orm import Session

import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -18,15 +20,20 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
+from core.app.entities.app_invoke_entities import (
+    AdvancedChatAppGenerateEntity,
+    InvokeFrom,
+)
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
-from models.workflow import Workflow
+from models.workflow import ConversationVariable, Workflow

logger = logging.getLogger(__name__)
@@ -120,7 +127,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            conversation=conversation,
            stream=stream
        )

    def single_iteration_generate(self, app_model: App,
                                  workflow: Workflow,
                                  node_id: str,
@@ -140,10 +147,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        """
        if not node_id:
            raise ValueError('node_id is required')

        if args.get('inputs') is None:
            raise ValueError('inputs is required')

        extras = {
            "auto_generate_conversation_name": False
        }
@@ -209,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        # update conversation features
        conversation.override_model_configs = workflow.features
        db.session.commit()
-       db.session.refresh(conversation)
+       # db.session.refresh(conversation)

        # init queue manager
        queue_manager = MessageBasedAppQueueManager(
+       # Init conversation variables
+       stmt = select(ConversationVariable).where(
+           ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
+       )
+       with Session(db.engine) as session:
+           conversation_variables = session.scalars(stmt).all()
+           if not conversation_variables:
+               # Create conversation variables if they don't exist.
+               conversation_variables = [
+                   ConversationVariable.from_variable(
+                       app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
+                   )
+                   for variable in workflow.conversation_variables
+               ]
+               session.add_all(conversation_variables)
+           # Convert database entities to variables.
+           conversation_variables = [item.to_variable() for item in conversation_variables]
+           session.commit()
+
+       # Increment dialogue count.
+       conversation.dialogue_count += 1
+
+       conversation_id = conversation.id
+       conversation_dialogue_count = conversation.dialogue_count
+       db.session.commit()
+       db.session.refresh(conversation)
+
+       inputs = application_generate_entity.inputs
+       query = application_generate_entity.query
+       files = application_generate_entity.files
+
+       user_id = None
+       if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+           end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
+           if end_user:
+               user_id = end_user.session_id
+       else:
+           user_id = application_generate_entity.user_id
+
+       # Create a variable pool.
+       system_inputs = {
+           SystemVariable.QUERY: query,
+           SystemVariable.FILES: files,
+           SystemVariable.CONVERSATION_ID: conversation_id,
+           SystemVariable.USER_ID: user_id,
+           SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
+       }
+       variable_pool = VariablePool(
+           system_variables=system_inputs,
+           user_inputs=inputs,
+           environment_variables=workflow.environment_variables,
+           conversation_variables=conversation_variables,
+       )
+       contexts.workflow_variable_pool.set(variable_pool)
        # new thread
        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'application_generate_entity': application_generate_entity,
            'queue_manager': queue_manager,
-           'conversation_id': conversation.id,
            'message_id': message.id,
-           'user': user,
-           'context': contextvars.copy_context()
+           'context': contextvars.copy_context(),
        })

        worker_thread.start()
@@ -242,7 +303,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            conversation=conversation,
            message=message,
            user=user,
-           stream=stream
+           stream=stream,
        )

        return AdvancedChatAppGenerateResponseConverter.convert(
@@ -253,9 +314,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
    def _generate_worker(self, flask_app: Flask,
                         application_generate_entity: AdvancedChatAppGenerateEntity,
                         queue_manager: AppQueueManager,
-                        conversation_id: str,
                         message_id: str,
-                        user: Account,
                         context: contextvars.Context) -> None:
        """
        Generate worker in a new thread.
@@ -282,8 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                    user_id=application_generate_entity.user_id
                )
            else:
-               # get conversation and message
-               conversation = self._get_conversation(conversation_id)
+               # get message
                message = self._get_message(message_id)

                # chatbot app
@@ -291,7 +349,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                runner.run(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
-                   conversation=conversation,
                    message=message
                )
        except GenerateTaskStoppedException:
@@ -314,14 +371,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        finally:
            db.session.close()

-   def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
-                                      workflow: Workflow,
-                                      queue_manager: AppQueueManager,
-                                      conversation: Conversation,
-                                      message: Message,
-                                      user: Union[Account, EndUser],
-                                      stream: bool = False) \
-           -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
+   def _handle_advanced_chat_response(
+       self,
+       *,
+       application_generate_entity: AdvancedChatAppGenerateEntity,
+       workflow: Workflow,
+       queue_manager: AppQueueManager,
+       conversation: Conversation,
+       message: Message,
+       user: Union[Account, EndUser],
+       stream: bool = False,
+   ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
        """
        Handle response.
        :param application_generate_entity: application generate entity
@@ -341,7 +401,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            conversation=conversation,
            message=message,
            user=user,
-           stream=stream
+           stream=stream,
        )

        try:
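
Note the hand-off in this file: the generator now builds the VariablePool on the request thread, stores it in contexts.workflow_variable_pool, and passes contextvars.copy_context() into the worker, because ContextVar values do not propagate into new threads on their own. A minimal sketch of that pattern (names are illustrative):

```python
import contextvars
import threading

pool: contextvars.ContextVar[dict] = contextvars.ContextVar('pool')

def worker(context: contextvars.Context) -> None:
    # Re-enter the snapshot taken on the request thread.
    for var, value in context.items():
        var.set(value)
    print(pool.get())  # now visible inside the worker thread

pool.set({'sys.query': 'hello'})
threading.Thread(target=worker, args=(contextvars.copy_context(),)).start()
```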

View File

@@ -4,9 +4,6 @@ import time
from collections.abc import Mapping
from typing import Any, Optional, cast

-from sqlalchemy import select
-from sqlalchemy.orm import Session

from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -19,13 +16,10 @@ from core.app.entities.app_invoke_entities import (
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import SystemVariable
-from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
-from models.model import App, Conversation, EndUser, Message
-from models.workflow import ConversationVariable, Workflow
+from models import App, Message, Workflow

logger = logging.getLogger(__name__)
@@ -39,7 +33,6 @@ class AdvancedChatAppRunner(AppRunner):
        self,
        application_generate_entity: AdvancedChatAppGenerateEntity,
        queue_manager: AppQueueManager,
-       conversation: Conversation,
        message: Message,
    ) -> None:
        """
@@ -63,15 +56,6 @@ class AdvancedChatAppRunner(AppRunner):
        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
-       files = application_generate_entity.files
-
-       user_id = None
-       if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
-           end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
-           if end_user:
-               user_id = end_user.session_id
-       else:
-           user_id = application_generate_entity.user_id

        # moderation
        if self.handle_input_moderation(
@@ -103,38 +87,6 @@ class AdvancedChatAppRunner(AppRunner):
        if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
            workflow_callbacks.append(WorkflowLoggingCallback())

-       # Init conversation variables
-       stmt = select(ConversationVariable).where(
-           ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
-       )
-       with Session(db.engine) as session:
-           conversation_variables = session.scalars(stmt).all()
-           if not conversation_variables:
-               conversation_variables = [
-                   ConversationVariable.from_variable(
-                       app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
-                   )
-                   for variable in workflow.conversation_variables
-               ]
-               session.add_all(conversation_variables)
-               session.commit()
-           # Convert database entities to variables
-           conversation_variables = [item.to_variable() for item in conversation_variables]
-
-       # Create a variable pool.
-       system_inputs = {
-           SystemVariable.QUERY: query,
-           SystemVariable.FILES: files,
-           SystemVariable.CONVERSATION_ID: conversation.id,
-           SystemVariable.USER_ID: user_id,
-       }
-       variable_pool = VariablePool(
-           system_variables=system_inputs,
-           user_inputs=inputs,
-           environment_variables=workflow.environment_variables,
-           conversation_variables=conversation_variables,
-       )

        # RUN WORKFLOW
        workflow_engine_manager = WorkflowEngineManager()
        workflow_engine_manager.run_workflow(
@@ -146,7 +98,6 @@ class AdvancedChatAppRunner(AppRunner):
            invoke_from=application_generate_entity.invoke_from,
            callbacks=workflow_callbacks,
            call_depth=application_generate_entity.call_depth,
-           variable_pool=variable_pool,
        )

    def single_iteration_run(
@@ -155,7 +106,7 @@ class AdvancedChatAppRunner(AppRunner):
        """
        Single iteration run
        """
-       app_record: App = db.session.query(App).filter(App.id == app_id).first()
+       app_record = db.session.query(App).filter(App.id == app_id).first()
        if not app_record:
            raise ValueError('App not found')

View File

@@ -4,6 +4,7 @@ import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast

+import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -47,7 +48,8 @@ from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.node_entities import NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.enums import SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
@@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
    _application_generate_entity: AdvancedChatAppGenerateEntity
    _workflow: Workflow
    _user: Union[Account, EndUser]
+   # Deprecated
    _workflow_system_variables: dict[SystemVariable, Any]
    _iteration_nested_relations: dict[str, list[str]]

@@ -81,7 +84,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
        conversation: Conversation,
        message: Message,
        user: Union[Account, EndUser],
-       stream: bool
+       stream: bool,
    ) -> None:
        """
        Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -103,11 +106,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
        self._workflow = workflow
        self._conversation = conversation
        self._message = message
+       # Deprecated
        self._workflow_system_variables = {
            SystemVariable.QUERY: message.query,
            SystemVariable.FILES: application_generate_entity.files,
            SystemVariable.CONVERSATION_ID: conversation.id,
-           SystemVariable.USER_ID: user_id
+           SystemVariable.USER_ID: user_id,
        }

        self._task_state = AdvancedChatTaskState(
@@ -613,7 +617,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                if route_chunk_node_id == 'sys':
                    # system variable
-                   value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
+                   value = contexts.workflow_variable_pool.get().get(value_selector)
+                   if value:
+                       value = value.text
                elif route_chunk_node_id in self._iteration_nested_relations:
                    # it's a iteration variable
                    if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
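
The streaming path now reads system variables from the shared pool set by the generator instead of the deprecated _workflow_system_variables dict. A sketch of the assumed read side (the selector layout is inferred from the hunk above):

```python
import contexts

# Selector layout inferred from this diff: ['sys', '<variable name>'];
# system variables live under the 'sys' node id.
pool = contexts.workflow_variable_pool.get()
segment = pool.get(['sys', 'query'])  # returns a Segment, or None if unset
if segment:
    print(segment.text)  # .text is the string rendering streamed to clients
```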

View File

@@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):

        return introduction

-   def _get_conversation(self, conversation_id: str) -> Conversation:
+   def _get_conversation(self, conversation_id: str):
        """
        Get conversation by conversation id
        :param conversation_id: conversation id
@@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
            .first()
        )

+       if not conversation:
+           raise ConversationNotExistsError()
+
        return conversation

    def _get_message(self, message_id: str) -> Message:

View File

@@ -11,8 +11,8 @@ from core.app.entities.app_invoke_entities import (
    WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db

View File

@@ -42,7 +42,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.node_entities import NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.enums import SystemVariable
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
@@ -519,7 +520,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        """
        nodes = graph.get('nodes')

        iteration_ids = [node.get('id') for node in nodes
                         if node.get('data', {}).get('type') in [
                             NodeType.ITERATION.value,
                             NodeType.LOOP.value,
@@ -530,4 +531,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
            ] for iteration_id in iteration_ids
        }

View File

@@ -2,7 +2,6 @@ from .segment_group import SegmentGroup
from .segments import (
    ArrayAnySegment,
    ArraySegment,
-   FileSegment,
    FloatSegment,
    IntegerSegment,
    NoneSegment,
@@ -13,11 +12,9 @@
from .types import SegmentType
from .variables import (
    ArrayAnyVariable,
-   ArrayFileVariable,
    ArrayNumberVariable,
    ArrayObjectVariable,
    ArrayStringVariable,
-   FileVariable,
    FloatVariable,
    IntegerVariable,
    NoneVariable,
@@ -32,7 +29,6 @@ __all__ = [
    'FloatVariable',
    'ObjectVariable',
    'SecretVariable',
-   'FileVariable',
    'StringVariable',
    'ArrayAnyVariable',
    'Variable',
@@ -45,11 +41,9 @@ __all__ = [
    'FloatSegment',
    'ObjectSegment',
    'ArrayAnySegment',
-   'FileSegment',
    'StringSegment',
    'ArrayStringVariable',
    'ArrayNumberVariable',
    'ArrayObjectVariable',
-   'ArrayFileVariable',
    'ArraySegment',
]

View File

@@ -2,12 +2,10 @@ from collections.abc import Mapping
from typing import Any

from configs import dify_config
-from core.file.file_obj import FileVar

from .exc import VariableError
from .segments import (
    ArrayAnySegment,
-   FileSegment,
    FloatSegment,
    IntegerSegment,
    NoneSegment,
@@ -17,11 +15,9 @@
)
from .types import SegmentType
from .variables import (
-   ArrayFileVariable,
    ArrayNumberVariable,
    ArrayObjectVariable,
    ArrayStringVariable,
-   FileVariable,
    FloatVariable,
    IntegerVariable,
    ObjectVariable,
@@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
            result = FloatVariable.model_validate(mapping)
        case SegmentType.NUMBER if not isinstance(value, float | int):
            raise VariableError(f'invalid number value {value}')
-       case SegmentType.FILE:
-           result = FileVariable.model_validate(mapping)
        case SegmentType.OBJECT if isinstance(value, dict):
            result = ObjectVariable.model_validate(mapping)
        case SegmentType.ARRAY_STRING if isinstance(value, list):
@@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
            result = ArrayNumberVariable.model_validate(mapping)
        case SegmentType.ARRAY_OBJECT if isinstance(value, list):
            result = ArrayObjectVariable.model_validate(mapping)
-       case SegmentType.ARRAY_FILE if isinstance(value, list):
-           mapping = dict(mapping)
-           mapping['value'] = [{'value': v} for v in value]
-           result = ArrayFileVariable.model_validate(mapping)
        case _:
            raise VariableError(f'not supported value type {value_type}')
    if result.size > dify_config.MAX_VARIABLE_SIZE:
@@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
        return ObjectSegment(value=value)
    if isinstance(value, list):
        return ArrayAnySegment(value=value)
-   if isinstance(value, FileVar):
-       return FileSegment(value=value)
    raise ValueError(f'not supported value {value}')
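
With the file cases removed, build_segment dispatches purely on built-in Python types. A quick illustration of the remaining mapping (the import path and the string case are assumed from the unchanged part of this factory):

```python
from core.app.segments.factory import build_segment

# Assumed dispatch, based on the cases left in build_segment:
print(type(build_segment('hello')).__name__)     # StringSegment
print(type(build_segment({'k': 'v'})).__name__)  # ObjectSegment
print(type(build_segment([1, 2, 3])).__name__)   # ArrayAnySegment
build_segment(object())                          # raises ValueError
```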

View File

@@ -5,8 +5,6 @@ from typing import Any

from pydantic import BaseModel, ConfigDict, field_validator

-from core.file.file_obj import FileVar

from .types import SegmentType

@@ -78,14 +76,7 @@ class IntegerSegment(Segment):
    value: int


-class FileSegment(Segment):
-   value_type: SegmentType = SegmentType.FILE
-   # TODO: embed FileVar in this model.
-   value: FileVar
-
-   @property
-   def markdown(self) -> str:
-       return self.value.to_markdown()


class ObjectSegment(Segment):
@@ -130,7 +121,3 @@ class ArrayObjectSegment(ArraySegment):
    value_type: SegmentType = SegmentType.ARRAY_OBJECT
    value: Sequence[Mapping[str, Any]]


-class ArrayFileSegment(ArraySegment):
-   value_type: SegmentType = SegmentType.ARRAY_FILE
-   value: Sequence[FileSegment]

View File

@@ -10,8 +10,6 @@ class SegmentType(str, Enum):
    ARRAY_STRING = 'array[string]'
    ARRAY_NUMBER = 'array[number]'
    ARRAY_OBJECT = 'array[object]'
-   ARRAY_FILE = 'array[file]'
    OBJECT = 'object'
-   FILE = 'file'
    GROUP = 'group'

View File

@@ -4,11 +4,9 @@ from core.helper import encrypter

from .segments import (
    ArrayAnySegment,
-   ArrayFileSegment,
    ArrayNumberSegment,
    ArrayObjectSegment,
    ArrayStringSegment,
-   FileSegment,
    FloatSegment,
    IntegerSegment,
    NoneSegment,
@@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
    pass


-class FileVariable(FileSegment, Variable):
-   pass


class ObjectVariable(ObjectSegment, Variable):
    pass
@@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
    pass


-class ArrayFileVariable(ArrayFileSegment, Variable):
-   pass


class SecretVariable(StringVariable):
    value_type: SegmentType = SegmentType.SECRET

View File

@@ -2,7 +2,7 @@ from typing import Any, Union

from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
-from core.workflow.entities.node_entities import SystemVariable
+from core.workflow.enums import SystemVariable
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
@@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
    _workflow: Workflow
    _user: Union[Account, EndUser]
    _task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
    _workflow_system_variables: dict[SystemVariable, Any]

View File

@@ -0,0 +1,81 @@
model: farui-plus
label:
  en_US: farui-plus
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 12288
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 2000
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: seed
    required: false
    type: int
    default: 1234
    label:
      zh_Hans: 随机种子
      en_US: Random seed
    help:
      zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
      en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
  - name: enable_search
    type: boolean
    default: false
    help:
      zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
      en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
  - name: response_format
    use_template: response_format
pricing:
  input: '0.02'
  output: '0.02'
  unit: '0.001'
  currency: RMB

View File

@@ -2,3 +2,8 @@ model: text-embedding-v1
model_type: text-embedding
model_properties:
  context_size: 2048
+  max_chunks: 25
+pricing:
+  input: "0.0007"
+  unit: "0.001"
+  currency: RMB

View File

@@ -2,3 +2,8 @@ model: text-embedding-v2
model_type: text-embedding
model_properties:
  context_size: 2048
+  max_chunks: 25
+pricing:
+  input: "0.0007"
+  unit: "0.001"
+  currency: RMB

View File

@@ -2,6 +2,7 @@ import time
from typing import Optional

import dashscope
+import numpy as np

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import (
@@ -21,11 +22,11 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model
@@ -37,16 +38,44 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
        :return: embeddings result
        """
        credentials_kwargs = self._to_credential_kwargs(credentials)
-       embeddings, embedding_used_tokens = self.embed_documents(
-           credentials_kwargs=credentials_kwargs,
-           model=model,
-           texts=texts
-       )
+       context_size = self._get_context_size(model, credentials)
+       max_chunks = self._get_max_chunks(model, credentials)
+       inputs = []
+       indices = []
+       used_tokens = 0
+
+       for i, text in enumerate(texts):
+           # Here token count is only an approximation based on the GPT2 tokenizer
+           num_tokens = self._get_num_tokens_by_gpt2(text)
+
+           if num_tokens >= context_size:
+               cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
+               # if num tokens is larger than context length, only use the start
+               inputs.append(text[0:cutoff])
+           else:
+               inputs.append(text)
+           indices += [i]
+
+       batched_embeddings = []
+       _iter = range(0, len(inputs), max_chunks)
+
+       for i in _iter:
+           embeddings_batch, embedding_used_tokens = self.embed_documents(
+               credentials_kwargs=credentials_kwargs,
+               model=model,
+               texts=inputs[i : i + max_chunks],
+           )
+           used_tokens += embedding_used_tokens
+           batched_embeddings += embeddings_batch
+
+       # calc usage
+       usage = self._calc_response_usage(
+           model=model, credentials=credentials, tokens=used_tokens
+       )

        return TextEmbeddingResult(
-           embeddings=embeddings,
-           usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens),
-           model=model
+           embeddings=batched_embeddings, usage=usage, model=model
        )
    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
@@ -79,12 +108,16 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
            credentials_kwargs = self._to_credential_kwargs(credentials)

            # call embedding model
-           self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"])
+           self.embed_documents(
+               credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]
+           )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @staticmethod
-   def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]:
+   def embed_documents(
+       credentials_kwargs: dict, model: str, texts: list[str]
+   ) -> tuple[list[list[float]], int]:
        """Call out to Tongyi's embedding endpoint.

        Args:
@@ -102,7 +135,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
                api_key=credentials_kwargs["dashscope_api_key"],
                model=model,
                input=text,
-               text_type="document"
+               text_type="document",
            )
            data = response.output["embeddings"][0]
            embeddings.append(data["embedding"])
@@ -111,7 +144,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
        return [list(map(float, e)) for e in embeddings], embedding_used_tokens

    def _calc_response_usage(
        self, model: str, credentials: dict, tokens: int
    ) -> EmbeddingUsage:
        """
        Calculate response usage
@@ -125,7 +158,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
-           tokens=tokens
+           tokens=tokens,
        )

        # transform usage
@@ -136,7 +169,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
-           latency=time.perf_counter() - self.started_at
+           latency=time.perf_counter() - self.started_at,
        )

        return usage
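
The reworked _invoke above does two things before hitting the endpoint: clip any text whose approximate token count exceeds the model's context_size, and submit the inputs in batches of at most max_chunks (25 for these models, per the updated YAML). A self-contained sketch of just that arithmetic:

```python
import numpy as np

def batch_inputs(texts: list[str], context_size: int, max_chunks: int,
                 count_tokens) -> list[list[str]]:
    """Clip over-long texts, then split into endpoint-sized batches."""
    inputs = []
    for text in texts:
        num_tokens = count_tokens(text)
        if num_tokens >= context_size:
            # proportional character cutoff, mirroring the approximation above
            cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
            inputs.append(text[:cutoff])
        else:
            inputs.append(text)
    return [inputs[i:i + max_chunks] for i in range(0, len(inputs), max_chunks)]

# e.g. 60 short texts with max_chunks=25 -> batches of 25, 25, 10
batches = batch_inputs(['hi'] * 60, context_size=2048, max_chunks=25,
                       count_tokens=lambda t: len(t.split()))
print([len(b) for b in batches])
```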

View File

@@ -1 +1 @@
-- soloar-1-mini-chat
+- solar-1-mini-chat

View File

@@ -0,0 +1,191 @@
import json
from typing import Any

import requests
from elasticsearch import Elasticsearch
from flask import current_app
from pydantic import BaseModel, model_validator

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from models.dataset import Dataset


class ElasticSearchConfig(BaseModel):
    host: str
    port: str
    username: str
    password: str

    @model_validator(mode='before')
    def validate_config(cls, values: dict) -> dict:
        if not values['host']:
            raise ValueError("config HOST is required")
        if not values['port']:
            raise ValueError("config PORT is required")
        if not values['username']:
            raise ValueError("config USERNAME is required")
        if not values['password']:
            raise ValueError("config PASSWORD is required")
        return values


class ElasticSearchVector(BaseVector):
    def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
        super().__init__(index_name.lower())
        self._client = self._init_client(config)
        self._attributes = attributes

    def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
        try:
            client = Elasticsearch(
                hosts=f'{config.host}:{config.port}',
                basic_auth=(config.username, config.password),
                request_timeout=100000,
                retry_on_timeout=True,
                max_retries=10000,
            )
        except requests.exceptions.ConnectionError:
            raise ConnectionError("Vector database connection error")

        return client

    def get_type(self) -> str:
        return 'elasticsearch'

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        uuids = self._get_uuids(documents)
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]

        if not self._client.indices.exists(index=self._collection_name):
            dim = len(embeddings[0])
            mapping = {
                "properties": {
                    "text": {
                        "type": "text"
                    },
                    "vector": {
                        "type": "dense_vector",
                        "index": True,
                        "dims": dim,
                        "similarity": "l2_norm"
                    },
                }
            }
            self._client.indices.create(index=self._collection_name, mappings=mapping)

        added_ids = []
        for i, text in enumerate(texts):
            self._client.index(index=self._collection_name,
                               id=uuids[i],
                               document={
                                   "text": text,
                                   "vector": embeddings[i] if embeddings[i] else None,
                                   "metadata": metadatas[i] if metadatas[i] else {},
                               })
            added_ids.append(uuids[i])

        self._client.indices.refresh(index=self._collection_name)
        return uuids

    def text_exists(self, id: str) -> bool:
        return self._client.exists(index=self._collection_name, id=id).__bool__()

    def delete_by_ids(self, ids: list[str]) -> None:
        for id in ids:
            self._client.delete(index=self._collection_name, id=id)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        query_str = {
            'query': {
                'match': {
                    f'metadata.{key}': f'{value}'
                }
            }
        }
        results = self._client.search(index=self._collection_name, body=query_str)
        ids = [hit['_id'] for hit in results['hits']['hits']]
        if ids:
            self.delete_by_ids(ids)

    def delete(self) -> None:
        self._client.indices.delete(index=self._collection_name)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        query_str = {
            "query": {
                "script_score": {
                    "query": {
                        "match_all": {}
                    },
                    "script": {
                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
                        "params": {
                            "query_vector": query_vector
                        }
                    }
                }
            }
        }
        results = self._client.search(index=self._collection_name, body=query_str)
        docs_and_scores = []
        for hit in results['hits']['hits']:
            docs_and_scores.append(
                (Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
        docs = []
        for doc, score in docs_and_scores:
            score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
            if score > score_threshold:
                doc.metadata['score'] = score
                docs.append(doc)
        # Sort the documents by score in descending order
        docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)

        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        query_str = {
            "match": {
                "text": query
            }
        }
        results = self._client.search(index=self._collection_name, query=query_str)
        docs = []
        for hit in results['hits']['hits']:
            docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))

        return docs

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        return self.add_texts(texts, embeddings, **kwargs)


class ElasticSearchVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
            collection_name = class_prefix
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))

        config = current_app.config
        return ElasticSearchVector(
            index_name=collection_name,
            config=ElasticSearchConfig(
                host=config.get('ELASTICSEARCH_HOST'),
                port=config.get('ELASTICSEARCH_PORT'),
                username=config.get('ELASTICSEARCH_USERNAME'),
                password=config.get('ELASTICSEARCH_PASSWORD'),
            ),
            attributes=[]
        )
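
End to end, the new store can be exercised directly. A hedged usage sketch (assumes a reachable Elasticsearch with the credentials from the new .env defaults; the index name and the 3-dimensional embeddings are dummies):

```python
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
    ElasticSearchConfig,
    ElasticSearchVector,
)
from core.rag.models.document import Document

vector = ElasticSearchVector(
    index_name='Vector_index_demo_Node',  # illustrative index name
    config=ElasticSearchConfig(host='127.0.0.1', port='9200',
                               username='elastic', password='elastic'),
    attributes=[],
)
docs = [Document(page_content='hello dify', metadata={'doc_id': 'demo-1'})]
vector.add_texts(docs, embeddings=[[0.1, 0.2, 0.3]])

hits = vector.search_by_vector([0.1, 0.2, 0.3], score_threshold=0.0)
print([d.page_content for d in hits])
```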

View File

@@ -93,7 +93,7 @@ class MyScaleVector(BaseVector):
    @staticmethod
    def escape_str(value: Any) -> str:
-       return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value))
+       return "".join(" " if c in ("\\", "'") else c for c in str(value))

    def text_exists(self, id: str) -> bool:
        results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
@@ -118,7 +118,7 @@ class MyScaleVector(BaseVector):
        return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-       return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs)
+       return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)

    def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 5)
View File

@@ -71,6 +71,9 @@ class Vector:
            case VectorType.RELYT:
                from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
                return RelytVectorFactory
+           case VectorType.ELASTICSEARCH:
+               from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
+               return ElasticSearchVectorFactory
            case VectorType.TIDB_VECTOR:
                from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
                return TiDBVectorFactory

View File

@@ -15,3 +15,4 @@ class VectorType(str, Enum):
    OPENSEARCH = 'opensearch'
    TENCENT = 'tencent'
    ORACLE = 'oracle'
+   ELASTICSEARCH = 'elasticsearch'

View File

@@ -46,7 +46,7 @@ class ToolProviderType(Enum):
            if mode.value == value:
                return mode
        raise ValueError(f'invalid mode value {value}')

class ApiProviderSchemaType(Enum):
    """
    Enum class for api provider schema type.
@@ -68,7 +68,7 @@ class ApiProviderSchemaType(Enum):
            if mode.value == value:
                return mode
        raise ValueError(f'invalid mode value {value}')

class ApiProviderAuthType(Enum):
    """
    Enum class for api provider auth type.
@@ -103,8 +103,8 @@ class ToolInvokeMessage(BaseModel):
    """
    plain text, image url or link url
    """
-   message: Union[str, bytes, dict] = None
-   meta: dict[str, Any] = None
+   message: str | bytes | dict | None = None
+   meta: dict[str, Any] | None = None
    save_as: str = ''

class ToolInvokeMessageBinary(BaseModel):
@@ -154,8 +154,8 @@ class ToolParameter(BaseModel):
    options: Optional[list[ToolParameterOption]] = None

    @classmethod
    def get_simple_instance(cls,
                            name: str, llm_description: str, type: ToolParameterType,
                            required: bool, options: Optional[list[str]] = None) -> 'ToolParameter':
        """
        get a simple tool parameter
@@ -222,7 +222,7 @@ class ToolProviderCredentials(BaseModel):
            if mode.value == value:
                return mode
        raise ValueError(f'invalid mode value {value}')

    @staticmethod
    def default(value: str) -> str:
        return ""
@@ -290,7 +290,7 @@ class ToolRuntimeVariablePool(BaseModel):
            'tenant_id': self.tenant_id,
            'pool': [variable.model_dump() for variable in self.pool],
        }

    def set_text(self, tool_name: str, name: str, value: str) -> None:
        """
        set a text variable
@@ -301,7 +301,7 @@ class ToolRuntimeVariablePool(BaseModel):
                variable = cast(ToolRuntimeTextVariable, variable)
                variable.value = value
                return

        variable = ToolRuntimeTextVariable(
            type=ToolRuntimeVariableType.TEXT,
            name=name,
@@ -334,7 +334,7 @@ class ToolRuntimeVariablePool(BaseModel):
                variable = cast(ToolRuntimeImageVariable, variable)
                variable.value = value
                return

        variable = ToolRuntimeImageVariable(
            type=ToolRuntimeVariableType.IMAGE,
            name=name,
@@ -388,21 +388,21 @@ class ToolInvokeMeta(BaseModel):
        Get an empty instance of ToolInvokeMeta
        """
        return cls(time_cost=0.0, error=None, tool_config={})

    @classmethod
    def error_instance(cls, error: str) -> 'ToolInvokeMeta':
        """
        Get an instance of ToolInvokeMeta with error
        """
        return cls(time_cost=0.0, error=error, tool_config={})

    def to_dict(self) -> dict:
        return {
            'time_cost': self.time_cost,
            'error': self.error,
            'tool_config': self.tool_config,
        }

class ToolLabel(BaseModel):
    """
    Tool label
@@ -416,4 +416,4 @@ class ToolInvokeFrom(Enum):
    Enum class for tool invoke
    """
    WORKFLOW = "workflow"
    AGENT = "agent"

View File

@ -0,0 +1,2 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg width="24" height="25" viewBox="0 0 24 25" xmlns="http://www.w3.org/2000/svg" fill="none"><path fill="#FC6D26" d="M14.975 8.904L14.19 6.55l-1.552-4.67a.268.268 0 00-.255-.18.268.268 0 00-.254.18l-1.552 4.667H5.422L3.87 1.879a.267.267 0 00-.254-.179.267.267 0 00-.254.18l-1.55 4.667-.784 2.357a.515.515 0 00.193.583l6.78 4.812 6.778-4.812a.516.516 0 00.196-.583z"/><path fill="#E24329" d="M8 14.296l2.578-7.75H5.423L8 14.296z"/><path fill="#FC6D26" d="M8 14.296l-2.579-7.75H1.813L8 14.296z"/><path fill="#FCA326" d="M1.81 6.549l-.784 2.354a.515.515 0 00.193.583L8 14.3 1.81 6.55z"/><path fill="#E24329" d="M1.812 6.549h3.612L3.87 1.882a.268.268 0 00-.254-.18.268.268 0 00-.255.18L1.812 6.549z"/><path fill="#FC6D26" d="M8 14.296l2.578-7.75h3.614L8 14.296z"/><path fill="#FCA326" d="M14.19 6.549l.783 2.354a.514.514 0 01-.193.583L8 14.296l6.188-7.747h.001z"/><path fill="#E24329" d="M14.19 6.549H10.58l1.551-4.667a.267.267 0 01.255-.18c.115 0 .217.073.254.18l1.552 4.667z"/></svg>


View File

@ -0,0 +1,34 @@
from typing import Any
import requests
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class GitlabProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.")
if 'site_url' not in credentials or not credentials.get('site_url'):
site_url = 'https://gitlab.com'
else:
site_url = credentials.get('site_url')
try:
headers = {
"Content-Type": "application/vnd.text+json",
"Authorization": f"Bearer {credentials.get('access_tokens')}",
}
response = requests.get(
url= f"{site_url}/api/v4/user",
headers=headers)
if response.status_code != 200:
raise ToolProviderCredentialValidationError((response.json()).get('message'))
except Exception as e:
raise ToolProviderCredentialValidationError("Gitlab Access Tokens and Api Version is invalid. {}".format(e))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
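
The validation above reduces to one authenticated GET against the /api/v4/user endpoint; a standalone sketch, assuming gitlab.com as the default site and a placeholder token:

import requests

def check_gitlab_token(site_url: str, access_token: str) -> bool:
    # GitLab returns the current user's profile for a valid token.
    response = requests.get(
        f"{site_url}/api/v4/user",
        headers={"Authorization": f"Bearer {access_token}"},
        timeout=10,
    )
    return response.status_code == 200

# check_gitlab_token('https://gitlab.com', 'glpat-<placeholder>')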

View File

@ -0,0 +1,38 @@
identity:
author: Leo.Wang
name: gitlab
label:
en_US: Gitlab
zh_Hans: Gitlab
description:
    en_US: GitLab plugin for fetching commits
zh_Hans: 用于获取Gitlab commit的插件
icon: gitlab.svg
credentials_for_provider:
access_tokens:
type: secret-input
required: true
label:
en_US: Gitlab access token
zh_Hans: Gitlab access token
placeholder:
en_US: Please input your Gitlab access token
zh_Hans: 请输入你的 Gitlab access token
help:
en_US: Get your Gitlab access token from Gitlab
zh_Hans: 从 Gitlab 获取您的 access token
url: https://docs.gitlab.com/16.9/ee/api/oauth2.html
site_url:
type: text-input
required: false
default: 'https://gitlab.com'
label:
en_US: Gitlab site url
zh_Hans: Gitlab site url
placeholder:
en_US: Please input your Gitlab site url
zh_Hans: 请输入你的 Gitlab site url
help:
en_US: Find your Gitlab url
zh_Hans: 找到你的Gitlab url
url: https://gitlab.com/help

View File

@ -0,0 +1,101 @@
import json
from datetime import datetime, timedelta
from typing import Any, Optional, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GitlabCommitsTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get('project', '')
employee = tool_parameters.get('employee', '')
start_time = tool_parameters.get('start_time', '')
end_time = tool_parameters.get('end_time', '')
if not project:
return self.create_text_message('Project is required')
if not start_time:
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
if not end_time:
end_time = datetime.utcnow().isoformat()
access_token = self.runtime.credentials.get('access_tokens')
site_url = self.runtime.credentials.get('site_url')
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
return self.create_text_message("Gitlab API Access Tokens is required.")
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
site_url = 'https://gitlab.com'
# Get commit content
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time)
return self.create_text_message(json.dumps(result, ensure_ascii=False))
    def fetch(self, user_id: str, site_url: str, access_token: str, project: str, employee: Optional[str] = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
            # Get all projects
url = f"{domain}/api/v4/projects"
response = requests.get(url, headers=headers)
response.raise_for_status()
projects = response.json()
filtered_projects = [p for p in projects if project == "*" or p['name'] == project]
for project in filtered_projects:
project_id = project['id']
project_name = project['name']
print(f"Project: {project_name}")
                # Get all commits for the project
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
params = {
'since': start_time,
'until': end_time
}
if employee:
params['author'] = employee
commits_response = requests.get(commits_url, headers=headers, params=params)
commits_response.raise_for_status()
commits = commits_response.json()
for commit in commits:
commit_sha = commit['id']
print(f"\tCommit SHA: {commit_sha}")
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
diff_response = requests.get(diff_url, headers=headers)
diff_response.raise_for_status()
diffs = diff_response.json()
for diff in diffs:
                        # Calculate the number of changed lines
added_lines = diff['diff'].count('\n+')
removed_lines = diff['diff'].count('\n-')
total_changes = added_lines + removed_lines
if total_changes > 1:
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
results.append({
"project": project_name,
"commit_sha": commit_sha,
"diff": final_code
})
print(f"Commit code:{final_code}")
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
return results
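
The line-counting heuristic inside the loop can be exercised on its own; this sketch replays it on a hard-coded diff string rather than a live API response:

def count_changed_lines(diff_text: str) -> tuple[int, int]:
    # Mirrors the tool's rule: '\n+' marks added lines, '\n-' removed ones.
    added = diff_text.count('\n+')
    removed = diff_text.count('\n-')
    return added, removed

sample = "@@ -1,2 +1,2 @@\n-old line\n+new line\n+another line"
assert count_changed_lines(sample) == (2, 1)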

View File

@ -0,0 +1,56 @@
identity:
name: gitlab_commits
author: Leo.Wang
label:
en_US: Gitlab Commits
zh_Hans: Gitlab代码提交内容
description:
human:
    en_US: A tool for querying GitLab commits. Input should be an existing username or project name.
    zh_Hans: 一个用于查询gitlab代码提交记录的工具,输入的内容应该是一个已存在的用户名或者项目名。
  llm: A tool for querying GitLab commits. Input should be an existing username or project.
parameters:
- name: employee
type: string
required: false
label:
en_US: employee
zh_Hans: 员工用户名
human_description:
en_US: employee
zh_Hans: 员工用户名
llm_description: employee for gitlab
form: llm
- name: project
type: string
required: true
label:
en_US: project
zh_Hans: 项目名
human_description:
en_US: project
zh_Hans: 项目名
llm_description: project for gitlab
form: llm
- name: start_time
type: string
required: false
label:
en_US: start_time
zh_Hans: 开始时间
human_description:
en_US: start_time
zh_Hans: 开始时间
llm_description: start_time for gitlab
form: llm
- name: end_time
type: string
required: false
label:
en_US: end_time
zh_Hans: 结束时间
human_description:
en_US: end_time
zh_Hans: 结束时间
llm_description: end_time for gitlab
form: llm
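
Given this schema, the tool_parameters mapping passed to _invoke would look roughly as follows (values are illustrative; only project is required, and the time bounds default to the last 24 hours as in the tool code):

tool_parameters = {
    'project': 'dify',                    # required; '*' matches all projects
    'employee': 'leo.wang',               # optional author filter
    'start_time': '2024-08-01T00:00:00',  # optional ISO timestamp
    'end_time': '2024-08-14T00:00:00',    # optional ISO timestamp
}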

View File

@ -4,13 +4,14 @@ from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from models.workflow import WorkflowNodeExecutionStatus from models import WorkflowNodeExecutionStatus
class NodeType(Enum): class NodeType(Enum):
""" """
Node Types. Node Types.
""" """
START = 'start' START = 'start'
END = 'end' END = 'end'
ANSWER = 'answer' ANSWER = 'answer'
@ -44,33 +45,11 @@ class NodeType(Enum):
raise ValueError(f'invalid node type value {value}') raise ValueError(f'invalid node type value {value}')
class SystemVariable(Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
@classmethod
def value_of(cls, value: str) -> 'SystemVariable':
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
class NodeRunMetadataKey(Enum): class NodeRunMetadataKey(Enum):
""" """
Node Run Metadata Key. Node Run Metadata Key.
""" """
TOTAL_TOKENS = 'total_tokens' TOTAL_TOKENS = 'total_tokens'
TOTAL_PRICE = 'total_price' TOTAL_PRICE = 'total_price'
CURRENCY = 'currency' CURRENCY = 'currency'
@ -83,6 +62,7 @@ class NodeRunResult(BaseModel):
""" """
Node Run Result. Node Run Result.
""" """
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[Mapping[str, Any]] = None # node inputs inputs: Optional[Mapping[str, Any]] = None # node inputs

View File

@ -6,7 +6,7 @@ from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable from core.workflow.enums import SystemVariable
VariableValue = Union[str, int, float, dict, list, FileVar] VariableValue = Union[str, int, float, dict, list, FileVar]

View File

@ -0,0 +1,25 @@
from enum import Enum
class SystemVariable(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'
@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
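
Because the enum now subclasses str, members compare equal to their raw values, which is what selector-based lookups such as ['sys', SystemVariable.FILES.value] rely on; a minimal check against a standalone copy:

from enum import Enum

class SystemVariable(str, Enum):
    QUERY = 'query'
    FILES = 'files'

assert SystemVariable.FILES == 'files'  # str subclass compares by value
assert ['sys', SystemVariable.FILES.value] == ['sys', 'files']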

View File

@ -133,9 +133,6 @@ class HttpRequestNode(BaseNode):
""" """
files = [] files = []
mimetype, file_binary = response.extract_file() mimetype, file_binary = response.extract_file()
# if not image, return directly
if 'image' not in mimetype:
return files
if mimetype: if mimetype:
# extract filename from url # extract filename from url

View File

@ -23,8 +23,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import ( from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage, LLMNodeChatModelMessage,
@ -201,8 +202,8 @@ class LLMNode(BaseNode):
usage = LLMUsage.empty_usage() usage = LLMUsage.empty_usage()
return full_text, usage return full_text, usage
def _transform_chat_messages(self, def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
""" """
@ -249,13 +250,13 @@ class LLMNode(BaseNode):
# check if it's a context structure # check if it's a context structure
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
return d['content'] return d['content']
# else, parse the dict # else, parse the dict
try: try:
return json.dumps(d, ensure_ascii=False) return json.dumps(d, ensure_ascii=False)
except Exception: except Exception:
return str(d) return str(d)
if isinstance(value, str): if isinstance(value, str):
value = value value = value
elif isinstance(value, list): elif isinstance(value, list):

View File

@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence
from os import path from os import path
from typing import Any, cast from typing import Any, cast
from core.app.segments import parser from core.app.segments import ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus from models import WorkflowNodeExecutionStatus
class ToolNode(BaseNode): class ToolNode(BaseNode):
@ -140,9 +141,9 @@ class ToolNode(BaseNode):
return result return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
# FIXME: ensure this is a ArrayVariable contains FileVariable.
variable = variable_pool.get(['sys', SystemVariable.FILES.value]) variable = variable_pool.get(['sys', SystemVariable.FILES.value])
return [file_var.value for file_var in variable.value] if variable else [] assert isinstance(variable, ArrayAnyVariable)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
""" """

View File

@ -3,6 +3,7 @@ import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
import contexts
from configs import dify_config from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -97,16 +98,16 @@ class WorkflowEngineManager:
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
callbacks: Sequence[WorkflowCallback], callbacks: Sequence[WorkflowCallback],
call_depth: int = 0, call_depth: int = 0,
variable_pool: VariablePool, variable_pool: VariablePool | None = None,
) -> None: ) -> None:
""" """
:param workflow: Workflow instance :param workflow: Workflow instance
:param user_id: user id :param user_id: user id
:param user_from: user from :param user_from: user from
:param user_inputs: user variables inputs :param invoke_from: invoke from
:param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks :param callbacks: workflow callbacks
:param call_depth: call depth :param call_depth: call depth
:param variable_pool: variable pool
""" """
# fetch workflow graph # fetch workflow graph
graph = workflow.graph_dict graph = workflow.graph_dict
@ -128,6 +129,8 @@ class WorkflowEngineManager:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
# init workflow run state # init workflow run state
if not variable_pool:
variable_pool = contexts.workflow_variable_pool.get()
workflow_run_state = WorkflowRunState( workflow_run_state = WorkflowRunState(
workflow=workflow, workflow=workflow,
start_at=time.perf_counter(), start_at=time.perf_counter(),
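
The fallback reads the pool from a ContextVar in the same style as the existing tenant_id context variable; a self-contained sketch of the pattern, with a plain dict standing in for the real VariablePool:

from contextvars import ContextVar

workflow_variable_pool: ContextVar[dict] = ContextVar('workflow_variable_pool')

def run_workflow(variable_pool: dict | None = None) -> dict:
    # Same shape as the engine's fallback: use the ambient pool when none is given.
    if not variable_pool:
        variable_pool = workflow_variable_pool.get()
    return variable_pool

workflow_variable_pool.set({'sys': {'query': 'hello'}})
assert run_workflow() == {'sys': {'query': 'hello'}}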

View File

@ -2,10 +2,10 @@
from abc import abstractmethod from abc import abstractmethod
import requests import requests
from api.models.source import DataSourceBearerBinding
from flask_login import current_user from flask_login import current_user
from extensions.ext_database import db from extensions.ext_database import db
from models.source import DataSourceBearerBinding
class BearerDataSource: class BearerDataSource:

View File

@ -0,0 +1,33 @@
"""add conversations.dialogue_count
Revision ID: 8782057ff0dc
Revises: 63a83fcf12ba
Create Date: 2024-08-14 13:54:25.161324
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '8782057ff0dc'
down_revision = '63a83fcf12ba'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.drop_column('dialogue_count')
# ### end Alembic commands ###

View File

@ -1,10 +1,10 @@
from enum import Enum from enum import Enum
from .model import AppMode from .model import App, AppMode, Message
from .types import StringUUID from .types import StringUUID
from .workflow import ConversationVariable, WorkflowNodeExecutionStatus from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus'] __all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message']
class CreatedByRole(Enum): class CreatedByRole(Enum):

View File

@ -7,6 +7,7 @@ from typing import Optional
from flask import request from flask import request
from flask_login import UserMixin from flask_login import UserMixin
from sqlalchemy import Float, func, text from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column
from configs import dify_config from configs import dify_config
from core.file.tool_file_parser import ToolFileParser from core.file.tool_file_parser import ToolFileParser
@ -512,12 +513,12 @@ class Conversation(db.Model):
from_account_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID)
read_at = db.Column(db.DateTime) read_at = db.Column(db.DateTime)
read_account_id = db.Column(StringUUID) read_account_id = db.Column(StringUUID)
dialogue_count: Mapped[int] = mapped_column(default=0)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
passive_deletes="all")
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
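
With the server default of 0 from the migration above, bumping the counter is a single atomic UPDATE; a hedged sketch (import paths assumed from the modules touched in this diff):

from extensions.ext_database import db
from models.model import Conversation

def bump_dialogue_count(conversation_id: str) -> None:
    # Increment in SQL rather than Python to avoid read-modify-write races.
    db.session.query(Conversation).filter(
        Conversation.id == conversation_id,
    ).update({Conversation.dialogue_count: Conversation.dialogue_count + 1})
    db.session.commit()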

api/poetry.lock generated
View File

@ -2100,6 +2100,44 @@ primp = ">=0.5.5"
dev = ["mypy (>=1.11.0)", "pytest (>=8.3.1)", "pytest-asyncio (>=0.23.8)", "ruff (>=0.5.5)"] dev = ["mypy (>=1.11.0)", "pytest (>=8.3.1)", "pytest-asyncio (>=0.23.8)", "ruff (>=0.5.5)"]
lxml = ["lxml (>=5.2.2)"] lxml = ["lxml (>=5.2.2)"]
[[package]]
name = "elastic-transport"
version = "8.15.0"
description = "Transport classes and utilities shared among Python Elastic client libraries"
optional = false
python-versions = ">=3.8"
files = [
{file = "elastic_transport-8.15.0-py3-none-any.whl", hash = "sha256:d7080d1dada2b4eee69e7574f9c17a76b42f2895eff428e562f94b0360e158c0"},
{file = "elastic_transport-8.15.0.tar.gz", hash = "sha256:85d62558f9baafb0868c801233a59b235e61d7b4804c28c2fadaa866b6766233"},
]
[package.dependencies]
certifi = "*"
urllib3 = ">=1.26.2,<3"
[package.extras]
develop = ["aiohttp", "furo", "httpx", "opentelemetry-api", "opentelemetry-sdk", "orjson", "pytest", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "pytest-mock", "requests", "respx", "sphinx (>2)", "sphinx-autodoc-typehints", "trustme"]
[[package]]
name = "elasticsearch"
version = "8.14.0"
description = "Python client for Elasticsearch"
optional = false
python-versions = ">=3.7"
files = [
{file = "elasticsearch-8.14.0-py3-none-any.whl", hash = "sha256:cef8ef70a81af027f3da74a4f7d9296b390c636903088439087b8262a468c130"},
{file = "elasticsearch-8.14.0.tar.gz", hash = "sha256:aa2490029dd96f4015b333c1827aa21fd6c0a4d223b00dfb0fe933b8d09a511b"},
]
[package.dependencies]
elastic-transport = ">=8.13,<9"
[package.extras]
async = ["aiohttp (>=3,<4)"]
orjson = ["orjson (>=3)"]
requests = ["requests (>=2.4.0,!=2.32.2,<3.0.0)"]
vectorstore-mmr = ["numpy (>=1)", "simsimd (>=3)"]
[[package]] [[package]]
name = "emoji" name = "emoji"
version = "2.12.1" version = "2.12.1"
@ -9546,4 +9584,4 @@ cffi = ["cffi (>=1.11)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "2b822039247a445f72e04e967aef84f841781e2789b70071acad022f36ba26a5" content-hash = "05dfa6b9bce9ed8ac21caf58eff1596f146080ab2ab6987924b189be673c22cf"

View File

@ -181,6 +181,7 @@ zhipuai = "1.0.7"
rank-bm25 = "~0.2.2" rank-bm25 = "~0.2.2"
openpyxl = "^3.1.5" openpyxl = "^3.1.5"
kaleido = "0.2.1" kaleido = "0.2.1"
elasticsearch = "8.14.0"
############################################################ ############################################################
# Tool dependencies required by tool implementations # Tool dependencies required by tool implementations

View File

@ -13,9 +13,9 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
current_dsl_version = "0.1.0" current_dsl_version = "0.1.1"
dsl_to_dify_version_mapping: dict[str, str] = { dsl_to_dify_version_mapping: dict[str, str] = {
"0.1.0": "0.6.0", # dsl version -> from dify version "0.1.1": "0.6.0", # dsl version -> from dify version
} }

View File

@ -1,5 +1,4 @@
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
from api.core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
class MockTEIClass: class MockTEIClass:
@ -12,7 +11,7 @@ class MockTEIClass:
model_type = 'embedding' model_type = 'embedding'
return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1) return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
@staticmethod @staticmethod
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
# Use space as token separator, and split the text into tokens # Use space as token separator, and split the text into tokens

View File

@ -1,12 +1,12 @@
import os import os
import pytest import pytest
from api.core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import ( from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
HuggingfaceTeiTextEmbeddingModel, HuggingfaceTeiTextEmbeddingModel,
TeiHelper,
) )
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

View File

@ -0,0 +1,25 @@
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
class ElasticSearchVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self.vector = ElasticSearchVector(
index_name=self.collection_name.lower(),
config=ElasticSearchConfig(
host='http://localhost',
port='9200',
username='elastic',
password='elastic'
),
attributes=self.attributes
)
def test_elasticsearch_vector(setup_mock_redis):
ElasticSearchVectorTest().run_all_tests()
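
The config above maps directly onto the official elasticsearch-py client; a quick connectivity sketch against the compose service defined later in this diff, using the default credentials:

from elasticsearch import Elasticsearch

client = Elasticsearch(
    'http://localhost:9200',
    basic_auth=('elastic', 'elastic'),
)
print(client.info())  # returns cluster metadata once the container is healthy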

View File

@ -10,8 +10,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db from extensions.ext_database import db
@ -236,4 +236,4 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert 'sunny' in json.dumps(result.process_data) assert 'sunny' in json.dumps(result.process_data)
assert 'what\'s the weather today?' in json.dumps(result.process_data) assert 'what\'s the weather today?' in json.dumps(result.process_data)

View File

@ -12,8 +12,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from extensions.ext_database import db from extensions.ext_database import db
@ -363,7 +363,7 @@ def test_extract_json_response():
{ {
"location": "kawaii" "location": "kawaii"
} }
hello world. hello world.
""") """)
assert result['location'] == 'kawaii' assert result['location'] == 'kawaii'
@ -445,4 +445,4 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
assert latest_role != prompt.get('role') assert latest_role != prompt.get('role')
if prompt.get('role') in ['user', 'assistant']: if prompt.get('role') in ['user', 'assistant']:
latest_role = prompt.get('role') latest_role = prompt.get('role')

View File

@ -3,12 +3,9 @@ from uuid import uuid4
import pytest import pytest
from core.app.segments import ( from core.app.segments import (
ArrayFileVariable,
ArrayNumberVariable, ArrayNumberVariable,
ArrayObjectVariable, ArrayObjectVariable,
ArrayStringVariable, ArrayStringVariable,
FileSegment,
FileVariable,
FloatVariable, FloatVariable,
IntegerVariable, IntegerVariable,
ObjectSegment, ObjectSegment,
@ -149,83 +146,6 @@ def test_array_object_variable():
assert isinstance(variable.value[1]['key2'], int) assert isinstance(variable.value[1]['key2'], int)
def test_file_variable():
mapping = {
'id': str(uuid4()),
'value_type': 'file',
'name': 'test_file',
'description': 'Description of the variable.',
'value': {
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, FileVariable)
def test_array_file_variable():
mapping = {
'id': str(uuid4()),
'value_type': 'array[file]',
'name': 'test_array_file',
'description': 'Description of the variable.',
'value': [
{
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
{
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayFileVariable)
assert isinstance(variable.value[0], FileSegment)
assert isinstance(variable.value[1], FileSegment)
def test_variable_cannot_large_than_5_kb(): def test_variable_cannot_large_than_5_kb():
with pytest.raises(VariableError): with pytest.raises(VariableError):
factory.build_variable_from_mapping( factory.build_variable_from_mapping(

View File

@ -1,7 +1,7 @@
from core.app.segments import SecretVariable, StringSegment, parser from core.app.segments import SecretVariable, StringSegment, parser
from core.helper import encrypter from core.helper import encrypter
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
def test_segment_group_to_text(): def test_segment_group_to_text():

View File

@ -1,8 +1,8 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from extensions.ext_database import db from extensions.ext_database import db

View File

@ -1,8 +1,8 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.if_else.if_else_node import IfElseNode from core.workflow.nodes.if_else.if_else_node import IfElseNode
from extensions.ext_database import db from extensions.ext_database import db

View File

@ -3,8 +3,8 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import ArrayStringVariable, StringVariable from core.app.segments import ArrayStringVariable, StringVariable
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode

View File

@ -7,4 +7,5 @@ pytest api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/pgvector \ api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \ api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate \ api/tests/integration_tests/vdb/weaviate \
api/tests/integration_tests/vdb/elasticsearch \
api/tests/integration_tests/vdb/test_vector_store.py api/tests/integration_tests/vdb/test_vector_store.py

View File

@ -2,7 +2,7 @@ version: '3'
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:0.6.16 image: langgenius/dify-api:0.7.0
restart: always restart: always
environment: environment:
# Startup mode, 'api' starts the API server. # Startup mode, 'api' starts the API server.
@ -169,6 +169,11 @@ services:
CHROMA_DATABASE: default_database CHROMA_DATABASE: default_database
CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
CHROMA_AUTH_CREDENTIALS: xxxxxx CHROMA_AUTH_CREDENTIALS: xxxxxx
# ElasticSearch Config
ELASTICSEARCH_HOST: 127.0.0.1
ELASTICSEARCH_PORT: 9200
ELASTICSEARCH_USERNAME: elastic
ELASTICSEARCH_PASSWORD: elastic
# Mail configuration, support: resend, smtp # Mail configuration, support: resend, smtp
MAIL_TYPE: '' MAIL_TYPE: ''
# default send from email address, if not specified # default send from email address, if not specified
@ -224,7 +229,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:0.6.16 image: langgenius/dify-api:0.7.0
restart: always restart: always
environment: environment:
CONSOLE_WEB_URL: '' CONSOLE_WEB_URL: ''
@ -371,6 +376,11 @@ services:
CHROMA_DATABASE: default_database CHROMA_DATABASE: default_database
CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
CHROMA_AUTH_CREDENTIALS: xxxxxx CHROMA_AUTH_CREDENTIALS: xxxxxx
# ElasticSearch Config
ELASTICSEARCH_HOST: 127.0.0.1
ELASTICSEARCH_PORT: 9200
ELASTICSEARCH_USERNAME: elastic
ELASTICSEARCH_PASSWORD: elastic
# Notion import configuration, support public and internal # Notion import configuration, support public and internal
NOTION_INTEGRATION_TYPE: public NOTION_INTEGRATION_TYPE: public
NOTION_CLIENT_SECRET: you-client-secret NOTION_CLIENT_SECRET: you-client-secret
@ -390,7 +400,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:0.6.16 image: langgenius/dify-web:0.7.0
restart: always restart: always
environment: environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is

View File

@ -125,6 +125,10 @@ x-shared-env: &shared-api-worker-env
CHROMA_DATABASE: ${CHROMA_DATABASE:-default_database} CHROMA_DATABASE: ${CHROMA_DATABASE:-default_database}
CHROMA_AUTH_PROVIDER: ${CHROMA_AUTH_PROVIDER:-chromadb.auth.token_authn.TokenAuthClientProvider} CHROMA_AUTH_PROVIDER: ${CHROMA_AUTH_PROVIDER:-chromadb.auth.token_authn.TokenAuthClientProvider}
CHROMA_AUTH_CREDENTIALS: ${CHROMA_AUTH_CREDENTIALS:-} CHROMA_AUTH_CREDENTIALS: ${CHROMA_AUTH_CREDENTIALS:-}
ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-127.0.0.1}
ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200}
ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}
ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
# AnalyticDB configuration # AnalyticDB configuration
ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-} ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-}
ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-} ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-}
@ -187,7 +191,7 @@ x-shared-env: &shared-api-worker-env
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:0.6.16 image: langgenius/dify-api:0.7.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -207,7 +211,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:0.6.16 image: langgenius/dify-api:0.7.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -226,7 +230,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:0.6.16 image: langgenius/dify-web:0.7.0
restart: always restart: always
environment: environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-} CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -583,7 +587,7 @@ services:
# MyScale vector database # MyScale vector database
myscale: myscale:
container_name: myscale container_name: myscale
image: myscale/myscaledb:1.6 image: myscale/myscaledb:1.6.4
profiles: profiles:
- myscale - myscale
restart: always restart: always
@ -595,6 +599,27 @@ services:
ports: ports:
- "${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}" - "${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}"
elasticsearch:
image: docker.elastic.co/elasticsearch/elasticsearch:8.14.3
container_name: elasticsearch
profiles:
- elasticsearch
restart: always
environment:
- "ELASTIC_PASSWORD=${ELASTICSEARCH_USERNAME:-elastic}"
- "cluster.name=dify-es-cluster"
- "node.name=dify-es0"
- "discovery.type=single-node"
- "xpack.security.http.ssl.enabled=false"
- "xpack.license.self_generated.type=trial"
ports:
- "${ELASTICSEARCH_PORT:-9200}:${ELASTICSEARCH_PORT:-9200}"
healthcheck:
test: ["CMD", "curl", "-s", "http://localhost:9200/_cluster/health?pretty"]
interval: 30s
timeout: 10s
retries: 50
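
The same probe the healthcheck runs can be issued from Python while debugging the setup; a sketch assuming the default port mapping and credentials above:

import requests

health = requests.get(
    'http://localhost:9200/_cluster/health',
    auth=('elastic', 'elastic'),
    timeout=10,
)
print(health.json().get('status'))  # 'green' or 'yellow' is fine for a single node
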
# unstructured . # unstructured .
# (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.) # (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.)
unstructured: unstructured:

View File

@ -922,6 +922,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='dataset_id' type='string' key='dataset_id'> <Property name='dataset_id' type='string' key='dataset_id'>
Knowledge ID Knowledge ID
</Property> </Property>
<Property name='document_id' type='string' key='document_id'>
Document ID
</Property>
<Property name='segment_id' type='string' key='segment_id'> <Property name='segment_id' type='string' key='segment_id'>
Document Segment ID Document Segment ID
</Property> </Property>
@ -965,6 +968,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='dataset_id' type='string' key='dataset_id'> <Property name='dataset_id' type='string' key='dataset_id'>
Knowledge ID Knowledge ID
</Property> </Property>
<Property name='document_id' type='string' key='document_id'>
Document ID
</Property>
<Property name='segment_id' type='string' key='segment_id'> <Property name='segment_id' type='string' key='segment_id'>
Document Segment ID Document Segment ID
</Property> </Property>

View File

@ -922,6 +922,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='dataset_id' type='string' key='dataset_id'> <Property name='dataset_id' type='string' key='dataset_id'>
知识库 ID 知识库 ID
</Property> </Property>
<Property name='document_id' type='string' key='document_id'>
文档 ID
</Property>
<Property name='segment_id' type='string' key='segment_id'> <Property name='segment_id' type='string' key='segment_id'>
文档分段ID 文档分段ID
</Property> </Property>
@ -965,6 +968,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='dataset_id' type='string' key='dataset_id'> <Property name='dataset_id' type='string' key='dataset_id'>
知识库 ID 知识库 ID
</Property> </Property>
<Property name='document_id' type='string' key='document_id'>
文档 ID
</Property>
<Property name='segment_id' type='string' key='segment_id'> <Property name='segment_id' type='string' key='segment_id'>
文档分段ID 文档分段ID
</Property> </Property>

View File

@ -1,4 +1,5 @@
import ReactMarkdown from 'react-markdown' import ReactMarkdown from 'react-markdown'
import ReactEcharts from 'echarts-for-react'
import 'katex/dist/katex.min.css' import 'katex/dist/katex.min.css'
import RemarkMath from 'remark-math' import RemarkMath from 'remark-math'
import RemarkBreaks from 'remark-breaks' import RemarkBreaks from 'remark-breaks'
@ -30,6 +31,7 @@ const capitalizationLanguageNameMap: Record<string, string> = {
mermaid: 'Mermaid', mermaid: 'Mermaid',
markdown: 'MarkDown', markdown: 'MarkDown',
makefile: 'MakeFile', makefile: 'MakeFile',
echarts: 'ECharts',
} }
const getCorrectCapitalizationLanguageName = (language: string) => { const getCorrectCapitalizationLanguageName = (language: string) => {
if (!language) if (!language)
@ -107,6 +109,14 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }
const match = /language-(\w+)/.exec(className || '') const match = /language-(\w+)/.exec(className || '')
const language = match?.[1] const language = match?.[1]
const languageShowName = getCorrectCapitalizationLanguageName(language || '') const languageShowName = getCorrectCapitalizationLanguageName(language || '')
let chartData = JSON.parse(String('{"title":{"text":"Something went wrong."}}').replace(/\n$/, ''))
if (language === 'echarts') {
try {
chartData = JSON.parse(String(children).replace(/\n$/, ''))
}
catch (error) {
}
}
// Use `useMemo` to ensure that `SyntaxHighlighter` only re-renders when necessary // Use `useMemo` to ensure that `SyntaxHighlighter` only re-renders when necessary
return useMemo(() => { return useMemo(() => {
@ -136,19 +146,25 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }
</div> </div>
{(language === 'mermaid' && isSVG) {(language === 'mermaid' && isSVG)
? (<Flowchart PrimitiveCode={String(children).replace(/\n$/, '')} />) ? (<Flowchart PrimitiveCode={String(children).replace(/\n$/, '')} />)
: (<SyntaxHighlighter : (
{...props} (language === 'echarts')
style={atelierHeathLight} ? (<div style={{ minHeight: '250px', minWidth: '250px' }}><ReactEcharts
customStyle={{ option={chartData}
paddingLeft: 12, >
backgroundColor: '#fff', </ReactEcharts></div>)
}} : (<SyntaxHighlighter
language={match[1]} {...props}
showLineNumbers style={atelierHeathLight}
PreTag="div" customStyle={{
> paddingLeft: 12,
{String(children).replace(/\n$/, '')} backgroundColor: '#fff',
</SyntaxHighlighter>)} }}
language={match[1]}
showLineNumbers
PreTag="div"
>
{String(children).replace(/\n$/, '')}
</SyntaxHighlighter>))}
</div> </div>
) )
: ( : (
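
The branch above hands the fenced block body straight to JSON.parse, so a chart renders only when the block labeled echarts contains one valid JSON option object; a sketch of emitting such a payload (Python is used here only to produce the JSON; the keys are standard ECharts options):

import json

option = {
    'title': {'text': 'Tokens per day'},
    'xAxis': {'type': 'category', 'data': ['Mon', 'Tue', 'Wed']},
    'yAxis': {'type': 'value'},
    'series': [{'type': 'bar', 'data': [120, 200, 150]}],
}
print(json.dumps(option))  # paste the output inside an echarts-labeled code block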

View File

@ -329,36 +329,36 @@ const EditCustomCollectionModal: FC<Props> = ({
<Button variant='primary' onClick={handleSave}>{t('common.operation.save')}</Button> <Button variant='primary' onClick={handleSave}>{t('common.operation.save')}</Button>
</div> </div>
</div> </div>
{showEmojiPicker && <EmojiPicker
onSelect={(icon, icon_background) => {
setEmoji({ content: icon, background: icon_background })
setShowEmojiPicker(false)
}}
onClose={() => {
setShowEmojiPicker(false)
}}
/>}
{credentialsModalShow && (
<ConfigCredentials
positionCenter={isAdd}
credential={credential}
onChange={setCredential}
onHide={() => setCredentialsModalShow(false)}
/>)
}
{isShowTestApi && (
<TestApi
positionCenter={isAdd}
tool={currTool as CustomParamSchema}
customCollection={customCollection}
onHide={() => setIsShowTestApi(false)}
/>
)}
</div> </div>
} }
isShowMask={true} isShowMask={true}
clickOutsideNotOpen={true} clickOutsideNotOpen={true}
/> />
{showEmojiPicker && <EmojiPicker
onSelect={(icon, icon_background) => {
setEmoji({ content: icon, background: icon_background })
setShowEmojiPicker(false)
}}
onClose={() => {
setShowEmojiPicker(false)
}}
/>}
{credentialsModalShow && (
<ConfigCredentials
positionCenter={isAdd}
credential={credential}
onChange={setCredential}
onHide={() => setCredentialsModalShow(false)}
/>)
}
{isShowTestApi && (
<TestApi
positionCenter={isAdd}
tool={currTool as CustomParamSchema}
customCollection={customCollection}
onHide={() => setIsShowTestApi(false)}
/>
)}
</> </>
) )

View File

@ -7,12 +7,16 @@ import {
import type { ValueSelector, Var } from '@/app/components/workflow/types' import type { ValueSelector, Var } from '@/app/components/workflow/types'
type Params = { type Params = {
onlyLeafNodeVar?: boolean onlyLeafNodeVar?: boolean
hideEnv?: boolean
hideChatVar?: boolean
filterVar: (payload: Var, selector: ValueSelector) => boolean filterVar: (payload: Var, selector: ValueSelector) => boolean
} }
const useAvailableVarList = (nodeId: string, { const useAvailableVarList = (nodeId: string, {
onlyLeafNodeVar, onlyLeafNodeVar,
filterVar, filterVar,
hideEnv,
hideChatVar,
}: Params = { }: Params = {
onlyLeafNodeVar: false, onlyLeafNodeVar: false,
filterVar: () => true, filterVar: () => true,
@ -32,6 +36,8 @@ const useAvailableVarList = (nodeId: string, {
beforeNodes: availableNodes, beforeNodes: availableNodes,
isChatMode, isChatMode,
filterVar, filterVar,
hideEnv,
hideChatVar,
}) })
return { return {

View File

@ -23,6 +23,8 @@ const Panel: FC<NodePanelProps<AnswerNodeType>> = ({
const { availableVars, availableNodesWithParent } = useAvailableVarList(id, { const { availableVars, availableNodesWithParent } = useAvailableVarList(id, {
onlyLeafNodeVar: false, onlyLeafNodeVar: false,
hideChatVar: true,
hideEnv: true,
filterVar, filterVar,
}) })

View File

@ -62,30 +62,25 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe
const formatNodeList = useCallback((list: NodeTracing[]) => { const formatNodeList = useCallback((list: NodeTracing[]) => {
const allItems = list.reverse() const allItems = list.reverse()
const result: NodeTracing[] = [] const result: NodeTracing[] = []
let iterationIndexInfos: { let iterationIndex = 0
start: number
end: number
}[] = []
allItems.forEach((item) => { allItems.forEach((item) => {
const { node_type, index, execution_metadata } = item const { node_type, execution_metadata } = item
if (node_type !== BlockEnum.Iteration) { if (node_type !== BlockEnum.Iteration) {
let isInIteration = false const isInIteration = !!execution_metadata?.iteration_id
let isIterationFirstNode = false
iterationIndexInfos.forEach(({ start, end }) => {
if (index >= start && index < end) {
if (index === start)
isIterationFirstNode = true
isInIteration = true
}
})
if (isInIteration) { if (isInIteration) {
const iterationDetails = result[result.length - 1].details! const iterationDetails = result[result.length - 1].details!
if (isIterationFirstNode) const currentIterationIndex = execution_metadata?.iteration_index
iterationDetails!.push([item]) const isIterationFirstNode = iterationIndex !== currentIterationIndex || iterationDetails.length === 0
else if (isIterationFirstNode) {
iterationDetails!.push([item])
iterationIndex = currentIterationIndex!
}
else {
iterationDetails[iterationDetails.length - 1].push(item) iterationDetails[iterationDetails.length - 1].push(item)
}
return return
} }
@ -95,26 +90,6 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe
return return
} }
const { steps_boundary } = execution_metadata
iterationIndexInfos = []
steps_boundary.forEach((boundary, index) => {
if (index === 0) {
iterationIndexInfos.push({
start: boundary,
end: 0,
})
}
else if (index === steps_boundary.length - 1) {
iterationIndexInfos[iterationIndexInfos.length - 1].end = boundary
}
else {
iterationIndexInfos[iterationIndexInfos.length - 1].end = boundary
iterationIndexInfos.push({
start: boundary,
end: 0,
})
}
})
result.push({ result.push({
...item, ...item,
details: [], details: [],
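
The rewritten grouping keys off execution_metadata.iteration_id and iteration_index instead of the removed steps_boundary array; the same pass, written compactly in Python for illustration (field names follow the NodeTracing type at the end of this diff):

def group_iteration_details(items: list[dict]) -> list[list[dict]]:
    # One inner list per iteration round, detected by a change of iteration_index.
    details: list[list[dict]] = []
    current_index = None
    for item in items:
        meta = item.get('execution_metadata') or {}
        if not meta.get('iteration_id'):
            continue  # node did not run inside an iteration
        index = meta.get('iteration_index')
        if index != current_index or not details:
            details.append([item])
            current_index = index
        else:
            details[-1].append(item)
    return details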

View File

@ -123,7 +123,7 @@ const NodePanel: FC<Props> = ({
<div <div
className='flex items-center h-[34px] justify-between px-3 bg-gray-100 border-[0.5px] border-gray-200 rounded-lg cursor-pointer' className='flex items-center h-[34px] justify-between px-3 bg-gray-100 border-[0.5px] border-gray-200 rounded-lg cursor-pointer'
onClick={handleOnShowIterationDetail}> onClick={handleOnShowIterationDetail}>
<div className='leading-[18px] text-[13px] font-medium text-gray-700'>{t('workflow.nodes.iteration.iteration', { count: nodeInfo.metadata?.iterator_length || (nodeInfo.execution_metadata?.steps_boundary?.length - 1) })}</div> <div className='leading-[18px] text-[13px] font-medium text-gray-700'>{t('workflow.nodes.iteration.iteration', { count: nodeInfo.metadata?.iterator_length })}</div>
{justShowIterationNavArrow {justShowIterationNavArrow
? ( ? (
<RiArrowRightSLine className='w-3.5 h-3.5 text-gray-500' /> <RiArrowRightSLine className='w-3.5 h-3.5 text-gray-500' />

View File

@ -94,11 +94,38 @@ const translation = {
}, },
export: { export: {
title: 'シークレット環境変数をエクスポートしますか?', title: 'シークレット環境変数をエクスポートしますか?',
checkbox: 'シクレート値をエクスポート', checkbox: 'シークレット値をエクスポート',
ignore: 'DSLをエクスポート', ignore: 'DSLをエクスポート',
export: 'シクレート値を含むDSLをエクスポート', export: 'シークレット値を含むDSLをエクスポート',
}, },
}, },
chatVariable: {
panelTitle: '会話変数',
panelDescription: '会話変数は、LLMが記憶すべき対話情報を保存するために使用されます。この情報には、対話の履歴、アップロードされたファイル、ユーザーの好みなどが含まれます。読み書きが可能です。',
docLink: '詳しくはドキュメントをご覧ください。',
button: '変数を追加',
modal: {
title: '会話変数を追加',
editTitle: '会話変数を編集',
name: '名前',
      namePlaceholder: '変数名',
type: 'タイプ',
value: 'デフォルト値',
      valuePlaceholder: 'デフォルト値、設定しない場合は空白にしてください',
description: '説明',
descriptionPlaceholder: '変数の説明',
editInJSON: 'JSONで編集する',
oneByOne: '次々に追加する',
editInForm: 'フォームで編集',
arrayValue: '値',
addArrayValue: '値を追加',
objectKey: 'キー',
objectType: 'タイプ',
objectValue: 'デフォルト値',
},
storedContent: '保存されたコンテンツ',
updatedAt: '更新日は',
},
changeHistory: { changeHistory: {
title: '変更履歴', title: '変更履歴',
placeholder: 'まだ何も変更していません', placeholder: 'まだ何も変更していません',
@ -149,6 +176,7 @@ const translation = {
tabs: { tabs: {
'searchBlock': 'ブロックを検索', 'searchBlock': 'ブロックを検索',
'blocks': 'ブロック', 'blocks': 'ブロック',
'searchTool': '検索ツール',
'tools': 'ツール', 'tools': 'ツール',
'allTool': 'すべて', 'allTool': 'すべて',
'workflowTool': 'ワークフロー', 'workflowTool': 'ワークフロー',
@ -171,8 +199,9 @@ const translation = {
'code': 'コード', 'code': 'コード',
'template-transform': 'テンプレート', 'template-transform': 'テンプレート',
'http-request': 'HTTPリクエスト', 'http-request': 'HTTPリクエスト',
'variable-assigner': '変数代入', 'variable-assigner': '変数代入',
'variable-aggregator': '変数集約器', 'variable-aggregator': '変数集約器',
'assigner': '変数代入',
'iteration-start': 'イテレーション開始', 'iteration-start': 'イテレーション開始',
'iteration': 'イテレーション', 'iteration': 'イテレーション',
'parameter-extractor': 'パラメーター抽出', 'parameter-extractor': 'パラメーター抽出',
@ -189,6 +218,7 @@ const translation = {
'template-transform': 'Jinjaテンプレート構文を使用してデータを文字列に変換します', 'template-transform': 'Jinjaテンプレート構文を使用してデータを文字列に変換します',
'http-request': 'HTTPプロトコル経由でサーバーリクエストを送信できます', 'http-request': 'HTTPプロトコル経由でサーバーリクエストを送信できます',
'variable-assigner': '複数のブランチの変数を1つの変数に集約し、下流のードに対して統一された設定を行います。', 'variable-assigner': '複数のブランチの変数を1つの変数に集約し、下流のードに対して統一された設定を行います。',
'assigner': '変数代入ノードは、書き込み可能な変数(例えば、会話変数)に値を割り当てるために使用されます。',
'variable-aggregator': '複数のブランチの変数を1つの変数に集約し、下流のードに対して統一された設定を行います。', 'variable-aggregator': '複数のブランチの変数を1つの変数に集約し、下流のードに対して統一された設定を行います。',
'iteration': 'リストオブジェクトに対して複数のステップを実行し、すべての結果が出力されるまで繰り返します。', 'iteration': 'リストオブジェクトに対して複数のステップを実行し、すべての結果が出力されるまで繰り返します。',
'parameter-extractor': '自然言語からツールの呼び出しやHTTPリクエストのための構造化されたパラメーターを抽出するためにLLMを使用します。', 'parameter-extractor': '自然言語からツールの呼び出しやHTTPリクエストのための構造化されたパラメーターを抽出するためにLLMを使用します。',
@ -215,6 +245,7 @@ const translation = {
checklistResolved: 'すべての問題が解決されました', checklistResolved: 'すべての問題が解決されました',
organizeBlocks: 'ブロックを整理', organizeBlocks: 'ブロックを整理',
change: '変更', change: '変更',
optional: '(オプション)',
}, },
nodes: { nodes: {
common: { common: {
@ -406,6 +437,17 @@ const translation = {
}, },
setAssignVariable: '代入された変数を設定', setAssignVariable: '代入された変数を設定',
}, },
assigner: {
'assignedVariable': '代入された変数',
'writeMode': '書き込みモード',
    'writeModeTip': '代入された変数が配列の場合、末尾に追記モードを追加する。',
'over-write': '上書き',
'append': '追記',
'plus': 'プラス',
'clear': 'クリア',
'setVariable': '変数を設定する',
'variable': '変数',
},
tool: { tool: {
toAuthorize: '承認するには', toAuthorize: '承認するには',
inputVars: '入力変数', inputVars: '入力変数',

View File

@ -1,6 +1,6 @@
{ {
"name": "dify-web", "name": "dify-web",
"version": "0.6.16", "version": "0.7.0",
"private": true, "private": true,
"engines": { "engines": {
"node": ">=18.17.0" "node": ">=18.17.0"

View File

@ -24,7 +24,8 @@ export type NodeTracing = {
total_tokens: number total_tokens: number
total_price: number total_price: number
currency: string currency: string
steps_boundary: number[] iteration_id?: string
iteration_index?: number
} }
metadata: { metadata: {
iterator_length: number iterator_length: number