From cd17ce9250d8c2b219eef391851e9dfa9e7eac5e Mon Sep 17 00:00:00 2001 From: kurokobo Date: Wed, 16 Apr 2025 10:54:03 +0900 Subject: [PATCH 01/68] fix: start api and worker after the database has become healthy (#18109) --- docker/docker-compose-template.yaml | 12 ++++++++---- docker/docker-compose.yaml | 12 ++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index e933dd85c7..86976063c3 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -17,8 +17,10 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: - - db - - redis + db: + condition: service_healthy + redis: + condition: service_started volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage @@ -42,8 +44,10 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: - - db - - redis + db: + condition: service_healthy + redis: + condition: service_started volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index b322015961..e9c8c8715a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -485,8 +485,10 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: - - db - - redis + db: + condition: service_healthy + redis: + condition: service_started volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage @@ -510,8 +512,10 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: - - db - - redis + db: + condition: service_healthy + redis: + condition: service_started volumes: # Mount the storage directory to the container, for storing user files. 
- ./volumes/app/storage:/app/api/storage From aead48726e8b092430c74492ae3ab33116d03cb6 Mon Sep 17 00:00:00 2001 From: Jimmyy <43959617+Jimmy0769@users.noreply.github.com> Date: Wed, 16 Apr 2025 09:56:46 +0800 Subject: [PATCH 02/68] fix: cannot regenerate with image(#15060) (#16611) Co-authored-by: werido <359066432@qq.com> --- api/fields/conversation_fields.py | 1 + api/models/model.py | 2 +- web/app/components/base/chat/chat-with-history/hooks.tsx | 4 ++-- web/app/components/base/file-uploader/utils.ts | 2 +- web/types/workflow.ts | 1 + 5 files changed, 6 insertions(+), 4 deletions(-) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 80d1f0baf5..78e0794833 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -42,6 +42,7 @@ message_file_fields = { "size": fields.Integer, "transfer_method": fields.String, "belongs_to": fields.String(default="user"), + "upload_file_id": fields.String(default=None), } agent_thought_fields = { diff --git a/api/models/model.py b/api/models/model.py index dfc1322d92..a826d13e7d 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1155,7 +1155,7 @@ class Message(db.Model): # type: ignore[name-defined] files.append(file) result = [ - {"belongs_to": message_file.belongs_to, **file.to_dict()} + {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} for (file, message_file) in zip(files, message_files) ] diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 0a4cbae964..9afaca2568 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -52,7 +52,7 @@ function getFormattedChatList(messages: any[]) { id: `question-${item.id}`, content: item.query, isAnswer: false, - message_files: getProcessedFilesFromResponse(questionFiles.map((item: any) => ({ ...item, related_id: item.id }))), + message_files: getProcessedFilesFromResponse(questionFiles.map((item: any) => ({ ...item, related_id: item.id, upload_file_id: item.upload_file_id }))), parentMessageId: item.parent_message_id || undefined, }) const answerFiles = item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [] @@ -63,7 +63,7 @@ function getFormattedChatList(messages: any[]) { feedback: item.feedback, isAnswer: true, citation: item.retriever_resources, - message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id }))), + message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id, upload_file_id: item.upload_file_id }))), parentMessageId: `question-${item.id}`, }) }) diff --git a/web/app/components/base/file-uploader/utils.ts b/web/app/components/base/file-uploader/utils.ts index e095d4aa93..e05c0b2087 100644 --- a/web/app/components/base/file-uploader/utils.ts +++ b/web/app/components/base/file-uploader/utils.ts @@ -134,7 +134,7 @@ export const getProcessedFilesFromResponse = (files: FileResponse[]) => { progress: 100, transferMethod: fileItem.transfer_method, supportFileType: fileItem.type, - uploadedId: fileItem.related_id, + uploadedId: fileItem.upload_file_id || fileItem.related_id, url: fileItem.url, } }) diff --git a/web/types/workflow.ts b/web/types/workflow.ts index 43af64e5ca..bd7334a261 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -197,6 +197,7 @@ export type FileResponse = { 
transfer_method: TransferMethod type: string url: string + upload_file_id: string } export type NodeFinishedResponse = { From 57b28576f02c432bf7554ac5b021d7db9196fd80 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Wed, 16 Apr 2025 11:55:19 +0800 Subject: [PATCH 03/68] chore: remove unused poetry.toml (#18112) --- api/poetry.toml | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 api/poetry.toml diff --git a/api/poetry.toml b/api/poetry.toml deleted file mode 100644 index 9a48dd825a..0000000000 --- a/api/poetry.toml +++ /dev/null @@ -1,4 +0,0 @@ -[virtualenvs] -in-project = true -create = true -prefer-active-python = true \ No newline at end of file From 2a0d7533d7912ef03f805c8151064378f8503704 Mon Sep 17 00:00:00 2001 From: AichiB7A Date: Wed, 16 Apr 2025 11:55:37 +0800 Subject: [PATCH 04/68] [Unit Test] Generate coverage number for UT (#18106) --- .github/workflows/api-tests.yml | 12 +++++++- .gitignore | 1 + api/pyproject.toml | 1 + api/pytest.ini | 1 + api/uv.lock | 49 +++++++++++++++++++++++++++++++++ 5 files changed, 63 insertions(+), 1 deletion(-) diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 7da370e283..02583cda06 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -45,7 +45,17 @@ jobs: run: uv sync --project api --dev - name: Run Unit tests - run: uv run --project api bash dev/pytest/pytest_unit_tests.sh + run: | + uv run --project api bash dev/pytest/pytest_unit_tests.sh + # Extract coverage percentage and create a summary + TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])') + + # Create a detailed coverage summary + echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY + echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY + uv run --project api coverage report >> $GITHUB_STEP_SUMMARY + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY - name: Run dify config tests run: uv run --project api dev/pytest/pytest_config_tests.py diff --git a/.gitignore b/.gitignore index 819a249581..8818ab6f65 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,7 @@ htmlcov/ .cache nosetests.xml coverage.xml +coverage.json *.cover *.py,cover .hypothesis/ diff --git a/api/pyproject.toml b/api/pyproject.toml index 1dc5a7cc7c..85679a6359 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -104,6 +104,7 @@ dev = [ "ruff~=0.11.5", "pytest~=8.3.2", "pytest-benchmark~=4.0.0", + "pytest-cov~=4.1.0", "pytest-env~=1.1.3", "pytest-mock~=3.14.0", "types-aiofiles~=24.1.0", diff --git a/api/pytest.ini b/api/pytest.ini index 3de1649798..618e921825 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -1,5 +1,6 @@ [pytest] continue-on-collection-errors = true +addopts = --cov=./api --cov-report=json --cov-report=xml env = ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com diff --git a/api/uv.lock b/api/uv.lock index ac77d8e8e5..4ff9c34446 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1012,6 +1012,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/00/14b00a0748e9eda26e97be07a63cc911108844004687321ddcc213be956c/coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3", size = 204347 }, ] +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11'" }, +] + [[package]] name = 
"crc32c" version = "2.7.1" @@ -1234,6 +1239,7 @@ dev = [ { name = "mypy" }, { name = "pytest" }, { name = "pytest-benchmark" }, + { name = "pytest-cov" }, { name = "pytest-env" }, { name = "pytest-mock" }, { name = "ruff" }, @@ -1401,6 +1407,7 @@ dev = [ { name = "mypy", specifier = "~=1.15.0" }, { name = "pytest", specifier = "~=8.3.2" }, { name = "pytest-benchmark", specifier = "~=4.0.0" }, + { name = "pytest-cov", specifier = "~=4.1.0" }, { name = "pytest-env", specifier = "~=1.1.3" }, { name = "pytest-mock", specifier = "~=3.14.0" }, { name = "ruff", specifier = "~=0.11.5" }, @@ -4333,6 +4340,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4d/a1/3b70862b5b3f830f0422844f25a823d0470739d994466be9dbbbb414d85a/pytest_benchmark-4.0.0-py3-none-any.whl", hash = "sha256:fdb7db64e31c8b277dff9850d2a2556d8b60bcb0ea6524e36e28ffd7c87f71d6", size = 43951 }, ] +[[package]] +name = "pytest-cov" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7a/15/da3df99fd551507694a9b01f512a2f6cf1254f33601605843c3775f39460/pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6", size = 63245 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/4b/8b78d126e275efa2379b1c2e09dc52cf70df16fc3b90613ef82531499d73/pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a", size = 21949 }, +] + [[package]] name = "pytest-env" version = "1.1.5" @@ -5235,6 +5255,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588 }, ] +[[package]] +name = "tomli" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077 }, + { url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429 }, + { url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067 }, + { url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030 }, + { url = 
"https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898 }, + { url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894 }, + { url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319 }, + { url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273 }, + { url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310 }, + { url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309 }, + { url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762 }, + { url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453 }, + { url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486 }, + { url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349 }, + { url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159 }, + { url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", size = 237243 }, + { url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = 
"sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645 }, + { url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584 }, + { url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875 }, + { url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418 }, + { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 }, +] + [[package]] name = "tos" version = "2.7.2" From 95283b4dd3b34132050cb8b625daf2c463a4d1ff Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Wed, 16 Apr 2025 12:28:22 +0800 Subject: [PATCH 05/68] Feat/change split length method (#18097) Co-authored-by: JzoNg --- api/core/rag/splitter/fixed_text_splitter.py | 10 ++++++++-- api/services/dataset_service.py | 2 +- .../components/datasets/create/step-two/index.tsx | 14 +++++++------- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 67f9b6384d..0fb1bcb2e0 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -39,6 +39,12 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): else: return [GPT2Tokenizer.get_num_tokens(text) for text in texts] + def _character_encoder(texts: list[str]) -> list[int]: + if not texts: + return [] + + return [len(text) for text in texts] + if issubclass(cls, TokenTextSplitter): extra_kwargs = { "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", @@ -47,7 +53,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): } kwargs = {**kwargs, **extra_kwargs} - return cls(length_function=_token_encoder, **kwargs) + return cls(length_function=_character_encoder, **kwargs) class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): @@ -103,7 +109,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) _good_splits_lengths = [] # cache the lengths of the splits _separator = "" if self._keep_separator else separator s_lens = self._length_function(splits) - if _separator != "": + if separator != "": for s, s_len in zip(splits, s_lens): if s_len < self._chunk_size: _good_splits.append(s) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 0301c8a584..deb6be5a43 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -553,7 +553,7 @@ class DocumentService: {"id": "remove_extra_spaces", "enabled": True}, {"id": "remove_urls_emails", "enabled": False}, ], - "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, + "segmentation": {"delimiter": "\n", "max_tokens": 1024, "chunk_overlap": 50}, }, "limits": { 
"indexing_max_segmentation_tokens_length": dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH, diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 12fd54d0fe..6b6580ae7e 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -97,7 +97,7 @@ export enum IndexingType { } const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n' -const DEFAULT_MAXIMUM_CHUNK_LENGTH = 500 +const DEFAULT_MAXIMUM_CHUNK_LENGTH = 1024 const DEFAULT_OVERLAP = 50 const MAXIMUM_CHUNK_TOKEN_LENGTH = Number.parseInt(globalThis.document?.body?.getAttribute('data-public-indexing-max-segmentation-tokens-length') || '4000', 10) @@ -117,11 +117,11 @@ const defaultParentChildConfig: ParentChildConfig = { chunkForContext: 'paragraph', parent: { delimiter: '\\n\\n', - maxLength: 500, + maxLength: 1024, }, child: { delimiter: '\\n', - maxLength: 200, + maxLength: 512, }, } @@ -623,12 +623,12 @@ const StepTwo = ({ onChange={e => setSegmentIdentifier(e.target.value, true)} /> setParentChildConfig({ ...parentChildConfig, @@ -803,7 +803,7 @@ const StepTwo = ({ })} /> setParentChildConfig({ ...parentChildConfig, From 640ee80010829035809e498c6fc7a4168045d09c Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Wed, 16 Apr 2025 15:15:23 +0800 Subject: [PATCH 06/68] feat: add red corner mark to Badge component for marketplace plugins (#18162) --- web/app/components/plugins/plugin-item/index.tsx | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/web/app/components/plugins/plugin-item/index.tsx b/web/app/components/plugins/plugin-item/index.tsx index 9f66e5f400..8ce26b737a 100644 --- a/web/app/components/plugins/plugin-item/index.tsx +++ b/web/app/components/plugins/plugin-item/index.tsx @@ -104,7 +104,10 @@ const PluginItem: FC = ({ {!isDifyVersionCompatible && } - +
From fcdf965037f03a531833f49c52080d94a07aa55a Mon Sep 17 00:00:00 2001 From: GuanMu Date: Wed, 16 Apr 2025 15:48:09 +0800 Subject: [PATCH 07/68] feat: add PATCH method support in Heading component (#18160) --- web/app/components/develop/md.tsx | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/web/app/components/develop/md.tsx b/web/app/components/develop/md.tsx index 655cd08280..a9b74a389f 100644 --- a/web/app/components/develop/md.tsx +++ b/web/app/components/develop/md.tsx @@ -12,7 +12,7 @@ type IChildrenProps = { type IHeaderingProps = { url: string - method: 'PUT' | 'DELETE' | 'GET' | 'POST' + method: 'PUT' | 'DELETE' | 'GET' | 'POST' | 'PATCH' title: string name: string } @@ -34,6 +34,9 @@ export const Heading = function H2({ case 'POST': style = 'ring-sky-300 bg-sky-400/10 text-sky-500 dark:ring-sky-400/30 dark:bg-sky-400/10 dark:text-sky-400' break + case 'PATCH': + style = 'ring-violet-300 bg-violet-400/10 text-violet-500 dark:ring-violet-400/30 dark:bg-violet-400/10 dark:text-violet-400' + break default: style = 'ring-emerald-300 dark:ring-emerald-400/30 bg-emerald-400/10 text-emerald-500 dark:text-emerald-400' break From b247ef85bf731c701745e3e61ea41e560db3e819 Mon Sep 17 00:00:00 2001 From: kenwoodjw Date: Wed, 16 Apr 2025 15:50:06 +0800 Subject: [PATCH 08/68] fix dataset api retrieval model null handling (#18151) Signed-off-by: kenwoodjw --- api/controllers/service_api/dataset/dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index f087243a25..e1e6f3168f 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -139,7 +139,9 @@ class DatasetListApi(DatasetApiResource): external_knowledge_id=args["external_knowledge_id"], embedding_model_provider=args["embedding_model_provider"], embedding_model_name=args["embedding_model"], - retrieval_model=RetrievalModel(**args["retrieval_model"]), + retrieval_model=RetrievalModel(**args["retrieval_model"]) + if args["retrieval_model"] is not None + else None, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() From e1455cecd8a863a97ab04a37b77cb2d5f94f8dd8 Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Wed, 16 Apr 2025 15:50:15 +0800 Subject: [PATCH 09/68] feat: add switches for jina firecrawl watercrawl (#18153) --- docker/.env.example | 6 ++++++ docker/docker-compose-template.yaml | 4 +++- docker/docker-compose.yaml | 7 ++++++- web/.env.example | 5 +++++ .../datasets/create/step-one/index.tsx | 14 ++++++------- .../datasets/create/website/index.tsx | 13 ++++++------ .../datasets/create/website/no-data.tsx | 20 ++++++++++--------- .../data-source-page/index.tsx | 7 ++++--- web/config/index.ts | 12 +++++++++++ web/docker/entrypoint.sh | 4 +++- 10 files changed, 64 insertions(+), 28 deletions(-) diff --git a/docker/.env.example b/docker/.env.example index acb09c0d4f..e49e8fee89 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -174,6 +174,12 @@ CELERY_MIN_WORKERS= API_TOOL_DEFAULT_CONNECT_TIMEOUT=10 API_TOOL_DEFAULT_READ_TIMEOUT=60 +# ------------------------------- +# Datasource Configuration +# -------------------------------- +ENABLE_WEBSITE_JINAREADER=true +ENABLE_WEBSITE_FIRECRAWL=true +ENABLE_WEBSITE_WATERCRAWL=true # ------------------------------ # Database Configuration diff --git a/docker/docker-compose-template.yaml 
b/docker/docker-compose-template.yaml index 86976063c3..a8f7b755fb 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -75,7 +75,9 @@ services: MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-5} - + ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true} + ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true} + ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true} # The postgres database. db: image: postgres:15-alpine diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index e9c8c8715a..25b0c56561 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -43,6 +43,9 @@ x-shared-env: &shared-api-worker-env CELERY_MIN_WORKERS: ${CELERY_MIN_WORKERS:-} API_TOOL_DEFAULT_CONNECT_TIMEOUT: ${API_TOOL_DEFAULT_CONNECT_TIMEOUT:-10} API_TOOL_DEFAULT_READ_TIMEOUT: ${API_TOOL_DEFAULT_READ_TIMEOUT:-60} + ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true} + ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true} + ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true} DB_USERNAME: ${DB_USERNAME:-postgres} DB_PASSWORD: ${DB_PASSWORD:-difyai123456} DB_HOST: ${DB_HOST:-db} @@ -543,7 +546,9 @@ services: MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-5} - + ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true} + ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true} + ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true} # The postgres database. db: image: postgres:15-alpine diff --git a/web/.env.example b/web/.env.example index 51dc3d6b3c..1c3f42ddfc 100644 --- a/web/.env.example +++ b/web/.env.example @@ -49,3 +49,8 @@ NEXT_PUBLIC_MAX_PARALLEL_LIMIT=10 # The maximum number of iterations for agent setting NEXT_PUBLIC_MAX_ITERATIONS_NUM=5 + +NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER=true +NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL=true +NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL=true + diff --git a/web/app/components/datasets/create/step-one/index.tsx b/web/app/components/datasets/create/step-one/index.tsx index 6f4231bb1f..38c885ebe2 100644 --- a/web/app/components/datasets/create/step-one/index.tsx +++ b/web/app/components/datasets/create/step-one/index.tsx @@ -20,7 +20,7 @@ import { useProviderContext } from '@/context/provider-context' import VectorSpaceFull from '@/app/components/billing/vector-space-full' import classNames from '@/utils/classnames' import { Icon3Dots } from '@/app/components/base/icons/src/vender/line/others' - +import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config' type IStepOneProps = { datasetId?: string dataSourceType?: DataSourceType @@ -126,9 +126,7 @@ const StepOne = ({ return true if (files.some(file => !file.file.id)) return true - if (isShowVectorSpaceFull) - return true - return false + return isShowVectorSpaceFull }, [files, isShowVectorSpaceFull]) return ( @@ -193,7 +191,8 @@ const StepOne = ({ {t('datasetCreation.stepOne.dataSourceType.notion')}
-
changeType(DataSourceType.WEB)} - > + > {t('datasetCreation.stepOne.dataSourceType.web')} -
+ + )} ) } diff --git a/web/app/components/datasets/create/website/index.tsx b/web/app/components/datasets/create/website/index.tsx index 5122ef6ed2..e2d0e2df99 100644 --- a/web/app/components/datasets/create/website/index.tsx +++ b/web/app/components/datasets/create/website/index.tsx @@ -12,6 +12,7 @@ import { useModalContext } from '@/context/modal-context' import type { CrawlOptions, CrawlResultItem } from '@/models/datasets' import { fetchDataSources } from '@/service/datasets' import { type DataSourceItem, DataSourceProvider } from '@/models/common' +import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config' type Props = { onPreview: (payload: CrawlResultItem) => void @@ -84,7 +85,7 @@ const Website: FC = ({ {t('datasetCreation.stepOne.website.chooseProvider')}
- - - + }
{source && selectedProvider === DataSourceProvider.fireCrawl && ( diff --git a/web/app/components/datasets/create/website/no-data.tsx b/web/app/components/datasets/create/website/no-data.tsx index 14be2e29f6..65a314f516 100644 --- a/web/app/components/datasets/create/website/no-data.tsx +++ b/web/app/components/datasets/create/website/no-data.tsx @@ -6,6 +6,7 @@ import s from './index.module.css' import { Icon3Dots } from '@/app/components/base/icons/src/vender/line/others' import Button from '@/app/components/base/button' import { DataSourceProvider } from '@/models/common' +import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config' const I18N_PREFIX = 'datasetCreation.stepOne.website' @@ -16,29 +17,30 @@ type Props = { const NoData: FC = ({ onConfig, - provider, }) => { const { t } = useTranslation() const providerConfig = { - [DataSourceProvider.jinaReader]: { + [DataSourceProvider.jinaReader]: ENABLE_WEBSITE_JINAREADER ? { emoji: , title: t(`${I18N_PREFIX}.jinaReaderNotConfigured`), description: t(`${I18N_PREFIX}.jinaReaderNotConfiguredDescription`), - }, - [DataSourceProvider.fireCrawl]: { + } : null, + [DataSourceProvider.fireCrawl]: ENABLE_WEBSITE_FIRECRAWL ? { emoji: '🔥', title: t(`${I18N_PREFIX}.fireCrawlNotConfigured`), description: t(`${I18N_PREFIX}.fireCrawlNotConfiguredDescription`), - }, - [DataSourceProvider.waterCrawl]: { - emoji: , + } : null, + [DataSourceProvider.waterCrawl]: ENABLE_WEBSITE_WATERCRAWL ? { + emoji: '💧', title: t(`${I18N_PREFIX}.waterCrawlNotConfigured`), description: t(`${I18N_PREFIX}.waterCrawlNotConfiguredDescription`), - }, + } : null, } - const currentProvider = providerConfig[provider] + const currentProvider = Object.values(providerConfig).find(provider => provider !== null) || providerConfig[DataSourceProvider.jinaReader] + + if (!currentProvider) return null return ( <> diff --git a/web/app/components/header/account-setting/data-source-page/index.tsx b/web/app/components/header/account-setting/data-source-page/index.tsx index d99bd25e02..fb13813d70 100644 --- a/web/app/components/header/account-setting/data-source-page/index.tsx +++ b/web/app/components/header/account-setting/data-source-page/index.tsx @@ -3,6 +3,7 @@ import DataSourceNotion from './data-source-notion' import DataSourceWebsite from './data-source-website' import { fetchDataSource } from '@/service/common' import { DataSourceProvider } from '@/models/common' +import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config' export default function DataSourcePage() { const { data } = useSWR({ url: 'data-source/integrates' }, fetchDataSource) @@ -11,9 +12,9 @@ export default function DataSourcePage() { return (
- - - + {ENABLE_WEBSITE_JINAREADER && } + {ENABLE_WEBSITE_FIRECRAWL && } + {ENABLE_WEBSITE_WATERCRAWL && }
) } diff --git a/web/config/index.ts b/web/config/index.ts index 2b81adb095..b164392c52 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -302,3 +302,15 @@ else if (globalThis.document?.body?.getAttribute('data-public-max-iterations-num maxIterationsNum = Number.parseInt(globalThis.document.body.getAttribute('data-public-max-iterations-num') as string) export const MAX_ITERATIONS_NUM = maxIterationsNum + +export const ENABLE_WEBSITE_JINAREADER = process.env.NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER !== undefined + ? process.env.NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER === 'true' + : true + +export const ENABLE_WEBSITE_FIRECRAWL = process.env.NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL !== undefined + ? process.env.NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL === 'true' + : true + +export const ENABLE_WEBSITE_WATERCRAWL = process.env.NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL !== undefined + ? process.env.NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL === 'true' + : true diff --git a/web/docker/entrypoint.sh b/web/docker/entrypoint.sh index 797b61081a..8395ac5f4d 100755 --- a/web/docker/entrypoint.sh +++ b/web/docker/entrypoint.sh @@ -28,5 +28,7 @@ export NEXT_PUBLIC_CSP_WHITELIST=${CSP_WHITELIST} export NEXT_PUBLIC_TOP_K_MAX_VALUE=${TOP_K_MAX_VALUE} export NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH} export NEXT_PUBLIC_MAX_TOOLS_NUM=${MAX_TOOLS_NUM} - +export NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER=${ENABLE_WEBSITE_JINAREADER:-true} +export NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL=${ENABLE_WEBSITE_FIRECRAWL:-true} +export NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL=${ENABLE_WEBSITE_WATERCRAWL:-true} pm2 start /app/web/server.js --name dify-web --cwd /app/web -i ${PM2_INSTANCES} --no-daemon From b006b9ac0cf8fbdbc07c33eb82c97f3067a5658b Mon Sep 17 00:00:00 2001 From: Ganondorf <364776488@qq.com> Date: Wed, 16 Apr 2025 15:59:34 +0800 Subject: [PATCH 10/68] Http requests node add ssl verify (#18125) Co-authored-by: lizb --- api/core/helper/ssrf_proxy.py | 19 ++++++++++--------- .../workflow/nodes/http_request/entities.py | 1 + .../workflow/nodes/http_request/executor.py | 2 ++ api/core/workflow/nodes/http_request/node.py | 1 + 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 969cd112ee..11f245812e 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -48,25 +48,26 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, ) + if "ssl_verify" not in kwargs: + kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY + + ssl_verify = kwargs.pop("ssl_verify") + retries = 0 while retries <= max_retries: try: if dify_config.SSRF_PROXY_ALL_URL: - with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: + with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client: response = client.request(method=method, url=url, **kwargs) elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: proxy_mounts = { - "http://": httpx.HTTPTransport( - proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY - ), - "https://": httpx.HTTPTransport( - proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY - ), + "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify), + "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify), } - with 
httpx.Client(mounts=proxy_mounts, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: + with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client: response = client.request(method=method, url=url, **kwargs) else: - with httpx.Client(verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: + with httpx.Client(verify=ssl_verify) as client: response = client.request(method=method, url=url, **kwargs) if response.status_code not in STATUS_FORCELIST: diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 054e30f0aa..8d7ba25d47 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -90,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData): params: str body: Optional[HttpRequestNodeBody] = None timeout: Optional[HttpRequestNodeTimeout] = None + ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY class Response: diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index f7fa8d670c..5d466e645f 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -88,6 +88,7 @@ class Executor: self.method = node_data.method self.auth = node_data.authorization self.timeout = timeout + self.ssl_verify = node_data.ssl_verify self.params = [] self.headers = {} self.content = None @@ -316,6 +317,7 @@ class Executor: "headers": headers, "params": self.params, "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), + "ssl_verify": self.ssl_verify, "follow_redirects": True, "max_retries": self.max_retries, } diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 467161d5ed..fd2b0f9ae8 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -51,6 +51,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, }, + "ssl_verify": dify_config.HTTP_REQUEST_NODE_SSL_VERIFY, }, "retry_config": { "max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES, From c6e2970b65ed63e2b6e7579b296dce9c461c0754 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 16 Apr 2025 17:09:17 +0900 Subject: [PATCH 11/68] chore: Reorganizes test file structure (#18155) Signed-off-by: -LAN- --- .../unit_tests/services/workflow}/test_workflow_deletion.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename api/{ => tests/unit_tests/services/workflow}/test_workflow_deletion.py (100%) diff --git a/api/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py similarity index 100% rename from api/test_workflow_deletion.py rename to api/tests/unit_tests/services/workflow/test_workflow_deletion.py From 8cc37f31157903286114caec7beca14d32ebd473 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com> Date: Wed, 16 Apr 2025 16:26:24 +0800 Subject: [PATCH 12/68] fix:the extraction function of the list operation node received 0 that should not be received (#18170) --- api/core/workflow/nodes/list_operator/node.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 432c57294e..04ccfc5405 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ 
b/api/core/workflow/nodes/list_operator/node.py @@ -149,7 +149,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): def _extract_slice( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) - 1 + value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) + if value < 1: + raise ValueError(f"Invalid serial index: must be >= 1, got {value}") + value -= 1 if len(variable.value) > int(value): result = variable.value[value] else: From da7c8621f7b09659b7b5a519f39d897f0c002d1f Mon Sep 17 00:00:00 2001 From: "Junjie.M" <118170653@qq.com> Date: Wed, 16 Apr 2025 17:03:18 +0800 Subject: [PATCH 13/68] fix: agent strategy string type parameter default value invalid (#18185) --- .../workflow/nodes/_base/components/agent-strategy.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx b/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx index 63f9fb92ec..be57cbca0f 100644 --- a/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx +++ b/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx @@ -65,7 +65,7 @@ export const AgentStrategy = memo((props: AgentStrategyProps) => { switch (schema.type) { case FormTypeEnum.textInput: { const def = schema as CredentialFormSchemaTextInput - const value = props.value[schema.variable] + const value = props.value[schema.variable] || schema.default const onChange = (value: string) => { props.onChange({ ...props.value, [schema.variable]: value }) } From b7e8517b31e6747576cd3e2227318934f60dbe85 Mon Sep 17 00:00:00 2001 From: "Junjie.M" <118170653@qq.com> Date: Wed, 16 Apr 2025 17:24:09 +0800 Subject: [PATCH 14/68] feat: agent strategy parameter add help information (#18192) --- api/core/agent/plugin_entities.py | 1 + web/app/components/plugins/types.ts | 3 +-- web/app/components/workflow/nodes/agent/panel.tsx | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index 6cf3975333..9c722baa23 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -52,6 +52,7 @@ class AgentStrategyParameter(PluginParameter): return cast_parameter_value(self, value) type: AgentStrategyParameterType = Field(..., description="The type of the parameter") + help: Optional[I18nObject] = None def init_frontend_parameter(self, value: Any): return init_frontend_parameter(self, self.type, value) diff --git a/web/app/components/plugins/types.ts b/web/app/components/plugins/types.ts index 6c42e50123..5ed05d4523 100644 --- a/web/app/components/plugins/types.ts +++ b/web/app/components/plugins/types.ts @@ -406,8 +406,7 @@ export type VersionProps = { export type StrategyParamItem = { name: string label: Record - human_description: Record - llm_description: string + help: Record placeholder: Record type: string scope: string diff --git a/web/app/components/workflow/nodes/agent/panel.tsx b/web/app/components/workflow/nodes/agent/panel.tsx index da87312a90..6a80728d91 100644 --- a/web/app/components/workflow/nodes/agent/panel.tsx +++ b/web/app/components/workflow/nodes/agent/panel.tsx @@ -27,6 +27,7 @@ export function strategyParamToCredientialForm(param: StrategyParamItem): Creden variable: 
param.name, show_on: [], type: toType(param.type), + tooltip: param.help, } } @@ -53,6 +54,7 @@ const AgentPanel: FC> = (props) => { outputSchema, handleMemoryChange, } = useConfig(props.id, props.data) + console.log('currentStrategy', currentStrategy) const { t } = useTranslation() const nodeInfo = useMemo(() => { if (!runResult) From bbd9fe9777e41f445092e0d2fcbb65c78e67519e Mon Sep 17 00:00:00 2001 From: KVOJJJin Date: Wed, 16 Apr 2025 17:25:25 +0800 Subject: [PATCH 15/68] Fix:style of opening questions (#18194) --- .../base/chat/chat/answer/suggested-questions.tsx | 12 +++--------- .../base/chat/embedded-chatbot/chat-wrapper.tsx | 4 ++-- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/web/app/components/base/chat/chat/answer/suggested-questions.tsx b/web/app/components/base/chat/chat/answer/suggested-questions.tsx index 7b8da0e9f0..8b64bff6a3 100644 --- a/web/app/components/base/chat/chat/answer/suggested-questions.tsx +++ b/web/app/components/base/chat/chat/answer/suggested-questions.tsx @@ -2,8 +2,6 @@ import type { FC } from 'react' import { memo } from 'react' import type { ChatItem } from '../../types' import { useChatContext } from '../context' -import Button from '@/app/components/base/button' -import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' type SuggestedQuestionsProps = { item: ChatItem @@ -12,9 +10,6 @@ const SuggestedQuestions: FC = ({ item, }) => { const { onSend } = useChatContext() - const media = useBreakpoints() - const isMobile = media === MediaType.mobile - const klassName = `mr-1 mt-1 ${isMobile ? 'block overflow-hidden text-ellipsis' : ''} max-w-full shrink-0 last:mr-0` const { isOpeningStatement, @@ -27,14 +22,13 @@ const SuggestedQuestions: FC = ({ return (
{suggestedQuestions.filter(q => !!q && q.trim()).map((question, index) => ( - ), +
), )} ) diff --git a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx index cb9dd37b43..a06930c48f 100644 --- a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx +++ b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx @@ -184,7 +184,7 @@ const ChatWrapper = () => { return null if (welcomeMessage.suggestedQuestions && welcomeMessage.suggestedQuestions?.length > 0) { return ( -
+
{ ) } return ( -
+
Date: Wed, 16 Apr 2025 17:26:47 +0800 Subject: [PATCH 16/68] fix: page/limit param not effective (#18196) --- api/controllers/service_api/dataset/segment.py | 2 ++ api/services/dataset_service.py | 15 +++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 3d5869c371..2a79e15cc5 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -122,6 +122,8 @@ class SegmentApi(DatasetApiResource): tenant_id=current_user.current_tenant_id, status_list=args["status"], keyword=args["keyword"], + page=page, + limit=limit, ) response = { diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index deb6be5a43..b08d70489a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2175,7 +2175,13 @@ class SegmentService: @classmethod def get_segments( - cls, document_id: str, tenant_id: str, status_list: list[str] | None = None, keyword: str | None = None + cls, + document_id: str, + tenant_id: str, + status_list: list[str] | None = None, + keyword: str | None = None, + page: int = 1, + limit: int = 20, ): """Get segments for a document with optional filtering.""" query = DocumentSegment.query.filter( @@ -2188,10 +2194,11 @@ class SegmentService: if keyword: query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) - segments = query.order_by(DocumentSegment.position.asc()).all() - total = len(segments) + paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate( + page=page, per_page=limit, max_per_page=100, error_out=False + ) - return segments, total + return paginated_segments.items, paginated_segments.total @classmethod def update_segment_by_id( From 18f98f4fe1a1a3feb50393d28cbaa8b2e55a8fe1 Mon Sep 17 00:00:00 2001 From: jiangbo721 <365065261@qq.com> Date: Wed, 16 Apr 2025 19:21:18 +0800 Subject: [PATCH 17/68] fix: ruff check isoparse (#18033) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 刘江波 --- api/controllers/console/app/workflow_app_log.py | 7 +++---- api/controllers/service_api/app/workflow.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 54640b1a19..d863747995 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,5 +1,4 @@ -from datetime import datetime - +from dateutil.parser import isoparse from flask_restful import Resource, marshal_with, reqparse # type: ignore from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session @@ -41,10 +40,10 @@ class WorkflowAppLogApi(Resource): args.status = WorkflowRunStatus(args.status) if args.status else None if args.created_at__before: - args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00")) + args.created_at__before = isoparse(args.created_at__before) if args.created_at__after: - args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00")) + args.created_at__after = isoparse(args.created_at__after) # get paginate workflow app logs workflow_app_service = WorkflowAppService() diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 2854a43505..8b10a028f3 100644 --- 
a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,6 +1,6 @@ import logging -from datetime import datetime +from dateutil.parser import isoparse from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session @@ -140,10 +140,10 @@ class WorkflowAppLogApi(Resource): args.status = WorkflowRunStatus(args.status) if args.status else None if args.created_at__before: - args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00")) + args.created_at__before = isoparse(args.created_at__before) if args.created_at__after: - args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00")) + args.created_at__after = isoparse(args.created_at__after) # get paginate workflow app logs workflow_app_service = WorkflowAppService() From cac0d3c33e04790cc8874162f7d97827ef4783ba Mon Sep 17 00:00:00 2001 From: Arcaner <52057416+lrhan321@users.noreply.github.com> Date: Wed, 16 Apr 2025 19:21:50 +0800 Subject: [PATCH 18/68] fix: implement robust file type checks to align with existing logic (#17557) Co-authored-by: Bowen Liang --- api/core/app/apps/base_app_generator.py | 2 + api/core/app/apps/workflow/app_generator.py | 6 +- api/factories/file_factory.py | 43 +++- .../factories/test_build_from_mapping.py | 198 ++++++++++++++++++ 4 files changed, 243 insertions(+), 6 deletions(-) create mode 100644 api/tests/unit_tests/factories/test_build_from_mapping.py diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 5d559b96d7..a83b75cc1a 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -17,6 +17,7 @@ class BaseAppGenerator: user_inputs: Optional[Mapping[str, Any]], variables: Sequence["VariableEntity"], tenant_id: str, + strict_type_validation: bool = False, ) -> Mapping[str, Any]: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values @@ -37,6 +38,7 @@ class BaseAppGenerator: allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, ), + strict_type_validation=strict_type_validation, ) for k, v in user_inputs.items() if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index cc7bcdeee1..08986b16f0 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator): mappings=files, tenant_id=app_model.tenant_id, config=file_extra_config, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, ) # convert to app config @@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator): app_config=app_config, file_upload_config=file_extra_config, inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + user_inputs=inputs, + variables=app_config.variables, + tenant_id=app_model.tenant_id, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, ), files=list(system_files), user_id=user.id, diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index b69621ba5b..52f119936f 
100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -52,6 +52,7 @@ def build_from_mapping( mapping: Mapping[str, Any], tenant_id: str, config: FileUploadConfig | None = None, + strict_type_validation: bool = False, ) -> File: transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) @@ -69,6 +70,7 @@ def build_from_mapping( mapping=mapping, tenant_id=tenant_id, transfer_method=transfer_method, + strict_type_validation=strict_type_validation, ) if config and not _is_file_valid_with_config( @@ -87,12 +89,14 @@ def build_from_mappings( mappings: Sequence[Mapping[str, Any]], config: FileUploadConfig | None = None, tenant_id: str, + strict_type_validation: bool = False, ) -> Sequence[File]: files = [ build_from_mapping( mapping=mapping, tenant_id=tenant_id, config=config, + strict_type_validation=strict_type_validation, ) for mapping in mappings ] @@ -116,6 +120,7 @@ def _build_from_local_file( mapping: Mapping[str, Any], tenant_id: str, transfer_method: FileTransferMethod, + strict_type_validation: bool = False, ) -> File: upload_file_id = mapping.get("upload_file_id") if not upload_file_id: @@ -134,10 +139,16 @@ def _build_from_local_file( if row is None: raise ValueError("Invalid upload file") - file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - if file_type.value != mapping.get("type", "custom"): + detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + specified_type = mapping.get("type", "custom") + + if strict_type_validation and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") + file_type = ( + FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type + ) + return File( id=mapping.get("id"), filename=row.name, @@ -158,6 +169,7 @@ def _build_from_remote_url( mapping: Mapping[str, Any], tenant_id: str, transfer_method: FileTransferMethod, + strict_type_validation: bool = False, ) -> File: upload_file_id = mapping.get("upload_file_id") if upload_file_id: @@ -174,10 +186,21 @@ def _build_from_remote_url( if upload_file is None: raise ValueError("Invalid upload file") - file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type) - if file_type.value != mapping.get("type", "custom"): + detected_file_type = _standardize_file_type( + extension="." + upload_file.extension, mime_type=upload_file.mime_type + ) + + specified_type = mapping.get("type") + + if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") + file_type = ( + FileType(specified_type) + if specified_type and specified_type != FileType.CUSTOM.value + else detected_file_type + ) + return File( id=mapping.get("id"), filename=upload_file.name, @@ -237,6 +260,7 @@ def _build_from_tool_file( mapping: Mapping[str, Any], tenant_id: str, transfer_method: FileTransferMethod, + strict_type_validation: bool = False, ) -> File: tool_file = ( db.session.query(ToolFile) @@ -252,7 +276,16 @@ def _build_from_tool_file( extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + detected_file_type = _standardize_file_type(extension="." 
+ extension, mime_type=tool_file.mimetype) + + specified_type = mapping.get("type") + + if strict_type_validation and specified_type and detected_file_type.value != specified_type: + raise ValueError("Detected file type does not match the specified type. Please verify the file.") + + file_type = ( + FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type + ) return File( id=mapping.get("id"), diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py new file mode 100644 index 0000000000..48463a369e --- /dev/null +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -0,0 +1,198 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from httpx import Response + +from factories.file_factory import ( + File, + FileTransferMethod, + FileType, + FileUploadConfig, + build_from_mapping, +) +from models import ToolFile, UploadFile + +# Test Data +TEST_TENANT_ID = "test_tenant_id" +TEST_UPLOAD_FILE_ID = str(uuid.uuid4()) +TEST_TOOL_FILE_ID = str(uuid.uuid4()) +TEST_REMOTE_URL = "http://example.com/test.jpg" + +# Test Config +TEST_CONFIG = FileUploadConfig( + allowed_file_types=["image", "document"], + allowed_file_extensions=[".jpg", ".pdf"], + allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE], + number_limits=10, +) + + +# Fixtures +@pytest.fixture +def mock_upload_file(): + mock = MagicMock(spec=UploadFile) + mock.id = TEST_UPLOAD_FILE_ID + mock.tenant_id = TEST_TENANT_ID + mock.name = "test.jpg" + mock.extension = "jpg" + mock.mime_type = "image/jpeg" + mock.source_url = TEST_REMOTE_URL + mock.size = 1024 + mock.key = "test_key" + with patch("factories.file_factory.db.session.scalar", return_value=mock) as m: + yield m + + +@pytest.fixture +def mock_tool_file(): + mock = MagicMock(spec=ToolFile) + mock.id = TEST_TOOL_FILE_ID + mock.tenant_id = TEST_TENANT_ID + mock.name = "tool_file.pdf" + mock.file_key = "tool_file.pdf" + mock.mimetype = "application/pdf" + mock.original_url = "http://example.com/tool.pdf" + mock.size = 2048 + with patch("factories.file_factory.db.session.query") as mock_query: + mock_query.return_value.filter.return_value.first.return_value = mock + yield mock + + +@pytest.fixture +def mock_http_head(): + def _mock_response(filename, size, content_type): + return Response( + status_code=200, + headers={ + "Content-Disposition": f'attachment; filename="{filename}"', + "Content-Length": str(size), + "Content-Type": content_type, + }, + ) + + with patch("factories.file_factory.ssrf_proxy.head") as mock_head: + mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg") + yield mock_head + + +# Helper functions +def local_file_mapping(file_type="image"): + return { + "transfer_method": "local_file", + "upload_file_id": TEST_UPLOAD_FILE_ID, + "type": file_type, + } + + +def tool_file_mapping(file_type="document"): + return { + "transfer_method": "tool_file", + "tool_file_id": TEST_TOOL_FILE_ID, + "type": file_type, + } + + +# Tests +def test_build_from_mapping_backward_compatibility(mock_upload_file): + mapping = local_file_mapping(file_type="image") + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + assert isinstance(file, File) + assert file.transfer_method == FileTransferMethod.LOCAL_FILE + assert file.type == FileType.IMAGE + assert file.related_id == TEST_UPLOAD_FILE_ID + + +@pytest.mark.parametrize( + ("file_type", "should_pass", 
"expected_error"), + [ + ("image", True, None), + ("document", False, "Detected file type does not match"), + ], +) +def test_build_from_local_file_strict_validation(mock_upload_file, file_type, should_pass, expected_error): + mapping = local_file_mapping(file_type=file_type) + if should_pass: + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + assert file.type == FileType(file_type) + else: + with pytest.raises(ValueError, match=expected_error): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + + +@pytest.mark.parametrize( + ("file_type", "should_pass", "expected_error"), + [ + ("document", True, None), + ("image", False, "Detected file type does not match"), + ], +) +def test_build_from_tool_file_strict_validation(mock_tool_file, file_type, should_pass, expected_error): + """Strict type validation for tool_file.""" + mapping = tool_file_mapping(file_type=file_type) + if should_pass: + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + assert file.type == FileType(file_type) + else: + with pytest.raises(ValueError, match=expected_error): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + + +def test_build_from_remote_url(mock_http_head): + mapping = { + "transfer_method": "remote_url", + "url": TEST_REMOTE_URL, + "type": "image", + } + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + assert file.transfer_method == FileTransferMethod.REMOTE_URL + assert file.type == FileType.IMAGE + assert file.filename == "remote_test.jpg" + assert file.size == 2048 + + +def test_tool_file_not_found(): + """Test ToolFile not found in database.""" + with patch("factories.file_factory.db.session.query") as mock_query: + mock_query.return_value.filter.return_value.first.return_value = None + mapping = tool_file_mapping() + with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + + +def test_local_file_not_found(): + """Test UploadFile not found in database.""" + with patch("factories.file_factory.db.session.scalar", return_value=None): + mapping = local_file_mapping() + with pytest.raises(ValueError, match="Invalid upload file"): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + + +def test_build_without_type_specification(mock_upload_file): + """Test the situation where no file type is specified""" + mapping = { + "transfer_method": "local_file", + "upload_file_id": TEST_UPLOAD_FILE_ID, + # leave out the type + } + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + # It should automatically infer the type as "image" based on the file extension + assert file.type == FileType.IMAGE + + +@pytest.mark.parametrize( + ("file_type", "should_pass", "expected_error"), + [ + ("image", True, None), + ("video", False, "File validation failed"), + ], +) +def test_file_validation_with_config(mock_upload_file, file_type, should_pass, expected_error): + """Test the validation of files and configurations""" + mapping = local_file_mapping(file_type=file_type) + if should_pass: + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG) + assert file is not None + else: + with pytest.raises(ValueError, match=expected_error): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG) From e912928ccef2388d5b5b2020909363f3aa8ba16e Mon Sep 17 
00:00:00 2001 From: devxing <66726106+devxing@users.noreply.github.com> Date: Wed, 16 Apr 2025 19:56:21 +0800 Subject: [PATCH 19/68] fix: create child chunk (#18209) Co-authored-by: devxing --- api/services/dataset_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index b08d70489a..44d2594ee8 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2025,7 +2025,7 @@ class SegmentService: dataset_id=dataset.id, document_id=document.id, segment_id=segment.id, - position=max_position + 1, + position=max_position + 1 if max_position else 1, index_node_id=index_node_id, index_node_hash=index_node_hash, content=content, From 358fd28c28a2ed1c491938ae3dd51a90e6736cb0 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Wed, 16 Apr 2025 20:27:29 +0800 Subject: [PATCH 20/68] feat: fetch app info in plugins (#18202) --- api/controllers/common/helpers.py | 41 ----------------- api/controllers/console/explore/parameter.py | 6 +-- api/controllers/inner_api/plugin/plugin.py | 13 ++++++ api/controllers/service_api/app/app.py | 6 +-- api/controllers/web/app.py | 6 +-- .../common/parameters_mapping/__init__.py | 45 +++++++++++++++++++ api/core/plugin/backwards_invocation/app.py | 29 ++++++++++++ api/core/plugin/entities/request.py | 8 ++++ 8 files changed, 101 insertions(+), 53 deletions(-) create mode 100644 api/core/app/app_config/common/parameters_mapping/__init__.py diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index 282708c037..008f1f0f7a 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -4,14 +4,10 @@ import platform import re import urllib.parse import warnings -from collections.abc import Mapping -from typing import Any from uuid import uuid4 import httpx -from constants import DEFAULT_FILE_NUMBER_LIMITS - try: import magic except ImportError: @@ -31,8 +27,6 @@ except ImportError: from pydantic import BaseModel -from configs import dify_config - class FileInfo(BaseModel): filename: str @@ -89,38 +83,3 @@ def guess_file_info_from_response(response: httpx.Response): mimetype=mimetype, size=int(response.headers.get("Content-Length", -1)), ) - - -def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]): - return { - "opening_statement": features_dict.get("opening_statement"), - "suggested_questions": features_dict.get("suggested_questions", []), - "suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}), - "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), - "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), - "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), - "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), - "more_like_this": features_dict.get("more_like_this", {"enabled": False}), - "user_input_form": user_input_form, - "sensitive_word_avoidance": features_dict.get( - "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} - ), - "file_upload": features_dict.get( - "file_upload", - { - "image": { - "enabled": False, - "number_limits": DEFAULT_FILE_NUMBER_LIMITS, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"], - } - }, - ), - "system_parameters": { - "image_file_size_limit": 
dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, - "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, - "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, - "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, - "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, - }, - } diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 5bc74d16e7..bf9f0d6b28 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,10 +1,10 @@ from flask_restful import marshal_with # type: ignore from controllers.common import fields -from controllers.common import helpers as controller_helpers from controllers.console import api from controllers.console.app.error import AppUnavailableError from controllers.console.explore.wraps import InstalledAppResource +from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from models.model import AppMode, InstalledApp from services.app_service import AppService @@ -36,9 +36,7 @@ class AppParameterApi(InstalledAppResource): user_input_form = features_dict.get("user_input_form", []) - return controller_helpers.get_parameters_from_feature_dict( - features_dict=features_dict, user_input_form=user_input_form - ) + return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) class ExploreAppMetaApi(InstalledAppResource): diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index fe892922e9..061ad62a4a 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -13,6 +13,7 @@ from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocatio from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation from core.plugin.entities.request import ( + RequestFetchAppInfo, RequestInvokeApp, RequestInvokeEncrypt, RequestInvokeLLM, @@ -278,6 +279,17 @@ class PluginUploadFileRequestApi(Resource): return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() +class PluginFetchAppInfoApi(Resource): + @setup_required + @plugin_inner_api_only + @get_user_tenant + @plugin_data(payload_type=RequestFetchAppInfo) + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo): + return BaseBackwardsInvocationResponse( + data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id) + ).model_dump() + + api.add_resource(PluginInvokeLLMApi, "/invoke/llm") api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") @@ -291,3 +303,4 @@ api.add_resource(PluginInvokeAppApi, "/invoke/app") api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") api.add_resource(PluginInvokeSummaryApi, "/invoke/summary") api.add_resource(PluginUploadFileRequestApi, "/upload/file/request") +api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info") diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 8388e2045d..7131e8a310 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,10 +1,10 @@ from flask_restful import Resource, marshal_with # type: ignore from controllers.common import fields -from controllers.common import helpers as controller_helpers from 
controllers.service_api import api from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token +from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from models.model import App, AppMode from services.app_service import AppService @@ -32,9 +32,7 @@ class AppParameterApi(Resource): user_input_form = features_dict.get("user_input_form", []) - return controller_helpers.get_parameters_from_feature_dict( - features_dict=features_dict, user_input_form=user_input_form - ) + return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) class AppMetaApi(Resource): diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 20e071c834..a84b846112 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,10 +1,10 @@ from flask_restful import marshal_with # type: ignore from controllers.common import fields -from controllers.common import helpers as controller_helpers from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource +from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from models.model import App, AppMode from services.app_service import AppService @@ -31,9 +31,7 @@ class AppParameterApi(WebApiResource): user_input_form = features_dict.get("user_input_form", []) - return controller_helpers.get_parameters_from_feature_dict( - features_dict=features_dict, user_input_form=user_input_form - ) + return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) class AppMeta(WebApiResource): diff --git a/api/core/app/app_config/common/parameters_mapping/__init__.py b/api/core/app/app_config/common/parameters_mapping/__init__.py new file mode 100644 index 0000000000..6f1a3bf045 --- /dev/null +++ b/api/core/app/app_config/common/parameters_mapping/__init__.py @@ -0,0 +1,45 @@ +from collections.abc import Mapping +from typing import Any + +from configs import dify_config +from constants import DEFAULT_FILE_NUMBER_LIMITS + + +def get_parameters_from_feature_dict( + *, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]] +) -> Mapping[str, Any]: + """ + Mapping from feature dict to webapp parameters + """ + return { + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": DEFAULT_FILE_NUMBER_LIMITS, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": { + "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + "video_file_size_limit": 
dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, + "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + }, + } diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 29873b508f..484f52e33c 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -2,6 +2,7 @@ from collections.abc import Generator, Mapping from typing import Optional, Union from controllers.service_api.wraps import create_or_update_end_user_for_user_id +from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.chat.app_generator import ChatAppGenerator @@ -15,6 +16,34 @@ from models.model import App, AppMode, EndUser class PluginAppBackwardsInvocation(BaseBackwardsInvocation): + @classmethod + def fetch_app_info(cls, app_id: str, tenant_id: str) -> Mapping: + """ + Fetch app info + """ + app = cls._get_app(app_id, tenant_id) + + """Retrieve app parameters.""" + if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + workflow = app.workflow + if workflow is None: + raise ValueError("unexpected app type") + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form(to_old_structure=True) + else: + app_model_config = app.app_model_config + if app_model_config is None: + raise ValueError("unexpected app type") + + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get("user_input_form", []) + + return { + "data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form), + } + @classmethod def invoke_app( cls, diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 837dcf59c4..6c0c7f2868 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -204,3 +204,11 @@ class RequestRequestUploadFile(BaseModel): filename: str mimetype: str + + +class RequestFetchAppInfo(BaseModel): + """ + Request to fetch app info + """ + + app_id: str From 44cdb3dceaf6414d3ca3e5fd864fed4ed1e335fa Mon Sep 17 00:00:00 2001 From: Panpan Date: Wed, 16 Apr 2025 21:08:13 +0800 Subject: [PATCH 21/68] feat: improve embedding sys.user_id and conversion id info usage (#18035) --- .../base/chat/chat-with-history/hooks.tsx | 23 +++++++--- .../base/chat/embedded-chatbot/hooks.tsx | 23 +++++++--- web/app/components/share/utils.ts | 44 ++++++++++++++----- web/service/base.ts | 8 ++-- web/service/fetch.ts | 34 ++++++-------- 5 files changed, 85 insertions(+), 47 deletions(-) diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 9afaca2568..91ceaffd1e 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -16,7 +16,7 @@ import type { Feedback, } from '../types' import { CONVERSATION_ID_INFO } from '../constants' -import { buildChatItemTree } from '../utils' +import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams } from '../utils' import { addFileInfos, sortAgentSorts } from '../../../tools/utils' import { getProcessedFilesFromResponse } from 
'@/app/components/base/file-uploader/utils' import { @@ -106,6 +106,13 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }, [isInstalledApp, installedAppInfo, appInfo]) const appId = useMemo(() => appData?.app_id, [appData]) + const [userId, setUserId] = useState() + useEffect(() => { + getProcessedSystemVariablesFromUrlParams().then(({ user_id }) => { + setUserId(user_id) + }) + }, []) + useEffect(() => { if (appData?.site.default_language) changeLanguage(appData.site.default_language) @@ -124,18 +131,24 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { setSidebarCollapseState(localState === 'collapsed') } }, [appId]) - const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState>(CONVERSATION_ID_INFO, { + const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState>>(CONVERSATION_ID_INFO, { defaultValue: {}, }) - const currentConversationId = useMemo(() => conversationIdInfo?.[appId || ''] || '', [appId, conversationIdInfo]) + const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId]) const handleConversationIdInfoChange = useCallback((changeConversationId: string) => { if (appId) { + let prevValue = conversationIdInfo?.[appId || ''] + if (typeof prevValue === 'string') + prevValue = {} setConversationIdInfo({ ...conversationIdInfo, - [appId || '']: changeConversationId, + [appId || '']: { + ...prevValue, + [userId || 'DEFAULT']: changeConversationId, + }, }) } - }, [appId, conversationIdInfo, setConversationIdInfo]) + }, [appId, conversationIdInfo, setConversationIdInfo, userId]) const [newConversationId, setNewConversationId] = useState('') const chatShouldReloadKey = useMemo(() => { diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index a5665ab346..d6a7b230e4 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -15,7 +15,7 @@ import type { Feedback, } from '../types' import { CONVERSATION_ID_INFO } from '../constants' -import { buildChatItemTree, getProcessedInputsFromUrlParams } from '../utils' +import { buildChatItemTree, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams } from '../utils' import { getProcessedFilesFromResponse } from '../../file-uploader/utils' import { fetchAppInfo, @@ -72,23 +72,36 @@ export const useEmbeddedChatbot = () => { }, [appInfo]) const appId = useMemo(() => appData?.app_id, [appData]) + const [userId, setUserId] = useState() + useEffect(() => { + getProcessedSystemVariablesFromUrlParams().then(({ user_id }) => { + setUserId(user_id) + }) + }, []) + useEffect(() => { if (appInfo?.site.default_language) changeLanguage(appInfo.site.default_language) }, [appInfo]) - const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState>(CONVERSATION_ID_INFO, { + const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState>>(CONVERSATION_ID_INFO, { defaultValue: {}, }) - const currentConversationId = useMemo(() => conversationIdInfo?.[appId || ''] || '', [appId, conversationIdInfo]) + const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId]) const handleConversationIdInfoChange = useCallback((changeConversationId: string) => { if (appId) { + let prevValue = conversationIdInfo?.[appId || ''] + if (typeof prevValue === 
'string') + prevValue = {} setConversationIdInfo({ ...conversationIdInfo, - [appId || '']: changeConversationId, + [appId || '']: { + ...prevValue, + [userId || 'DEFAULT']: changeConversationId, + }, }) } - }, [appId, conversationIdInfo, setConversationIdInfo]) + }, [appId, conversationIdInfo, setConversationIdInfo, userId]) const [newConversationId, setNewConversationId] = useState('') const chatShouldReloadKey = useMemo(() => { diff --git a/web/app/components/share/utils.ts b/web/app/components/share/utils.ts index f3ef12e4aa..9ce891a50c 100644 --- a/web/app/components/share/utils.ts +++ b/web/app/components/share/utils.ts @@ -2,29 +2,44 @@ import { CONVERSATION_ID_INFO } from '../base/chat/constants' import { fetchAccessToken } from '@/service/share' import { getProcessedSystemVariablesFromUrlParams } from '../base/chat/utils' +export const isTokenV1 = (token: Record) => { + return !token.version +} + +export const getInitialTokenV2 = (): Record => ({ + version: 2, +}) + export const checkOrSetAccessToken = async () => { const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] - const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) - let accessTokenJson = { [sharedToken]: '' } + const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id + const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2()) + let accessTokenJson = getInitialTokenV2() try { accessTokenJson = JSON.parse(accessToken) + if (isTokenV1(accessTokenJson)) + accessTokenJson = getInitialTokenV2() } catch { } - if (!accessTokenJson[sharedToken]) { - const sysUserId = (await getProcessedSystemVariablesFromUrlParams()).user_id - const res = await fetchAccessToken(sharedToken, sysUserId) - accessTokenJson[sharedToken] = res.access_token + if (!accessTokenJson[sharedToken]?.[userId || 'DEFAULT']) { + const res = await fetchAccessToken(sharedToken, userId) + accessTokenJson[sharedToken] = { + ...accessTokenJson[sharedToken], + [userId || 'DEFAULT']: res.access_token, + } localStorage.setItem('token', JSON.stringify(accessTokenJson)) } } -export const setAccessToken = async (sharedToken: string, token: string) => { - const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) - let accessTokenJson = { [sharedToken]: '' } +export const setAccessToken = async (sharedToken: string, token: string, user_id?: string) => { + const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2()) + let accessTokenJson = getInitialTokenV2() try { accessTokenJson = JSON.parse(accessToken) + if (isTokenV1(accessTokenJson)) + accessTokenJson = getInitialTokenV2() } catch { @@ -32,17 +47,22 @@ export const setAccessToken = async (sharedToken: string, token: string) => { localStorage.removeItem(CONVERSATION_ID_INFO) - accessTokenJson[sharedToken] = token + accessTokenJson[sharedToken] = { + ...accessTokenJson[sharedToken], + [user_id || 'DEFAULT']: token, + } localStorage.setItem('token', JSON.stringify(accessTokenJson)) } export const removeAccessToken = () => { const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] - const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) - let accessTokenJson = { [sharedToken]: '' } + const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2()) + let accessTokenJson = getInitialTokenV2() try { accessTokenJson = JSON.parse(accessToken) + if (isTokenV1(accessTokenJson)) + accessTokenJson 
= getInitialTokenV2() } catch { diff --git a/web/service/base.ts b/web/service/base.ts index f265d8052c..e3d1dc0ca2 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -287,9 +287,9 @@ const handleStream = ( const baseFetch = base -export const upload = (options: any, isPublicAPI?: boolean, url?: string, searchParams?: string): Promise => { +export const upload = async (options: any, isPublicAPI?: boolean, url?: string, searchParams?: string): Promise => { const urlPrefix = isPublicAPI ? PUBLIC_API_PREFIX : API_PREFIX - const token = getAccessToken(isPublicAPI) + const token = await getAccessToken(isPublicAPI) const defaultOptions = { method: 'POST', url: (url ? `${urlPrefix}${url}` : `${urlPrefix}/files/upload`) + (searchParams || ''), @@ -324,7 +324,7 @@ export const upload = (options: any, isPublicAPI?: boolean, url?: string, search }) } -export const ssePost = ( +export const ssePost = async ( url: string, fetchOptions: FetchOptionType, otherOptions: IOtherOptions, @@ -385,7 +385,7 @@ export const ssePost = ( if (body) options.body = JSON.stringify(body) - const accessToken = getAccessToken(isPublicAPI) + const accessToken = await getAccessToken(isPublicAPI) ; (options.headers as Headers).set('Authorization', `Bearer ${accessToken}`) globalThis.fetch(urlWithPrefix, options as RequestInit) diff --git a/web/service/fetch.ts b/web/service/fetch.ts index 75dd775f6c..fc41310c80 100644 --- a/web/service/fetch.ts +++ b/web/service/fetch.ts @@ -3,6 +3,8 @@ import ky from 'ky' import type { IOtherOptions } from './base' import Toast from '@/app/components/base/toast' import { API_PREFIX, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config' +import { getInitialTokenV2, isTokenV1 } from '@/app/components/share/utils' +import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/chat/utils' const TIME_OUT = 100000 @@ -67,44 +69,34 @@ const beforeErrorToast = (otherOptions: IOtherOptions): BeforeErrorHook => { } } -export const getPublicToken = () => { - let token = '' - const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] - const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) - let accessTokenJson = { [sharedToken]: '' } - try { - accessTokenJson = JSON.parse(accessToken) - } - catch { } - token = accessTokenJson[sharedToken] - return token || '' -} - -export function getAccessToken(isPublicAPI?: boolean) { +export async function getAccessToken(isPublicAPI?: boolean) { if (isPublicAPI) { const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] - const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) - let accessTokenJson = { [sharedToken]: '' } + const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id + const accessToken = localStorage.getItem('token') || JSON.stringify({ version: 2 }) + let accessTokenJson: Record = { version: 2 } try { accessTokenJson = JSON.parse(accessToken) + if (isTokenV1(accessTokenJson)) + accessTokenJson = getInitialTokenV2() } catch { } - return accessTokenJson[sharedToken] + return accessTokenJson[sharedToken]?.[userId || 'DEFAULT'] } else { return localStorage.getItem('console_token') || '' } } -const beforeRequestPublicAuthorization: BeforeRequestHook = (request) => { - const token = getAccessToken(true) +const beforeRequestPublicAuthorization: BeforeRequestHook = async (request) => { + const token = await getAccessToken(true) request.headers.set('Authorization', `Bearer ${token}`) } -const 
beforeRequestAuthorization: BeforeRequestHook = (request) => { - const accessToken = getAccessToken() +const beforeRequestAuthorization: BeforeRequestHook = async (request) => { + const accessToken = await getAccessToken() request.headers.set('Authorization', `Bearer ${accessToken}`) } From c91045a9d031f4acc200313a9e28389c6c7c871a Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Wed, 16 Apr 2025 22:34:07 +0800 Subject: [PATCH 22/68] fix(fail-branch): prevent streaming output in exception branches (#17153) --- .../nodes/answer/answer_stream_processor.py | 23 +++++++- .../workflow/nodes/test_continue_on_error.py | 57 +++++++++++++++++-- 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index d8ad1dbd49..ba6ba16e36 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor): for answer_node_id, route_position in self.route_position.items(): if answer_node_id not in self.rest_node_ids: continue - # exclude current node id + # Remove current node id from answer dependencies to support stream output if it is a success branch answer_dependencies = self.generate_routes.answer_dependencies - if event.node_id in answer_dependencies[answer_node_id]: + edge_mapping = self.graph.edge_mapping.get(event.node_id) + success_edge = ( + next( + ( + edge + for edge in edge_mapping + if edge.run_condition + and edge.run_condition.type == "branch_identify" + and edge.run_condition.branch_identify == "success-branch" + ), + None, + ) + if edge_mapping + else None + ) + if ( + event.node_id in answer_dependencies[answer_node_id] + and success_edge + and success_edge.target_node_id == answer_node_id + ): answer_dependencies[answer_node_id].remove(event.node_id) answer_dependencies_ids = answer_dependencies.get(answer_node_id, []) # all depends on answer node id not in rest node ids diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index ed35d8a32a..111c647d9c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -1,14 +1,20 @@ +from unittest.mock import patch + from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( GraphRunPartialSucceededEvent, NodeRunExceptionEvent, + NodeRunFailedEvent, NodeRunStreamChunkEvent, ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom -from models.workflow import WorkflowType +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType class ContinueOnErrorTestHelper: @@ -492,10 +498,7 @@ def test_no_node_in_fail_branch_continue_on_error(): "edges": FAIL_BRANCH_EDGES[:-1], "nodes": [ {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, - "id": 
"success", - }, + {"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"}, ContinueOnErrorTestHelper.get_http_node(), ], } @@ -506,3 +509,47 @@ def test_no_node_in_fail_branch_continue_on_error(): assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 + + +def test_stream_output_with_fail_branch_continue_on_error(): + """Test stream output with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_llm_node(), + ], + } + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + + def llm_generator(self): + contents = ["hi", "bye", "good morning"] + + yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"]) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={}, + process_data={}, + outputs={}, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: 1, + NodeRunMetadataKey.TOTAL_PRICE: 1, + NodeRunMetadataKey.CURRENCY: "USD", + }, + ) + ) + + with patch.object(LLMNode, "_run", new=llm_generator): + events = list(graph_engine.run()) + assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1 + assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events) From 6da7e6158f14421d41fe50b8cacdce2e8027b14b Mon Sep 17 00:00:00 2001 From: AirLin <49427685+AAirLin@users.noreply.github.com> Date: Wed, 16 Apr 2025 23:07:05 +0800 Subject: [PATCH 23/68] Add the parameter appid to apiserver (#18224) --- web/app/components/develop/ApiServer.tsx | 4 +++- web/app/components/develop/index.tsx | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/web/app/components/develop/ApiServer.tsx b/web/app/components/develop/ApiServer.tsx index 4de98c6cd4..9f2c9cf7f4 100644 --- a/web/app/components/develop/ApiServer.tsx +++ b/web/app/components/develop/ApiServer.tsx @@ -7,9 +7,11 @@ import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-butt type ApiServerProps = { apiBaseUrl: string + appId?: string } const ApiServer: FC = ({ apiBaseUrl, + appId, }) => { const { t } = useTranslation() @@ -25,7 +27,7 @@ const ApiServer: FC = ({ {t('appApi.ok')}
  )
diff --git a/web/app/components/develop/index.tsx b/web/app/components/develop/index.tsx
index 5b14b680c1..c3f88a15f8 100644
--- a/web/app/components/develop/index.tsx
+++ b/web/app/components/develop/index.tsx
@@ -23,7 +23,7 @@ const DevelopMain = ({ appId }: IDevelopMainProps) => {
-
+
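The hunks above only record that ApiServer now accepts an optional appId and that the develop page passes it in. A minimal, hypothetical TypeScript sketch of that prop threading follows; the component bodies and the child element are assumptions for illustration, not the actual Dify components:

```tsx
// Hypothetical sketch only; the real wiring lives in ApiServer.tsx / index.tsx above.
import React, { type FC } from 'react'

type ApiServerProps = {
  apiBaseUrl: string
  appId?: string // newly added, optional so existing callers keep working
}

// ApiServer receives the optional appId and forwards it to whichever child needs it.
const ApiServer: FC<ApiServerProps> = ({ apiBaseUrl, appId }) => (
  <div>
    <code>{apiBaseUrl}</code>
    {/* appId is simply forwarded; when absent, the child keeps its previous behaviour */}
    <span data-app-id={appId ?? ''} />
  </div>
)

// The develop page already knows the appId, so it just passes it down.
const DevelopMain: FC<{ appId: string }> = ({ appId }) => (
  <ApiServer apiBaseUrl="https://api.example.com/v1" appId={appId} />
)

export { ApiServer, DevelopMain }
```

Keeping appId optional means existing call sites of ApiServer keep compiling without changes.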
From a1d20085e659e816e70a2ff7c39e04fbf48292fb Mon Sep 17 00:00:00 2001 From: Chenming C <43266446+chen622@users.noreply.github.com> Date: Thu, 17 Apr 2025 10:10:27 +0800 Subject: [PATCH 24/68] fix: change the method of update_dataset api in document (#18197) --- .../datasets/template/template.en.mdx | 69 ++++++++++++++++--- .../datasets/template/template.zh.mdx | 69 ++++++++++++++++--- 2 files changed, 122 insertions(+), 16 deletions(-) diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index 357b66a96f..54e08b45d8 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -557,7 +557,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi @@ -585,8 +585,21 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi Specified embedding model, corresponding to the model field(Optional) - - Specified retrieval model, corresponding to the model field(Optional) + + Retrieval model (optional, if not filled, it will be recalled according to the default method) + - search_method (text) Search method: One of the following four keywords is required + - keyword_search Keyword search + - semantic_search Semantic search + - full_text_search Full-text search + - hybrid_search Hybrid search + - reranking_enable (bool) Whether to enable reranking, required if the search mode is semantic_search or hybrid_search (optional) + - reranking_mode (object) Rerank model configuration, required if reranking is enabled + - reranking_provider_name (string) Rerank model provider + - reranking_model_name (string) Rerank model name + - weights (float) Semantic search weight setting in hybrid search mode + - top_k (integer) Number of results to return (optional) + - score_threshold_enabled (bool) Whether to enable score threshold + - score_threshold (float) Score threshold Partial member list(Optional) @@ -596,16 +609,56 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}' \ + curl --location --request PATCH '${props.apiBaseUrl}/datasets/{dataset_id}' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ - --data-raw '{"name": "Test Knowledge Base", "indexing_technique": "high_quality", "permission": "only_me",\ - "embedding_model_provider": "zhipuai", "embedding_model": "embedding-3", "retrieval_model": "", "partial_member_list": []}' + --data-raw '{ + "name": "Test Knowledge Base", + "indexing_technique": "high_quality", + "permission": "only_me", + "embedding_model_provider": "zhipuai", + "embedding_model": "embedding-3", + "retrieval_model": { + "search_method": "keyword_search", + "reranking_enable": false, + "reranking_mode": null, + "reranking_model": { + "reranking_provider_name": "", + "reranking_model_name": "" + }, + "weights": null, + "top_k": 1, + "score_threshold_enabled": false, + "score_threshold": null + }, + "partial_member_list": [] + }' ``` diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index fb8f728b61..b435a9bb67 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -557,7 +557,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi 
@@ -589,8 +589,21 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi 嵌入模型(选填) - - 检索模型(选填) + + 检索参数(选填,如不填,按照默认方式召回) + - search_method (text) 检索方法:以下三个关键字之一,必填 + - keyword_search 关键字检索 + - semantic_search 语义检索 + - full_text_search 全文检索 + - hybrid_search 混合检索 + - reranking_enable (bool) 是否启用 Reranking,非必填,如果检索模式为 semantic_search 模式或者 hybrid_search 则传值 + - reranking_mode (object) Rerank 模型配置,非必填,如果启用了 reranking 则传值 + - reranking_provider_name (string) Rerank 模型提供商 + - reranking_model_name (string) Rerank 模型名称 + - weights (float) 混合检索模式下语意检索的权重设置 + - top_k (integer) 返回结果数量,非必填 + - score_threshold_enabled (bool) 是否开启 score 阈值 + - score_threshold (float) Score 阈值 部分团队成员 ID 列表(选填) @@ -600,16 +613,56 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}' \ + curl --location --request PATCH '${props.apiBaseUrl}/datasets/{dataset_id}' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ - --data-raw '{"name": "Test Knowledge Base", "indexing_technique": "high_quality", "permission": "only_me",\ - "embedding_model_provider": "zhipuai", "embedding_model": "embedding-3", "retrieval_model": "", "partial_member_list": []}' + --data-raw '{ + "name": "Test Knowledge Base", + "indexing_technique": "high_quality", + "permission": "only_me", + "embedding_model_provider": "zhipuai", + "embedding_model": "embedding-3", + "retrieval_model": { + "search_method": "keyword_search", + "reranking_enable": false, + "reranking_mode": null, + "reranking_model": { + "reranking_provider_name": "", + "reranking_model_name": "" + }, + "weights": null, + "top_k": 1, + "score_threshold_enabled": false, + "score_threshold": null + }, + "partial_member_list": [] + }' ``` From e8d98e3d8907105c524f045c360d7115edc238b7 Mon Sep 17 00:00:00 2001 From: Rain Wang Date: Thu, 17 Apr 2025 10:38:56 +0800 Subject: [PATCH 25/68] Add analyzer_params config for milvus vectordb (#18180) --- api/.env.example | 1 + api/configs/middleware/vdb/milvus_config.py | 5 ++++ .../datasource/vdb/milvus/milvus_vector.py | 24 ++++++++++++------- docker/.env.example | 1 + docker/docker-compose.yaml | 1 + 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/api/.env.example b/api/.env.example index af95a4fe2d..502461f658 100644 --- a/api/.env.example +++ b/api/.env.example @@ -165,6 +165,7 @@ MILVUS_URI=http://127.0.0.1:19530 MILVUS_TOKEN= MILVUS_USER=root MILVUS_PASSWORD=Milvus +MILVUS_ANALYZER_PARAMS= # MyScale configuration MYSCALE_HOST=127.0.0.1 diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index ebdf8857b9..d398ef5bd8 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -39,3 +39,8 @@ class MilvusConfig(BaseSettings): "older versions", default=True, ) + + MILVUS_ANALYZER_PARAMS: Optional[str] = Field( + description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.', + default=None, + ) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 7a3319f4a6..100bcb198c 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -32,6 +32,7 @@ class MilvusConfig(BaseModel): batch_size: int = 100 # Batch size for operations database: str = "default" # Database 
name enable_hybrid_search: bool = False # Flag to enable hybrid search + analyzer_params: Optional[str] = None # Analyzer params @model_validator(mode="before") @classmethod @@ -58,6 +59,7 @@ class MilvusConfig(BaseModel): "user": self.user, "password": self.password, "db_name": self.database, + "analyzer_params": self.analyzer_params, } @@ -300,14 +302,19 @@ class MilvusVector(BaseVector): # Create the text field, enable_analyzer will be set True to support milvus automatically # transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md - fields.append( - FieldSchema( - Field.CONTENT_KEY.value, - DataType.VARCHAR, - max_length=65_535, - enable_analyzer=self._hybrid_search_enabled, - ) - ) + content_field_kwargs: dict[str, Any] = { + "max_length": 65_535, + "enable_analyzer": self._hybrid_search_enabled, + } + if ( + self._hybrid_search_enabled + and self._client_config.analyzer_params is not None + and self._client_config.analyzer_params.strip() + ): + content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params + + fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs)) + # Create the primary key field fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) # Create the vector field, supports binary or float vectors @@ -383,5 +390,6 @@ class MilvusVectorFactory(AbstractVectorFactory): password=dify_config.MILVUS_PASSWORD or "", database=dify_config.MILVUS_DATABASE or "", enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False, + analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "", ), ) diff --git a/docker/.env.example b/docker/.env.example index e49e8fee89..9b372dcec9 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -410,6 +410,7 @@ MILVUS_TOKEN= MILVUS_USER= MILVUS_PASSWORD= MILVUS_ENABLE_HYBRID_SEARCH=False +MILVUS_ANALYZER_PARAMS= # MyScale configuration, only available when VECTOR_STORE is `myscale` # For multi-language support, please set MYSCALE_FTS_PARAMS with referring to: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 25b0c56561..172cbe2d2f 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -142,6 +142,7 @@ x-shared-env: &shared-api-worker-env MILVUS_USER: ${MILVUS_USER:-} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-} MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False} + MILVUS_ANALYZER_PARAMS: ${MILVUS_ANALYZER_PARAMS:-} MYSCALE_HOST: ${MYSCALE_HOST:-myscale} MYSCALE_PORT: ${MYSCALE_PORT:-8123} MYSCALE_USER: ${MYSCALE_USER:-default} From 6d66e3f680b849cfb718e7dd73bdbd4916ce4194 Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Thu, 17 Apr 2025 10:41:56 +0800 Subject: [PATCH 26/68] fix(follow_ups): handle empty LLM responses in context (#18237) --- api/core/memory/token_buffer_memory.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 003a0c85b1..3c90dd22a2 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -44,6 +44,7 @@ class TokenBufferMemory: Message.created_at, Message.workflow_run_id, Message.parent_message_id, + Message.answer_tokens, ) .filter( Message.conversation_id == self.conversation.id, @@ -63,7 +64,7 @@ class TokenBufferMemory: thread_messages = extract_thread_messages(messages) # for newly created message, its answer is temporarily empty, we don't need to add it to memory - if 
thread_messages and not thread_messages[0].answer: + if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0: thread_messages.pop(0) messages = list(reversed(thread_messages)) From 9d139fa30677821588fc03f360576a50bd5ad13d Mon Sep 17 00:00:00 2001 From: Joel Date: Thu, 17 Apr 2025 11:22:06 +0800 Subject: [PATCH 27/68] fix: Could not load the logo of workflow as Tool in Agent Node (#18243) --- .../workflow/nodes/agent/components/tool-icon.tsx | 6 ++++-- web/app/components/workflow/nodes/agent/node.tsx | 7 ++++--- web/app/components/workflow/nodes/agent/panel.tsx | 1 - 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/web/app/components/workflow/nodes/agent/components/tool-icon.tsx b/web/app/components/workflow/nodes/agent/components/tool-icon.tsx index 4ac789f22e..b94258855a 100644 --- a/web/app/components/workflow/nodes/agent/components/tool-icon.tsx +++ b/web/app/components/workflow/nodes/agent/components/tool-icon.tsx @@ -10,6 +10,7 @@ import { Group } from '@/app/components/base/icons/src/vender/other' type Status = 'not-installed' | 'not-authorized' | undefined export type ToolIconProps = { + id: string providerName: string } @@ -29,10 +30,11 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => { const author = providerNameParts[0] const name = providerNameParts[1] const icon = useMemo(() => { + if (!isDataReady) return '' if (currentProvider) return currentProvider.icon as string const iconFromMarketPlace = getIconFromMarketPlace(`${author}/${name}`) return iconFromMarketPlace - }, [author, currentProvider, name]) + }, [author, currentProvider, name, isDataReady]) const status: Status = useMemo(() => { if (!isDataReady) return undefined if (!currentProvider) return 'not-installed' @@ -60,7 +62,7 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => { )} ref={containerRef} > - {!iconFetchError + {(!iconFetchError && isDataReady) ? > = (props) => { const tools = useMemo(() => { const tools: Array = [] - currentStrategy?.parameters.forEach((param) => { + currentStrategy?.parameters.forEach((param, i) => { if (param.type === FormTypeEnum.toolSelector) { const field = param.name const value = inputs.agent_parameters?.[field]?.value if (value) { tools.push({ + id: `${param.name}-${i}`, providerName: value.provider_name as any, }) } @@ -55,6 +56,7 @@ const AgentNode: FC> = (props) => { if (value) { (value as unknown as any[]).forEach((item) => { tools.push({ + id: `${param.name}-${i}`, providerName: item.provider_name, }) }) @@ -102,8 +104,7 @@ const AgentNode: FC> = (props) => { {t('workflow.nodes.agent.toolbox')} }>
-          {/* eslint-disable-next-line sonarjs/no-uniq-key */}
-          {tools.map(tool => )}
+          {tools.map(tool => )}
}
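The ToolIcon and node changes in this commit come down to two ideas: give each rendered tool a stable id built from the parameter name plus its index, and return an empty icon until the provider list has actually loaded so the marketplace fallback is not hit prematurely. A small hypothetical sketch of both, using simplified names that are not the real Dify helpers:

```tsx
// Hypothetical sketch of the two ideas above (simplified names, not the actual Dify code).
import { useMemo } from 'react'

type ToolRef = { id: string; providerName: string }

// 1) Stable keys: parameter name + position is unique within one agent node.
const collectTools = (paramName: string, providerNames: string[]): ToolRef[] =>
  providerNames.map((providerName, i) => ({ id: `${paramName}-${i}`, providerName }))

// 2) Icon resolution waits for data readiness, so the marketplace URL is never
//    used as a fallback while the local provider list is still loading.
const useToolIcon = (isDataReady: boolean, localIcon?: string, marketplaceIcon?: string) =>
  useMemo(() => {
    if (!isDataReady)
      return ''
    return localIcon ?? marketplaceIcon ?? ''
  }, [isDataReady, localIcon, marketplaceIcon])

export { collectTools, useToolIcon }
```

A deterministic `${paramName}-${i}` key also removes the need for the sonarjs/no-uniq-key suppression that the old rendering code carried.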
diff --git a/web/app/components/workflow/nodes/agent/panel.tsx b/web/app/components/workflow/nodes/agent/panel.tsx index 6a80728d91..19be60cb51 100644 --- a/web/app/components/workflow/nodes/agent/panel.tsx +++ b/web/app/components/workflow/nodes/agent/panel.tsx @@ -54,7 +54,6 @@ const AgentPanel: FC> = (props) => { outputSchema, handleMemoryChange, } = useConfig(props.id, props.data) - console.log('currentStrategy', currentStrategy) const { t } = useTranslation() const nodeInfo = useMemo(() => { if (!runResult) From 77fde04ef7ec3e24935320485f3616ee185a1d8b Mon Sep 17 00:00:00 2001 From: GuanMu Date: Thu, 17 Apr 2025 11:47:59 +0800 Subject: [PATCH 28/68] style: add left padding to editor component and remove unused CSS (#18247) --- .../workflow/nodes/_base/components/editor/base.tsx | 2 +- .../nodes/_base/components/editor/code-editor/style.css | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/web/app/components/workflow/nodes/_base/components/editor/base.tsx b/web/app/components/workflow/nodes/_base/components/editor/base.tsx index 3b31f44619..38968b2e0d 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/base.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/base.tsx @@ -109,7 +109,7 @@ const Base: FC = ({ onHeightChange={setEditorContentHeight} hideResize={isExpand} > -
+
{children}
diff --git a/web/app/components/workflow/nodes/_base/components/editor/code-editor/style.css b/web/app/components/workflow/nodes/_base/components/editor/code-editor/style.css index 296ea0ab14..72e0087a3c 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/code-editor/style.css +++ b/web/app/components/workflow/nodes/_base/components/editor/code-editor/style.css @@ -1,10 +1,3 @@ -.margin-view-overlays { - padding-left: 10px; -} - -.no-wrapper .margin-view-overlays { - padding-left: 0; -} .monaco-editor { background-color: transparent !important; From 6d9dd3109e807f0648a1fe7d0844c766717488b4 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 17 Apr 2025 12:48:52 +0900 Subject: [PATCH 29/68] feat: add a abstract layer for WorkflowNodeExcetion (#18026) --- api/.env.example | 6 + api/app_factory.py | 2 + api/configs/feature/__init__.py | 7 +- .../advanced_chat/generate_task_pipeline.py | 28 ++- .../apps/workflow/generate_task_pipeline.py | 46 ++--- .../task_pipeline/workflow_cycle_manage.py | 109 ++++++----- api/core/ops/langfuse_trace/langfuse_trace.py | 39 ++-- .../ops/langsmith_trace/langsmith_trace.py | 43 ++--- api/core/ops/opik_trace/opik_trace.py | 43 ++--- api/core/repository/__init__.py | 15 ++ api/core/repository/repository_factory.py | 97 ++++++++++ .../workflow_node_execution_repository.py | 88 +++++++++ api/extensions/ext_repositories.py | 18 ++ api/extensions/ext_storage.py | 48 ++--- api/models/workflow.py | 30 +--- api/repositories/__init__.py | 6 + api/repositories/repository_registry.py | 87 +++++++++ .../workflow_node_execution/__init__.py | 9 + .../sqlalchemy_repository.py | 170 ++++++++++++++++++ api/tests/unit_tests/repositories/__init__.py | 3 + .../workflow_node_execution/__init__.py | 3 + .../test_sqlalchemy_repository.py | 154 ++++++++++++++++ docker/.env.example | 6 + docker/docker-compose.yaml | 1 + 24 files changed, 807 insertions(+), 251 deletions(-) create mode 100644 api/core/repository/__init__.py create mode 100644 api/core/repository/repository_factory.py create mode 100644 api/core/repository/workflow_node_execution_repository.py create mode 100644 api/extensions/ext_repositories.py create mode 100644 api/repositories/__init__.py create mode 100644 api/repositories/repository_registry.py create mode 100644 api/repositories/workflow_node_execution/__init__.py create mode 100644 api/repositories/workflow_node_execution/sqlalchemy_repository.py create mode 100644 api/tests/unit_tests/repositories/__init__.py create mode 100644 api/tests/unit_tests/repositories/workflow_node_execution/__init__.py create mode 100644 api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py diff --git a/api/.env.example b/api/.env.example index 502461f658..01ddb4adfd 100644 --- a/api/.env.example +++ b/api/.env.example @@ -424,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5 WORKFLOW_PARALLEL_DEPTH_LIMIT=3 MAX_VARIABLE_SIZE=204800 +# Workflow storage configuration +# Options: rdbms, hybrid +# rdbms: Use only the relational database (default) +# hybrid: Save new data to object storage, read from both object storage and RDBMS +WORKFLOW_NODE_EXECUTION_STORAGE=rdbms + # App configuration APP_MAX_EXECUTION_TIME=1200 APP_MAX_ACTIVE_REQUESTS=0 diff --git a/api/app_factory.py b/api/app_factory.py index 1c886ac5c7..586f2ded9e 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp): ext_otel, ext_proxy_fix, ext_redis, + ext_repositories, ext_sentry, ext_set_secretkey, ext_storage, @@ -74,6 +75,7 @@ 
def initialize_extensions(app: DifyApp): ext_migrate, ext_redis, ext_storage, + ext_repositories, ext_celery, ext_login, ext_mail, diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d35a74e3ee..f498dccbbc 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -12,7 +12,7 @@ from pydantic import ( ) from pydantic_settings import BaseSettings -from configs.feature.hosted_service import HostedServiceConfig +from .hosted_service import HostedServiceConfig class SecurityConfig(BaseSettings): @@ -519,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings): default=100, ) + WORKFLOW_NODE_EXECUTION_STORAGE: str = Field( + default="rdbms", + description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'", + ) + class AuthConfig(BaseSettings): """ diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 66f2c754bb..3bf6c330db 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline: with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( - session=session, event=event + event=event ) node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, ): - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( - session=session, event=event - ) + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + event=event + ) - node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + 
workflow_node_execution=workflow_node_execution, + ) if node_finish_resp: yield node_finish_resp diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 14441ada40..1f998edb6a 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline: if node_start_response: yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( - session=session, event=event - ) - node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + event=event + ) + node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_success_response: yield node_success_response @@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, ): - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( - session=session, - event=event, - ) - node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + event=event, + ) + node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_failed_response: yield node_failed_response @@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline: workflow_app_log.created_by = self._user_id session.add(workflow_app_log) + session.commit() def 
_text_chunk_to_stream_response( self, text: str, from_variable_selector: Optional[list[str]] = None diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4d629ca186..5ce9f737d1 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast from uuid import uuid4 from sqlalchemy import func, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( @@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repository import RepositoryFactory from core.tools.tool_manager import ToolManager from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db from models.account import Account from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser @@ -80,6 +82,21 @@ class WorkflowCycleManage: self._application_generate_entity = application_generate_entity self._workflow_system_variables = workflow_system_variables + # Initialize the session factory and repository + # We use the global db engine instead of the session passed to methods + # Disable expire_on_commit to avoid the need for merging objects + self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": self._application_generate_entity.app_config.tenant_id, + "app_id": self._application_generate_entity.app_config.app_id, + "session_factory": self._session_factory, + } + ) + + # We'll still keep the cache for backward compatibility and performance + # but use the repository for database operations + def _handle_workflow_run_start( self, *, @@ -254,19 +271,15 @@ class WorkflowCycleManage: workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - stmt = select(WorkflowNodeExecution.node_execution_id).where( - WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, + # Use the instance repository to find running executions for a workflow run + running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions( + workflow_run_id=workflow_run.id ) - ids = session.scalars(stmt).all() - # Use self._get_workflow_node_execution here to make sure the cache is updated - running_workflow_node_executions = [ - self._get_workflow_node_execution(session=session, node_execution_id=id) for id in 
ids if id - ] + + # Update the cache with the retrieved executions + for execution in running_workflow_node_executions: + if execution.node_execution_id: + self._workflow_node_executions[execution.node_execution_id] = execution for workflow_node_execution in running_workflow_node_executions: now = datetime.now(UTC).replace(tzinfo=None) @@ -288,7 +301,7 @@ class WorkflowCycleManage: return workflow_run def _handle_node_execution_start( - self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent ) -> WorkflowNodeExecution: workflow_node_execution = WorkflowNodeExecution() workflow_node_execution.id = str(uuid4()) @@ -315,17 +328,14 @@ class WorkflowCycleManage: ) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - session.add(workflow_node_execution) + # Use the instance repository to save the workflow node execution + self._workflow_node_execution_repository.save(workflow_node_execution) self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution - def _handle_workflow_node_execution_success( - self, *, session: Session, event: QueueNodeSucceededEvent - ) -> WorkflowNodeExecution: - workflow_node_execution = self._get_workflow_node_execution( - session=session, node_execution_id=event.node_execution_id - ) + def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: + workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) @@ -344,13 +354,13 @@ class WorkflowCycleManage: workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution = session.merge(workflow_node_execution) + # Use the instance repository to update the workflow node execution + self._workflow_node_execution_repository.update(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_failed( self, *, - session: Session, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeInLoopFailedEvent @@ -361,9 +371,7 @@ class WorkflowCycleManage: :param event: queue node failed event :return: """ - workflow_node_execution = self._get_workflow_node_execution( - session=session, node_execution_id=event.node_execution_id - ) + workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) @@ -387,14 +395,14 @@ class WorkflowCycleManage: workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.execution_metadata = execution_metadata - workflow_node_execution = session.merge(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_retried( - self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent ) -> WorkflowNodeExecution: """ Workflow node execution failed + :param workflow_run: workflow run :param event: queue node failed event :return: """ @@ -439,15 +447,12 @@ class WorkflowCycleManage: workflow_node_execution.execution_metadata = execution_metadata 
workflow_node_execution.index = event.node_run_index - session.add(workflow_node_execution) + # Use the instance repository to save the workflow node execution + self._workflow_node_execution_repository.save(workflow_node_execution) self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution - ################################################# - # to stream responses # - ################################################# - def _workflow_start_to_stream_response( self, *, @@ -455,7 +460,6 @@ class WorkflowCycleManage: task_id: str, workflow_run: WorkflowRun, ) -> WorkflowStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return WorkflowStartStreamResponse( task_id=task_id, @@ -521,14 +525,10 @@ class WorkflowCycleManage: def _workflow_node_start_to_stream_response( self, *, - session: Session, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeStartStreamResponse]: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this - _ = session - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None if not workflow_node_execution.workflow_run_id: @@ -571,7 +571,6 @@ class WorkflowCycleManage: def _workflow_node_finish_to_stream_response( self, *, - session: Session, event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent @@ -580,8 +579,6 @@ class WorkflowCycleManage: task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this - _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None if not workflow_node_execution.workflow_run_id: @@ -621,13 +618,10 @@ class WorkflowCycleManage: def _workflow_node_retry_to_stream_response( self, *, - session: Session, event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this - _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None if not workflow_node_execution.workflow_run_id: @@ -668,7 +662,6 @@ class WorkflowCycleManage: def _workflow_parallel_branch_start_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent ) -> ParallelBranchStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return ParallelBranchStartStreamResponse( task_id=task_id, @@ -692,7 +685,6 @@ class WorkflowCycleManage: workflow_run: WorkflowRun, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, ) -> ParallelBranchFinishedStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return ParallelBranchFinishedStreamResponse( task_id=task_id, @@ -713,7 +705,6 @@ class WorkflowCycleManage: def _workflow_iteration_start_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent ) -> 
IterationNodeStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return IterationNodeStartStreamResponse( task_id=task_id, @@ -735,7 +726,6 @@ class WorkflowCycleManage: def _workflow_iteration_next_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent ) -> IterationNodeNextStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return IterationNodeNextStreamResponse( task_id=task_id, @@ -759,7 +749,6 @@ class WorkflowCycleManage: def _workflow_iteration_completed_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent ) -> IterationNodeCompletedStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return IterationNodeCompletedStreamResponse( task_id=task_id, @@ -790,7 +779,6 @@ class WorkflowCycleManage: def _workflow_loop_start_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent ) -> LoopNodeStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return LoopNodeStartStreamResponse( task_id=task_id, @@ -812,7 +800,6 @@ class WorkflowCycleManage: def _workflow_loop_next_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent ) -> LoopNodeNextStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return LoopNodeNextStreamResponse( task_id=task_id, @@ -836,7 +823,6 @@ class WorkflowCycleManage: def _workflow_loop_completed_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent ) -> LoopNodeCompletedStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return LoopNodeCompletedStreamResponse( task_id=task_id, @@ -934,11 +920,22 @@ class WorkflowCycleManage: return workflow_run - def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: - if node_execution_id not in self._workflow_node_executions: + def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: + # First check the cache for performance + if node_execution_id in self._workflow_node_executions: + cached_execution = self._workflow_node_executions[node_execution_id] + # No need to merge with session since expire_on_commit=False + return cached_execution + + # If not in cache, use the instance repository to get by node_execution_id + execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id) + + if not execution: raise ValueError(f"Workflow node execution not found: {node_execution_id}") - cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] - return session.merge(cached_workflow_node_execution) + + # Update cache + self._workflow_node_executions[node_execution_id] = execution + return execution def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: """ diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py 
b/api/core/ops/langfuse_trace/langfuse_trace.py index f67e270ab1..fa78b7b8e9 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta from typing import Optional from langfuse import Langfuse # type: ignore +from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangfuseConfig @@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.utils import filter_none_values +from core.repository.repository_factory import RepositoryFactory from extensions.ext_database import db from models.model import EndUser -from models.workflow import WorkflowNodeExecution logger = logging.getLogger(__name__) @@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.add_trace(langfuse_trace_data=trace_data) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory}, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index e3494e2f23..85a0eafdc1 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -7,6 +7,7 @@ from typing import Optional, cast from langsmith import Client from langsmith.schemas import RunBase +from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangSmithConfig @@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.utils import filter_none_values, generate_dotted_order +from core.repository.repository_factory import RepositoryFactory from extensions.ext_database import db from models.model import EndUser, MessageFile -from models.workflow import WorkflowNodeExecution logger = logging.getLogger(__name__) @@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance): self.add_run(langsmith_run) - # through 
workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": trace_info.tenant_id, + "app_id": trace_info.metadata.get("app_id"), + "session_factory": session_factory, + }, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index fabf38fbd6..923b9a24ed 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -7,6 +7,7 @@ from typing import Optional, cast from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 +from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import OpikConfig @@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) +from core.repository.repository_factory import RepositoryFactory from extensions.ext_database import db from models.model import EndUser, MessageFile -from models.workflow import WorkflowNodeExecution logger = logging.getLogger(__name__) @@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance): } self.add_trace(trace_data) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": trace_info.tenant_id, + "app_id": trace_info.metadata.get("app_id"), + "session_factory": session_factory, + }, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - 
WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id diff --git a/api/core/repository/__init__.py b/api/core/repository/__init__.py new file mode 100644 index 0000000000..253df1251d --- /dev/null +++ b/api/core/repository/__init__.py @@ -0,0 +1,15 @@ +""" +Repository interfaces for data access. + +This package contains repository interfaces that define the contract +for accessing and manipulating data, regardless of the underlying +storage mechanism. +""" + +from core.repository.repository_factory import RepositoryFactory +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository + +__all__ = [ + "RepositoryFactory", + "WorkflowNodeExecutionRepository", +] diff --git a/api/core/repository/repository_factory.py b/api/core/repository/repository_factory.py new file mode 100644 index 0000000000..02e343d7ff --- /dev/null +++ b/api/core/repository/repository_factory.py @@ -0,0 +1,97 @@ +""" +Repository factory for creating repository instances. + +This module provides a simple factory interface for creating repository instances. +It does not contain any implementation details or dependencies on specific repositories. +""" + +from collections.abc import Callable, Mapping +from typing import Any, Literal, Optional, cast + +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository + +# Type for factory functions - takes a dict of parameters and returns any repository type +RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any] + +# Type for workflow node execution factory function +WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository] + +# Repository type literals +RepositoryType = Literal["workflow_node_execution"] + + +class RepositoryFactory: + """ + Factory class for creating repository instances. + + This factory delegates the actual repository creation to implementation-specific + factory functions that are registered with the factory at runtime. + """ + + # Dictionary to store factory functions + _factory_functions: dict[str, RepositoryFactoryFunc] = {} + + @classmethod + def _register_factory(cls, repository_type: RepositoryType, factory_func: RepositoryFactoryFunc) -> None: + """ + Register a factory function for a specific repository type. + This is a private method and should not be called directly. + + Args: + repository_type: The type of repository (e.g., 'workflow_node_execution') + factory_func: A function that takes parameters and returns a repository instance + """ + cls._factory_functions[repository_type] = factory_func + + @classmethod + def _create_repository(cls, repository_type: RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any: + """ + Create a new repository instance with the provided parameters. + This is a private method and should not be called directly. 
+ + Args: + repository_type: The type of repository to create + params: A dictionary of parameters to pass to the factory function + + Returns: + A new instance of the requested repository + + Raises: + ValueError: If no factory function is registered for the repository type + """ + if repository_type not in cls._factory_functions: + raise ValueError(f"No factory function registered for repository type '{repository_type}'") + + # Use empty dict if params is None + params = params or {} + + return cls._factory_functions[repository_type](params) + + @classmethod + def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None: + """ + Register a factory function for the workflow node execution repository. + + Args: + factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance + """ + cls._register_factory("workflow_node_execution", factory_func) + + @classmethod + def create_workflow_node_execution_repository( + cls, params: Optional[Mapping[str, Any]] = None + ) -> WorkflowNodeExecutionRepository: + """ + Create a new WorkflowNodeExecutionRepository instance with the provided parameters. + + Args: + params: A dictionary of parameters to pass to the factory function + + Returns: + A new instance of the WorkflowNodeExecutionRepository + + Raises: + ValueError: If no factory function is registered for the workflow_node_execution repository type + """ + # We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc + return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params)) diff --git a/api/core/repository/workflow_node_execution_repository.py b/api/core/repository/workflow_node_execution_repository.py new file mode 100644 index 0000000000..6dea4566de --- /dev/null +++ b/api/core/repository/workflow_node_execution_repository.py @@ -0,0 +1,88 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal, Optional, Protocol + +from models.workflow import WorkflowNodeExecution + + +@dataclass +class OrderConfig: + """Configuration for ordering WorkflowNodeExecution instances.""" + + order_by: list[str] + order_direction: Optional[Literal["asc", "desc"]] = None + + +class WorkflowNodeExecutionRepository(Protocol): + """ + Repository interface for WorkflowNodeExecution. + + This interface defines the contract for accessing and manipulating + WorkflowNodeExecution data, regardless of the underlying storage mechanism. + + Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), + and trigger sources (triggered_from) should be handled at the implementation level, not in + the core interface. This keeps the core domain model clean and independent of specific + application domains or deployment scenarios. + """ + + def save(self, execution: WorkflowNodeExecution) -> None: + """ + Save a WorkflowNodeExecution instance. + + Args: + execution: The WorkflowNodeExecution instance to save + """ + ... + + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: + """ + Retrieve a WorkflowNodeExecution by its node_execution_id. + + Args: + node_execution_id: The node execution ID + + Returns: + The WorkflowNodeExecution instance if found, None otherwise + """ + ... 
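A minimal, illustrative usage sketch of the factory and repository interface introduced in this patch, using the get_by_workflow_run method defined just below; it assumes a factory function has already been registered (for example via register_repositories() in the ext_repositories extension later in this patch) and that the application's database engine is configured. The tenant, app, and workflow-run IDs are placeholder values, not identifiers from the patch itself:

from sqlalchemy.orm import sessionmaker

from core.repository import RepositoryFactory
from core.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db

# Obtain a repository instance; tenant_id is required by the RDBMS
# implementation, while app_id and session_factory are optional.
repo = RepositoryFactory.create_workflow_node_execution_repository(
    params={
        "tenant_id": "example-tenant-id",
        "app_id": "example-app-id",
        "session_factory": sessionmaker(bind=db.engine),
    }
)

# Fetch all node executions of a workflow run, ordered by index descending.
executions = repo.get_by_workflow_run(
    workflow_run_id="example-workflow-run-id",
    order_config=OrderConfig(order_by=["index"], order_direction="desc"),
)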
+ + def get_by_workflow_run( + self, + workflow_run_id: str, + order_config: Optional[OrderConfig] = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) + order_config.order_direction: Direction to order ("asc" or "desc") + + Returns: + A list of WorkflowNodeExecution instances + """ + ... + + def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + + Returns: + A list of running WorkflowNodeExecution instances + """ + ... + + def update(self, execution: WorkflowNodeExecution) -> None: + """ + Update an existing WorkflowNodeExecution instance. + + Args: + execution: The WorkflowNodeExecution instance to update + """ + ... diff --git a/api/extensions/ext_repositories.py b/api/extensions/ext_repositories.py new file mode 100644 index 0000000000..27d8408ec1 --- /dev/null +++ b/api/extensions/ext_repositories.py @@ -0,0 +1,18 @@ +""" +Extension for initializing repositories. + +This extension registers repository implementations with the RepositoryFactory. +""" + +from dify_app import DifyApp +from repositories.repository_registry import register_repositories + + +def init_app(_app: DifyApp) -> None: + """ + Initialize repository implementations. + + Args: + _app: The Flask application instance (unused) + """ + register_repositories() diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 588bdb2d27..4c811c66ba 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -73,11 +73,7 @@ class Storage: raise ValueError(f"unsupported storage type {storage_type}") def save(self, filename, data): - try: - self.storage_runner.save(filename, data) - except Exception as e: - logger.exception(f"Failed to save file {filename}") - raise e + self.storage_runner.save(filename, data) @overload def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ... @@ -86,49 +82,25 @@ class Storage: def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ... 
def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: - try: - if stream: - return self.load_stream(filename) - else: - return self.load_once(filename) - except Exception as e: - logger.exception(f"Failed to load file {filename}") - raise e + if stream: + return self.load_stream(filename) + else: + return self.load_once(filename) def load_once(self, filename: str) -> bytes: - try: - return self.storage_runner.load_once(filename) - except Exception as e: - logger.exception(f"Failed to load_once file {filename}") - raise e + return self.storage_runner.load_once(filename) def load_stream(self, filename: str) -> Generator: - try: - return self.storage_runner.load_stream(filename) - except Exception as e: - logger.exception(f"Failed to load_stream file {filename}") - raise e + return self.storage_runner.load_stream(filename) def download(self, filename, target_filepath): - try: - self.storage_runner.download(filename, target_filepath) - except Exception as e: - logger.exception(f"Failed to download file {filename}") - raise e + self.storage_runner.download(filename, target_filepath) def exists(self, filename): - try: - return self.storage_runner.exists(filename) - except Exception as e: - logger.exception(f"Failed to check file exists {filename}") - raise e + return self.storage_runner.exists(filename) def delete(self, filename): - try: - return self.storage_runner.delete(filename) - except Exception as e: - logger.exception(f"Failed to delete file {filename}") - raise e + return self.storage_runner.delete(filename) storage = Storage() diff --git a/api/models/workflow.py b/api/models/workflow.py index 8b7c376e4b..045fa0aaa0 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -510,7 +510,7 @@ class WorkflowRun(Base): ) -class WorkflowNodeExecutionTriggeredFrom(Enum): +class WorkflowNodeExecutionTriggeredFrom(StrEnum): """ Workflow Node Execution Triggered From Enum """ @@ -518,21 +518,8 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): SINGLE_STEP = "single-step" WORKFLOW_RUN = "workflow-run" - @classmethod - def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom": - """ - Get value of given mode. - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid workflow node execution triggered from value {value}") - - -class WorkflowNodeExecutionStatus(Enum): +class WorkflowNodeExecutionStatus(StrEnum): """ Workflow Node Execution Status Enum """ @@ -543,19 +530,6 @@ class WorkflowNodeExecutionStatus(Enum): EXCEPTION = "exception" RETRY = "retry" - @classmethod - def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid workflow node execution status value {value}") - class WorkflowNodeExecution(Base): """ diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py new file mode 100644 index 0000000000..4cc339688b --- /dev/null +++ b/api/repositories/__init__.py @@ -0,0 +1,6 @@ +""" +Repository implementations for data access. + +This package contains concrete implementations of the repository interfaces +defined in the core.repository package. 
+""" diff --git a/api/repositories/repository_registry.py b/api/repositories/repository_registry.py new file mode 100644 index 0000000000..aa0a208d8e --- /dev/null +++ b/api/repositories/repository_registry.py @@ -0,0 +1,87 @@ +""" +Registry for repository implementations. + +This module is responsible for registering factory functions with the repository factory. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repository.repository_factory import RepositoryFactory +from extensions.ext_database import db +from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository + +logger = logging.getLogger(__name__) + +# Storage type constants +STORAGE_TYPE_RDBMS = "rdbms" +STORAGE_TYPE_HYBRID = "hybrid" + + +def register_repositories() -> None: + """ + Register repository factory functions with the RepositoryFactory. + + This function reads configuration settings to determine which repository + implementations to register. + """ + # Configure WorkflowNodeExecutionRepository factory based on configuration + workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE + + # Check storage type and register appropriate implementation + if workflow_node_execution_storage == STORAGE_TYPE_RDBMS: + # Register SQLAlchemy implementation for RDBMS storage + logger.info("Registering WorkflowNodeExecution repository with RDBMS storage") + RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository) + elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID: + # Hybrid storage is not yet implemented + raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented") + else: + # Unknown storage type + raise ValueError( + f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. " + f"Supported types: {STORAGE_TYPE_RDBMS}" + ) + + +def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository: + """ + Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation. + + This factory function creates a repository for the RDBMS storage type. + + Args: + params: Parameters for creating the repository, including: + - tenant_id: Required. The tenant ID for multi-tenancy. + - app_id: Optional. The application ID for filtering. + - session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided, + a new sessionmaker will be created using the global database engine. 
+ + Returns: + A WorkflowNodeExecutionRepository instance + + Raises: + ValueError: If required parameters are missing + """ + # Extract required parameters + tenant_id = params.get("tenant_id") + if tenant_id is None: + raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage") + + # Extract optional parameters + app_id = params.get("app_id") + + # Use the session_factory from params if provided, otherwise create one using the global db engine + session_factory = params.get("session_factory") + if session_factory is None: + # Create a sessionmaker using the same engine as the global db session + session_factory = sessionmaker(bind=db.engine) + + # Create and return the repository + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, tenant_id=tenant_id, app_id=app_id + ) diff --git a/api/repositories/workflow_node_execution/__init__.py b/api/repositories/workflow_node_execution/__init__.py new file mode 100644 index 0000000000..eed827bd05 --- /dev/null +++ b/api/repositories/workflow_node_execution/__init__.py @@ -0,0 +1,9 @@ +""" +WorkflowNodeExecution repository implementations. +""" + +from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository + +__all__ = [ + "SQLAlchemyWorkflowNodeExecutionRepository", +] diff --git a/api/repositories/workflow_node_execution/sqlalchemy_repository.py b/api/repositories/workflow_node_execution/sqlalchemy_repository.py new file mode 100644 index 0000000000..01c54dfcd7 --- /dev/null +++ b/api/repositories/workflow_node_execution/sqlalchemy_repository.py @@ -0,0 +1,170 @@ +""" +SQLAlchemy implementation of the WorkflowNodeExecutionRepository. +""" + +import logging +from collections.abc import Sequence +from typing import Optional + +from sqlalchemy import UnaryExpression, asc, desc, select +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.repository.workflow_node_execution_repository import OrderConfig +from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom + +logger = logging.getLogger(__name__) + + +class SQLAlchemyWorkflowNodeExecutionRepository: + """ + SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface. + + This implementation supports multi-tenancy by filtering operations based on tenant_id. + Each method creates its own session, handles the transaction, and commits changes + to the database. This prevents long-running connections in the workflow core. + """ + + def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None): + """ + Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context. + + Args: + session_factory: SQLAlchemy sessionmaker or engine for creating sessions + tenant_id: Tenant ID for multi-tenancy + app_id: Optional app ID for filtering by application + """ + # If an engine is provided, create a sessionmaker from it + if isinstance(session_factory, Engine): + self._session_factory = sessionmaker(bind=session_factory) + else: + self._session_factory = session_factory + + self._tenant_id = tenant_id + self._app_id = app_id + + def save(self, execution: WorkflowNodeExecution) -> None: + """ + Save a WorkflowNodeExecution instance and commit changes to the database. 
+ + Args: + execution: The WorkflowNodeExecution instance to save + """ + with self._session_factory() as session: + # Ensure tenant_id is set + if not execution.tenant_id: + execution.tenant_id = self._tenant_id + + # Set app_id if provided and not already set + if self._app_id and not execution.app_id: + execution.app_id = self._app_id + + session.add(execution) + session.commit() + + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: + """ + Retrieve a WorkflowNodeExecution by its node_execution_id. + + Args: + node_execution_id: The node execution ID + + Returns: + The WorkflowNodeExecution instance if found, None otherwise + """ + with self._session_factory() as session: + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.node_execution_id == node_execution_id, + WorkflowNodeExecution.tenant_id == self._tenant_id, + ) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + return session.scalar(stmt) + + def get_by_workflow_run( + self, + workflow_run_id: str, + order_config: Optional[OrderConfig] = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) + order_config.order_direction: Direction to order ("asc" or "desc") + + Returns: + A list of WorkflowNodeExecution instances + """ + with self._session_factory() as session: + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.workflow_run_id == workflow_run_id, + WorkflowNodeExecution.tenant_id == self._tenant_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + # Apply ordering if provided + if order_config and order_config.order_by: + order_columns: list[UnaryExpression] = [] + for field in order_config.order_by: + column = getattr(WorkflowNodeExecution, field, None) + if not column: + continue + if order_config.order_direction == "desc": + order_columns.append(desc(column)) + else: + order_columns.append(asc(column)) + + if order_columns: + stmt = stmt.order_by(*order_columns) + + return session.scalars(stmt).all() + + def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + + Returns: + A list of running WorkflowNodeExecution instances + """ + with self._session_factory() as session: + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.workflow_run_id == workflow_run_id, + WorkflowNodeExecution.tenant_id == self._tenant_id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + return session.scalars(stmt).all() + + def update(self, execution: WorkflowNodeExecution) -> None: + """ + Update an existing WorkflowNodeExecution instance and commit changes to the database. 
+ + Args: + execution: The WorkflowNodeExecution instance to update + """ + with self._session_factory() as session: + # Ensure tenant_id is set + if not execution.tenant_id: + execution.tenant_id = self._tenant_id + + # Set app_id if provided and not already set + if self._app_id and not execution.app_id: + execution.app_id = self._app_id + + session.merge(execution) + session.commit() diff --git a/api/tests/unit_tests/repositories/__init__.py b/api/tests/unit_tests/repositories/__init__.py new file mode 100644 index 0000000000..bc0d6e78c9 --- /dev/null +++ b/api/tests/unit_tests/repositories/__init__.py @@ -0,0 +1,3 @@ +""" +Unit tests for repositories. +""" diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py b/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py new file mode 100644 index 0000000000..78815a8d1a --- /dev/null +++ b/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py @@ -0,0 +1,3 @@ +""" +Unit tests for workflow_node_execution repositories. +""" diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py new file mode 100644 index 0000000000..f31adab2a8 --- /dev/null +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -0,0 +1,154 @@ +""" +Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository. +""" + +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from sqlalchemy.orm import Session, sessionmaker + +from core.repository.workflow_node_execution_repository import OrderConfig +from models.workflow import WorkflowNodeExecution +from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository + + +@pytest.fixture +def session(): + """Create a mock SQLAlchemy session.""" + session = MagicMock(spec=Session) + # Configure the session to be used as a context manager + session.__enter__ = MagicMock(return_value=session) + session.__exit__ = MagicMock(return_value=None) + + # Configure the session factory to return the session + session_factory = MagicMock(spec=sessionmaker) + session_factory.return_value = session + return session, session_factory + + +@pytest.fixture +def repository(session): + """Create a repository instance with test data.""" + _, session_factory = session + tenant_id = "test-tenant" + app_id = "test-app" + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, tenant_id=tenant_id, app_id=app_id + ) + + +def test_save(repository, session): + """Test save method.""" + session_obj, _ = session + # Create a mock execution + execution = MagicMock(spec=WorkflowNodeExecution) + execution.tenant_id = None + execution.app_id = None + + # Call save method + repository.save(execution) + + # Assert tenant_id and app_id are set + assert execution.tenant_id == repository._tenant_id + assert execution.app_id == repository._app_id + + # Assert session.add was called + session_obj.add.assert_called_once_with(execution) + + +def test_save_with_existing_tenant_id(repository, session): + """Test save method with existing tenant_id.""" + session_obj, _ = session + # Create a mock execution with existing tenant_id + execution = MagicMock(spec=WorkflowNodeExecution) + execution.tenant_id = "existing-tenant" + execution.app_id = None + + # Call save method + repository.save(execution) + + # Assert 
tenant_id is not changed and app_id is set + assert execution.tenant_id == "existing-tenant" + assert execution.app_id == repository._app_id + + # Assert session.add was called + session_obj.add.assert_called_once_with(execution) + + +def test_get_by_node_execution_id(repository, session, mocker: MockerFixture): + """Test get_by_node_execution_id method.""" + session_obj, _ = session + # Set up mock + mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select") + mock_stmt = mocker.MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution) + + # Call method + result = repository.get_by_node_execution_id("test-node-execution-id") + + # Assert select was called with correct parameters + mock_select.assert_called_once() + session_obj.scalar.assert_called_once_with(mock_stmt) + assert result is not None + + +def test_get_by_workflow_run(repository, session, mocker: MockerFixture): + """Test get_by_workflow_run method.""" + session_obj, _ = session + # Set up mock + mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select") + mock_stmt = mocker.MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)] + + # Call method + order_config = OrderConfig(order_by=["index"], order_direction="desc") + result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config) + + # Assert select was called with correct parameters + mock_select.assert_called_once() + session_obj.scalars.assert_called_once_with(mock_stmt) + assert len(result) == 1 + + +def test_get_running_executions(repository, session, mocker: MockerFixture): + """Test get_running_executions method.""" + session_obj, _ = session + # Set up mock + mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select") + mock_stmt = mocker.MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)] + + # Call method + result = repository.get_running_executions("test-workflow-run-id") + + # Assert select was called with correct parameters + mock_select.assert_called_once() + session_obj.scalars.assert_called_once_with(mock_stmt) + assert len(result) == 1 + + +def test_update(repository, session): + """Test update method.""" + session_obj, _ = session + # Create a mock execution + execution = MagicMock(spec=WorkflowNodeExecution) + execution.tenant_id = None + execution.app_id = None + + # Call update method + repository.update(execution) + + # Assert tenant_id and app_id are set + assert execution.tenant_id == repository._tenant_id + assert execution.app_id == repository._app_id + + # Assert session.merge was called + session_obj.merge.assert_called_once_with(execution) diff --git a/docker/.env.example b/docker/.env.example index 9b372dcec9..82ef4174c2 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -744,6 +744,12 @@ MAX_VARIABLE_SIZE=204800 WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_FILE_UPLOAD_LIMIT=10 +# Workflow storage configuration +# Options: rdbms, hybrid +# rdbms: Use only the relational database (default) +# hybrid: Save new data to object storage, read from both 
object storage and RDBMS +WORKFLOW_NODE_EXECUTION_STORAGE=rdbms + # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 172cbe2d2f..e01b9f7e9a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -327,6 +327,7 @@ x-shared-env: &shared-api-worker-env MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800} WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3} WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} + WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} From 83f1aeec1d255d9d474147d651deda3eca904fa5 Mon Sep 17 00:00:00 2001 From: Rain Wang Date: Thu, 17 Apr 2025 14:15:05 +0800 Subject: [PATCH 30/68] Fix ORDER BY (score, id) error in api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py line 249 (#18252) --- api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index 778e8a07d8..c1792943bb 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -246,7 +246,7 @@ class AnalyticdbVectorBySql: ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score FROM {self.table_name} WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause} - ORDER BY (score,id) DESC + ORDER BY score DESC, id DESC LIMIT {top_k}""", (f"'{query}'", f"'{query}'"), ) From e8e47aee21b3e190b5512d3aea9f5eed6d20d42e Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Thu, 17 Apr 2025 15:17:22 +0800 Subject: [PATCH 31/68] fix: Access the text-generation app's API doc error (#18278) --- web/app/components/develop/template/template.zh.mdx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/app/components/develop/template/template.zh.mdx b/web/app/components/develop/template/template.zh.mdx index 17a2090dce..24abb481e3 100755 --- a/web/app/components/develop/template/template.zh.mdx +++ b/web/app/components/develop/template/template.zh.mdx @@ -776,6 +776,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 嵌入模型的提供商和模型名称可以通过以下接口获取:v1/workspaces/current/models/model-types/text-embedding, 具体见:通过 API 维护知识库。 使用的Authorization是Dataset的API Token。 + 该接口是异步执行,所以会返回一个job_id,通过查询job状态接口可以获取到最终的执行结果。 From caa179a1d3fec3795df40f3b2dcd1709c12dd474 Mon Sep 17 00:00:00 2001 From: moonpanda Date: Thu, 17 Apr 2025 15:25:31 +0800 Subject: [PATCH 32/68] If the DSL version is less than 0.1.5, it causes errors in an intranet environment. 
(#18273) Co-authored-by: warlocgao --- api/services/plugin/dependencies_analysis.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/api/services/plugin/dependencies_analysis.py b/api/services/plugin/dependencies_analysis.py index 778f05a0cd..07e624b4e8 100644 --- a/api/services/plugin/dependencies_analysis.py +++ b/api/services/plugin/dependencies_analysis.py @@ -1,3 +1,4 @@ +from configs import dify_config from core.helper import marketplace from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID from core.plugin.manager.plugin import PluginInstallationManager @@ -111,6 +112,8 @@ class DependenciesAnalysisService: Generate the latest version of dependencies """ dependencies = list(set(dependencies)) + if not dify_config.MARKETPLACE_ENABLED: + return [] deps = marketplace.batch_fetch_plugin_manifests(dependencies) return [ PluginDependency( From 22a1bc337f7a46dc75c58d8fc88e0bde8af6590b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 17 Apr 2025 16:44:00 +0900 Subject: [PATCH 33/68] fix: perferred model provider not match with provider. (#18282) Signed-off-by: -LAN- --- api/core/provider_manager.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 099acfd7f4..7570200175 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -124,6 +124,15 @@ class ProviderManager: # Get All preferred provider types of the workspace provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) + # Ensure that both the original provider name and its ModelProviderID string representation + # are present in the dictionary to handle cases where either form might be used + for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()): + provider_id = ModelProviderID(provider_name) + if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict: + # Add the ModelProviderID string representation if it's not already present + provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = ( + provider_name_to_preferred_model_provider_records_dict[provider_name] + ) # Get All provider model settings provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) @@ -497,8 +506,8 @@ class ProviderManager: @staticmethod def _init_trial_provider_records( - tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] - ) -> dict[str, list]: + tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] + ) -> dict[str, list[Provider]]: """ Initialize trial provider records if not exists. @@ -532,7 +541,7 @@ class ProviderManager: if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: try: # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic - provider_record = Provider( + new_provider_record = Provider( tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. 
provider_name=ModelProviderID(provider_name).provider_name, @@ -542,11 +551,12 @@ class ProviderManager: quota_used=0, is_valid=True, ) - db.session.add(provider_record) + db.session.add(new_provider_record) db.session.commit() + provider_name_to_provider_records_dict[provider_name].append(new_provider_record) except IntegrityError: db.session.rollback() - provider_record = ( + existed_provider_record = ( db.session.query(Provider) .filter( Provider.tenant_id == tenant_id, @@ -556,11 +566,14 @@ class ProviderManager: ) .first() ) - if provider_record and not provider_record.is_valid: - provider_record.is_valid = True + if not existed_provider_record: + continue + + if not existed_provider_record.is_valid: + existed_provider_record.is_valid = True db.session.commit() - provider_name_to_provider_records_dict[provider_name].append(provider_record) + provider_name_to_provider_records_dict[provider_name].append(existed_provider_record) return provider_name_to_provider_records_dict From b6b608219aae21fcfe8e36a19e95ea3aee24d62e Mon Sep 17 00:00:00 2001 From: GuanMu Date: Thu, 17 Apr 2025 16:18:06 +0800 Subject: [PATCH 34/68] fix: update retrieval_model documentation (#18289) --- web/app/(commonLayout)/datasets/template/template.zh.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index b435a9bb67..099a37ab63 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -591,7 +591,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi 检索参数(选填,如不填,按照默认方式召回) - - search_method (text) 检索方法:以下三个关键字之一,必填 + - search_method (text) 检索方法:以下四个关键字之一,必填 - keyword_search 关键字检索 - semantic_search 语义检索 - full_text_search 全文检索 @@ -1817,7 +1817,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi 检索参数(选填,如不填,按照默认方式召回) - - search_method (text) 检索方法:以下三个关键字之一,必填 + - search_method (text) 检索方法:以下四个关键字之一,必填 - keyword_search 关键字检索 - semantic_search 语义检索 - full_text_search 全文检索 From defd5520ea02116588edbbe7b84e5a91477ff430 Mon Sep 17 00:00:00 2001 From: Vitor Date: Thu, 17 Apr 2025 16:52:49 +0800 Subject: [PATCH 35/68] fix: invalid new tool call creation logic during response handling in OAI-Compat model (#17781) --- .../__base/large_language_model.py | 86 ++++++++++------ .../core/model_runtime/__base/__init__.py | 0 .../__base/test_increase_tool_call.py | 99 +++++++++++++++++++ 3 files changed, 153 insertions(+), 32 deletions(-) create mode 100644 api/tests/unit_tests/core/model_runtime/__base/__init__.py create mode 100644 api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 53de16d621..1b799131e7 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -1,5 +1,6 @@ import logging import time +import uuid from collections.abc import Generator, Sequence from typing import Optional, Union @@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager logger = logging.getLogger(__name__) +def _gen_tool_call_id() -> str: + return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" + + +def _increase_tool_call( + new_tool_calls: list[AssistantPromptMessage.ToolCall], 
existing_tools_calls: list[AssistantPromptMessage.ToolCall] +): + """ + Merge incremental tool call updates into existing tool calls. + + :param new_tool_calls: List of new tool call deltas to be merged. + :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. + """ + + def get_tool_call(tool_call_id: str): + """ + Get or create a tool call by ID + + :param tool_call_id: tool call ID + :return: existing or new tool call + """ + if not tool_call_id: + return existing_tools_calls[-1] + + _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None) + if _tool_call is None: + _tool_call = AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), + ) + existing_tools_calls.append(_tool_call) + + return _tool_call + + for new_tool_call in new_tool_calls: + # generate ID for tool calls with function name but no ID to track them + if new_tool_call.function.name and not new_tool_call.id: + new_tool_call.id = _gen_tool_call_id() + # get tool call + tool_call = get_tool_call(new_tool_call.id) + # update tool call + if new_tool_call.id: + tool_call.id = new_tool_call.id + if new_tool_call.type: + tool_call.type = new_tool_call.type + if new_tool_call.function.name: + tool_call.function.name = new_tool_call.function.name + if new_tool_call.function.arguments: + tool_call.function.arguments += new_tool_call.function.arguments + + class LargeLanguageModel(AIModel): """ Model class for large language model. @@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel): system_fingerprint = None tools_calls: list[AssistantPromptMessage.ToolCall] = [] - def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]): - def get_tool_call(tool_name: str): - if not tool_name: - return tools_calls[-1] - - tool_call = next( - (tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None - ) - if tool_call is None: - tool_call = AssistantPromptMessage.ToolCall( - id="", - type="", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), - ) - tools_calls.append(tool_call) - - return tool_call - - for new_tool_call in new_tool_calls: - # get tool call - tool_call = get_tool_call(new_tool_call.function.name) - # update tool call - if new_tool_call.id: - tool_call.id = new_tool_call.id - if new_tool_call.type: - tool_call.type = new_tool_call.type - if new_tool_call.function.name: - tool_call.function.name = new_tool_call.function.name - if new_tool_call.function.arguments: - tool_call.function.arguments += new_tool_call.function.arguments - for chunk in result: if isinstance(chunk.delta.message.content, str): content += chunk.delta.message.content elif isinstance(chunk.delta.message.content, list): content_list.extend(chunk.delta.message.content) if chunk.delta.message.tool_calls: - increase_tool_call(chunk.delta.message.tool_calls) + _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) usage = chunk.delta.usage or LLMUsage.empty_usage() system_fingerprint = chunk.system_fingerprint diff --git a/api/tests/unit_tests/core/model_runtime/__base/__init__.py b/api/tests/unit_tests/core/model_runtime/__base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py new file mode 100644 index 
0000000000..93d8a20cac --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py @@ -0,0 +1,99 @@ +from unittest.mock import MagicMock, patch + +from core.model_runtime.entities.message_entities import AssistantPromptMessage +from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call + +ToolCall = AssistantPromptMessage.ToolCall + +# CASE 1: Single tool call +INPUTS_CASE_1 = [ + ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), +] +EXPECTED_CASE_1 = [ + ToolCall( + id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') + ), +] + +# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...) +INPUTS_CASE_2 = [ + ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), + ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), +] +EXPECTED_CASE_2 = [ + ToolCall( + id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') + ), + ToolCall( + id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}') + ), +] + +# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...) 
+INPUTS_CASE_3 = [ + ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), + ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), + ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), + ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), + ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), +] +EXPECTED_CASE_3 = [ + ToolCall( + id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') + ), + ToolCall( + id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}') + ), +] + +# CASE 4: Tool call sequences with no IDs +INPUTS_CASE_4 = [ + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), +] +EXPECTED_CASE_4 = [ + ToolCall( + id="RANDOM_ID_1", + type="function", + function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), + ), + ToolCall( + id="RANDOM_ID_2", + type="function", + function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'), + ), +] + + +def _run_case(inputs: list[ToolCall], expected: list[ToolCall]): + actual = [] + _increase_tool_call(inputs, actual) + assert actual == expected + + +def test__increase_tool_call(): + # case 1: + _run_case(INPUTS_CASE_1, EXPECTED_CASE_1) + + # case 2: + _run_case(INPUTS_CASE_2, EXPECTED_CASE_2) + + # case 3: + _run_case(INPUTS_CASE_3, EXPECTED_CASE_3) + + # case 4: + mock_id_generator = MagicMock() + mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4] + with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator): + _run_case(INPUTS_CASE_4, EXPECTED_CASE_4) From 8f547e63409ed50b9d035c6bc7a7893cb56a19fc Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Thu, 17 Apr 2025 16:58:29 +0800 Subject: [PATCH 36/68] fix(typing): validate OAuth code before processing access token (#18288) --- api/controllers/console/auth/data_source_oauth.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index e911c9a5e5..b4bd80fe2f 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource): if not oauth_provider: return {"error": "Invalid provider"}, 400 if "code" in request.args: - code = request.args.get("code") + code = request.args.get("code", "") + if not code: + return {"error": "Invalid code"}, 400 try: oauth_provider.get_access_token(code) except 
requests.exceptions.HTTPError as e: From 397e2a85220708d0feb9ea021a0e1b1055a35c5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=B0=8F=E9=BE=99?= <258392906@qq.com> Date: Thu, 17 Apr 2025 18:04:43 +0800 Subject: [PATCH 37/68] datasets api create-by-file add reranking_mode properties (#18300) --- web/app/(commonLayout)/datasets/template/template.zh.mdx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index 099a37ab63..a8bb7046e6 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -94,6 +94,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - semantic_search 语义检索 - full_text_search 全文检索 - reranking_enable (bool) 是否开启rerank + - reranking_mode (String) 混合检索 + - weighted_score 权重设置 + - reranking_model Rerank 模型 - reranking_model (object) Rerank 模型配置 - reranking_provider_name (string) Rerank 模型的提供商 - reranking_model_name (string) Rerank 模型的名称 From e90c532c3ab334274f318adfe99f6e239ee93c77 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Thu, 17 Apr 2025 18:05:15 +0800 Subject: [PATCH 38/68] fix retrival resource miss in chatflow (#18307) --- api/controllers/web/message.py | 1 + .../index_tool_callback_handler.py | 24 ------------------- api/models/model.py | 7 +----- 3 files changed, 2 insertions(+), 30 deletions(-) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 494b357d46..17e9a3990f 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -46,6 +46,7 @@ class MessageListApi(WebApiResource): "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "created_at": TimestampField, "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "metadata": fields.Raw(attribute="message_metadata_dict"), "status": fields.String, "error": fields.String, } diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 64c734f626..56859df7f4 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -6,7 +6,6 @@ from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import DatasetRetrieverResource class DatasetIndexToolCallbackHandler: @@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler: def return_retriever_resource_info(self, resource: list): """Handle return_retriever_resource_info.""" - if resource and len(resource) > 0: - for item in resource: - dataset_retriever_resource = DatasetRetrieverResource( - message_id=self._message_id, - position=item.get("position") or 0, - dataset_id=item.get("dataset_id"), - dataset_name=item.get("dataset_name"), - document_id=item.get("document_id"), - document_name=item.get("document_name"), - data_source_type=item.get("data_source_type"), - segment_id=item.get("segment_id"), - score=item.get("score") if "score" in item else None, - hit_count=item.get("hit_count") if "hit_count" in item else None, - word_count=item.get("word_count") if "word_count" in item else None, - segment_position=item.get("segment_position") if "segment_position" in item else None, - 
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None, - content=item.get("content"), - retriever_from=item.get("retriever_from"), - created_by=self._user_id, - ) - db.session.add(dataset_retriever_resource) - db.session.commit() - self._queue_manager.publish( QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/models/model.py b/api/models/model.py index a826d13e7d..6577492d1b 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1091,12 +1091,7 @@ class Message(db.Model): # type: ignore[name-defined] @property def retriever_resources(self): - return ( - db.session.query(DatasetRetrieverResource) - .filter(DatasetRetrieverResource.message_id == self.id) - .order_by(DatasetRetrieverResource.position.asc()) - .all() - ) + return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @property def message_files(self): From dc9c5a4bc7a2746262e051e2f973a06780c113ba Mon Sep 17 00:00:00 2001 From: Ganondorf <364776488@qq.com> Date: Thu, 17 Apr 2025 18:49:22 +0800 Subject: [PATCH 39/68] make repository type be private (#18304) Co-authored-by: lizb --- api/core/repository/repository_factory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/core/repository/repository_factory.py b/api/core/repository/repository_factory.py index 02e343d7ff..7da7e49055 100644 --- a/api/core/repository/repository_factory.py +++ b/api/core/repository/repository_factory.py @@ -17,7 +17,7 @@ RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any] WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository] # Repository type literals -RepositoryType = Literal["workflow_node_execution"] +_RepositoryType = Literal["workflow_node_execution"] class RepositoryFactory: @@ -32,7 +32,7 @@ class RepositoryFactory: _factory_functions: dict[str, RepositoryFactoryFunc] = {} @classmethod - def _register_factory(cls, repository_type: RepositoryType, factory_func: RepositoryFactoryFunc) -> None: + def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None: """ Register a factory function for a specific repository type. This is a private method and should not be called directly. @@ -44,7 +44,7 @@ class RepositoryFactory: cls._factory_functions[repository_type] = factory_func @classmethod - def _create_repository(cls, repository_type: RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any: + def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any: """ Create a new repository instance with the provided parameters. This is a private method and should not be called directly. 
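The factory change above (PATCH 39) is purely about visibility: the `Literal` of valid repository keys and the register/create helpers become module-private, so callers are expected to go through public convenience wrappers instead of passing raw type strings around. Below is a minimal, self-contained sketch of that registration pattern; the names, the public wrapper, and the `__main__` wiring are simplified stand-ins for illustration, not the actual Dify implementation.

```python
from collections.abc import Callable, Mapping
from typing import Any, Literal, Optional

# Module-private: the set of valid repository keys is an implementation detail.
_RepositoryType = Literal["workflow_node_execution"]
_FactoryFunc = Callable[[Mapping[str, Any]], Any]


class RepositoryFactory:
    """Registry of factory functions keyed by repository type (simplified sketch)."""

    _factory_functions: dict[str, _FactoryFunc] = {}

    @classmethod
    def _register_factory(cls, repository_type: _RepositoryType, factory_func: _FactoryFunc) -> None:
        # Private: only module-level wiring code should register factories.
        cls._factory_functions[repository_type] = factory_func

    @classmethod
    def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
        # Private: look up the factory by its typed key and call it with the provided params.
        return cls._factory_functions[repository_type](params or {})

    @classmethod
    def create_workflow_node_execution_repository(cls, params: Optional[Mapping[str, Any]] = None) -> Any:
        # Public wrapper: callers never handle the raw "workflow_node_execution" string themselves.
        return cls._create_repository("workflow_node_execution", params)


if __name__ == "__main__":
    # Hypothetical registration: a real setup would build a SQLAlchemy-backed repository here.
    RepositoryFactory._register_factory("workflow_node_execution", lambda p: {"tenant_id": p.get("tenant_id")})
    print(RepositoryFactory.create_workflow_node_execution_repository({"tenant_id": "t-1"}))
```

Renaming `RepositoryType` to `_RepositoryType` does not change runtime behavior; it only signals to readers and type checkers that the accepted key strings are not part of the module's public API.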
From bbc6efd7733f27d6cceb8e567598db30ef81d046 Mon Sep 17 00:00:00 2001 From: devxing <66726106+devxing@users.noreply.github.com> Date: Thu, 17 Apr 2025 19:50:20 +0800 Subject: [PATCH 40/68] fix: curl request address (#18320) Co-authored-by: devxing --- .../template/template_advanced_chat.zh.mdx | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx index 42eaf4f7b2..7135cf6188 100755 --- a/web/app/components/develop/template/template_advanced_chat.zh.mdx +++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx @@ -523,7 +523,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested'?user=abc-123 \ + curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested?user=abc-123' \ --header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \ --header 'Content-Type: application/json' \ ``` @@ -967,7 +967,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' "user": "abc-123" }' ``` - + @@ -1191,10 +1191,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="GET" label="/apps/annotations" - targetCode={`curl --location --request GET '${props.apiBaseUrl}/apps/annotations?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`} + targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/apps/annotations?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`} > ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.apiBaseUrl}/apps/annotations?page=1&limit=20' \ + curl --location --request GET '${props.appDetail.api_base_url}/apps/annotations?page=1&limit=20' \ --header 'Authorization: Bearer {api_key}' ``` @@ -1245,10 +1245,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="POST" label="/apps/annotations" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/apps/annotations' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`} + targetCode={`curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`} > ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/apps/annotations' \ + curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -1301,10 +1301,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="PUT" label="/apps/annotations/{annotation_id}" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`} + targetCode={`curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' 
\\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`} > ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \ + curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -1351,10 +1351,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="PUT" label="/apps/annotations/{annotation_id}" - targetCode={`curl --location --request DELETE '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'`} + targetCode={`curl --location --request DELETE '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'`} > ```bash {{ title: 'cURL' }} - curl --location --request DELETE '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \ + curl --location --request DELETE '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \ --header 'Authorization: Bearer {api_key}' ``` @@ -1398,7 +1398,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="POST" label="/apps/annotation-reply/{action}" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/apps/annotation-reply/{action}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"score_threshold": 0.9, "embedding_provider_name": "zhipu", "embedding_model_name": "embedding_3"}'`} + targetCode={`curl --location --request POST '${props.appDetail.api_base_url}/apps/annotation-reply/{action}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"score_threshold": 0.9, "embedding_provider_name": "zhipu", "embedding_model_name": "embedding_3"}'`} > ```bash {{ title: 'cURL' }} curl --location --request POST 'https://api.dify.ai/v1/apps/annotation-reply/{action}' \ @@ -1448,10 +1448,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="GET" label="/apps/annotations" - targetCode={`curl --location --request GET '${props.apiBaseUrl}/apps/annotation-reply/{action}/status/{job_id}' \\\n--header 'Authorization: Bearer {api_key}'`} + targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/apps/annotation-reply/{action}/status/{job_id}' \\\n--header 'Authorization: Bearer {api_key}'`} > ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.apiBaseUrl}/apps/annotation-reply/{action}/status/{job_id}' \ + curl --location --request GET '${props.appDetail.api_base_url}/apps/annotation-reply/{action}/status/{job_id}' \ --header 'Authorization: Bearer {api_key}' ``` From b287aaccecfa8a283a94d9bda4bb3b2f9b1fe524 Mon Sep 17 00:00:00 2001 From: sayThQ199 <693858278@qq.com> Date: Thu, 17 Apr 2025 19:50:41 +0800 Subject: [PATCH 41/68] fix: Correctly render multiple think blocks in Markdown (#18310) Co-authored-by: xzj16125 Co-authored-by: crazywoola <427733928@qq.com> --- web/app/components/base/markdown.tsx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx 
index 24ae59af73..d50c397177 100644
--- a/web/app/components/base/markdown.tsx
+++ b/web/app/components/base/markdown.tsx
@@ -85,9 +85,11 @@ const preprocessLaTeX = (content: string) => {
 }
 
 const preprocessThinkTag = (content: string) => {
+  const thinkOpenTagRegex = /<think>\n/g
+  const thinkCloseTagRegex = /\n<\/think>/g
   return flow([
-    (str: string) => str.replace('<think>\n', '<details data-think=true>\n'),
-    (str: string) => str.replace('\n</think>', '\n[ENDTHINKFLAG]</details>'),
+    (str: string) => str.replace(thinkOpenTagRegex, '<details data-think=true>\n'),
+    (str: string) => str.replace(thinkCloseTagRegex, '\n[ENDTHINKFLAG]</details>
'), ])(content) } From 721294948c8f56d9daf01e73791ef1053d401e07 Mon Sep 17 00:00:00 2001 From: Ganondorf <364776488@qq.com> Date: Thu, 17 Apr 2025 21:09:19 +0800 Subject: [PATCH 42/68] =?UTF-8?q?Diable=20expire=5Fon=5Fcommit=20in=20the?= =?UTF-8?q?=20implemention=20of=20the=20WorkflowNodeExecut=E2=80=A6=20(#18?= =?UTF-8?q?321)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: lizb --- .../workflow_node_execution/sqlalchemy_repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/repositories/workflow_node_execution/sqlalchemy_repository.py b/api/repositories/workflow_node_execution/sqlalchemy_repository.py index 01c54dfcd7..c9c6e70ff3 100644 --- a/api/repositories/workflow_node_execution/sqlalchemy_repository.py +++ b/api/repositories/workflow_node_execution/sqlalchemy_repository.py @@ -36,7 +36,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository: """ # If an engine is provided, create a sessionmaker from it if isinstance(session_factory, Engine): - self._session_factory = sessionmaker(bind=session_factory) + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) else: self._session_factory = session_factory From 28ffe7e3dbb32de126e2ad475a69e1448eda5cc6 Mon Sep 17 00:00:00 2001 From: hbprotoss Date: Thu, 17 Apr 2025 21:10:58 +0800 Subject: [PATCH 43/68] fix: missing headers in some cases (#18283) Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> --- web/service/fetch.ts | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/web/service/fetch.ts b/web/service/fetch.ts index fc41310c80..5d09256f1d 100644 --- a/web/service/fetch.ts +++ b/web/service/fetch.ts @@ -132,12 +132,13 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: getAbortController, } = otherOptions - const base - = isMarketplaceAPI - ? MARKETPLACE_API_PREFIX - : isPublicAPI - ? PUBLIC_API_PREFIX - : API_PREFIX + let base: string + if (isMarketplaceAPI) + base = MARKETPLACE_API_PREFIX + else if (isPublicAPI) + base = PUBLIC_API_PREFIX + else + base = API_PREFIX if (getAbortController) { const abortController = new AbortController() @@ -145,7 +146,7 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: options.signal = abortController.signal } - const fetchPathname = `${base}${url.startsWith('/') ? url : `/${url}`}` + const fetchPathname = base + (url.startsWith('/') ? url : `/${url}`) if (deleteContentType) (headers as any).delete('Content-Type') @@ -180,6 +181,16 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: }, ...(bodyStringify ? 
{ json: body } : { body: body as BodyInit }), searchParams: params, + fetch(resource: RequestInfo | URL, options?: RequestInit) { + if (resource instanceof Request && options) { + const mergedHeaders = new Headers(options.headers || {}) + resource.headers.forEach((value, key) => { + mergedHeaders.append(key, value) + }) + options.headers = mergedHeaders + } + return globalThis.fetch(resource, options) + }, }) if (needAllResponseContent) From b96ecd072a7636d78a64fefcc2b812600b746da6 Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Fri, 18 Apr 2025 09:42:08 +0800 Subject: [PATCH 44/68] fix: can not input R when debug (#18323) --- .../components/workflow/panel/debug-and-preview/index.tsx | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/web/app/components/workflow/panel/debug-and-preview/index.tsx b/web/app/components/workflow/panel/debug-and-preview/index.tsx index 53c91299a2..c33a6355f2 100644 --- a/web/app/components/workflow/panel/debug-and-preview/index.tsx +++ b/web/app/components/workflow/panel/debug-and-preview/index.tsx @@ -5,7 +5,7 @@ import { useRef, useState, } from 'react' -import { useKeyPress } from 'ahooks' + import { RiCloseLine, RiEqualizer2Line } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useNodes } from 'reactflow' @@ -48,12 +48,6 @@ const DebugAndPreview = () => { chatRef.current.handleRestart() } - useKeyPress('shift.r', () => { - handleRestartChat() - }, { - exactMatch: true, - }) - const [panelWidth, setPanelWidth] = useState(420) const [isResizing, setIsResizing] = useState(false) From 523efbfea5574c1b50e6a004adfa6e38539b84f2 Mon Sep 17 00:00:00 2001 From: Ethan <118581835+realethanhsu@users.noreply.github.com> Date: Fri, 18 Apr 2025 09:42:38 +0800 Subject: [PATCH 45/68] Fix: ValueError: Formatting field not found in record: 'req_id' (#18327) --- api/extensions/ext_logging.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 422ec87765..aa55862b7c 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -26,9 +26,12 @@ def init_app(app: DifyApp): # Always add StreamHandler to log to console sh = logging.StreamHandler(sys.stdout) - sh.addFilter(RequestIdFilter()) log_handlers.append(sh) + # Apply RequestIdFilter to all handlers + for handler in log_handlers: + handler.addFilter(RequestIdFilter()) + logging.basicConfig( level=dify_config.LOG_LEVEL, format=dify_config.LOG_FORMAT, From efe5db38ee028071d1eb5d4fbde163166c909c6f Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Fri, 18 Apr 2025 13:59:12 +0800 Subject: [PATCH 46/68] Chore/slice workflow (#18351) --- .../[appId]/workflow/page.tsx | 4 +- .../components/workflow-children.tsx | 69 ++++ .../workflow-header/chat-variable-trigger.tsx | 11 + .../workflow-header/features-trigger.tsx | 152 ++++++++ .../components/workflow-header/index.tsx | 31 ++ .../workflow-app/components/workflow-main.tsx | 87 +++++ .../components/workflow-panel.tsx | 109 ++++++ .../components/workflow-app/hooks/index.ts | 6 + .../workflow-app/hooks/use-is-chat-mode.ts | 7 + .../hooks/use-nodes-sync-draft.ts | 148 ++++++++ .../workflow-app/hooks/use-workflow-init.ts | 123 ++++++ .../workflow-app/hooks/use-workflow-run.ts | 357 ++++++++++++++++++ .../hooks/use-workflow-start-run.tsx | 96 +++++ .../hooks/use-workflow-template.ts | 8 +- web/app/components/workflow-app/index.tsx | 108 ++++++ .../store/workflow/workflow-slice.ts | 18 + 
web/app/components/workflow/context.tsx | 13 +- .../workflow/header/editing-title.tsx | 4 +- .../workflow/header/header-in-normal.tsx | 69 ++++ .../workflow/header/header-in-restoring.tsx | 93 +++++ .../header/header-in-view-history.tsx | 50 +++ web/app/components/workflow/header/index.tsx | 283 ++------------ .../workflow/header/restoring-title.tsx | 4 +- .../workflow/header/view-history.tsx | 4 +- .../components/workflow/hooks-store/index.ts | 2 + .../workflow/hooks-store/provider.tsx | 36 ++ .../components/workflow/hooks-store/store.ts | 72 ++++ web/app/components/workflow/hooks/index.ts | 2 +- .../use-edges-interactions-without-sync.ts | 27 ++ .../workflow/hooks/use-edges-interactions.ts | 17 - .../hooks/use-format-time-from-now.ts | 12 + .../use-nodes-interactions-without-sync.ts | 27 ++ .../workflow/hooks/use-nodes-interactions.ts | 17 - .../workflow/hooks/use-nodes-sync-draft.ts | 136 +------ .../hooks/use-workflow-interactions.ts | 8 +- .../workflow/hooks/use-workflow-run.ts | 351 +---------------- .../workflow/hooks/use-workflow-start-run.tsx | 91 +---- .../components/workflow/hooks/use-workflow.ts | 130 +------ web/app/components/workflow/index.tsx | 191 +++------- web/app/components/workflow/panel/index.tsx | 83 +--- .../workflow/store/workflow/index.ts | 16 +- .../workflow/store/workflow/node-slice.ts | 4 - .../workflow/store/workflow/workflow-slice.ts | 10 +- web/service/use-workflow.ts | 8 +- 44 files changed, 1855 insertions(+), 1239 deletions(-) create mode 100644 web/app/components/workflow-app/components/workflow-children.tsx create mode 100644 web/app/components/workflow-app/components/workflow-header/chat-variable-trigger.tsx create mode 100644 web/app/components/workflow-app/components/workflow-header/features-trigger.tsx create mode 100644 web/app/components/workflow-app/components/workflow-header/index.tsx create mode 100644 web/app/components/workflow-app/components/workflow-main.tsx create mode 100644 web/app/components/workflow-app/components/workflow-panel.tsx create mode 100644 web/app/components/workflow-app/hooks/index.ts create mode 100644 web/app/components/workflow-app/hooks/use-is-chat-mode.ts create mode 100644 web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts create mode 100644 web/app/components/workflow-app/hooks/use-workflow-init.ts create mode 100644 web/app/components/workflow-app/hooks/use-workflow-run.ts create mode 100644 web/app/components/workflow-app/hooks/use-workflow-start-run.tsx rename web/app/components/{workflow => workflow-app}/hooks/use-workflow-template.ts (87%) create mode 100644 web/app/components/workflow-app/index.tsx create mode 100644 web/app/components/workflow-app/store/workflow/workflow-slice.ts create mode 100644 web/app/components/workflow/header/header-in-normal.tsx create mode 100644 web/app/components/workflow/header/header-in-restoring.tsx create mode 100644 web/app/components/workflow/header/header-in-view-history.tsx create mode 100644 web/app/components/workflow/hooks-store/index.ts create mode 100644 web/app/components/workflow/hooks-store/provider.tsx create mode 100644 web/app/components/workflow/hooks-store/store.ts create mode 100644 web/app/components/workflow/hooks/use-edges-interactions-without-sync.ts create mode 100644 web/app/components/workflow/hooks/use-format-time-from-now.ts create mode 100644 web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx 
b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx index f4d49425ae..d5df70f004 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx @@ -1,11 +1,11 @@ 'use client' -import Workflow from '@/app/components/workflow' +import WorkflowApp from '@/app/components/workflow-app' const Page = () => { return (
-      <Workflow />
+      <WorkflowApp />
) } diff --git a/web/app/components/workflow-app/components/workflow-children.tsx b/web/app/components/workflow-app/components/workflow-children.tsx new file mode 100644 index 0000000000..6a6bbcd61a --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-children.tsx @@ -0,0 +1,69 @@ +import { + memo, + useState, +} from 'react' +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import { DSL_EXPORT_CHECK } from '@/app/components/workflow/constants' +import { useStore } from '@/app/components/workflow/store' +import Features from '@/app/components/workflow/features' +import PluginDependency from '@/app/components/workflow/plugin-dependency' +import UpdateDSLModal from '@/app/components/workflow/update-dsl-modal' +import DSLExportConfirmModal from '@/app/components/workflow/dsl-export-confirm-modal' +import { + useDSL, + usePanelInteractions, +} from '@/app/components/workflow/hooks' +import { useEventEmitterContextContext } from '@/context/event-emitter' +import WorkflowHeader from './workflow-header' +import WorkflowPanel from './workflow-panel' + +const WorkflowChildren = () => { + const { eventEmitter } = useEventEmitterContextContext() + const [secretEnvList, setSecretEnvList] = useState([]) + const showFeaturesPanel = useStore(s => s.showFeaturesPanel) + const showImportDSLModal = useStore(s => s.showImportDSLModal) + const setShowImportDSLModal = useStore(s => s.setShowImportDSLModal) + const { + handlePaneContextmenuCancel, + } = usePanelInteractions() + const { + exportCheck, + handleExportDSL, + } = useDSL() + + eventEmitter?.useSubscription((v: any) => { + if (v.type === DSL_EXPORT_CHECK) + setSecretEnvList(v.payload.data as EnvironmentVariable[]) + }) + + return ( + <> + + { + showFeaturesPanel && + } + { + showImportDSLModal && ( + setShowImportDSLModal(false)} + onBackup={exportCheck} + onImport={handlePaneContextmenuCancel} + /> + ) + } + { + secretEnvList.length > 0 && ( + setSecretEnvList([])} + /> + ) + } + + + + ) +} + +export default memo(WorkflowChildren) diff --git a/web/app/components/workflow-app/components/workflow-header/chat-variable-trigger.tsx b/web/app/components/workflow-app/components/workflow-header/chat-variable-trigger.tsx new file mode 100644 index 0000000000..df93914285 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-header/chat-variable-trigger.tsx @@ -0,0 +1,11 @@ +import { memo } from 'react' +import ChatVariableButton from '@/app/components/workflow/header/chat-variable-button' +import { + useNodesReadOnly, +} from '@/app/components/workflow/hooks' + +const ChatVariableTrigger = () => { + const { nodesReadOnly } = useNodesReadOnly() + return +} +export default memo(ChatVariableTrigger) diff --git a/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx b/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx new file mode 100644 index 0000000000..da64409090 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx @@ -0,0 +1,152 @@ +import { + memo, + useCallback, + useMemo, +} from 'react' +import { useNodes } from 'reactflow' +import { RiApps2AddLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { + useStore, + useWorkflowStore, +} from '@/app/components/workflow/store' +import { + useChecklistBeforePublish, + useNodesReadOnly, + useNodesSyncDraft, +} from '@/app/components/workflow/hooks' +import Button from '@/app/components/base/button' +import 
AppPublisher from '@/app/components/app/app-publisher' +import { useFeatures } from '@/app/components/base/features/hooks' +import { + BlockEnum, + InputVarType, +} from '@/app/components/workflow/types' +import type { StartNodeType } from '@/app/components/workflow/nodes/start/types' +import { useToastContext } from '@/app/components/base/toast' +import { usePublishWorkflow, useResetWorkflowVersionHistory } from '@/service/use-workflow' +import type { PublishWorkflowParams } from '@/types/workflow' +import { fetchAppDetail, fetchAppSSO } from '@/service/apps' +import { useStore as useAppStore } from '@/app/components/app/store' +import { useSelector as useAppSelector } from '@/context/app-context' + +const FeaturesTrigger = () => { + const { t } = useTranslation() + const workflowStore = useWorkflowStore() + const appDetail = useAppStore(s => s.appDetail) + const appID = appDetail?.id + const setAppDetail = useAppStore(s => s.setAppDetail) + const systemFeatures = useAppSelector(state => state.systemFeatures) + const { + nodesReadOnly, + getNodesReadOnly, + } = useNodesReadOnly() + const publishedAt = useStore(s => s.publishedAt) + const draftUpdatedAt = useStore(s => s.draftUpdatedAt) + const toolPublished = useStore(s => s.toolPublished) + const nodes = useNodes() + const startNode = nodes.find(node => node.data.type === BlockEnum.Start) + const startVariables = startNode?.data.variables + const fileSettings = useFeatures(s => s.features.file) + const variables = useMemo(() => { + const data = startVariables || [] + if (fileSettings?.image?.enabled) { + return [ + ...data, + { + type: InputVarType.files, + variable: '__image', + required: false, + label: 'files', + }, + ] + } + + return data + }, [fileSettings?.image?.enabled, startVariables]) + + const { handleCheckBeforePublish } = useChecklistBeforePublish() + const { handleSyncWorkflowDraft } = useNodesSyncDraft() + const { notify } = useToastContext() + + const handleShowFeatures = useCallback(() => { + const { + showFeaturesPanel, + isRestoring, + setShowFeaturesPanel, + } = workflowStore.getState() + if (getNodesReadOnly() && !isRestoring) + return + setShowFeaturesPanel(!showFeaturesPanel) + }, [workflowStore, getNodesReadOnly]) + + const resetWorkflowVersionHistory = useResetWorkflowVersionHistory(appDetail!.id) + + const updateAppDetail = useCallback(async () => { + try { + const res = await fetchAppDetail({ url: '/apps', id: appID! }) + if (systemFeatures.enable_web_sso_switch_component) { + const ssoRes = await fetchAppSSO({ appId: appID! }) + setAppDetail({ ...res, enable_sso: ssoRes.enabled }) + } + else { + setAppDetail({ ...res }) + } + } + catch (error) { + console.error(error) + } + }, [appID, setAppDetail, systemFeatures.enable_web_sso_switch_component]) + const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!) 
+ const onPublish = useCallback(async (params?: PublishWorkflowParams) => { + if (await handleCheckBeforePublish()) { + const res = await publishWorkflow({ + title: params?.title || '', + releaseNotes: params?.releaseNotes || '', + }) + + if (res) { + notify({ type: 'success', message: t('common.api.actionSuccess') }) + updateAppDetail() + workflowStore.getState().setPublishedAt(res.created_at) + resetWorkflowVersionHistory() + } + } + else { + throw new Error('Checklist failed') + } + }, [handleCheckBeforePublish, notify, t, workflowStore, publishWorkflow, resetWorkflowVersionHistory, updateAppDetail]) + + const onPublisherToggle = useCallback((state: boolean) => { + if (state) + handleSyncWorkflowDraft(true) + }, [handleSyncWorkflowDraft]) + + const handleToolConfigureUpdate = useCallback(() => { + workflowStore.setState({ toolPublished: true }) + }, [workflowStore]) + + return ( + <> + + + + ) +} + +export default memo(FeaturesTrigger) diff --git a/web/app/components/workflow-app/components/workflow-header/index.tsx b/web/app/components/workflow-app/components/workflow-header/index.tsx new file mode 100644 index 0000000000..4eb8df7162 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-header/index.tsx @@ -0,0 +1,31 @@ +import { useMemo } from 'react' +import type { HeaderProps } from '@/app/components/workflow/header' +import Header from '@/app/components/workflow/header' +import { useStore as useAppStore } from '@/app/components/app/store' +import ChatVariableTrigger from './chat-variable-trigger' +import FeaturesTrigger from './features-trigger' +import { useResetWorkflowVersionHistory } from '@/service/use-workflow' + +const WorkflowHeader = () => { + const appDetail = useAppStore(s => s.appDetail) + const resetWorkflowVersionHistory = useResetWorkflowVersionHistory(appDetail!.id) + + const headerProps: HeaderProps = useMemo(() => { + return { + normal: { + components: { + left: , + middle: , + }, + }, + restoring: { + onRestoreSettled: resetWorkflowVersionHistory, + }, + } + }, [resetWorkflowVersionHistory]) + return ( +
+ ) +} + +export default WorkflowHeader diff --git a/web/app/components/workflow-app/components/workflow-main.tsx b/web/app/components/workflow-app/components/workflow-main.tsx new file mode 100644 index 0000000000..4ff1f4c624 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-main.tsx @@ -0,0 +1,87 @@ +import { + useCallback, + useMemo, +} from 'react' +import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { WorkflowWithInnerContext } from '@/app/components/workflow' +import type { WorkflowProps } from '@/app/components/workflow' +import WorkflowChildren from './workflow-children' +import { + useNodesSyncDraft, + useWorkflowRun, + useWorkflowStartRun, +} from '../hooks' + +type WorkflowMainProps = Pick +const WorkflowMain = ({ + nodes, + edges, + viewport, +}: WorkflowMainProps) => { + const featuresStore = useFeaturesStore() + + const handleWorkflowDataUpdate = useCallback((payload: any) => { + if (payload.features && featuresStore) { + const { setFeatures } = featuresStore.getState() + + setFeatures(payload.features) + } + }, [featuresStore]) + + const { + doSyncWorkflowDraft, + syncWorkflowDraftWhenPageClose, + } = useNodesSyncDraft() + const { + handleBackupDraft, + handleLoadBackupDraft, + handleRestoreFromPublishedWorkflow, + handleRun, + handleStopRun, + } = useWorkflowRun() + const { + handleStartWorkflowRun, + handleWorkflowStartRunInChatflow, + handleWorkflowStartRunInWorkflow, + } = useWorkflowStartRun() + + const hooksStore = useMemo(() => { + return { + syncWorkflowDraftWhenPageClose, + doSyncWorkflowDraft, + handleBackupDraft, + handleLoadBackupDraft, + handleRestoreFromPublishedWorkflow, + handleRun, + handleStopRun, + handleStartWorkflowRun, + handleWorkflowStartRunInChatflow, + handleWorkflowStartRunInWorkflow, + } + }, [ + syncWorkflowDraftWhenPageClose, + doSyncWorkflowDraft, + handleBackupDraft, + handleLoadBackupDraft, + handleRestoreFromPublishedWorkflow, + handleRun, + handleStopRun, + handleStartWorkflowRun, + handleWorkflowStartRunInChatflow, + handleWorkflowStartRunInWorkflow, + ]) + + return ( + + + + ) +} + +export default WorkflowMain diff --git a/web/app/components/workflow-app/components/workflow-panel.tsx b/web/app/components/workflow-app/components/workflow-panel.tsx new file mode 100644 index 0000000000..3c1b5c8aac --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-panel.tsx @@ -0,0 +1,109 @@ +import { useMemo } from 'react' +import { useShallow } from 'zustand/react/shallow' +import { useStore } from '@/app/components/workflow/store' +import { + useIsChatMode, +} from '../hooks' +import DebugAndPreview from '@/app/components/workflow/panel/debug-and-preview' +import Record from '@/app/components/workflow/panel/record' +import WorkflowPreview from '@/app/components/workflow/panel/workflow-preview' +import ChatRecord from '@/app/components/workflow/panel/chat-record' +import ChatVariablePanel from '@/app/components/workflow/panel/chat-variable-panel' +import GlobalVariablePanel from '@/app/components/workflow/panel/global-variable-panel' +import VersionHistoryPanel from '@/app/components/workflow/panel/version-history-panel' +import { useStore as useAppStore } from '@/app/components/app/store' +import MessageLogModal from '@/app/components/base/message-log-modal' +import type { PanelProps } from '@/app/components/workflow/panel' +import Panel from '@/app/components/workflow/panel' + +const WorkflowPanelOnLeft = () => { + const { currentLogItem, setCurrentLogItem, showMessageLogModal, 
setShowMessageLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ + currentLogItem: state.currentLogItem, + setCurrentLogItem: state.setCurrentLogItem, + showMessageLogModal: state.showMessageLogModal, + setShowMessageLogModal: state.setShowMessageLogModal, + currentLogModalActiveTab: state.currentLogModalActiveTab, + }))) + return ( + <> + { + showMessageLogModal && ( + { + setCurrentLogItem() + setShowMessageLogModal(false) + }} + defaultTab={currentLogModalActiveTab} + /> + ) + } + + ) +} +const WorkflowPanelOnRight = () => { + const isChatMode = useIsChatMode() + const historyWorkflowData = useStore(s => s.historyWorkflowData) + const showDebugAndPreviewPanel = useStore(s => s.showDebugAndPreviewPanel) + const showChatVariablePanel = useStore(s => s.showChatVariablePanel) + const showGlobalVariablePanel = useStore(s => s.showGlobalVariablePanel) + const showWorkflowVersionHistoryPanel = useStore(s => s.showWorkflowVersionHistoryPanel) + + return ( + <> + { + historyWorkflowData && !isChatMode && ( + + ) + } + { + historyWorkflowData && isChatMode && ( + + ) + } + { + showDebugAndPreviewPanel && isChatMode && ( + + ) + } + { + showDebugAndPreviewPanel && !isChatMode && ( + + ) + } + { + showChatVariablePanel && ( + + ) + } + { + showGlobalVariablePanel && ( + + ) + } + { + showWorkflowVersionHistoryPanel && ( + + ) + } + + ) +} +const WorkflowPanel = () => { + const panelProps: PanelProps = useMemo(() => { + return { + components: { + left: , + right: , + }, + } + }, []) + + return ( + + ) +} + +export default WorkflowPanel diff --git a/web/app/components/workflow-app/hooks/index.ts b/web/app/components/workflow-app/hooks/index.ts new file mode 100644 index 0000000000..1517eb9a16 --- /dev/null +++ b/web/app/components/workflow-app/hooks/index.ts @@ -0,0 +1,6 @@ +export * from './use-workflow-init' +export * from './use-workflow-template' +export * from './use-nodes-sync-draft' +export * from './use-workflow-run' +export * from './use-workflow-start-run' +export * from './use-is-chat-mode' diff --git a/web/app/components/workflow-app/hooks/use-is-chat-mode.ts b/web/app/components/workflow-app/hooks/use-is-chat-mode.ts new file mode 100644 index 0000000000..3cdfc77b2a --- /dev/null +++ b/web/app/components/workflow-app/hooks/use-is-chat-mode.ts @@ -0,0 +1,7 @@ +import { useStore as useAppStore } from '@/app/components/app/store' + +export const useIsChatMode = () => { + const appDetail = useAppStore(s => s.appDetail) + + return appDetail?.mode === 'advanced-chat' +} diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts new file mode 100644 index 0000000000..7c6eb6a5be --- /dev/null +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -0,0 +1,148 @@ +import { useCallback } from 'react' +import produce from 'immer' +import { useStoreApi } from 'reactflow' +import { useParams } from 'next/navigation' +import { + useWorkflowStore, +} from '@/app/components/workflow/store' +import { BlockEnum } from '@/app/components/workflow/types' +import { useWorkflowUpdate } from '@/app/components/workflow/hooks' +import { + useNodesReadOnly, +} from '@/app/components/workflow/hooks/use-workflow' +import { syncWorkflowDraft } from '@/service/workflow' +import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { API_PREFIX } from '@/config' + +export const useNodesSyncDraft = () => { + const store = useStoreApi() + const workflowStore = useWorkflowStore() + const 
featuresStore = useFeaturesStore() + const { getNodesReadOnly } = useNodesReadOnly() + const { handleRefreshWorkflowDraft } = useWorkflowUpdate() + const params = useParams() + + const getPostParams = useCallback(() => { + const { + getNodes, + edges, + transform, + } = store.getState() + const [x, y, zoom] = transform + const { + appId, + conversationVariables, + environmentVariables, + syncWorkflowDraftHash, + } = workflowStore.getState() + + if (appId) { + const nodes = getNodes() + const hasStartNode = nodes.find(node => node.data.type === BlockEnum.Start) + + if (!hasStartNode) + return + + const features = featuresStore!.getState().features + const producedNodes = produce(nodes, (draft) => { + draft.forEach((node) => { + Object.keys(node.data).forEach((key) => { + if (key.startsWith('_')) + delete node.data[key] + }) + }) + }) + const producedEdges = produce(edges, (draft) => { + draft.forEach((edge) => { + Object.keys(edge.data).forEach((key) => { + if (key.startsWith('_')) + delete edge.data[key] + }) + }) + }) + return { + url: `/apps/${appId}/workflows/draft`, + params: { + graph: { + nodes: producedNodes, + edges: producedEdges, + viewport: { + x, + y, + zoom, + }, + }, + features: { + opening_statement: features.opening?.enabled ? (features.opening?.opening_statement || '') : '', + suggested_questions: features.opening?.enabled ? (features.opening?.suggested_questions || []) : [], + suggested_questions_after_answer: features.suggested, + text_to_speech: features.text2speech, + speech_to_text: features.speech2text, + retriever_resource: features.citation, + sensitive_word_avoidance: features.moderation, + file_upload: features.file, + }, + environment_variables: environmentVariables, + conversation_variables: conversationVariables, + hash: syncWorkflowDraftHash, + }, + } + } + }, [store, featuresStore, workflowStore]) + + const syncWorkflowDraftWhenPageClose = useCallback(() => { + if (getNodesReadOnly()) + return + const postParams = getPostParams() + + if (postParams) { + navigator.sendBeacon( + `${API_PREFIX}/apps/${params.appId}/workflows/draft?_token=${localStorage.getItem('console_token')}`, + JSON.stringify(postParams.params), + ) + } + }, [getPostParams, params.appId, getNodesReadOnly]) + + const doSyncWorkflowDraft = useCallback(async ( + notRefreshWhenSyncError?: boolean, + callback?: { + onSuccess?: () => void + onError?: () => void + onSettled?: () => void + }, + ) => { + if (getNodesReadOnly()) + return + const postParams = getPostParams() + + if (postParams) { + const { + setSyncWorkflowDraftHash, + setDraftUpdatedAt, + } = workflowStore.getState() + try { + const res = await syncWorkflowDraft(postParams) + setSyncWorkflowDraftHash(res.hash) + setDraftUpdatedAt(res.updated_at) + callback?.onSuccess && callback.onSuccess() + } + catch (error: any) { + if (error && error.json && !error.bodyUsed) { + error.json().then((err: any) => { + if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) + handleRefreshWorkflowDraft() + }) + } + callback?.onError && callback.onError() + } + finally { + callback?.onSettled && callback.onSettled() + } + } + }, [workflowStore, getPostParams, getNodesReadOnly, handleRefreshWorkflowDraft]) + + return { + doSyncWorkflowDraft, + syncWorkflowDraftWhenPageClose, + } +} diff --git a/web/app/components/workflow-app/hooks/use-workflow-init.ts b/web/app/components/workflow-app/hooks/use-workflow-init.ts new file mode 100644 index 0000000000..e1c4c25a4e --- /dev/null +++ 
b/web/app/components/workflow-app/hooks/use-workflow-init.ts @@ -0,0 +1,123 @@ +import { + useCallback, + useEffect, + useState, +} from 'react' +import { + useStore, + useWorkflowStore, +} from '@/app/components/workflow/store' +import { useWorkflowTemplate } from './use-workflow-template' +import { useStore as useAppStore } from '@/app/components/app/store' +import { + fetchNodesDefaultConfigs, + fetchPublishedWorkflow, + fetchWorkflowDraft, + syncWorkflowDraft, +} from '@/service/workflow' +import type { FetchWorkflowDraftResponse } from '@/types/workflow' +import { useWorkflowConfig } from '@/service/use-workflow' + +export const useWorkflowInit = () => { + const workflowStore = useWorkflowStore() + const { + nodes: nodesTemplate, + edges: edgesTemplate, + } = useWorkflowTemplate() + const appDetail = useAppStore(state => state.appDetail)! + const setSyncWorkflowDraftHash = useStore(s => s.setSyncWorkflowDraftHash) + const [data, setData] = useState() + const [isLoading, setIsLoading] = useState(true) + useEffect(() => { + workflowStore.setState({ appId: appDetail.id }) + }, [appDetail.id, workflowStore]) + + const handleUpdateWorkflowConfig = useCallback((config: Record) => { + const { setWorkflowConfig } = workflowStore.getState() + + setWorkflowConfig(config) + }, [workflowStore]) + useWorkflowConfig(appDetail.id, handleUpdateWorkflowConfig) + + const handleGetInitialWorkflowData = useCallback(async () => { + try { + const res = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) + setData(res) + workflowStore.setState({ + envSecrets: (res.environment_variables || []).filter(env => env.value_type === 'secret').reduce((acc, env) => { + acc[env.id] = env.value + return acc + }, {} as Record), + environmentVariables: res.environment_variables?.map(env => env.value_type === 'secret' ? 
{ ...env, value: '[__HIDDEN__]' } : env) || [], + conversationVariables: res.conversation_variables || [], + }) + setSyncWorkflowDraftHash(res.hash) + setIsLoading(false) + } + catch (error: any) { + if (error && error.json && !error.bodyUsed && appDetail) { + error.json().then((err: any) => { + if (err.code === 'draft_workflow_not_exist') { + workflowStore.setState({ notInitialWorkflow: true }) + syncWorkflowDraft({ + url: `/apps/${appDetail.id}/workflows/draft`, + params: { + graph: { + nodes: nodesTemplate, + edges: edgesTemplate, + }, + features: { + retriever_resource: { enabled: true }, + }, + environment_variables: [], + conversation_variables: [], + }, + }).then((res) => { + workflowStore.getState().setDraftUpdatedAt(res.updated_at) + handleGetInitialWorkflowData() + }) + } + }) + } + } + }, [appDetail, nodesTemplate, edgesTemplate, workflowStore, setSyncWorkflowDraftHash]) + + useEffect(() => { + handleGetInitialWorkflowData() + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + const handleFetchPreloadData = useCallback(async () => { + try { + const nodesDefaultConfigsData = await fetchNodesDefaultConfigs(`/apps/${appDetail?.id}/workflows/default-workflow-block-configs`) + const publishedWorkflow = await fetchPublishedWorkflow(`/apps/${appDetail?.id}/workflows/publish`) + workflowStore.setState({ + nodesDefaultConfigs: nodesDefaultConfigsData.reduce((acc, block) => { + if (!acc[block.type]) + acc[block.type] = { ...block.config } + return acc + }, {} as Record), + }) + workflowStore.getState().setPublishedAt(publishedWorkflow?.created_at) + } + catch (e) { + console.error(e) + } + }, [workflowStore, appDetail]) + + useEffect(() => { + handleFetchPreloadData() + }, [handleFetchPreloadData]) + + useEffect(() => { + if (data) { + workflowStore.getState().setDraftUpdatedAt(data.updated_at) + workflowStore.getState().setToolPublished(data.tool_published) + } + }, [data, workflowStore]) + + return { + data, + isLoading, + } +} diff --git a/web/app/components/workflow-app/hooks/use-workflow-run.ts b/web/app/components/workflow-app/hooks/use-workflow-run.ts new file mode 100644 index 0000000000..1e484d0760 --- /dev/null +++ b/web/app/components/workflow-app/hooks/use-workflow-run.ts @@ -0,0 +1,357 @@ +import { useCallback } from 'react' +import { + useReactFlow, + useStoreApi, +} from 'reactflow' +import produce from 'immer' +import { v4 as uuidV4 } from 'uuid' +import { usePathname } from 'next/navigation' +import { useWorkflowStore } from '@/app/components/workflow/store' +import { WorkflowRunningStatus } from '@/app/components/workflow/types' +import { useWorkflowUpdate } from '@/app/components/workflow/hooks/use-workflow-interactions' +import { useWorkflowRunEvent } from '@/app/components/workflow/hooks/use-workflow-run-event/use-workflow-run-event' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { IOtherOptions } from '@/service/base' +import { ssePost } from '@/service/base' +import { stopWorkflowRun } from '@/service/workflow' +import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager' +import type { VersionHistory } from '@/types/workflow' +import { noop } from 'lodash-es' +import { useNodesSyncDraft } from './use-nodes-sync-draft' + +export const useWorkflowRun = () => { + const store = useStoreApi() + const workflowStore = useWorkflowStore() + const reactflow = useReactFlow() + const featuresStore = useFeaturesStore() + 
const { doSyncWorkflowDraft } = useNodesSyncDraft() + const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() + const pathname = usePathname() + + const { + handleWorkflowStarted, + handleWorkflowFinished, + handleWorkflowFailed, + handleWorkflowNodeStarted, + handleWorkflowNodeFinished, + handleWorkflowNodeIterationStarted, + handleWorkflowNodeIterationNext, + handleWorkflowNodeIterationFinished, + handleWorkflowNodeLoopStarted, + handleWorkflowNodeLoopNext, + handleWorkflowNodeLoopFinished, + handleWorkflowNodeRetry, + handleWorkflowAgentLog, + handleWorkflowTextChunk, + handleWorkflowTextReplace, + } = useWorkflowRunEvent() + + const handleBackupDraft = useCallback(() => { + const { + getNodes, + edges, + } = store.getState() + const { getViewport } = reactflow + const { + backupDraft, + setBackupDraft, + environmentVariables, + } = workflowStore.getState() + const { features } = featuresStore!.getState() + + if (!backupDraft) { + setBackupDraft({ + nodes: getNodes(), + edges, + viewport: getViewport(), + features, + environmentVariables, + }) + doSyncWorkflowDraft() + } + }, [reactflow, workflowStore, store, featuresStore, doSyncWorkflowDraft]) + + const handleLoadBackupDraft = useCallback(() => { + const { + backupDraft, + setBackupDraft, + setEnvironmentVariables, + } = workflowStore.getState() + + if (backupDraft) { + const { + nodes, + edges, + viewport, + features, + environmentVariables, + } = backupDraft + handleUpdateWorkflowCanvas({ + nodes, + edges, + viewport, + }) + setEnvironmentVariables(environmentVariables) + featuresStore!.setState({ features }) + setBackupDraft(undefined) + } + }, [handleUpdateWorkflowCanvas, workflowStore, featuresStore]) + + const handleRun = useCallback(async ( + params: any, + callback?: IOtherOptions, + ) => { + const { + getNodes, + setNodes, + } = store.getState() + const newNodes = produce(getNodes(), (draft) => { + draft.forEach((node) => { + node.data.selected = false + node.data._runningStatus = undefined + }) + }) + setNodes(newNodes) + await doSyncWorkflowDraft() + + const { + onWorkflowStarted, + onWorkflowFinished, + onNodeStarted, + onNodeFinished, + onIterationStart, + onIterationNext, + onIterationFinish, + onLoopStart, + onLoopNext, + onLoopFinish, + onNodeRetry, + onAgentLog, + onError, + ...restCallback + } = callback || {} + workflowStore.setState({ historyWorkflowData: undefined }) + const appDetail = useAppStore.getState().appDetail + const workflowContainer = document.getElementById('workflow-container') + + const { + clientWidth, + clientHeight, + } = workflowContainer! 
+ + let url = '' + if (appDetail?.mode === 'advanced-chat') + url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run` + + if (appDetail?.mode === 'workflow') + url = `/apps/${appDetail.id}/workflows/draft/run` + + const { + setWorkflowRunningData, + } = workflowStore.getState() + setWorkflowRunningData({ + result: { + status: WorkflowRunningStatus.Running, + }, + tracing: [], + resultText: '', + }) + + let ttsUrl = '' + let ttsIsPublic = false + if (params.token) { + ttsUrl = '/text-to-audio' + ttsIsPublic = true + } + else if (params.appId) { + if (pathname.search('explore/installed') > -1) + ttsUrl = `/installed-apps/${params.appId}/text-to-audio` + else + ttsUrl = `/apps/${params.appId}/text-to-audio` + } + const player = AudioPlayerManager.getInstance().getAudioPlayer(ttsUrl, ttsIsPublic, uuidV4(), 'none', 'none', noop) + + ssePost( + url, + { + body: params, + }, + { + onWorkflowStarted: (params) => { + handleWorkflowStarted(params) + + if (onWorkflowStarted) + onWorkflowStarted(params) + }, + onWorkflowFinished: (params) => { + handleWorkflowFinished(params) + + if (onWorkflowFinished) + onWorkflowFinished(params) + }, + onError: (params) => { + handleWorkflowFailed() + + if (onError) + onError(params) + }, + onNodeStarted: (params) => { + handleWorkflowNodeStarted( + params, + { + clientWidth, + clientHeight, + }, + ) + + if (onNodeStarted) + onNodeStarted(params) + }, + onNodeFinished: (params) => { + handleWorkflowNodeFinished(params) + + if (onNodeFinished) + onNodeFinished(params) + }, + onIterationStart: (params) => { + handleWorkflowNodeIterationStarted( + params, + { + clientWidth, + clientHeight, + }, + ) + + if (onIterationStart) + onIterationStart(params) + }, + onIterationNext: (params) => { + handleWorkflowNodeIterationNext(params) + + if (onIterationNext) + onIterationNext(params) + }, + onIterationFinish: (params) => { + handleWorkflowNodeIterationFinished(params) + + if (onIterationFinish) + onIterationFinish(params) + }, + onLoopStart: (params) => { + handleWorkflowNodeLoopStarted( + params, + { + clientWidth, + clientHeight, + }, + ) + + if (onLoopStart) + onLoopStart(params) + }, + onLoopNext: (params) => { + handleWorkflowNodeLoopNext(params) + + if (onLoopNext) + onLoopNext(params) + }, + onLoopFinish: (params) => { + handleWorkflowNodeLoopFinished(params) + + if (onLoopFinish) + onLoopFinish(params) + }, + onNodeRetry: (params) => { + handleWorkflowNodeRetry(params) + + if (onNodeRetry) + onNodeRetry(params) + }, + onAgentLog: (params) => { + handleWorkflowAgentLog(params) + + if (onAgentLog) + onAgentLog(params) + }, + onTextChunk: (params) => { + handleWorkflowTextChunk(params) + }, + onTextReplace: (params) => { + handleWorkflowTextReplace(params) + }, + onTTSChunk: (messageId: string, audio: string) => { + if (!audio || audio === '') + return + player.playAudioWithAudio(audio, true) + AudioPlayerManager.getInstance().resetMsgId(messageId) + }, + onTTSEnd: (messageId: string, audio: string) => { + player.playAudioWithAudio(audio, false) + }, + ...restCallback, + }, + ) + }, [ + store, + workflowStore, + doSyncWorkflowDraft, + handleWorkflowStarted, + handleWorkflowFinished, + handleWorkflowFailed, + handleWorkflowNodeStarted, + handleWorkflowNodeFinished, + handleWorkflowNodeIterationStarted, + handleWorkflowNodeIterationNext, + handleWorkflowNodeIterationFinished, + handleWorkflowNodeLoopStarted, + handleWorkflowNodeLoopNext, + handleWorkflowNodeLoopFinished, + handleWorkflowNodeRetry, + handleWorkflowTextChunk, + handleWorkflowTextReplace, + 
handleWorkflowAgentLog, + pathname], + ) + + const handleStopRun = useCallback((taskId: string) => { + const appId = useAppStore.getState().appDetail?.id + + stopWorkflowRun(`/apps/${appId}/workflow-runs/tasks/${taskId}/stop`) + }, []) + + const handleRestoreFromPublishedWorkflow = useCallback((publishedWorkflow: VersionHistory) => { + const nodes = publishedWorkflow.graph.nodes.map(node => ({ ...node, selected: false, data: { ...node.data, selected: false } })) + const edges = publishedWorkflow.graph.edges + const viewport = publishedWorkflow.graph.viewport! + handleUpdateWorkflowCanvas({ + nodes, + edges, + viewport, + }) + const mappedFeatures = { + opening: { + enabled: !!publishedWorkflow.features.opening_statement || !!publishedWorkflow.features.suggested_questions.length, + opening_statement: publishedWorkflow.features.opening_statement, + suggested_questions: publishedWorkflow.features.suggested_questions, + }, + suggested: publishedWorkflow.features.suggested_questions_after_answer, + text2speech: publishedWorkflow.features.text_to_speech, + speech2text: publishedWorkflow.features.speech_to_text, + citation: publishedWorkflow.features.retriever_resource, + moderation: publishedWorkflow.features.sensitive_word_avoidance, + file: publishedWorkflow.features.file_upload, + } + + featuresStore?.setState({ features: mappedFeatures }) + workflowStore.getState().setEnvironmentVariables(publishedWorkflow.environment_variables || []) + }, [featuresStore, handleUpdateWorkflowCanvas, workflowStore]) + + return { + handleBackupDraft, + handleLoadBackupDraft, + handleRun, + handleStopRun, + handleRestoreFromPublishedWorkflow, + } +} diff --git a/web/app/components/workflow-app/hooks/use-workflow-start-run.tsx b/web/app/components/workflow-app/hooks/use-workflow-start-run.tsx new file mode 100644 index 0000000000..3f5ea1c1df --- /dev/null +++ b/web/app/components/workflow-app/hooks/use-workflow-start-run.tsx @@ -0,0 +1,96 @@ +import { useCallback } from 'react' +import { useStoreApi } from 'reactflow' +import { useWorkflowStore } from '@/app/components/workflow/store' +import { + BlockEnum, + WorkflowRunningStatus, +} from '@/app/components/workflow/types' +import { useWorkflowInteractions } from '@/app/components/workflow/hooks' +import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { + useIsChatMode, + useNodesSyncDraft, + useWorkflowRun, +} from '.' 
+ +export const useWorkflowStartRun = () => { + const store = useStoreApi() + const workflowStore = useWorkflowStore() + const featuresStore = useFeaturesStore() + const isChatMode = useIsChatMode() + const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() + const { handleRun } = useWorkflowRun() + const { doSyncWorkflowDraft } = useNodesSyncDraft() + + const handleWorkflowStartRunInWorkflow = useCallback(async () => { + const { + workflowRunningData, + } = workflowStore.getState() + + if (workflowRunningData?.result.status === WorkflowRunningStatus.Running) + return + + const { getNodes } = store.getState() + const nodes = getNodes() + const startNode = nodes.find(node => node.data.type === BlockEnum.Start) + const startVariables = startNode?.data.variables || [] + const fileSettings = featuresStore!.getState().features.file + const { + showDebugAndPreviewPanel, + setShowDebugAndPreviewPanel, + setShowInputsPanel, + setShowEnvPanel, + } = workflowStore.getState() + + setShowEnvPanel(false) + + if (showDebugAndPreviewPanel) { + handleCancelDebugAndPreviewPanel() + return + } + + if (!startVariables.length && !fileSettings?.image?.enabled) { + await doSyncWorkflowDraft() + handleRun({ inputs: {}, files: [] }) + setShowDebugAndPreviewPanel(true) + setShowInputsPanel(false) + } + else { + setShowDebugAndPreviewPanel(true) + setShowInputsPanel(true) + } + }, [store, workflowStore, featuresStore, handleCancelDebugAndPreviewPanel, handleRun, doSyncWorkflowDraft]) + + const handleWorkflowStartRunInChatflow = useCallback(async () => { + const { + showDebugAndPreviewPanel, + setShowDebugAndPreviewPanel, + setHistoryWorkflowData, + setShowEnvPanel, + setShowChatVariablePanel, + } = workflowStore.getState() + + setShowEnvPanel(false) + setShowChatVariablePanel(false) + + if (showDebugAndPreviewPanel) + handleCancelDebugAndPreviewPanel() + else + setShowDebugAndPreviewPanel(true) + + setHistoryWorkflowData(undefined) + }, [workflowStore, handleCancelDebugAndPreviewPanel]) + + const handleStartWorkflowRun = useCallback(() => { + if (!isChatMode) + handleWorkflowStartRunInWorkflow() + else + handleWorkflowStartRunInChatflow() + }, [isChatMode, handleWorkflowStartRunInWorkflow, handleWorkflowStartRunInChatflow]) + + return { + handleStartWorkflowRun, + handleWorkflowStartRunInWorkflow, + handleWorkflowStartRunInChatflow, + } +} diff --git a/web/app/components/workflow/hooks/use-workflow-template.ts b/web/app/components/workflow-app/hooks/use-workflow-template.ts similarity index 87% rename from web/app/components/workflow/hooks/use-workflow-template.ts rename to web/app/components/workflow-app/hooks/use-workflow-template.ts index c2dc956b63..9f47b981dc 100644 --- a/web/app/components/workflow/hooks/use-workflow-template.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-template.ts @@ -1,10 +1,10 @@ -import { generateNewNode } from '../utils' +import { generateNewNode } from '@/app/components/workflow/utils' import { NODE_WIDTH_X_OFFSET, START_INITIAL_POSITION, -} from '../constants' -import { useIsChatMode } from './use-workflow' -import { useNodesInitialData } from './use-nodes-data' +} from '@/app/components/workflow/constants' +import { useNodesInitialData } from '@/app/components/workflow/hooks' +import { useIsChatMode } from './use-is-chat-mode' export const useWorkflowTemplate = () => { const isChatMode = useIsChatMode() diff --git a/web/app/components/workflow-app/index.tsx b/web/app/components/workflow-app/index.tsx new file mode 100644 index 0000000000..761a7f29c4 --- 
/dev/null +++ b/web/app/components/workflow-app/index.tsx @@ -0,0 +1,108 @@ +import { + useMemo, +} from 'react' +import useSWR from 'swr' +import { + SupportUploadFileTypes, +} from '@/app/components/workflow/types' +import { + useWorkflowInit, +} from './hooks' +import { + initialEdges, + initialNodes, +} from '@/app/components/workflow/utils' +import Loading from '@/app/components/base/loading' +import { FeaturesProvider } from '@/app/components/base/features' +import type { Features as FeaturesData } from '@/app/components/base/features/types' +import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' +import { fetchFileUploadConfig } from '@/service/common' +import WorkflowWithDefaultContext from '@/app/components/workflow' +import { + WorkflowContextProvider, +} from '@/app/components/workflow/context' +import { createWorkflowSlice } from './store/workflow/workflow-slice' +import WorkflowAppMain from './components/workflow-main' + +const WorkflowAppWithAdditionalContext = () => { + const { + data, + isLoading, + } = useWorkflowInit() + const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) + + const nodesData = useMemo(() => { + if (data) + return initialNodes(data.graph.nodes, data.graph.edges) + + return [] + }, [data]) + const edgesData = useMemo(() => { + if (data) + return initialEdges(data.graph.edges, data.graph.nodes) + + return [] + }, [data]) + + if (!data || isLoading) { + return ( +
+ +
+ ) + } + + const features = data.features || {} + const initialFeatures: FeaturesData = { + file: { + image: { + enabled: !!features.file_upload?.image?.enabled, + number_limits: features.file_upload?.image?.number_limits || 3, + transfer_methods: features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + }, + enabled: !!(features.file_upload?.enabled || features.file_upload?.image?.enabled), + allowed_file_types: features.file_upload?.allowed_file_types || [SupportUploadFileTypes.image], + allowed_file_extensions: features.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), + allowed_file_upload_methods: features.file_upload?.allowed_file_upload_methods || features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + number_limits: features.file_upload?.number_limits || features.file_upload?.image?.number_limits || 3, + fileUploadConfig: fileUploadConfigResponse, + }, + opening: { + enabled: !!features.opening_statement, + opening_statement: features.opening_statement, + suggested_questions: features.suggested_questions, + }, + suggested: features.suggested_questions_after_answer || { enabled: false }, + speech2text: features.speech_to_text || { enabled: false }, + text2speech: features.text_to_speech || { enabled: false }, + citation: features.retriever_resource || { enabled: false }, + moderation: features.sensitive_word_avoidance || { enabled: false }, + } + + return ( + + + + + + ) +} + +const WorkflowAppWrapper = () => { + return ( + + + + ) +} + +export default WorkflowAppWrapper diff --git a/web/app/components/workflow-app/store/workflow/workflow-slice.ts b/web/app/components/workflow-app/store/workflow/workflow-slice.ts new file mode 100644 index 0000000000..77626e52b1 --- /dev/null +++ b/web/app/components/workflow-app/store/workflow/workflow-slice.ts @@ -0,0 +1,18 @@ +import type { StateCreator } from 'zustand' + +export type WorkflowSliceShape = { + appId: string + notInitialWorkflow: boolean + setNotInitialWorkflow: (notInitialWorkflow: boolean) => void + nodesDefaultConfigs: Record + setNodesDefaultConfigs: (nodesDefaultConfigs: Record) => void +} + +export type CreateWorkflowSlice = StateCreator +export const createWorkflowSlice: StateCreator = set => ({ + appId: '', + notInitialWorkflow: false, + setNotInitialWorkflow: notInitialWorkflow => set(() => ({ notInitialWorkflow })), + nodesDefaultConfigs: {}, + setNodesDefaultConfigs: nodesDefaultConfigs => set(() => ({ nodesDefaultConfigs })), +}) diff --git a/web/app/components/workflow/context.tsx b/web/app/components/workflow/context.tsx index bb34ce6319..cae14fc2b2 100644 --- a/web/app/components/workflow/context.tsx +++ b/web/app/components/workflow/context.tsx @@ -2,19 +2,24 @@ import { createContext, useRef, } from 'react' -import { createWorkflowStore } from './store' +import { + createWorkflowStore, +} from './store' +import type { StateCreator } from 'zustand' +import type { WorkflowSliceShape } from '@/app/components/workflow-app/store/workflow/workflow-slice' type WorkflowStore = ReturnType export const WorkflowContext = createContext(null) -type WorkflowProviderProps = { +export type WorkflowProviderProps = { children: React.ReactNode + injectWorkflowStoreSliceFn?: StateCreator } -export const WorkflowContextProvider = ({ children }: WorkflowProviderProps) => { +export const WorkflowContextProvider = ({ children, injectWorkflowStoreSliceFn }: WorkflowProviderProps) => { const storeRef = useRef(undefined) if 
(!storeRef.current) - storeRef.current = createWorkflowStore() + storeRef.current = createWorkflowStore({ injectWorkflowStoreSliceFn }) return ( diff --git a/web/app/components/workflow/header/editing-title.tsx b/web/app/components/workflow/header/editing-title.tsx index b99564a5f9..2444cf8c29 100644 --- a/web/app/components/workflow/header/editing-title.tsx +++ b/web/app/components/workflow/header/editing-title.tsx @@ -1,13 +1,13 @@ import { memo } from 'react' import { useTranslation } from 'react-i18next' -import { useWorkflow } from '../hooks' +import { useFormatTimeFromNow } from '../hooks' import { useStore } from '@/app/components/workflow/store' import useTimestamp from '@/hooks/use-timestamp' const EditingTitle = () => { const { t } = useTranslation() const { formatTime } = useTimestamp() - const { formatTimeFromNow } = useWorkflow() + const { formatTimeFromNow } = useFormatTimeFromNow() const draftUpdatedAt = useStore(state => state.draftUpdatedAt) const publishedAt = useStore(state => state.publishedAt) const isSyncingWorkflowDraft = useStore(s => s.isSyncingWorkflowDraft) diff --git a/web/app/components/workflow/header/header-in-normal.tsx b/web/app/components/workflow/header/header-in-normal.tsx new file mode 100644 index 0000000000..ec016b1b65 --- /dev/null +++ b/web/app/components/workflow/header/header-in-normal.tsx @@ -0,0 +1,69 @@ +import { + useCallback, +} from 'react' +import { useNodes } from 'reactflow' +import { + useStore, + useWorkflowStore, +} from '../store' +import type { StartNodeType } from '../nodes/start/types' +import { + useNodesInteractions, + useNodesReadOnly, + useWorkflowRun, +} from '../hooks' +import Divider from '../../base/divider' +import RunAndHistory from './run-and-history' +import EditingTitle from './editing-title' +import EnvButton from './env-button' +import VersionHistoryButton from './version-history-button' + +export type HeaderInNormalProps = { + components?: { + left?: React.ReactNode + middle?: React.ReactNode + } +} +const HeaderInNormal = ({ + components, +}: HeaderInNormalProps) => { + const workflowStore = useWorkflowStore() + const { nodesReadOnly } = useNodesReadOnly() + const { handleNodeSelect } = useNodesInteractions() + const setShowWorkflowVersionHistoryPanel = useStore(s => s.setShowWorkflowVersionHistoryPanel) + const setShowEnvPanel = useStore(s => s.setShowEnvPanel) + const setShowDebugAndPreviewPanel = useStore(s => s.setShowDebugAndPreviewPanel) + const nodes = useNodes() + const selectedNode = nodes.find(node => node.data.selected) + const { handleBackupDraft } = useWorkflowRun() + + const onStartRestoring = useCallback(() => { + workflowStore.setState({ isRestoring: true }) + handleBackupDraft() + // clear right panel + if (selectedNode) + handleNodeSelect(selectedNode.id, true) + setShowWorkflowVersionHistoryPanel(true) + setShowEnvPanel(false) + setShowDebugAndPreviewPanel(false) + }, [handleBackupDraft, workflowStore, handleNodeSelect, selectedNode, + setShowWorkflowVersionHistoryPanel, setShowEnvPanel, setShowDebugAndPreviewPanel]) + + return ( + <> +
+ +
+
+ {components?.left} + + + + {components?.middle} + +
+ + ) +} + +export default HeaderInNormal diff --git a/web/app/components/workflow/header/header-in-restoring.tsx b/web/app/components/workflow/header/header-in-restoring.tsx new file mode 100644 index 0000000000..4d1954587d --- /dev/null +++ b/web/app/components/workflow/header/header-in-restoring.tsx @@ -0,0 +1,93 @@ +import { + useCallback, +} from 'react' +import { RiHistoryLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { + useStore, + useWorkflowStore, +} from '../store' +import { + WorkflowVersion, +} from '../types' +import { + useNodesSyncDraft, + useWorkflowRun, +} from '../hooks' +import Toast from '../../base/toast' +import RestoringTitle from './restoring-title' +import Button from '@/app/components/base/button' + +export type HeaderInRestoringProps = { + onRestoreSettled?: () => void +} +const HeaderInRestoring = ({ + onRestoreSettled, +}: HeaderInRestoringProps) => { + const { t } = useTranslation() + const workflowStore = useWorkflowStore() + const currentVersion = useStore(s => s.currentVersion) + const setShowWorkflowVersionHistoryPanel = useStore(s => s.setShowWorkflowVersionHistoryPanel) + + const { + handleLoadBackupDraft, + } = useWorkflowRun() + const { handleSyncWorkflowDraft } = useNodesSyncDraft() + + const handleCancelRestore = useCallback(() => { + handleLoadBackupDraft() + workflowStore.setState({ isRestoring: false }) + setShowWorkflowVersionHistoryPanel(false) + }, [workflowStore, handleLoadBackupDraft, setShowWorkflowVersionHistoryPanel]) + + const handleRestore = useCallback(() => { + setShowWorkflowVersionHistoryPanel(false) + workflowStore.setState({ isRestoring: false }) + workflowStore.setState({ backupDraft: undefined }) + handleSyncWorkflowDraft(true, false, { + onSuccess: () => { + Toast.notify({ + type: 'success', + message: t('workflow.versionHistory.action.restoreSuccess'), + }) + }, + onError: () => { + Toast.notify({ + type: 'error', + message: t('workflow.versionHistory.action.restoreFailure'), + }) + }, + onSettled: () => { + onRestoreSettled?.() + }, + }) + }, [handleSyncWorkflowDraft, workflowStore, setShowWorkflowVersionHistoryPanel, onRestoreSettled, t]) + + return ( + <> +
+ +
+
+ + +
+ + ) +} + +export default HeaderInRestoring diff --git a/web/app/components/workflow/header/header-in-view-history.tsx b/web/app/components/workflow/header/header-in-view-history.tsx new file mode 100644 index 0000000000..81858ccc89 --- /dev/null +++ b/web/app/components/workflow/header/header-in-view-history.tsx @@ -0,0 +1,50 @@ +import { + useCallback, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + useWorkflowStore, +} from '../store' +import { + useWorkflowRun, +} from '../hooks' +import Divider from '../../base/divider' +import RunningTitle from './running-title' +import ViewHistory from './view-history' +import Button from '@/app/components/base/button' +import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' + +const HeaderInHistory = () => { + const { t } = useTranslation() + const workflowStore = useWorkflowStore() + + const { + handleLoadBackupDraft, + } = useWorkflowRun() + + const handleGoBackToEdit = useCallback(() => { + handleLoadBackupDraft() + workflowStore.setState({ historyWorkflowData: undefined }) + }, [workflowStore, handleLoadBackupDraft]) + + return ( + <> +
+ +
+
+ + + +
+ + ) +} + +export default HeaderInHistory diff --git a/web/app/components/workflow/header/index.tsx b/web/app/components/workflow/header/index.tsx index 7e99f5dd6b..e5391afb09 100644 --- a/web/app/components/workflow/header/index.tsx +++ b/web/app/components/workflow/header/index.tsx @@ -1,292 +1,51 @@ -import type { FC } from 'react' import { - memo, - useCallback, - useMemo, -} from 'react' -import { RiApps2AddLine, RiHistoryLine } from '@remixicon/react' -import { useNodes } from 'reactflow' -import { useTranslation } from 'react-i18next' -import { useContext, useContextSelector } from 'use-context-selector' -import { - useStore, - useWorkflowStore, -} from '../store' -import { - BlockEnum, - InputVarType, - WorkflowVersion, -} from '../types' -import type { StartNodeType } from '../nodes/start/types' -import { - useChecklistBeforePublish, - useIsChatMode, - useNodesInteractions, - useNodesReadOnly, - useNodesSyncDraft, useWorkflowMode, - useWorkflowRun, } from '../hooks' -import AppPublisher from '../../app/app-publisher' -import Toast, { ToastContext } from '../../base/toast' -import Divider from '../../base/divider' -import RunAndHistory from './run-and-history' -import EditingTitle from './editing-title' -import RunningTitle from './running-title' -import RestoringTitle from './restoring-title' -import ViewHistory from './view-history' -import ChatVariableButton from './chat-variable-button' -import EnvButton from './env-button' -import VersionHistoryButton from './version-history-button' -import Button from '@/app/components/base/button' -import { useStore as useAppStore } from '@/app/components/app/store' -import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' -import { useFeatures } from '@/app/components/base/features/hooks' -import { usePublishWorkflow, useResetWorkflowVersionHistory } from '@/service/use-workflow' -import type { PublishWorkflowParams } from '@/types/workflow' -import { fetchAppDetail, fetchAppSSO } from '@/service/apps' -import AppContext from '@/context/app-context' +import type { HeaderInNormalProps } from './header-in-normal' +import HeaderInNormal from './header-in-normal' +import HeaderInHistory from './header-in-view-history' +import type { HeaderInRestoringProps } from './header-in-restoring' +import HeaderInRestoring from './header-in-restoring' -const Header: FC = () => { - const { t } = useTranslation() - const workflowStore = useWorkflowStore() - const appDetail = useAppStore(s => s.appDetail) - const setAppDetail = useAppStore(s => s.setAppDetail) - const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) - const appID = appDetail?.id - const isChatMode = useIsChatMode() - const { nodesReadOnly, getNodesReadOnly } = useNodesReadOnly() - const { handleNodeSelect } = useNodesInteractions() - const publishedAt = useStore(s => s.publishedAt) - const draftUpdatedAt = useStore(s => s.draftUpdatedAt) - const toolPublished = useStore(s => s.toolPublished) - const currentVersion = useStore(s => s.currentVersion) - const setShowWorkflowVersionHistoryPanel = useStore(s => s.setShowWorkflowVersionHistoryPanel) - const setShowEnvPanel = useStore(s => s.setShowEnvPanel) - const setShowDebugAndPreviewPanel = useStore(s => s.setShowDebugAndPreviewPanel) - const nodes = useNodes() - const startNode = nodes.find(node => node.data.type === BlockEnum.Start) - const selectedNode = nodes.find(node => node.data.selected) - const startVariables = startNode?.data.variables - const fileSettings = useFeatures(s => 
s.features.file) - const variables = useMemo(() => { - const data = startVariables || [] - if (fileSettings?.image?.enabled) { - return [ - ...data, - { - type: InputVarType.files, - variable: '__image', - required: false, - label: 'files', - }, - ] - } - - return data - }, [fileSettings?.image?.enabled, startVariables]) - - const { - handleLoadBackupDraft, - handleBackupDraft, - } = useWorkflowRun() - const { handleCheckBeforePublish } = useChecklistBeforePublish() - const { handleSyncWorkflowDraft } = useNodesSyncDraft() - const { notify } = useContext(ToastContext) +export type HeaderProps = { + normal?: HeaderInNormalProps + restoring?: HeaderInRestoringProps +} +const Header = ({ + normal: normalProps, + restoring: restoringProps, +}: HeaderProps) => { const { normal, restoring, viewHistory, } = useWorkflowMode() - const handleShowFeatures = useCallback(() => { - const { - showFeaturesPanel, - isRestoring, - setShowFeaturesPanel, - } = workflowStore.getState() - if (getNodesReadOnly() && !isRestoring) - return - setShowFeaturesPanel(!showFeaturesPanel) - }, [workflowStore, getNodesReadOnly]) - - const handleCancelRestore = useCallback(() => { - handleLoadBackupDraft() - workflowStore.setState({ isRestoring: false }) - setShowWorkflowVersionHistoryPanel(false) - }, [workflowStore, handleLoadBackupDraft, setShowWorkflowVersionHistoryPanel]) - - const resetWorkflowVersionHistory = useResetWorkflowVersionHistory(appDetail!.id) - - const handleRestore = useCallback(() => { - setShowWorkflowVersionHistoryPanel(false) - workflowStore.setState({ isRestoring: false }) - workflowStore.setState({ backupDraft: undefined }) - handleSyncWorkflowDraft(true, false, { - onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('workflow.versionHistory.action.restoreSuccess'), - }) - }, - onError: () => { - Toast.notify({ - type: 'error', - message: t('workflow.versionHistory.action.restoreFailure'), - }) - }, - onSettled: () => { - resetWorkflowVersionHistory() - }, - }) - }, [handleSyncWorkflowDraft, workflowStore, setShowWorkflowVersionHistoryPanel, resetWorkflowVersionHistory, t]) - - const updateAppDetail = useCallback(async () => { - try { - const res = await fetchAppDetail({ url: '/apps', id: appID! }) - if (systemFeatures.enable_web_sso_switch_component) { - const ssoRes = await fetchAppSSO({ appId: appID! }) - setAppDetail({ ...res, enable_sso: ssoRes.enabled }) - } - else { - setAppDetail({ ...res }) - } - } - catch (error) { - console.error(error) - } - }, [appID, setAppDetail, systemFeatures.enable_web_sso_switch_component]) - - const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!) 
- - const onPublish = useCallback(async (params?: PublishWorkflowParams) => { - if (await handleCheckBeforePublish()) { - const res = await publishWorkflow({ - title: params?.title || '', - releaseNotes: params?.releaseNotes || '', - }) - - if (res) { - notify({ type: 'success', message: t('common.api.actionSuccess') }) - updateAppDetail() - workflowStore.getState().setPublishedAt(res.created_at) - resetWorkflowVersionHistory() - } - } - else { - throw new Error('Checklist failed') - } - }, [handleCheckBeforePublish, notify, t, workflowStore, publishWorkflow, resetWorkflowVersionHistory, updateAppDetail]) - - const onStartRestoring = useCallback(() => { - workflowStore.setState({ isRestoring: true }) - handleBackupDraft() - // clear right panel - if (selectedNode) - handleNodeSelect(selectedNode.id, true) - setShowWorkflowVersionHistoryPanel(true) - setShowEnvPanel(false) - setShowDebugAndPreviewPanel(false) - }, [handleBackupDraft, workflowStore, handleNodeSelect, selectedNode, - setShowWorkflowVersionHistoryPanel, setShowEnvPanel, setShowDebugAndPreviewPanel]) - - const onPublisherToggle = useCallback((state: boolean) => { - if (state) - handleSyncWorkflowDraft(true) - }, [handleSyncWorkflowDraft]) - - const handleGoBackToEdit = useCallback(() => { - handleLoadBackupDraft() - workflowStore.setState({ historyWorkflowData: undefined }) - }, [workflowStore, handleLoadBackupDraft]) - - const handleToolConfigureUpdate = useCallback(() => { - workflowStore.setState({ toolPublished: true }) - }, [workflowStore]) - return (
-
- { - normal && - } - { - viewHistory && - } - { - restoring && - } -
{ normal && ( -
- {/* */} - {isChatMode && } - - - - - - -
+ ) } { viewHistory && ( -
- - - -
+ ) } { restoring && ( -
- - -
+ ) }
) } -export default memo(Header) +export default Header diff --git a/web/app/components/workflow/header/restoring-title.tsx b/web/app/components/workflow/header/restoring-title.tsx index 310ab5c35a..26cdd79d13 100644 --- a/web/app/components/workflow/header/restoring-title.tsx +++ b/web/app/components/workflow/header/restoring-title.tsx @@ -1,13 +1,13 @@ import { memo, useMemo } from 'react' import { useTranslation } from 'react-i18next' -import { useWorkflow } from '../hooks' +import { useFormatTimeFromNow } from '../hooks' import { useStore } from '../store' import { WorkflowVersion } from '../types' import useTimestamp from '@/hooks/use-timestamp' const RestoringTitle = () => { const { t } = useTranslation() - const { formatTimeFromNow } = useWorkflow() + const { formatTimeFromNow } = useFormatTimeFromNow() const { formatTime } = useTimestamp() const currentVersion = useStore(state => state.currentVersion) const isDraft = currentVersion?.version === WorkflowVersion.Draft diff --git a/web/app/components/workflow/header/view-history.tsx b/web/app/components/workflow/header/view-history.tsx index 1298c0e42d..21b4462867 100644 --- a/web/app/components/workflow/header/view-history.tsx +++ b/web/app/components/workflow/header/view-history.tsx @@ -11,9 +11,9 @@ import { RiErrorWarningLine, } from '@remixicon/react' import { + useFormatTimeFromNow, useIsChatMode, useNodesInteractions, - useWorkflow, useWorkflowInteractions, useWorkflowRun, } from '../hooks' @@ -50,7 +50,7 @@ const ViewHistory = ({ const { t } = useTranslation() const isChatMode = useIsChatMode() const [open, setOpen] = useState(false) - const { formatTimeFromNow } = useWorkflow() + const { formatTimeFromNow } = useFormatTimeFromNow() const { handleNodesCancelSelected, } = useNodesInteractions() diff --git a/web/app/components/workflow/hooks-store/index.ts b/web/app/components/workflow/hooks-store/index.ts new file mode 100644 index 0000000000..40b4132dfd --- /dev/null +++ b/web/app/components/workflow/hooks-store/index.ts @@ -0,0 +1,2 @@ +export * from './provider' +export * from './store' diff --git a/web/app/components/workflow/hooks-store/provider.tsx b/web/app/components/workflow/hooks-store/provider.tsx new file mode 100644 index 0000000000..c1090ae3f8 --- /dev/null +++ b/web/app/components/workflow/hooks-store/provider.tsx @@ -0,0 +1,36 @@ +import { + createContext, + useEffect, + useRef, +} from 'react' +import { useStore } from 'reactflow' +import { + createHooksStore, +} from './store' +import type { Shape } from './store' + +type HooksStore = ReturnType +export const HooksStoreContext = createContext(null) +type HooksStoreContextProviderProps = Partial & { + children: React.ReactNode +} +export const HooksStoreContextProvider = ({ children, ...restProps }: HooksStoreContextProviderProps) => { + const storeRef = useRef(undefined) + const d3Selection = useStore(s => s.d3Selection) + const d3Zoom = useStore(s => s.d3Zoom) + + useEffect(() => { + if (storeRef.current && d3Selection && d3Zoom) + storeRef.current.getState().refreshAll(restProps) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [d3Selection, d3Zoom]) + + if (!storeRef.current) + storeRef.current = createHooksStore(restProps) + + return ( + + {children} + + ) +} diff --git a/web/app/components/workflow/hooks-store/store.ts b/web/app/components/workflow/hooks-store/store.ts new file mode 100644 index 0000000000..2e40cbfbc9 --- /dev/null +++ b/web/app/components/workflow/hooks-store/store.ts @@ -0,0 +1,72 @@ +import { useContext } from 'react' 
+import {
+  noop,
+} from 'lodash-es'
+import {
+  useStore as useZustandStore,
+} from 'zustand'
+import { createStore } from 'zustand/vanilla'
+import { HooksStoreContext } from './provider'
+
+type CommonHooksFnMap = {
+  doSyncWorkflowDraft: (
+    notRefreshWhenSyncError?: boolean,
+    callback?: {
+      onSuccess?: () => void
+      onError?: () => void
+      onSettled?: () => void
+    }
+  ) => Promise<void>
+  syncWorkflowDraftWhenPageClose: () => void
+  handleBackupDraft: () => void
+  handleLoadBackupDraft: () => void
+  handleRestoreFromPublishedWorkflow: (...args: any[]) => void
+  handleRun: (...args: any[]) => void
+  handleStopRun: (...args: any[]) => void
+  handleStartWorkflowRun: () => void
+  handleWorkflowStartRunInWorkflow: () => void
+  handleWorkflowStartRunInChatflow: () => void
+}
+
+export type Shape = {
+  refreshAll: (props: Partial<Shape>) => void
+} & CommonHooksFnMap
+
+export const createHooksStore = ({
+  doSyncWorkflowDraft = async () => noop(),
+  syncWorkflowDraftWhenPageClose = noop,
+  handleBackupDraft = noop,
+  handleLoadBackupDraft = noop,
+  handleRestoreFromPublishedWorkflow = noop,
+  handleRun = noop,
+  handleStopRun = noop,
+  handleStartWorkflowRun = noop,
+  handleWorkflowStartRunInWorkflow = noop,
+  handleWorkflowStartRunInChatflow = noop,
+}: Partial<Shape>) => {
+  return createStore<Shape>(set => ({
+    refreshAll: props => set(state => ({ ...state, ...props })),
+    doSyncWorkflowDraft,
+    syncWorkflowDraftWhenPageClose,
+    handleBackupDraft,
+    handleLoadBackupDraft,
+    handleRestoreFromPublishedWorkflow,
+    handleRun,
+    handleStopRun,
+    handleStartWorkflowRun,
+    handleWorkflowStartRunInWorkflow,
+    handleWorkflowStartRunInChatflow,
+  }))
+}
+
+export function useHooksStore<T>(selector: (state: Shape) => T): T {
+  const store = useContext(HooksStoreContext)
+  if (!store)
+    throw new Error('Missing HooksStoreContext.Provider in the tree')
+
+  return useZustandStore(store, selector)
+}
+
+export const useHooksStoreApi = () => {
+  return useContext(HooksStoreContext)!
+} diff --git a/web/app/components/workflow/hooks/index.ts b/web/app/components/workflow/hooks/index.ts index 463e9b3271..20a34c69e3 100644 --- a/web/app/components/workflow/hooks/index.ts +++ b/web/app/components/workflow/hooks/index.ts @@ -5,7 +5,6 @@ export * from './use-nodes-data' export * from './use-nodes-sync-draft' export * from './use-workflow' export * from './use-workflow-run' -export * from './use-workflow-template' export * from './use-checklist' export * from './use-selection-interactions' export * from './use-panel-interactions' @@ -16,3 +15,4 @@ export * from './use-workflow-variables' export * from './use-shortcuts' export * from './use-workflow-interactions' export * from './use-workflow-mode' +export * from './use-format-time-from-now' diff --git a/web/app/components/workflow/hooks/use-edges-interactions-without-sync.ts b/web/app/components/workflow/hooks/use-edges-interactions-without-sync.ts new file mode 100644 index 0000000000..c4c709cd25 --- /dev/null +++ b/web/app/components/workflow/hooks/use-edges-interactions-without-sync.ts @@ -0,0 +1,27 @@ +import { useCallback } from 'react' +import produce from 'immer' +import { useStoreApi } from 'reactflow' + +export const useEdgesInteractionsWithoutSync = () => { + const store = useStoreApi() + + const handleEdgeCancelRunningStatus = useCallback(() => { + const { + edges, + setEdges, + } = store.getState() + + const newEdges = produce(edges, (draft) => { + draft.forEach((edge) => { + edge.data._sourceRunningStatus = undefined + edge.data._targetRunningStatus = undefined + edge.data._waitingRun = false + }) + }) + setEdges(newEdges) + }, [store]) + + return { + handleEdgeCancelRunningStatus, + } +} diff --git a/web/app/components/workflow/hooks/use-edges-interactions.ts b/web/app/components/workflow/hooks/use-edges-interactions.ts index 688f0b26ce..306af1e96c 100644 --- a/web/app/components/workflow/hooks/use-edges-interactions.ts +++ b/web/app/components/workflow/hooks/use-edges-interactions.ts @@ -151,28 +151,11 @@ export const useEdgesInteractions = () => { setEdges(newEdges) }, [store, getNodesReadOnly]) - const handleEdgeCancelRunningStatus = useCallback(() => { - const { - edges, - setEdges, - } = store.getState() - - const newEdges = produce(edges, (draft) => { - draft.forEach((edge) => { - edge.data._sourceRunningStatus = undefined - edge.data._targetRunningStatus = undefined - edge.data._waitingRun = false - }) - }) - setEdges(newEdges) - }, [store]) - return { handleEdgeEnter, handleEdgeLeave, handleEdgeDeleteByDeleteBranch, handleEdgeDelete, handleEdgesChange, - handleEdgeCancelRunningStatus, } } diff --git a/web/app/components/workflow/hooks/use-format-time-from-now.ts b/web/app/components/workflow/hooks/use-format-time-from-now.ts new file mode 100644 index 0000000000..b2b521557f --- /dev/null +++ b/web/app/components/workflow/hooks/use-format-time-from-now.ts @@ -0,0 +1,12 @@ +import dayjs from 'dayjs' +import { useCallback } from 'react' +import { useI18N } from '@/context/i18n' + +export const useFormatTimeFromNow = () => { + const { locale } = useI18N() + const formatTimeFromNow = useCallback((time: number) => { + return dayjs(time).locale(locale === 'zh-Hans' ? 
'zh-cn' : locale).fromNow() + }, [locale]) + + return { formatTimeFromNow } +} diff --git a/web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts b/web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts new file mode 100644 index 0000000000..7fbf0ce868 --- /dev/null +++ b/web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts @@ -0,0 +1,27 @@ +import { useCallback } from 'react' +import produce from 'immer' +import { useStoreApi } from 'reactflow' + +export const useNodesInteractionsWithoutSync = () => { + const store = useStoreApi() + + const handleNodeCancelRunningStatus = useCallback(() => { + const { + getNodes, + setNodes, + } = store.getState() + + const nodes = getNodes() + const newNodes = produce(nodes, (draft) => { + draft.forEach((node) => { + node.data._runningStatus = undefined + node.data._waitingRun = false + }) + }) + setNodes(newNodes) + }, [store]) + + return { + handleNodeCancelRunningStatus, + } +} diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 90231cfcc8..94b10c9929 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -1177,22 +1177,6 @@ export const useNodesInteractions = () => { saveStateToHistory(WorkflowHistoryEvent.NodeChange) }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory]) - const handleNodeCancelRunningStatus = useCallback(() => { - const { - getNodes, - setNodes, - } = store.getState() - - const nodes = getNodes() - const newNodes = produce(nodes, (draft) => { - draft.forEach((node) => { - node.data._runningStatus = undefined - node.data._waitingRun = false - }) - }) - setNodes(newNodes) - }, [store]) - const handleNodesCancelSelected = useCallback(() => { const { getNodes, @@ -1554,7 +1538,6 @@ export const useNodesInteractions = () => { handleNodeDelete, handleNodeChange, handleNodeAdd, - handleNodeCancelRunningStatus, handleNodesCancelSelected, handleNodeContextMenu, handleNodesCopy, diff --git a/web/app/components/workflow/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow/hooks/use-nodes-sync-draft.ts index 5cd8f36bff..e6cc3a97e3 100644 --- a/web/app/components/workflow/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow/hooks/use-nodes-sync-draft.ts @@ -1,147 +1,17 @@ import { useCallback } from 'react' -import produce from 'immer' -import { useStoreApi } from 'reactflow' -import { useParams } from 'next/navigation' import { useStore, - useWorkflowStore, } from '../store' -import { BlockEnum } from '../types' -import { useWorkflowUpdate } from '../hooks' import { useNodesReadOnly, } from './use-workflow' -import { syncWorkflowDraft } from '@/service/workflow' -import { useFeaturesStore } from '@/app/components/base/features/hooks' -import { API_PREFIX } from '@/config' +import { useHooksStore } from '@/app/components/workflow/hooks-store' export const useNodesSyncDraft = () => { - const store = useStoreApi() - const workflowStore = useWorkflowStore() - const featuresStore = useFeaturesStore() const { getNodesReadOnly } = useNodesReadOnly() - const { handleRefreshWorkflowDraft } = useWorkflowUpdate() const debouncedSyncWorkflowDraft = useStore(s => s.debouncedSyncWorkflowDraft) - const params = useParams() - - const getPostParams = useCallback(() => { - const { - getNodes, - edges, - transform, - } = store.getState() - const [x, y, zoom] = transform - const { - appId, - 
conversationVariables, - environmentVariables, - syncWorkflowDraftHash, - } = workflowStore.getState() - - if (appId) { - const nodes = getNodes() - const hasStartNode = nodes.find(node => node.data.type === BlockEnum.Start) - - if (!hasStartNode) - return - - const features = featuresStore!.getState().features - const producedNodes = produce(nodes, (draft) => { - draft.forEach((node) => { - Object.keys(node.data).forEach((key) => { - if (key.startsWith('_')) - delete node.data[key] - }) - }) - }) - const producedEdges = produce(edges, (draft) => { - draft.forEach((edge) => { - Object.keys(edge.data).forEach((key) => { - if (key.startsWith('_')) - delete edge.data[key] - }) - }) - }) - return { - url: `/apps/${appId}/workflows/draft`, - params: { - graph: { - nodes: producedNodes, - edges: producedEdges, - viewport: { - x, - y, - zoom, - }, - }, - features: { - opening_statement: features.opening?.enabled ? (features.opening?.opening_statement || '') : '', - suggested_questions: features.opening?.enabled ? (features.opening?.suggested_questions || []) : [], - suggested_questions_after_answer: features.suggested, - text_to_speech: features.text2speech, - speech_to_text: features.speech2text, - retriever_resource: features.citation, - sensitive_word_avoidance: features.moderation, - file_upload: features.file, - }, - environment_variables: environmentVariables, - conversation_variables: conversationVariables, - hash: syncWorkflowDraftHash, - }, - } - } - }, [store, featuresStore, workflowStore]) - - const syncWorkflowDraftWhenPageClose = useCallback(() => { - if (getNodesReadOnly()) - return - const postParams = getPostParams() - - if (postParams) { - navigator.sendBeacon( - `${API_PREFIX}/apps/${params.appId}/workflows/draft?_token=${localStorage.getItem('console_token')}`, - JSON.stringify(postParams.params), - ) - } - }, [getPostParams, params.appId, getNodesReadOnly]) - - const doSyncWorkflowDraft = useCallback(async ( - notRefreshWhenSyncError?: boolean, - callback?: { - onSuccess?: () => void - onError?: () => void - onSettled?: () => void - }, - ) => { - if (getNodesReadOnly()) - return - const postParams = getPostParams() - - if (postParams) { - const { - setSyncWorkflowDraftHash, - setDraftUpdatedAt, - } = workflowStore.getState() - try { - const res = await syncWorkflowDraft(postParams) - setSyncWorkflowDraftHash(res.hash) - setDraftUpdatedAt(res.updated_at) - callback?.onSuccess && callback.onSuccess() - } - catch (error: any) { - if (error && error.json && !error.bodyUsed) { - error.json().then((err: any) => { - if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) - handleRefreshWorkflowDraft() - }) - } - callback?.onError && callback.onError() - } - finally { - callback?.onSettled && callback.onSettled() - } - } - }, [workflowStore, getPostParams, getNodesReadOnly, handleRefreshWorkflowDraft]) + const doSyncWorkflowDraft = useHooksStore(s => s.doSyncWorkflowDraft) + const syncWorkflowDraftWhenPageClose = useHooksStore(s => s.syncWorkflowDraftWhenPageClose) const handleSyncWorkflowDraft = useCallback(( sync?: boolean, diff --git a/web/app/components/workflow/hooks/use-workflow-interactions.ts b/web/app/components/workflow/hooks/use-workflow-interactions.ts index 202867e22f..740868c594 100644 --- a/web/app/components/workflow/hooks/use-workflow-interactions.ts +++ b/web/app/components/workflow/hooks/use-workflow-interactions.ts @@ -25,8 +25,8 @@ import { useSelectionInteractions, useWorkflowReadOnly, } from '../hooks' -import { useEdgesInteractions } from 
'./use-edges-interactions' -import { useNodesInteractions } from './use-nodes-interactions' +import { useEdgesInteractionsWithoutSync } from './use-edges-interactions-without-sync' +import { useNodesInteractionsWithoutSync } from './use-nodes-interactions-without-sync' import { useNodesSyncDraft } from './use-nodes-sync-draft' import { WorkflowHistoryEvent, useWorkflowHistory } from './use-workflow-history' import { useEventEmitterContextContext } from '@/context/event-emitter' @@ -37,8 +37,8 @@ import { useStore as useAppStore } from '@/app/components/app/store' export const useWorkflowInteractions = () => { const workflowStore = useWorkflowStore() - const { handleNodeCancelRunningStatus } = useNodesInteractions() - const { handleEdgeCancelRunningStatus } = useEdgesInteractions() + const { handleNodeCancelRunningStatus } = useNodesInteractionsWithoutSync() + const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync() const handleCancelDebugAndPreviewPanel = useCallback(() => { workflowStore.setState({ diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index 99d9a45702..05a60ebb4b 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -1,350 +1,11 @@ -import { useCallback } from 'react' -import { - useReactFlow, - useStoreApi, -} from 'reactflow' -import produce from 'immer' -import { v4 as uuidV4 } from 'uuid' -import { usePathname } from 'next/navigation' -import { useWorkflowStore } from '../store' -import { useNodesSyncDraft } from '../hooks' -import { WorkflowRunningStatus } from '../types' -import { useWorkflowUpdate } from './use-workflow-interactions' -import { useWorkflowRunEvent } from './use-workflow-run-event/use-workflow-run-event' -import { useStore as useAppStore } from '@/app/components/app/store' -import type { IOtherOptions } from '@/service/base' -import { ssePost } from '@/service/base' -import { stopWorkflowRun } from '@/service/workflow' -import { useFeaturesStore } from '@/app/components/base/features/hooks' -import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager' -import type { VersionHistory } from '@/types/workflow' -import { noop } from 'lodash-es' +import { useHooksStore } from '@/app/components/workflow/hooks-store' export const useWorkflowRun = () => { - const store = useStoreApi() - const workflowStore = useWorkflowStore() - const reactflow = useReactFlow() - const featuresStore = useFeaturesStore() - const { doSyncWorkflowDraft } = useNodesSyncDraft() - const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() - const pathname = usePathname() - const { - handleWorkflowStarted, - handleWorkflowFinished, - handleWorkflowFailed, - handleWorkflowNodeStarted, - handleWorkflowNodeFinished, - handleWorkflowNodeIterationStarted, - handleWorkflowNodeIterationNext, - handleWorkflowNodeIterationFinished, - handleWorkflowNodeLoopStarted, - handleWorkflowNodeLoopNext, - handleWorkflowNodeLoopFinished, - handleWorkflowNodeRetry, - handleWorkflowAgentLog, - handleWorkflowTextChunk, - handleWorkflowTextReplace, - } = useWorkflowRunEvent() - - const handleBackupDraft = useCallback(() => { - const { - getNodes, - edges, - } = store.getState() - const { getViewport } = reactflow - const { - backupDraft, - setBackupDraft, - environmentVariables, - } = workflowStore.getState() - const { features } = featuresStore!.getState() - - if (!backupDraft) { - setBackupDraft({ - nodes: 
getNodes(), - edges, - viewport: getViewport(), - features, - environmentVariables, - }) - doSyncWorkflowDraft() - } - }, [reactflow, workflowStore, store, featuresStore, doSyncWorkflowDraft]) - - const handleLoadBackupDraft = useCallback(() => { - const { - backupDraft, - setBackupDraft, - setEnvironmentVariables, - } = workflowStore.getState() - - if (backupDraft) { - const { - nodes, - edges, - viewport, - features, - environmentVariables, - } = backupDraft - handleUpdateWorkflowCanvas({ - nodes, - edges, - viewport, - }) - setEnvironmentVariables(environmentVariables) - featuresStore!.setState({ features }) - setBackupDraft(undefined) - } - }, [handleUpdateWorkflowCanvas, workflowStore, featuresStore]) - - const handleRun = useCallback(async ( - params: any, - callback?: IOtherOptions, - ) => { - const { - getNodes, - setNodes, - } = store.getState() - const newNodes = produce(getNodes(), (draft) => { - draft.forEach((node) => { - node.data.selected = false - node.data._runningStatus = undefined - }) - }) - setNodes(newNodes) - await doSyncWorkflowDraft() - - const { - onWorkflowStarted, - onWorkflowFinished, - onNodeStarted, - onNodeFinished, - onIterationStart, - onIterationNext, - onIterationFinish, - onLoopStart, - onLoopNext, - onLoopFinish, - onNodeRetry, - onAgentLog, - onError, - ...restCallback - } = callback || {} - workflowStore.setState({ historyWorkflowData: undefined }) - const appDetail = useAppStore.getState().appDetail - const workflowContainer = document.getElementById('workflow-container') - - const { - clientWidth, - clientHeight, - } = workflowContainer! - - let url = '' - if (appDetail?.mode === 'advanced-chat') - url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run` - - if (appDetail?.mode === 'workflow') - url = `/apps/${appDetail.id}/workflows/draft/run` - - const { - setWorkflowRunningData, - } = workflowStore.getState() - setWorkflowRunningData({ - result: { - status: WorkflowRunningStatus.Running, - }, - tracing: [], - resultText: '', - }) - - let ttsUrl = '' - let ttsIsPublic = false - if (params.token) { - ttsUrl = '/text-to-audio' - ttsIsPublic = true - } - else if (params.appId) { - if (pathname.search('explore/installed') > -1) - ttsUrl = `/installed-apps/${params.appId}/text-to-audio` - else - ttsUrl = `/apps/${params.appId}/text-to-audio` - } - const player = AudioPlayerManager.getInstance().getAudioPlayer(ttsUrl, ttsIsPublic, uuidV4(), 'none', 'none', noop) - - ssePost( - url, - { - body: params, - }, - { - onWorkflowStarted: (params) => { - handleWorkflowStarted(params) - - if (onWorkflowStarted) - onWorkflowStarted(params) - }, - onWorkflowFinished: (params) => { - handleWorkflowFinished(params) - - if (onWorkflowFinished) - onWorkflowFinished(params) - }, - onError: (params) => { - handleWorkflowFailed() - - if (onError) - onError(params) - }, - onNodeStarted: (params) => { - handleWorkflowNodeStarted( - params, - { - clientWidth, - clientHeight, - }, - ) - - if (onNodeStarted) - onNodeStarted(params) - }, - onNodeFinished: (params) => { - handleWorkflowNodeFinished(params) - - if (onNodeFinished) - onNodeFinished(params) - }, - onIterationStart: (params) => { - handleWorkflowNodeIterationStarted( - params, - { - clientWidth, - clientHeight, - }, - ) - - if (onIterationStart) - onIterationStart(params) - }, - onIterationNext: (params) => { - handleWorkflowNodeIterationNext(params) - - if (onIterationNext) - onIterationNext(params) - }, - onIterationFinish: (params) => { - handleWorkflowNodeIterationFinished(params) - - if 
(onIterationFinish) - onIterationFinish(params) - }, - onLoopStart: (params) => { - handleWorkflowNodeLoopStarted( - params, - { - clientWidth, - clientHeight, - }, - ) - - if (onLoopStart) - onLoopStart(params) - }, - onLoopNext: (params) => { - handleWorkflowNodeLoopNext(params) - - if (onLoopNext) - onLoopNext(params) - }, - onLoopFinish: (params) => { - handleWorkflowNodeLoopFinished(params) - - if (onLoopFinish) - onLoopFinish(params) - }, - onNodeRetry: (params) => { - handleWorkflowNodeRetry(params) - - if (onNodeRetry) - onNodeRetry(params) - }, - onAgentLog: (params) => { - handleWorkflowAgentLog(params) - - if (onAgentLog) - onAgentLog(params) - }, - onTextChunk: (params) => { - handleWorkflowTextChunk(params) - }, - onTextReplace: (params) => { - handleWorkflowTextReplace(params) - }, - onTTSChunk: (messageId: string, audio: string) => { - if (!audio || audio === '') - return - player.playAudioWithAudio(audio, true) - AudioPlayerManager.getInstance().resetMsgId(messageId) - }, - onTTSEnd: (messageId: string, audio: string) => { - player.playAudioWithAudio(audio, false) - }, - ...restCallback, - }, - ) - }, [ - store, - workflowStore, - doSyncWorkflowDraft, - handleWorkflowStarted, - handleWorkflowFinished, - handleWorkflowFailed, - handleWorkflowNodeStarted, - handleWorkflowNodeFinished, - handleWorkflowNodeIterationStarted, - handleWorkflowNodeIterationNext, - handleWorkflowNodeIterationFinished, - handleWorkflowNodeLoopStarted, - handleWorkflowNodeLoopNext, - handleWorkflowNodeLoopFinished, - handleWorkflowNodeRetry, - handleWorkflowTextChunk, - handleWorkflowTextReplace, - handleWorkflowAgentLog, - pathname], - ) - - const handleStopRun = useCallback((taskId: string) => { - const appId = useAppStore.getState().appDetail?.id - - stopWorkflowRun(`/apps/${appId}/workflow-runs/tasks/${taskId}/stop`) - }, []) - - const handleRestoreFromPublishedWorkflow = useCallback((publishedWorkflow: VersionHistory) => { - const nodes = publishedWorkflow.graph.nodes.map(node => ({ ...node, selected: false, data: { ...node.data, selected: false } })) - const edges = publishedWorkflow.graph.edges - const viewport = publishedWorkflow.graph.viewport! 
- handleUpdateWorkflowCanvas({ - nodes, - edges, - viewport, - }) - const mappedFeatures = { - opening: { - enabled: !!publishedWorkflow.features.opening_statement || !!publishedWorkflow.features.suggested_questions.length, - opening_statement: publishedWorkflow.features.opening_statement, - suggested_questions: publishedWorkflow.features.suggested_questions, - }, - suggested: publishedWorkflow.features.suggested_questions_after_answer, - text2speech: publishedWorkflow.features.text_to_speech, - speech2text: publishedWorkflow.features.speech_to_text, - citation: publishedWorkflow.features.retriever_resource, - moderation: publishedWorkflow.features.sensitive_word_avoidance, - file: publishedWorkflow.features.file_upload, - } - - featuresStore?.setState({ features: mappedFeatures }) - workflowStore.getState().setEnvironmentVariables(publishedWorkflow.environment_variables || []) - }, [featuresStore, handleUpdateWorkflowCanvas, workflowStore]) + const handleBackupDraft = useHooksStore(s => s.handleBackupDraft) + const handleLoadBackupDraft = useHooksStore(s => s.handleLoadBackupDraft) + const handleRestoreFromPublishedWorkflow = useHooksStore(s => s.handleRestoreFromPublishedWorkflow) + const handleRun = useHooksStore(s => s.handleRun) + const handleStopRun = useHooksStore(s => s.handleStopRun) return { handleBackupDraft, diff --git a/web/app/components/workflow/hooks/use-workflow-start-run.tsx b/web/app/components/workflow/hooks/use-workflow-start-run.tsx index b2b1c69975..0f4e68fe95 100644 --- a/web/app/components/workflow/hooks/use-workflow-start-run.tsx +++ b/web/app/components/workflow/hooks/use-workflow-start-run.tsx @@ -1,92 +1,9 @@ -import { useCallback } from 'react' -import { useStoreApi } from 'reactflow' -import { useWorkflowStore } from '../store' -import { - BlockEnum, - WorkflowRunningStatus, -} from '../types' -import { - useIsChatMode, - useNodesSyncDraft, - useWorkflowInteractions, - useWorkflowRun, -} from './index' -import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { useHooksStore } from '@/app/components/workflow/hooks-store' export const useWorkflowStartRun = () => { - const store = useStoreApi() - const workflowStore = useWorkflowStore() - const featuresStore = useFeaturesStore() - const isChatMode = useIsChatMode() - const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() - const { handleRun } = useWorkflowRun() - const { doSyncWorkflowDraft } = useNodesSyncDraft() - - const handleWorkflowStartRunInWorkflow = useCallback(async () => { - const { - workflowRunningData, - } = workflowStore.getState() - - if (workflowRunningData?.result.status === WorkflowRunningStatus.Running) - return - - const { getNodes } = store.getState() - const nodes = getNodes() - const startNode = nodes.find(node => node.data.type === BlockEnum.Start) - const startVariables = startNode?.data.variables || [] - const fileSettings = featuresStore!.getState().features.file - const { - showDebugAndPreviewPanel, - setShowDebugAndPreviewPanel, - setShowInputsPanel, - setShowEnvPanel, - } = workflowStore.getState() - - setShowEnvPanel(false) - - if (showDebugAndPreviewPanel) { - handleCancelDebugAndPreviewPanel() - return - } - - if (!startVariables.length && !fileSettings?.image?.enabled) { - await doSyncWorkflowDraft() - handleRun({ inputs: {}, files: [] }) - setShowDebugAndPreviewPanel(true) - setShowInputsPanel(false) - } - else { - setShowDebugAndPreviewPanel(true) - setShowInputsPanel(true) - } - }, [store, workflowStore, featuresStore, 
handleCancelDebugAndPreviewPanel, handleRun, doSyncWorkflowDraft]) - - const handleWorkflowStartRunInChatflow = useCallback(async () => { - const { - showDebugAndPreviewPanel, - setShowDebugAndPreviewPanel, - setHistoryWorkflowData, - setShowEnvPanel, - setShowChatVariablePanel, - } = workflowStore.getState() - - setShowEnvPanel(false) - setShowChatVariablePanel(false) - - if (showDebugAndPreviewPanel) - handleCancelDebugAndPreviewPanel() - else - setShowDebugAndPreviewPanel(true) - - setHistoryWorkflowData(undefined) - }, [workflowStore, handleCancelDebugAndPreviewPanel]) - - const handleStartWorkflowRun = useCallback(() => { - if (!isChatMode) - handleWorkflowStartRunInWorkflow() - else - handleWorkflowStartRunInChatflow() - }, [isChatMode, handleWorkflowStartRunInWorkflow, handleWorkflowStartRunInChatflow]) + const handleStartWorkflowRun = useHooksStore(s => s.handleStartWorkflowRun) + const handleWorkflowStartRunInWorkflow = useHooksStore(s => s.handleWorkflowStartRunInWorkflow) + const handleWorkflowStartRunInChatflow = useHooksStore(s => s.handleWorkflowStartRunInChatflow) return { handleStartWorkflowRun, diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index 7a15afa2e4..99dce4dc15 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -1,13 +1,9 @@ import { useCallback, - useEffect, useMemo, - useState, } from 'react' -import dayjs from 'dayjs' import { uniqBy } from 'lodash-es' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import { getIncomers, getOutgoers, @@ -40,25 +36,15 @@ import { import { CUSTOM_NOTE_NODE } from '../note-node/constants' import { findUsedVarNodes, getNodeOutputVars, updateNodeVars } from '../nodes/_base/components/variable/utils' import { useNodesExtraData } from './use-nodes-data' -import { useWorkflowTemplate } from './use-workflow-template' import { useStore as useAppStore } from '@/app/components/app/store' -import { - fetchNodesDefaultConfigs, - fetchPublishedWorkflow, - fetchWorkflowDraft, - syncWorkflowDraft, -} from '@/service/workflow' -import type { FetchWorkflowDraftResponse } from '@/types/workflow' import { fetchAllBuiltInTools, fetchAllCustomTools, fetchAllWorkflowTools, } from '@/service/tools' -import I18n from '@/context/i18n' import { CollectionType } from '@/app/components/tools/types' import { CUSTOM_ITERATION_START_NODE } from '@/app/components/workflow/nodes/iteration-start/constants' import { CUSTOM_LOOP_START_NODE } from '@/app/components/workflow/nodes/loop-start/constants' -import { useWorkflowConfig } from '@/service/use-workflow' import { basePath } from '@/utils/var' import { canFindTool } from '@/utils' @@ -70,12 +56,9 @@ export const useIsChatMode = () => { export const useWorkflow = () => { const { t } = useTranslation() - const { locale } = useContext(I18n) const store = useStoreApi() const workflowStore = useWorkflowStore() - const appId = useStore(s => s.appId) const nodesExtraData = useNodesExtraData() - const { data: workflowConfig } = useWorkflowConfig(appId) const setPanelWidth = useCallback((width: number) => { localStorage.setItem('workflow-node-panel-width', `${width}`) workflowStore.setState({ panelWidth: width }) @@ -120,7 +103,7 @@ export const useWorkflow = () => { list.push(...incomers) - return uniqBy(list, 'id').filter((item) => { + return uniqBy(list, 'id').filter((item: Node) => { return 
SUPPORT_OUTPUT_VARS_NODE.includes(item.data.type) }) }, [store]) @@ -167,7 +150,7 @@ export const useWorkflow = () => { const length = list.length if (length) { - return uniqBy(list, 'id').reverse().filter((item) => { + return uniqBy(list, 'id').reverse().filter((item: Node) => { return SUPPORT_OUTPUT_VARS_NODE.includes(item.data.type) }) } @@ -344,6 +327,7 @@ export const useWorkflow = () => { parallelList, hasAbnormalEdges, } = getParallelInfo(nodes, edges, parentNodeId) + const { workflowConfig } = workflowStore.getState() if (hasAbnormalEdges) return false @@ -359,7 +343,7 @@ export const useWorkflow = () => { } return true - }, [t, workflowStore, workflowConfig?.parallel_depth_limit]) + }, [t, workflowStore]) const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => { const { @@ -407,10 +391,6 @@ export const useWorkflow = () => { return !hasCycle(targetNode) }, [store, nodesExtraData, checkParallelLimit]) - const formatTimeFromNow = useCallback((time: number) => { - return dayjs(time).locale(locale === 'zh-Hans' ? 'zh-cn' : locale).fromNow() - }, [locale]) - const getNode = useCallback((nodeId?: string) => { const { getNodes } = store.getState() const nodes = getNodes() @@ -432,7 +412,6 @@ export const useWorkflow = () => { checkNestedParallelLimit, isValidConnection, isFromStartNode, - formatTimeFromNow, getNode, getBeforeNodeById, getIterationNodeChildren, @@ -478,107 +457,6 @@ export const useFetchToolsData = () => { } } -export const useWorkflowInit = () => { - const workflowStore = useWorkflowStore() - const { - nodes: nodesTemplate, - edges: edgesTemplate, - } = useWorkflowTemplate() - const { handleFetchAllTools } = useFetchToolsData() - const appDetail = useAppStore(state => state.appDetail)! - const setSyncWorkflowDraftHash = useStore(s => s.setSyncWorkflowDraftHash) - const [data, setData] = useState() - const [isLoading, setIsLoading] = useState(true) - useEffect(() => { - workflowStore.setState({ appId: appDetail.id }) - }, [appDetail.id, workflowStore]) - - const handleGetInitialWorkflowData = useCallback(async () => { - try { - const res = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) - setData(res) - workflowStore.setState({ - envSecrets: (res.environment_variables || []).filter(env => env.value_type === 'secret').reduce((acc, env) => { - acc[env.id] = env.value - return acc - }, {} as Record), - environmentVariables: res.environment_variables?.map(env => env.value_type === 'secret' ? 
{ ...env, value: '[__HIDDEN__]' } : env) || [], - conversationVariables: res.conversation_variables || [], - }) - setSyncWorkflowDraftHash(res.hash) - setIsLoading(false) - } - catch (error: any) { - if (error && error.json && !error.bodyUsed && appDetail) { - error.json().then((err: any) => { - if (err.code === 'draft_workflow_not_exist') { - workflowStore.setState({ notInitialWorkflow: true }) - syncWorkflowDraft({ - url: `/apps/${appDetail.id}/workflows/draft`, - params: { - graph: { - nodes: nodesTemplate, - edges: edgesTemplate, - }, - features: { - retriever_resource: { enabled: true }, - }, - environment_variables: [], - conversation_variables: [], - }, - }).then((res) => { - workflowStore.getState().setDraftUpdatedAt(res.updated_at) - handleGetInitialWorkflowData() - }) - } - }) - } - } - }, [appDetail, nodesTemplate, edgesTemplate, workflowStore, setSyncWorkflowDraftHash]) - - useEffect(() => { - handleGetInitialWorkflowData() - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) - - const handleFetchPreloadData = useCallback(async () => { - try { - const nodesDefaultConfigsData = await fetchNodesDefaultConfigs(`/apps/${appDetail?.id}/workflows/default-workflow-block-configs`) - const publishedWorkflow = await fetchPublishedWorkflow(`/apps/${appDetail?.id}/workflows/publish`) - workflowStore.setState({ - nodesDefaultConfigs: nodesDefaultConfigsData.reduce((acc, block) => { - if (!acc[block.type]) - acc[block.type] = { ...block.config } - return acc - }, {} as Record), - }) - workflowStore.getState().setPublishedAt(publishedWorkflow?.created_at) - } - catch (e) { - console.error(e) - } - }, [workflowStore, appDetail]) - - useEffect(() => { - handleFetchPreloadData() - handleFetchAllTools('builtin') - handleFetchAllTools('custom') - handleFetchAllTools('workflow') - }, [handleFetchPreloadData, handleFetchAllTools]) - - useEffect(() => { - if (data) { - workflowStore.getState().setDraftUpdatedAt(data.updated_at) - workflowStore.getState().setToolPublished(data.tool_published) - } - }, [data, workflowStore]) - - return { - data, - isLoading, - } -} - export const useWorkflowReadOnly = () => { const workflowStore = useWorkflowStore() const workflowRunningData = useStore(s => s.workflowRunningData) diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index 4c48afb56c..9a3e13822a 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -5,11 +5,8 @@ import { memo, useCallback, useEffect, - useMemo, useRef, - useState, } from 'react' -import useSWR from 'swr' import { setAutoFreeze } from 'immer' import { useEventListener, @@ -31,17 +28,14 @@ import 'reactflow/dist/style.css' import './style.css' import type { Edge, - EnvironmentVariable, Node, } from './types' import { ControlMode, - SupportUploadFileTypes, } from './types' -import { WorkflowContextProvider } from './context' import { - useDSL, useEdgesInteractions, + useFetchToolsData, useNodesInteractions, useNodesReadOnly, useNodesSyncDraft, @@ -49,11 +43,9 @@ import { useSelectionInteractions, useShortcuts, useWorkflow, - useWorkflowInit, useWorkflowReadOnly, useWorkflowUpdate, } from './hooks' -import Header from './header' import CustomNode from './nodes' import CustomNoteNode from './note-node' import { CUSTOM_NOTE_NODE } from './note-node/constants' @@ -66,42 +58,28 @@ import { CUSTOM_SIMPLE_NODE } from './simple-node/constants' import Operator from './operator' import CustomEdge from './custom-edge' import CustomConnectionLine from 
'./custom-connection-line' -import Panel from './panel' -import Features from './features' import HelpLine from './help-line' import CandidateNode from './candidate-node' import PanelContextmenu from './panel-contextmenu' import NodeContextmenu from './node-contextmenu' import SyncingDataModal from './syncing-data-modal' -import UpdateDSLModal from './update-dsl-modal' -import DSLExportConfirmModal from './dsl-export-confirm-modal' import LimitTips from './limit-tips' -import PluginDependency from './plugin-dependency' import { useStore, useWorkflowStore, } from './store' -import { - initialEdges, - initialNodes, -} from './utils' import { CUSTOM_EDGE, CUSTOM_NODE, - DSL_EXPORT_CHECK, ITERATION_CHILDREN_Z_INDEX, WORKFLOW_DATA_UPDATE, } from './constants' import { WorkflowHistoryProvider } from './workflow-history-store' -import Loading from '@/app/components/base/loading' -import { FeaturesProvider } from '@/app/components/base/features' -import type { Features as FeaturesData } from '@/app/components/base/features/types' -import { useFeaturesStore } from '@/app/components/base/features/hooks' import { useEventEmitterContextContext } from '@/context/event-emitter' import Confirm from '@/app/components/base/confirm' -import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' -import { fetchFileUploadConfig } from '@/service/common' import DatasetsDetailProvider from './datasets-detail-store/provider' +import { HooksStoreContextProvider } from './hooks-store' +import type { Shape as HooksStoreShape } from './hooks-store' const nodeTypes = { [CUSTOM_NODE]: CustomNode, @@ -114,32 +92,32 @@ const edgeTypes = { [CUSTOM_EDGE]: CustomEdge, } -type WorkflowProps = { +export type WorkflowProps = { nodes: Node[] edges: Edge[] viewport?: Viewport + children?: React.ReactNode + onWorkflowDataUpdate?: (v: any) => void } -const Workflow: FC = memo(({ +export const Workflow: FC = memo(({ nodes: originalNodes, edges: originalEdges, viewport, + children, + onWorkflowDataUpdate, }) => { const workflowContainerRef = useRef(null) const workflowStore = useWorkflowStore() const reactflow = useReactFlow() - const featuresStore = useFeaturesStore() const [nodes, setNodes] = useNodesState(originalNodes) const [edges, setEdges] = useEdgesState(originalEdges) - const showFeaturesPanel = useStore(state => state.showFeaturesPanel) const controlMode = useStore(s => s.controlMode) const nodeAnimation = useStore(s => s.nodeAnimation) const showConfirm = useStore(s => s.showConfirm) - const showImportDSLModal = useStore(s => s.showImportDSLModal) const { setShowConfirm, setControlPromptEditorRerenderKey, - setShowImportDSLModal, setSyncWorkflowDraftHash, } = workflowStore.getState() const { @@ -148,9 +126,6 @@ const Workflow: FC = memo(({ } = useNodesSyncDraft() const { workflowReadOnly } = useWorkflowReadOnly() const { nodesReadOnly } = useNodesReadOnly() - - const [secretEnvList, setSecretEnvList] = useState([]) - const { eventEmitter } = useEventEmitterContextContext() eventEmitter?.useSubscription((v: any) => { @@ -161,19 +136,13 @@ const Workflow: FC = memo(({ if (v.payload.viewport) reactflow.setViewport(v.payload.viewport) - if (v.payload.features && featuresStore) { - const { setFeatures } = featuresStore.getState() - - setFeatures(v.payload.features) - } - if (v.payload.hash) setSyncWorkflowDraftHash(v.payload.hash) + onWorkflowDataUpdate?.(v.payload) + setTimeout(() => setControlPromptEditorRerenderKey(Date.now())) } - if (v.type === DSL_EXPORT_CHECK) - setSecretEnvList(v.payload.data as 
EnvironmentVariable[]) }) useEffect(() => { @@ -231,6 +200,12 @@ const Workflow: FC = memo(({ }) } }) + const { handleFetchAllTools } = useFetchToolsData() + useEffect(() => { + handleFetchAllTools('builtin') + handleFetchAllTools('custom') + handleFetchAllTools('workflow') + }, [handleFetchAllTools]) const { handleNodeDragStart, @@ -258,15 +233,10 @@ const Workflow: FC = memo(({ } = useSelectionInteractions() const { handlePaneContextMenu, - handlePaneContextmenuCancel, } = usePanelInteractions() const { isValidConnection, } = useWorkflow() - const { - exportCheck, - handleExportDSL, - } = useDSL() useOnViewportChange({ onEnd: () => { @@ -297,12 +267,7 @@ const Workflow: FC = memo(({ > -
- - { - showFeaturesPanel && - } @@ -317,26 +282,8 @@ const Workflow: FC = memo(({ /> ) } - { - showImportDSLModal && ( - setShowImportDSLModal(false)} - onBackup={exportCheck} - onImport={handlePaneContextmenuCancel} - /> - ) - } - { - secretEnvList.length > 0 && ( - setSecretEnvList([])} - /> - ) - } - + {children} = memo(({
) }) -Workflow.displayName = 'Workflow' -const WorkflowWrap = memo(() => { - const { - data, - isLoading, - } = useWorkflowInit() - const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) +type WorkflowWithInnerContextProps = WorkflowProps & { + hooksStore?: Partial +} +export const WorkflowWithInnerContext = memo(({ + hooksStore, + ...restProps +}: WorkflowWithInnerContextProps) => { + return ( + + + + ) +}) - const nodesData = useMemo(() => { - if (data) - return initialNodes(data.graph.nodes, data.graph.edges) - - return [] - }, [data]) - const edgesData = useMemo(() => { - if (data) - return initialEdges(data.graph.edges, data.graph.nodes) - - return [] - }, [data]) - - if (!data || isLoading) { - return ( -
- -
- ) - } - - const features = data.features || {} - const initialFeatures: FeaturesData = { - file: { - image: { - enabled: !!features.file_upload?.image?.enabled, - number_limits: features.file_upload?.image?.number_limits || 3, - transfer_methods: features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], - }, - enabled: !!(features.file_upload?.enabled || features.file_upload?.image?.enabled), - allowed_file_types: features.file_upload?.allowed_file_types || [SupportUploadFileTypes.image], - allowed_file_extensions: features.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), - allowed_file_upload_methods: features.file_upload?.allowed_file_upload_methods || features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], - number_limits: features.file_upload?.number_limits || features.file_upload?.image?.number_limits || 3, - fileUploadConfig: fileUploadConfigResponse, - }, - opening: { - enabled: !!features.opening_statement, - opening_statement: features.opening_statement, - suggested_questions: features.suggested_questions, - }, - suggested: features.suggested_questions_after_answer || { enabled: false }, - speech2text: features.speech_to_text || { enabled: false }, - text2speech: features.text_to_speech || { enabled: false }, - citation: features.retriever_resource || { enabled: false }, - moderation: features.sensitive_word_avoidance || { enabled: false }, +type WorkflowWithDefaultContextProps = + Pick + & { + children: React.ReactNode } +const WorkflowWithDefaultContext = ({ + nodes, + edges, + children, +}: WorkflowWithDefaultContextProps) => { return ( - - - - - + nodes={nodes} + edges={edges} > + + {children} + ) -}) -WorkflowWrap.displayName = 'WorkflowWrap' - -const WorkflowContainer = () => { - return ( - - - - ) } -export default memo(WorkflowContainer) +export default memo(WorkflowWithDefaultContext) diff --git a/web/app/components/workflow/panel/index.tsx b/web/app/components/workflow/panel/index.tsx index 40920ab256..8e510f4e77 100644 --- a/web/app/components/workflow/panel/index.tsx +++ b/web/app/components/workflow/panel/index.tsx @@ -1,43 +1,25 @@ import type { FC } from 'react' import { memo } from 'react' import { useNodes } from 'reactflow' -import { useShallow } from 'zustand/react/shallow' import type { CommonNodeType } from '../types' import { Panel as NodePanel } from '../nodes' import { useStore } from '../store' -import { - useIsChatMode, -} from '../hooks' -import DebugAndPreview from './debug-and-preview' -import Record from './record' -import WorkflowPreview from './workflow-preview' -import ChatRecord from './chat-record' -import ChatVariablePanel from './chat-variable-panel' import EnvPanel from './env-panel' -import GlobalVariablePanel from './global-variable-panel' -import VersionHistoryPanel from './version-history-panel' import cn from '@/utils/classnames' -import { useStore as useAppStore } from '@/app/components/app/store' -import MessageLogModal from '@/app/components/base/message-log-modal' -const Panel: FC = () => { +export type PanelProps = { + components?: { + left?: React.ReactNode + right?: React.ReactNode + } +} +const Panel: FC = ({ + components, +}) => { const nodes = useNodes() - const isChatMode = useIsChatMode() const selectedNode = nodes.find(node => node.data.selected) - const historyWorkflowData = useStore(s => s.historyWorkflowData) - const showDebugAndPreviewPanel = useStore(s => s.showDebugAndPreviewPanel) const showEnvPanel = useStore(s => 
s.showEnvPanel) - const showChatVariablePanel = useStore(s => s.showChatVariablePanel) - const showGlobalVariablePanel = useStore(s => s.showGlobalVariablePanel) - const showWorkflowVersionHistoryPanel = useStore(s => s.showWorkflowVersionHistoryPanel) const isRestoring = useStore(s => s.isRestoring) - const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ - currentLogItem: state.currentLogItem, - setCurrentLogItem: state.setCurrentLogItem, - showMessageLogModal: state.showMessageLogModal, - setShowMessageLogModal: state.setShowMessageLogModal, - currentLogModalActiveTab: state.currentLogModalActiveTab, - }))) return (
{ key={`${isRestoring}`} > { - showMessageLogModal && ( - { - setCurrentLogItem() - setShowMessageLogModal(false) - }} - defaultTab={currentLogModalActiveTab} - /> - ) + components?.left } { !!selectedNode && ( @@ -65,45 +36,13 @@ const Panel: FC = () => { ) } { - historyWorkflowData && !isChatMode && ( - - ) - } - { - historyWorkflowData && isChatMode && ( - - ) - } - { - showDebugAndPreviewPanel && isChatMode && ( - - ) - } - { - showDebugAndPreviewPanel && !isChatMode && ( - - ) + components?.right } { showEnvPanel && ( ) } - { - showChatVariablePanel && ( - - ) - } - { - showGlobalVariablePanel && ( - - ) - } - { - showWorkflowVersionHistoryPanel && ( - - ) - }
) } diff --git a/web/app/components/workflow/store/workflow/index.ts b/web/app/components/workflow/store/workflow/index.ts index 769b986606..0e2f5eb0f7 100644 --- a/web/app/components/workflow/store/workflow/index.ts +++ b/web/app/components/workflow/store/workflow/index.ts @@ -1,4 +1,7 @@ import { useContext } from 'react' +import type { + StateCreator, +} from 'zustand' import { useStore as useZustandStore, } from 'zustand' @@ -26,6 +29,7 @@ import { createWorkflowDraftSlice } from './workflow-draft-slice' import type { WorkflowSliceShape } from './workflow-slice' import { createWorkflowSlice } from './workflow-slice' import { WorkflowContext } from '@/app/components/workflow/context' +import type { WorkflowSliceShape as WorkflowAppSliceShape } from '@/app/components/workflow-app/store/workflow/workflow-slice' export type Shape = ChatVariableSliceShape & @@ -38,9 +42,16 @@ export type Shape = ToolSliceShape & VersionSliceShape & WorkflowDraftSliceShape & - WorkflowSliceShape + WorkflowSliceShape & + WorkflowAppSliceShape + +type CreateWorkflowStoreParams = { + injectWorkflowStoreSliceFn?: StateCreator +} + +export const createWorkflowStore = (params: CreateWorkflowStoreParams) => { + const { injectWorkflowStoreSliceFn } = params || {} -export const createWorkflowStore = () => { return createStore((...args) => ({ ...createChatVariableSlice(...args), ...createEnvVariableSlice(...args), @@ -53,6 +64,7 @@ export const createWorkflowStore = () => { ...createVersionSlice(...args), ...createWorkflowDraftSlice(...args), ...createWorkflowSlice(...args), + ...(injectWorkflowStoreSliceFn?.(...args) || {} as WorkflowAppSliceShape), })) } diff --git a/web/app/components/workflow/store/workflow/node-slice.ts b/web/app/components/workflow/store/workflow/node-slice.ts index d937dc2099..2068ee0ba1 100644 --- a/web/app/components/workflow/store/workflow/node-slice.ts +++ b/web/app/components/workflow/store/workflow/node-slice.ts @@ -12,8 +12,6 @@ import type { export type NodeSliceShape = { showSingleRunPanel: boolean setShowSingleRunPanel: (showSingleRunPanel: boolean) => void - nodesDefaultConfigs: Record - setNodesDefaultConfigs: (nodesDefaultConfigs: Record) => void nodeAnimation: boolean setNodeAnimation: (nodeAnimation: boolean) => void candidateNode?: Node @@ -55,8 +53,6 @@ export type NodeSliceShape = { export const createNodeSlice: StateCreator = set => ({ showSingleRunPanel: false, setShowSingleRunPanel: showSingleRunPanel => set(() => ({ showSingleRunPanel })), - nodesDefaultConfigs: {}, - setNodesDefaultConfigs: nodesDefaultConfigs => set(() => ({ nodesDefaultConfigs })), nodeAnimation: false, setNodeAnimation: nodeAnimation => set(() => ({ nodeAnimation })), candidateNode: undefined, diff --git a/web/app/components/workflow/store/workflow/workflow-slice.ts b/web/app/components/workflow/store/workflow/workflow-slice.ts index 19248161d2..6bb69cdfcd 100644 --- a/web/app/components/workflow/store/workflow/workflow-slice.ts +++ b/web/app/components/workflow/store/workflow/workflow-slice.ts @@ -10,11 +10,8 @@ type PreviewRunningData = WorkflowRunningData & { } export type WorkflowSliceShape = { - appId: string workflowRunningData?: PreviewRunningData setWorkflowRunningData: (workflowData: PreviewRunningData) => void - notInitialWorkflow: boolean - setNotInitialWorkflow: (notInitialWorkflow: boolean) => void clipboardElements: Node[] setClipboardElements: (clipboardElements: Node[]) => void selection: null | { x1: number; y1: number; x2: number; y2: number } @@ -33,14 +30,13 @@ export type 
WorkflowSliceShape = { setShowImportDSLModal: (showImportDSLModal: boolean) => void showTips: string setShowTips: (showTips: string) => void + workflowConfig?: Record + setWorkflowConfig: (workflowConfig: Record) => void } export const createWorkflowSlice: StateCreator = set => ({ - appId: '', workflowRunningData: undefined, setWorkflowRunningData: workflowRunningData => set(() => ({ workflowRunningData })), - notInitialWorkflow: false, - setNotInitialWorkflow: notInitialWorkflow => set(() => ({ notInitialWorkflow })), clipboardElements: [], setClipboardElements: clipboardElements => set(() => ({ clipboardElements })), selection: null, @@ -62,4 +58,6 @@ export const createWorkflowSlice: StateCreator = set => ({ setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })), showTips: '', setShowTips: showTips => set(() => ({ showTips })), + workflowConfig: undefined, + setWorkflowConfig: workflowConfig => set(() => ({ workflowConfig })), }) diff --git a/web/service/use-workflow.ts b/web/service/use-workflow.ts index ee4132d22f..4321552cc7 100644 --- a/web/service/use-workflow.ts +++ b/web/service/use-workflow.ts @@ -21,10 +21,14 @@ export const useAppWorkflow = (appID: string) => { }) } -export const useWorkflowConfig = (appId: string) => { +export const useWorkflowConfig = (appId: string, onSuccess: (v: WorkflowConfigResponse) => void) => { return useQuery({ queryKey: [NAME_SPACE, 'config', appId], - queryFn: () => get(`/apps/${appId}/workflows/draft/config`), + queryFn: async () => { + const data = await get(`/apps/${appId}/workflows/draft/config`) + onSuccess(data) + return data + }, }) } From 1e7418095f6ba4668987841a527d8b74a891a1aa Mon Sep 17 00:00:00 2001 From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:54:22 +0800 Subject: [PATCH 47/68] feat/TanStack-Form (#18346) --- .../config-var/config-select/index.spec.tsx | 82 ++++++++++++++++ .../config-var/config-select/index.tsx | 3 +- .../checkbox/assets/indeterminate-icon.tsx | 11 +++ .../components/base/checkbox/assets/mixed.svg | 5 - .../components/base/checkbox/index.module.css | 10 -- .../components/base/checkbox/index.spec.tsx | 67 +++++++++++++ web/app/components/base/checkbox/index.tsx | 49 +++++----- .../base/form/components/field/checkbox.tsx | 43 ++++++++ .../form/components/field/number-input.tsx | 49 ++++++++++ .../base/form/components/field/options.tsx | 34 +++++++ .../base/form/components/field/select.tsx | 51 ++++++++++ .../base/form/components/field/text.tsx | 48 +++++++++ .../form/components/form/submit-button.tsx | 25 +++++ .../base/form/components/label.spec.tsx | 53 ++++++++++ .../components/base/form/components/label.tsx | 48 +++++++++ .../form-scenarios/demo/contact-fields.tsx | 35 +++++++ .../base/form/form-scenarios/demo/index.tsx | 68 +++++++++++++ .../form-scenarios/demo/shared-options.tsx | 14 +++ .../base/form/form-scenarios/demo/types.ts | 34 +++++++ web/app/components/base/form/index.tsx | 25 +++++ .../base/input-number/index.spec.tsx | 97 +++++++++++++++++++ .../components/base/input-number/index.tsx | 42 ++++---- web/app/components/base/input/index.tsx | 2 +- web/app/components/base/param-item/index.tsx | 2 +- web/app/components/base/tooltip/index.tsx | 4 +- .../datasets/create/step-two/inputs.tsx | 4 +- .../documents/detail/completed/index.tsx | 55 +++++------ .../detail/completed/segment-card/index.tsx | 6 +- .../detail/completed/segment-detail.tsx | 9 +- .../detail/completed/segment-list.tsx | 2 +- .../components/datasets/documents/list.tsx | 6 +- 
.../edit-metadata-batch/input-combined.tsx | 2 +- .../nodes/_base/components/agent-strategy.tsx | 2 +- web/app/dev-preview/page.tsx | 20 ++-- web/jest.config.ts | 13 +-- web/jest.setup.ts | 5 + web/package.json | 1 + web/pnpm-lock.yaml | 60 ++++++++++++ 38 files changed, 959 insertions(+), 127 deletions(-) create mode 100644 web/app/components/app/configuration/config-var/config-select/index.spec.tsx create mode 100644 web/app/components/base/checkbox/assets/indeterminate-icon.tsx delete mode 100644 web/app/components/base/checkbox/assets/mixed.svg delete mode 100644 web/app/components/base/checkbox/index.module.css create mode 100644 web/app/components/base/checkbox/index.spec.tsx create mode 100644 web/app/components/base/form/components/field/checkbox.tsx create mode 100644 web/app/components/base/form/components/field/number-input.tsx create mode 100644 web/app/components/base/form/components/field/options.tsx create mode 100644 web/app/components/base/form/components/field/select.tsx create mode 100644 web/app/components/base/form/components/field/text.tsx create mode 100644 web/app/components/base/form/components/form/submit-button.tsx create mode 100644 web/app/components/base/form/components/label.spec.tsx create mode 100644 web/app/components/base/form/components/label.tsx create mode 100644 web/app/components/base/form/form-scenarios/demo/contact-fields.tsx create mode 100644 web/app/components/base/form/form-scenarios/demo/index.tsx create mode 100644 web/app/components/base/form/form-scenarios/demo/shared-options.tsx create mode 100644 web/app/components/base/form/form-scenarios/demo/types.ts create mode 100644 web/app/components/base/form/index.tsx create mode 100644 web/app/components/base/input-number/index.spec.tsx diff --git a/web/app/components/app/configuration/config-var/config-select/index.spec.tsx b/web/app/components/app/configuration/config-var/config-select/index.spec.tsx new file mode 100644 index 0000000000..18df318de3 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-select/index.spec.tsx @@ -0,0 +1,82 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import ConfigSelect from './index' + +jest.mock('react-sortablejs', () => ({ + ReactSortable: ({ children }: { children: React.ReactNode }) =>
<div>{children}</div>
, +})) + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('ConfigSelect Component', () => { + const defaultProps = { + options: ['Option 1', 'Option 2'], + onChange: jest.fn(), + } + + afterEach(() => { + jest.clearAllMocks() + }) + + it('renders all options', () => { + render() + + defaultProps.options.forEach((option) => { + expect(screen.getByDisplayValue(option)).toBeInTheDocument() + }) + }) + + it('renders add button', () => { + render() + + expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument() + }) + + it('handles option deletion', () => { + render() + const optionContainer = screen.getByDisplayValue('Option 1').closest('div') + const deleteButton = optionContainer?.querySelector('div[role="button"]') + + if (!deleteButton) return + fireEvent.click(deleteButton) + expect(defaultProps.onChange).toHaveBeenCalledWith(['Option 2']) + }) + + it('handles adding new option', () => { + render() + const addButton = screen.getByText('appDebug.variableConfig.addOption') + + fireEvent.click(addButton) + + expect(defaultProps.onChange).toHaveBeenCalledWith([...defaultProps.options, '']) + }) + + it('applies focus styles on input focus', () => { + render() + const firstInput = screen.getByDisplayValue('Option 1') + + fireEvent.focus(firstInput) + + expect(firstInput.closest('div')).toHaveClass('border-components-input-border-active') + }) + + it('applies delete hover styles', () => { + render() + const optionContainer = screen.getByDisplayValue('Option 1').closest('div') + const deleteButton = optionContainer?.querySelector('div[role="button"]') + + if (!deleteButton) return + fireEvent.mouseEnter(deleteButton) + expect(optionContainer).toHaveClass('border-components-input-border-destructive') + }) + + it('renders empty state correctly', () => { + render() + + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config-var/config-select/index.tsx b/web/app/components/app/configuration/config-var/config-select/index.tsx index d2dc1662c1..40ddaef78f 100644 --- a/web/app/components/app/configuration/config-var/config-select/index.tsx +++ b/web/app/components/app/configuration/config-var/config-select/index.tsx @@ -51,7 +51,7 @@ const ConfigSelect: FC = ({ { const value = e.target.value @@ -67,6 +67,7 @@ const ConfigSelect: FC = ({ onBlur={() => setFocusID(null)} />
{ onChange(options.filter((_, i) => index !== i)) diff --git a/web/app/components/base/checkbox/assets/indeterminate-icon.tsx b/web/app/components/base/checkbox/assets/indeterminate-icon.tsx new file mode 100644 index 0000000000..56df8db6a4 --- /dev/null +++ b/web/app/components/base/checkbox/assets/indeterminate-icon.tsx @@ -0,0 +1,11 @@ +const IndeterminateIcon = () => { + return ( +
+ + + +
+ ) +} + +export default IndeterminateIcon diff --git a/web/app/components/base/checkbox/assets/mixed.svg b/web/app/components/base/checkbox/assets/mixed.svg deleted file mode 100644 index e16b8fc975..0000000000 --- a/web/app/components/base/checkbox/assets/mixed.svg +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/web/app/components/base/checkbox/index.module.css b/web/app/components/base/checkbox/index.module.css deleted file mode 100644 index d675607b46..0000000000 --- a/web/app/components/base/checkbox/index.module.css +++ /dev/null @@ -1,10 +0,0 @@ -.mixed { - background: var(--color-components-checkbox-bg) url(./assets/mixed.svg) center center no-repeat; - background-size: 12px 12px; - border: none; -} - -.checked.disabled { - background-color: #d0d5dd; - border-color: #d0d5dd; -} \ No newline at end of file diff --git a/web/app/components/base/checkbox/index.spec.tsx b/web/app/components/base/checkbox/index.spec.tsx new file mode 100644 index 0000000000..7ef901aef5 --- /dev/null +++ b/web/app/components/base/checkbox/index.spec.tsx @@ -0,0 +1,67 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import Checkbox from './index' + +describe('Checkbox Component', () => { + const mockProps = { + id: 'test', + } + + it('renders unchecked checkbox by default', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toBeInTheDocument() + expect(checkbox).not.toHaveClass('bg-components-checkbox-bg') + }) + + it('renders checked checkbox when checked prop is true', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass('bg-components-checkbox-bg') + expect(screen.getByTestId('check-icon-test')).toBeInTheDocument() + }) + + it('renders indeterminate state correctly', () => { + render() + expect(screen.getByTestId('indeterminate-icon')).toBeInTheDocument() + }) + + it('handles click events when not disabled', () => { + const onCheck = jest.fn() + render() + const checkbox = screen.getByTestId('checkbox-test') + + fireEvent.click(checkbox) + expect(onCheck).toHaveBeenCalledTimes(1) + }) + + it('does not handle click events when disabled', () => { + const onCheck = jest.fn() + render() + const checkbox = screen.getByTestId('checkbox-test') + + fireEvent.click(checkbox) + expect(onCheck).not.toHaveBeenCalled() + expect(checkbox).toHaveClass('cursor-not-allowed') + }) + + it('applies custom className when provided', () => { + const customClass = 'custom-class' + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass(customClass) + }) + + it('applies correct styles for disabled checked state', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled-checked') + expect(checkbox).toHaveClass('cursor-not-allowed') + }) + + it('applies correct styles for disabled unchecked state', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled') + expect(checkbox).toHaveClass('cursor-not-allowed') + }) +}) diff --git a/web/app/components/base/checkbox/index.tsx b/web/app/components/base/checkbox/index.tsx index b0b0ebca7c..3e47967c62 100644 --- a/web/app/components/base/checkbox/index.tsx +++ b/web/app/components/base/checkbox/index.tsx @@ -1,48 +1,49 @@ import { RiCheckLine } from '@remixicon/react' -import s from './index.module.css' import cn from '@/utils/classnames' +import IndeterminateIcon 
from './assets/indeterminate-icon' type CheckboxProps = { + id?: string checked?: boolean onCheck?: () => void className?: string disabled?: boolean - mixed?: boolean + indeterminate?: boolean } -const Checkbox = ({ checked, onCheck, className, disabled, mixed }: CheckboxProps) => { - if (!checked) { - return ( -
{ - if (disabled) - return - onCheck?.() - }} - >
- ) - } +const Checkbox = ({ + id, + checked, + onCheck, + className, + disabled, + indeterminate, +}: CheckboxProps) => { + const checkClassName = (checked || indeterminate) + ? 'bg-components-checkbox-bg text-components-checkbox-icon hover:bg-components-checkbox-bg-hover' + : 'border border-components-checkbox-border bg-components-checkbox-bg-unchecked hover:bg-components-checkbox-bg-unchecked-hover hover:border-components-checkbox-border-hover' + const disabledClassName = (checked || indeterminate) + ? 'cursor-not-allowed bg-components-checkbox-bg-disabled-checked text-components-checkbox-icon-disabled hover:bg-components-checkbox-bg-disabled-checked' + : 'cursor-not-allowed border-components-checkbox-border-disabled bg-components-checkbox-bg-disabled hover:border-components-checkbox-border-disabled hover:bg-components-checkbox-bg-disabled' + return (
{ if (disabled) return - onCheck?.() }} + data-testid={`checkbox-${id}`} > - + {!checked && indeterminate && } + {checked && }
) } diff --git a/web/app/components/base/form/components/field/checkbox.tsx b/web/app/components/base/form/components/field/checkbox.tsx new file mode 100644 index 0000000000..855dbd80fe --- /dev/null +++ b/web/app/components/base/form/components/field/checkbox.tsx @@ -0,0 +1,43 @@ +import cn from '@/utils/classnames' +import { useFieldContext } from '../..' +import Checkbox from '../../../checkbox' + +type CheckboxFieldProps = { + label: string; + labelClassName?: string; +} + +const CheckboxField = ({ + label, + labelClassName, +}: CheckboxFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ { + field.handleChange(!field.state.value) + }} + /> +
+ +
+ ) +} + +export default CheckboxField diff --git a/web/app/components/base/form/components/field/number-input.tsx b/web/app/components/base/form/components/field/number-input.tsx new file mode 100644 index 0000000000..fce3143fe1 --- /dev/null +++ b/web/app/components/base/form/components/field/number-input.tsx @@ -0,0 +1,49 @@ +import React from 'react' +import { useFieldContext } from '../..' +import Label from '../label' +import cn from '@/utils/classnames' +import type { InputNumberProps } from '../../../input-number' +import { InputNumber } from '../../../input-number' + +type TextFieldProps = { + label: string + isRequired?: boolean + showOptional?: boolean + tooltip?: string + className?: string + labelClassName?: string +} & Omit + +const NumberInputField = ({ + label, + isRequired, + showOptional, + tooltip, + className, + labelClassName, + ...inputProps +}: TextFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default NumberInputField diff --git a/web/app/components/base/form/components/field/options.tsx b/web/app/components/base/form/components/field/options.tsx new file mode 100644 index 0000000000..9ff71e50af --- /dev/null +++ b/web/app/components/base/form/components/field/options.tsx @@ -0,0 +1,34 @@ +import cn from '@/utils/classnames' +import { useFieldContext } from '../..' +import Label from '../label' +import ConfigSelect from '@/app/components/app/configuration/config-var/config-select' + +type OptionsFieldProps = { + label: string; + className?: string; + labelClassName?: string; +} + +const OptionsField = ({ + label, + className, + labelClassName, +}: OptionsFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default OptionsField diff --git a/web/app/components/base/form/components/field/select.tsx b/web/app/components/base/form/components/field/select.tsx new file mode 100644 index 0000000000..95af3c0116 --- /dev/null +++ b/web/app/components/base/form/components/field/select.tsx @@ -0,0 +1,51 @@ +import cn from '@/utils/classnames' +import { useFieldContext } from '../..' +import PureSelect from '../../../select/pure' +import Label from '../label' + +type SelectOption = { + value: string + label: string +} + +type SelectFieldProps = { + label: string + options: SelectOption[] + isRequired?: boolean + showOptional?: boolean + tooltip?: string + className?: string + labelClassName?: string +} + +const SelectField = ({ + label, + options, + isRequired, + showOptional, + tooltip, + className, + labelClassName, +}: SelectFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default SelectField diff --git a/web/app/components/base/form/components/field/text.tsx b/web/app/components/base/form/components/field/text.tsx new file mode 100644 index 0000000000..b2090291a0 --- /dev/null +++ b/web/app/components/base/form/components/field/text.tsx @@ -0,0 +1,48 @@ +import React from 'react' +import { useFieldContext } from '../..' +import Input, { type InputProps } from '../../../input' +import Label from '../label' +import cn from '@/utils/classnames' + +type TextFieldProps = { + label: string + isRequired?: boolean + showOptional?: boolean + tooltip?: string + className?: string + labelClassName?: string +} & Omit + +const TextField = ({ + label, + isRequired, + showOptional, + tooltip, + className, + labelClassName, + ...inputProps +}: TextFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default TextField diff --git a/web/app/components/base/form/components/form/submit-button.tsx b/web/app/components/base/form/components/form/submit-button.tsx new file mode 100644 index 0000000000..494d19b843 --- /dev/null +++ b/web/app/components/base/form/components/form/submit-button.tsx @@ -0,0 +1,25 @@ +import { useStore } from '@tanstack/react-form' +import { useFormContext } from '../..' +import Button, { type ButtonProps } from '../../../button' + +type SubmitButtonProps = Omit + +const SubmitButton = ({ ...buttonProps }: SubmitButtonProps) => { + const form = useFormContext() + + const [isSubmitting, canSubmit] = useStore(form.store, state => [ + state.isSubmitting, + state.canSubmit, + ]) + + return ( +
+ ) +} + +export default Label diff --git a/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx b/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx new file mode 100644 index 0000000000..9ba664fc10 --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx @@ -0,0 +1,35 @@ +import { withForm } from '../..' +import { demoFormOpts } from './shared-options' +import { ContactMethods } from './types' + +const ContactFields = withForm({ + ...demoFormOpts, + render: ({ form }) => { + return ( +
+

Contacts

+
+ } + /> + } + /> + ( + + )} + /> +
+
+ ) + }, +}) + +export default ContactFields diff --git a/web/app/components/base/form/form-scenarios/demo/index.tsx b/web/app/components/base/form/form-scenarios/demo/index.tsx new file mode 100644 index 0000000000..f08edee41e --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/index.tsx @@ -0,0 +1,68 @@ +import { useStore } from '@tanstack/react-form' +import { useAppForm } from '../..' +import ContactFields from './contact-fields' +import { demoFormOpts } from './shared-options' +import { UserSchema } from './types' + +const DemoForm = () => { + const form = useAppForm({ + ...demoFormOpts, + validators: { + onSubmit: ({ value }) => { + // Validate the entire form + const result = UserSchema.safeParse(value) + if (!result.success) { + const issues = result.error.issues + console.log('Validation errors:', issues) + return issues[0].message + } + return undefined + }, + }, + onSubmit: ({ value }) => { + console.log('Form submitted:', value) + }, + }) + +const name = useStore(form.store, state => state.values.name) + + return ( +
{ + e.preventDefault() + e.stopPropagation() + form.handleSubmit() + }} + > + ( + + )} + /> + ( + + )} + /> + ( + + )} + /> + { + !!name && ( + + ) + } + + Submit + + + ) +} + +export default DemoForm diff --git a/web/app/components/base/form/form-scenarios/demo/shared-options.tsx b/web/app/components/base/form/form-scenarios/demo/shared-options.tsx new file mode 100644 index 0000000000..8b216c8b90 --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/shared-options.tsx @@ -0,0 +1,14 @@ +import { formOptions } from '@tanstack/react-form' + +export const demoFormOpts = formOptions({ + defaultValues: { + name: '', + surname: '', + isAcceptingTerms: false, + contact: { + email: '', + phone: '', + preferredContactMethod: 'email', + }, + }, +}) diff --git a/web/app/components/base/form/form-scenarios/demo/types.ts b/web/app/components/base/form/form-scenarios/demo/types.ts new file mode 100644 index 0000000000..c4e626ef63 --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/types.ts @@ -0,0 +1,34 @@ +import { z } from 'zod' + +const ContactMethod = z.union([ + z.literal('email'), + z.literal('phone'), + z.literal('whatsapp'), + z.literal('sms'), +]) + +export const ContactMethods = ContactMethod.options.map(({ value }) => ({ + value, + label: value.charAt(0).toUpperCase() + value.slice(1), +})) + +export const UserSchema = z.object({ + name: z + .string() + .regex(/^[A-Z]/, 'Name must start with a capital letter') + .min(3, 'Name must be at least 3 characters long'), + surname: z + .string() + .min(3, 'Surname must be at least 3 characters long') + .regex(/^[A-Z]/, 'Surname must start with a capital letter'), + isAcceptingTerms: z.boolean().refine(val => val, { + message: 'You must accept the terms and conditions', + }), + contact: z.object({ + email: z.string().email('Invalid email address'), + phone: z.string().optional(), + preferredContactMethod: ContactMethod, + }), +}) + +export type User = z.infer diff --git a/web/app/components/base/form/index.tsx b/web/app/components/base/form/index.tsx new file mode 100644 index 0000000000..aeb482ad02 --- /dev/null +++ b/web/app/components/base/form/index.tsx @@ -0,0 +1,25 @@ +import { createFormHook, createFormHookContexts } from '@tanstack/react-form' +import TextField from './components/field/text' +import NumberInputField from './components/field/number-input' +import CheckboxField from './components/field/checkbox' +import SelectField from './components/field/select' +import OptionsField from './components/field/options' +import SubmitButton from './components/form/submit-button' + +export const { fieldContext, useFieldContext, formContext, useFormContext } + = createFormHookContexts() + +export const { useAppForm, withForm } = createFormHook({ + fieldComponents: { + TextField, + NumberInputField, + CheckboxField, + SelectField, + OptionsField, + }, + formComponents: { + SubmitButton, + }, + fieldContext, + formContext, +}) diff --git a/web/app/components/base/input-number/index.spec.tsx b/web/app/components/base/input-number/index.spec.tsx new file mode 100644 index 0000000000..8dfd1184b0 --- /dev/null +++ b/web/app/components/base/input-number/index.spec.tsx @@ -0,0 +1,97 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { InputNumber } from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('InputNumber Component', () => { + const defaultProps = { + onChange: jest.fn(), + } + + afterEach(() => { + 
jest.clearAllMocks() + }) + + it('renders input with default values', () => { + render() + const input = screen.getByRole('textbox') + expect(input).toBeInTheDocument() + }) + + it('handles increment button click', () => { + render() + const incrementBtn = screen.getByRole('button', { name: /increment/i }) + + fireEvent.click(incrementBtn) + expect(defaultProps.onChange).toHaveBeenCalledWith(6) + }) + + it('handles decrement button click', () => { + render() + const decrementBtn = screen.getByRole('button', { name: /decrement/i }) + + fireEvent.click(decrementBtn) + expect(defaultProps.onChange).toHaveBeenCalledWith(4) + }) + + it('respects max value constraint', () => { + render() + const incrementBtn = screen.getByRole('button', { name: /increment/i }) + + fireEvent.click(incrementBtn) + expect(defaultProps.onChange).not.toHaveBeenCalled() + }) + + it('respects min value constraint', () => { + render() + const decrementBtn = screen.getByRole('button', { name: /decrement/i }) + + fireEvent.click(decrementBtn) + expect(defaultProps.onChange).not.toHaveBeenCalled() + }) + + it('handles direct input changes', () => { + render() + const input = screen.getByRole('textbox') + + fireEvent.change(input, { target: { value: '42' } }) + expect(defaultProps.onChange).toHaveBeenCalledWith(42) + }) + + it('handles empty input', () => { + render() + const input = screen.getByRole('textbox') + + fireEvent.change(input, { target: { value: '' } }) + expect(defaultProps.onChange).toHaveBeenCalledWith(undefined) + }) + + it('handles invalid input', () => { + render() + const input = screen.getByRole('textbox') + + fireEvent.change(input, { target: { value: 'abc' } }) + expect(defaultProps.onChange).not.toHaveBeenCalled() + }) + + it('displays unit when provided', () => { + const unit = 'px' + render() + expect(screen.getByText(unit)).toBeInTheDocument() + }) + + it('disables controls when disabled prop is true', () => { + render() + const input = screen.getByRole('textbox') + const incrementBtn = screen.getByRole('button', { name: /increment/i }) + const decrementBtn = screen.getByRole('button', { name: /decrement/i }) + + expect(input).toBeDisabled() + expect(incrementBtn).toBeDisabled() + expect(decrementBtn).toBeDisabled() + }) +}) diff --git a/web/app/components/base/input-number/index.tsx b/web/app/components/base/input-number/index.tsx index 5b88fc67f8..98efc94462 100644 --- a/web/app/components/base/input-number/index.tsx +++ b/web/app/components/base/input-number/index.tsx @@ -8,7 +8,7 @@ export type InputNumberProps = { value?: number onChange: (value?: number) => void amount?: number - size?: 'sm' | 'md' + size?: 'regular' | 'large' max?: number min?: number defaultValue?: number @@ -19,14 +19,12 @@ export type InputNumberProps = { } & Omit export const InputNumber: FC = (props) => { - const { unit, className, onChange, amount = 1, value, size = 'md', max, min, defaultValue, wrapClassName, controlWrapClassName, controlClassName, disabled, ...rest } = props + const { unit, className, onChange, amount = 1, value, size = 'regular', max, min, defaultValue, wrapClassName, controlWrapClassName, controlClassName, disabled, ...rest } = props const isValidValue = (v: number) => { - if (max && v > max) + if (typeof max === 'number' && v > max) return false - if (min && v < min) - return false - return true + return !(typeof min === 'number' && v < min) } const inc = () => { @@ -76,29 +74,39 @@ export const InputNumber: FC = (props) => { onChange(parsed) }} unit={unit} + size={size} />
- +
diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index 5f059c3b7f..30fd90aff8 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -30,7 +30,7 @@ export type InputProps = { wrapperClassName?: string styleCss?: CSSProperties unit?: string -} & React.InputHTMLAttributes & VariantProps +} & Omit, 'size'> & VariantProps const Input = ({ size, diff --git a/web/app/components/base/param-item/index.tsx b/web/app/components/base/param-item/index.tsx index 4cae402e3b..03eb5a7c42 100644 --- a/web/app/components/base/param-item/index.tsx +++ b/web/app/components/base/param-item/index.tsx @@ -54,7 +54,7 @@ const ParamItem: FC = ({ className, id, name, noTooltip, tip, step = 0.1, max={max} step={step} amount={step} - size='sm' + size='regular' value={value} onChange={(value) => { onChange(id, value) diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index e9b7ab047a..e6c4de31f1 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -10,6 +10,7 @@ export type TooltipProps = { position?: Placement triggerMethod?: 'hover' | 'click' triggerClassName?: string + triggerTestId?: string disabled?: boolean popupContent?: React.ReactNode children?: React.ReactNode @@ -24,6 +25,7 @@ const Tooltip: FC = ({ position = 'top', triggerMethod = 'hover', triggerClassName, + triggerTestId, disabled = false, popupContent, children, @@ -91,7 +93,7 @@ const Tooltip: FC = ({ onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)} asChild={asChild} > - {children ||
} + {children ||
} = (props) => {
}> = (props) => {
}> = ({ const resetList = useCallback(() => { setSelectedSegmentIds([]) invalidSegmentList() - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) + }, [invalidSegmentList]) const resetChildList = useCallback(() => { invalidChildSegmentList() - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) + }, [invalidChildSegmentList]) const onClickCard = (detail: SegmentDetailModel, isEditMode = false) => { setCurrSegment({ segInfo: detail, showModal: true, isEditMode }) @@ -253,7 +251,7 @@ const Completed: FC = ({ const invalidChunkListEnabled = useInvalid(useChunkListEnabledKey) const invalidChunkListDisabled = useInvalid(useChunkListDisabledKey) - const refreshChunkListWithStatusChanged = () => { + const refreshChunkListWithStatusChanged = useCallback(() => { switch (selectedStatus) { case 'all': invalidChunkListDisabled() @@ -262,7 +260,7 @@ const Completed: FC = ({ default: invalidSegmentList() } - } + }, [selectedStatus, invalidChunkListDisabled, invalidChunkListEnabled, invalidSegmentList]) const onChangeSwitch = useCallback(async (enable: boolean, segId?: string) => { const operationApi = enable ? enableSegment : disableSegment @@ -280,8 +278,7 @@ const Completed: FC = ({ notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') }) }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [datasetId, documentId, selectedSegmentIds, segments]) + }, [datasetId, documentId, selectedSegmentIds, segments, disableSegment, enableSegment, t, notify, refreshChunkListWithStatusChanged]) const { mutateAsync: deleteSegment } = useDeleteSegment() @@ -296,12 +293,11 @@ const Completed: FC = ({ notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') }) }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [datasetId, documentId, selectedSegmentIds]) + }, [datasetId, documentId, selectedSegmentIds, deleteSegment, resetList, t, notify]) const { mutateAsync: updateSegment } = useUpdateSegment() - const refreshChunkListDataWithDetailChanged = () => { + const refreshChunkListDataWithDetailChanged = useCallback(() => { switch (selectedStatus) { case 'all': invalidChunkListDisabled() @@ -316,7 +312,7 @@ const Completed: FC = ({ invalidChunkListEnabled() break } - } + }, [selectedStatus, invalidChunkListDisabled, invalidChunkListEnabled, invalidChunkListAll]) const handleUpdateSegment = useCallback(async ( segmentId: string, @@ -375,17 +371,18 @@ const Completed: FC = ({ eventEmitter?.emit('update-segment-done') }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segments, datasetId, documentId]) + }, [segments, datasetId, documentId, updateSegment, docForm, notify, eventEmitter, onCloseSegmentDetail, refreshChunkListDataWithDetailChanged, t]) useEffect(() => { resetList() + // eslint-disable-next-line react-hooks/exhaustive-deps }, [pathname]) useEffect(() => { if (importStatus === ProcessStatus.COMPLETED) resetList() - }, [importStatus, resetList]) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [importStatus]) const onCancelBatchOperation = useCallback(() => { setSelectedSegmentIds([]) @@ -430,8 +427,7 @@ const Completed: FC = ({ const count = segmentListData?.total || 0 return `${total} ${t('datasetDocuments.segment.searchResults', { count })}` } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segmentListData?.total, mode, parentMode, searchValue, selectedStatus]) + }, [segmentListData, mode, parentMode, searchValue, selectedStatus, t]) const toggleFullScreen = 
useCallback(() => { setFullScreen(!fullScreen) @@ -449,8 +445,7 @@ const Completed: FC = ({ resetList() currentPage !== totalPages && setCurrentPage(totalPages) } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segmentListData, limit, currentPage]) + }, [segmentListData, limit, currentPage, resetList]) const { mutateAsync: deleteChildSegment } = useDeleteChildSegment() @@ -470,8 +465,7 @@ const Completed: FC = ({ }, }, ) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [datasetId, documentId, parentMode]) + }, [datasetId, documentId, parentMode, deleteChildSegment, resetList, resetChildList, t, notify]) const handleAddNewChildChunk = useCallback((parentChunkId: string) => { setShowNewChildSegmentModal(true) @@ -490,8 +484,7 @@ const Completed: FC = ({ else { resetChildList() } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [parentMode, currChunkId, segments]) + }, [parentMode, currChunkId, segments, refreshChunkListDataWithDetailChanged, resetChildList]) const viewNewlyAddedChildChunk = useCallback(() => { const totalPages = childChunkListData?.total_pages || 0 @@ -505,8 +498,7 @@ const Completed: FC = ({ resetChildList() currentPage !== totalPages && setCurrentPage(totalPages) } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [childChunkListData, limit, currentPage]) + }, [childChunkListData, limit, currentPage, resetChildList]) const onClickSlice = useCallback((detail: ChildChunkDetail) => { setCurrChildChunk({ childChunkInfo: detail, showModal: true }) @@ -560,8 +552,7 @@ const Completed: FC = ({ eventEmitter?.emit('update-child-segment-done') }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segments, childSegments, datasetId, documentId, parentMode]) + }, [segments, datasetId, documentId, parentMode, updateChildSegment, notify, eventEmitter, onCloseChildSegmentDetail, refreshChunkListDataWithDetailChanged, resetChildList, t]) const onClearFilter = useCallback(() => { setInputValue('') @@ -570,6 +561,12 @@ const Completed: FC = ({ setCurrentPage(1) }, []) + const selectDefaultValue = useMemo(() => { + if (selectedStatus === 'all') + return 'all' + return selectedStatus ? 1 : 0 + }, [selectedStatus]) + return ( = ({ @@ -591,7 +588,7 @@ const Completed: FC = ({ = ({ const wordCountText = useMemo(() => { const total = formatNumber(word_count) return `${total} ${t('datasetDocuments.segment.characters', { count: word_count })}` - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [word_count]) + }, [word_count, t]) const labelPrefix = useMemo(() => { return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk') - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isParentChildMode]) + }, [isParentChildMode, t]) if (loading) return diff --git a/web/app/components/datasets/documents/detail/completed/segment-detail.tsx b/web/app/components/datasets/documents/detail/completed/segment-detail.tsx index cea3402499..d3575c18ed 100644 --- a/web/app/components/datasets/documents/detail/completed/segment-detail.tsx +++ b/web/app/components/datasets/documents/detail/completed/segment-detail.tsx @@ -86,8 +86,7 @@ const SegmentDetail: FC = ({ const titleText = useMemo(() => { return isEditMode ? 
t('datasetDocuments.segment.editChunk') : t('datasetDocuments.segment.chunkDetail') - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isEditMode]) + }, [isEditMode, t]) const isQAModel = useMemo(() => { return docForm === ChunkingMode.qa @@ -98,13 +97,11 @@ const SegmentDetail: FC = ({ const total = formatNumber(isEditMode ? contentLength : segInfo!.word_count as number) const count = isEditMode ? contentLength : segInfo!.word_count as number return `${total} ${t('datasetDocuments.segment.characters', { count })}` - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isEditMode, question.length, answer.length, segInfo?.word_count, isQAModel]) + }, [isEditMode, question.length, answer.length, isQAModel, segInfo, t]) const labelPrefix = useMemo(() => { return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk') - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isParentChildMode]) + }, [isParentChildMode, t]) return (
diff --git a/web/app/components/datasets/documents/detail/completed/segment-list.tsx b/web/app/components/datasets/documents/detail/completed/segment-list.tsx index b2351c1b97..f6076e5813 100644 --- a/web/app/components/datasets/documents/detail/completed/segment-list.tsx +++ b/web/app/components/datasets/documents/detail/completed/segment-list.tsx @@ -42,7 +42,7 @@ const SegmentList = ( embeddingAvailable, onClearFilter, }: ISegmentListProps & { - ref: React.RefObject; + ref: React.LegacyRef }, ) => { const mode = useDocumentContext(s => s.mode) diff --git a/web/app/components/datasets/documents/list.tsx b/web/app/components/datasets/documents/list.tsx index 8ed878fe56..cb349ee01c 100644 --- a/web/app/components/datasets/documents/list.tsx +++ b/web/app/components/datasets/documents/list.tsx @@ -202,7 +202,7 @@ export const OperationAction: FC<{ const isListScene = scene === 'list' const onOperate = async (operationName: OperationName) => { - let opApi = deleteDocument + let opApi switch (operationName) { case 'archive': opApi = archiveDocument @@ -490,7 +490,7 @@ const DocumentList: FC = ({ const handleAction = (actionName: DocumentActionType) => { return async () => { - let opApi = deleteDocument + let opApi switch (actionName) { case DocumentActionType.archive: opApi = archiveDocument @@ -527,7 +527,7 @@ const DocumentList: FC = ({ )} diff --git a/web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx b/web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx index 25e19506d0..fd7bb89bd3 100644 --- a/web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx +++ b/web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx @@ -40,7 +40,7 @@ const InputCombined: FC = ({ className={cn(className, 'rounded-l-md')} value={value} onChange={onChange} - size='sm' + size='regular' controlWrapClassName='overflow-hidden' controlClassName='pt-0 pb-0' readOnly={readOnly} diff --git a/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx b/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx index be57cbca0f..d67b7af1a4 100644 --- a/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx +++ b/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx @@ -133,7 +133,7 @@ export const AgentStrategy = memo((props: AgentStrategyProps) => { // TODO: maybe empty, handle this onChange={onChange as any} defaultValue={defaultValue} - size='sm' + size='regular' min={def.min} max={def.max} className='w-12' diff --git a/web/app/dev-preview/page.tsx b/web/app/dev-preview/page.tsx index 24631aa28e..69464d612a 100644 --- a/web/app/dev-preview/page.tsx +++ b/web/app/dev-preview/page.tsx @@ -1,19 +1,11 @@ 'use client' -import { ToolTipContent } from '../components/base/tooltip/content' -import { SwitchPluginVersion } from '../components/workflow/nodes/_base/components/switch-plugin-version' -import { useTranslation } from 'react-i18next' +import DemoForm from '../components/base/form/form-scenarios/demo' export default function Page() { - const { t } = useTranslation() - return
- - {t('workflow.nodes.agent.strategyNotFoundDescAndSwitchVersion')} - } - /> -
+ return ( +
+ +
+ ) } diff --git a/web/jest.config.ts b/web/jest.config.ts index e29734fdef..ebeb2f7d7e 100644 --- a/web/jest.config.ts +++ b/web/jest.config.ts @@ -43,12 +43,13 @@ const config: Config = { coverageProvider: 'v8', // A list of reporter names that Jest uses when writing coverage reports - // coverageReporters: [ - // "json", - // "text", - // "lcov", - // "clover" - // ], + coverageReporters: [ + 'json', + 'text', + 'text-summary', + 'lcov', + 'clover', + ], // An object that configures minimum threshold enforcement for coverage results // coverageThreshold: undefined, diff --git a/web/jest.setup.ts b/web/jest.setup.ts index c44951a680..ef9ede0492 100644 --- a/web/jest.setup.ts +++ b/web/jest.setup.ts @@ -1 +1,6 @@ import '@testing-library/jest-dom' +import { cleanup } from '@testing-library/react' + +afterEach(() => { + cleanup() +}) diff --git a/web/package.json b/web/package.json index 5edc388068..a1af12cff4 100644 --- a/web/package.json +++ b/web/package.json @@ -54,6 +54,7 @@ "@sentry/utils": "^8.54.0", "@svgdotjs/svg.js": "^3.2.4", "@tailwindcss/typography": "^0.5.15", + "@tanstack/react-form": "^1.3.3", "@tanstack/react-query": "^5.60.5", "@tanstack/react-query-devtools": "^5.60.5", "ahooks": "^3.8.4", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index c86fe8baf0..d1c65b6a4a 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -94,6 +94,9 @@ importers: '@tailwindcss/typography': specifier: ^0.5.15 version: 0.5.16(tailwindcss@3.4.17(ts-node@10.9.2(@types/node@18.15.0)(typescript@4.9.5))) + '@tanstack/react-form': + specifier: ^1.3.3 + version: 1.3.3(react-dom@19.0.0(react@19.0.0))(react@19.0.0) '@tanstack/react-query': specifier: ^5.60.5 version: 5.72.2(react@19.0.0) @@ -2781,12 +2784,27 @@ packages: peerDependencies: tailwindcss: '>=3.0.0 || insiders || >=4.0.0-alpha.20 || >=4.0.0-beta.1' + '@tanstack/form-core@1.3.2': + resolution: {integrity: sha512-hqRLw9EJ8bLJ5zvorGgTI4INcKh1hAtjPRTslwdB529soP8LpguzqWhn7yVV5/c2GcMSlqmpy5NZarkF5Mf54A==} + '@tanstack/query-core@5.72.2': resolution: {integrity: sha512-fxl9/0yk3mD/FwTmVEf1/H6N5B975H0luT+icKyX566w6uJG0x6o+Yl+I38wJRCaogiMkstByt+seXfDbWDAcA==} '@tanstack/query-devtools@5.72.2': resolution: {integrity: sha512-mMKnGb+iOhVBcj6jaerCFRpg8pACStdG8hmUBHPtToeZzs4ctjBUL1FajqpVn2WaMxnq8Wya+P3Q5tPFNM9jQw==} + '@tanstack/react-form@1.3.3': + resolution: {integrity: sha512-rjZU6ufaQYbZU9I0uIXUJ1CPQ9M/LFyfpbsgA4oqpX/lLoiCFYsV7tZYVlWMMHkpSr1hhmAywp/8rmCFt14lnw==} + peerDependencies: + '@tanstack/react-start': ^1.112.0 + react: ^17.0.0 || ^18.0.0 || ^19.0.0 + vinxi: ^0.5.0 + peerDependenciesMeta: + '@tanstack/react-start': + optional: true + vinxi: + optional: true + '@tanstack/react-query-devtools@5.72.2': resolution: {integrity: sha512-n53qr9JdHCJTCUba6OvMhwiV2CcsckngOswKEE7nM5pQBa/fW9c43qw8omw1RPT2s+aC7MuwS8fHsWT8g+j6IQ==} peerDependencies: @@ -2798,12 +2816,21 @@ packages: peerDependencies: react: ^18 || ^19 + '@tanstack/react-store@0.7.0': + resolution: {integrity: sha512-S/Rq17HaGOk+tQHV/yrePMnG1xbsKZIl/VsNWnNXt4XW+tTY8dTlvpJH2ZQ3GRALsusG5K6Q3unAGJ2pd9W/Ng==} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + '@tanstack/react-virtual@3.13.6': resolution: {integrity: sha512-WT7nWs8ximoQ0CDx/ngoFP7HbQF9Q2wQe4nh2NB+u2486eX3nZRE40P9g6ccCVq7ZfTSH5gFOuCoVH5DLNS/aA==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + '@tanstack/store@0.7.0': + resolution: {integrity: 
sha512-CNIhdoUsmD2NolYuaIs8VfWM467RK6oIBAW4nPEKZhg1smZ+/CwtCdpURgp7nxSqOaV9oKkzdWD80+bC66F/Jg==} + '@tanstack/virtual-core@3.13.6': resolution: {integrity: sha512-cnQUeWnhNP8tJ4WsGcYiX24Gjkc9ALstLbHcBj1t3E7EimN6n6kHH+DPV4PpDnuw00NApQp+ViojMj1GRdwYQg==} @@ -4348,6 +4375,9 @@ packages: decimal.js@10.5.0: resolution: {integrity: sha512-8vDa8Qxvr/+d94hSh5P3IJwI5t8/c0KsMp+g8bNw9cY2icONa5aPfvKeieW1WlG0WQYwwhJ7mjui2xtiePQSXw==} + decode-formdata@0.9.0: + resolution: {integrity: sha512-q5uwOjR3Um5YD+ZWPOF/1sGHVW9A5rCrRwITQChRXlmPkxDFBqCm4jNTIVdGHNH9OnR+V9MoZVgRhsFb+ARbUw==} + decode-named-character-reference@1.1.0: resolution: {integrity: sha512-Wy+JTSbFThEOXQIR2L6mxJvEs+veIzpmqD7ynWxMXGpnk3smkHQOp6forLdHsKpAMW9iJpaBBIxz285t1n1C3w==} @@ -4423,6 +4453,9 @@ packages: resolution: {integrity: sha512-TLz+x/vEXm/Y7P7wn1EJFNLxYpUD4TgMosxY6fAVJUnJMbupHBOncxyWUG9OpTaH9EBD7uFI5LfEgmMOc54DsA==} engines: {node: '>=8'} + devalue@5.1.1: + resolution: {integrity: sha512-maua5KUiapvEwiEAe+XnlZ3Rh0GD+qI1J/nb9vrJc3muPXvcF/8gXYTWF76+5DAqHyDUtOIImEuo0YKE9mshVw==} + devlop@1.1.0: resolution: {integrity: sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA==} @@ -11352,10 +11385,24 @@ snapshots: postcss-selector-parser: 6.0.10 tailwindcss: 3.4.17(ts-node@10.9.2(@types/node@18.15.0)(typescript@4.9.5)) + '@tanstack/form-core@1.3.2': + dependencies: + '@tanstack/store': 0.7.0 + '@tanstack/query-core@5.72.2': {} '@tanstack/query-devtools@5.72.2': {} + '@tanstack/react-form@1.3.3(react-dom@19.0.0(react@19.0.0))(react@19.0.0)': + dependencies: + '@tanstack/form-core': 1.3.2 + '@tanstack/react-store': 0.7.0(react-dom@19.0.0(react@19.0.0))(react@19.0.0) + decode-formdata: 0.9.0 + devalue: 5.1.1 + react: 19.0.0 + transitivePeerDependencies: + - react-dom + '@tanstack/react-query-devtools@5.72.2(@tanstack/react-query@5.72.2(react@19.0.0))(react@19.0.0)': dependencies: '@tanstack/query-devtools': 5.72.2 @@ -11367,12 +11414,21 @@ snapshots: '@tanstack/query-core': 5.72.2 react: 19.0.0 + '@tanstack/react-store@0.7.0(react-dom@19.0.0(react@19.0.0))(react@19.0.0)': + dependencies: + '@tanstack/store': 0.7.0 + react: 19.0.0 + react-dom: 19.0.0(react@19.0.0) + use-sync-external-store: 1.5.0(react@19.0.0) + '@tanstack/react-virtual@3.13.6(react-dom@19.0.0(react@19.0.0))(react@19.0.0)': dependencies: '@tanstack/virtual-core': 3.13.6 react: 19.0.0 react-dom: 19.0.0(react@19.0.0) + '@tanstack/store@0.7.0': {} + '@tanstack/virtual-core@3.13.6': {} '@testing-library/dom@10.4.0': @@ -13139,6 +13195,8 @@ snapshots: decimal.js@10.5.0: {} + decode-formdata@0.9.0: {} + decode-named-character-reference@1.1.0: dependencies: character-entities: 2.0.2 @@ -13199,6 +13257,8 @@ snapshots: detect-newline@3.1.0: {} + devalue@5.1.1: {} + devlop@1.1.0: dependencies: dequal: 2.0.3 From 3914cf07e7bce8c86a7a0db60e256d9a4c78f9e7 Mon Sep 17 00:00:00 2001 From: GuanMu Date: Fri, 18 Apr 2025 16:00:12 +0800 Subject: [PATCH 48/68] fix: Adjust span height and alignment in WorkplaceSelector component (#18361) --- .../header/account-dropdown/workplace-selector/index.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/app/components/header/account-dropdown/workplace-selector/index.tsx b/web/app/components/header/account-dropdown/workplace-selector/index.tsx index a9a886376a..da3f8bae6d 100644 --- a/web/app/components/header/account-dropdown/workplace-selector/index.tsx +++ b/web/app/components/header/account-dropdown/workplace-selector/index.tsx @@ -42,7 +42,7 @@ const WorkplaceSelector = () => { `, )}>
- {currentWorkspace?.name[0]?.toLocaleUpperCase()} + {currentWorkspace?.name[0]?.toLocaleUpperCase()}
{currentWorkspace?.name}
@@ -73,7 +73,7 @@ const WorkplaceSelector = () => { workspaces.map(workspace => (
handleSwitchWorkspace(workspace.id)}>
- {workspace?.name[0]?.toLocaleUpperCase()} + {workspace?.name[0]?.toLocaleUpperCase()}
{workspace.name}
From d2e3744ca3377fbd399b465e8a048c7aeadbe90b Mon Sep 17 00:00:00 2001 From: Rain Wang Date: Fri, 18 Apr 2025 16:05:48 +0800 Subject: [PATCH 49/68] Switching from CONSOLE_API_URL to FILES_URL in word_extractor.py (#18249) --- api/core/rag/extractor/word_extractor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 70c618a631..edaa8c92fa 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -126,9 +126,7 @@ class WordExtractor(BaseExtractor): db.session.add(upload_file) db.session.commit() - image_map[rel.target_part] = ( - f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)" - ) + image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" return image_map From da9269ca97c3648dce668ce7cf8e418bfd66ce91 Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Fri, 18 Apr 2025 16:33:53 +0800 Subject: [PATCH 50/68] feat: structured output (#17877) --- api/controllers/console/app/generator.py | 30 ++ api/core/llm_generator/llm_generator.py | 35 +++ api/core/llm_generator/prompts.py | 107 +++++++ .../model_runtime/entities/model_entities.py | 16 +- api/core/workflow/nodes/agent/agent_node.py | 16 +- api/core/workflow/nodes/agent/entities.py | 15 + api/core/workflow/nodes/llm/entities.py | 2 + api/core/workflow/nodes/llm/node.py | 261 +++++++++++++++++- .../utils/structured_output/entities.py | 24 ++ .../utils/structured_output/prompt.py | 17 ++ api/pyproject.toml | 6 +- api/uv.lock | 14 +- 12 files changed, 530 insertions(+), 13 deletions(-) create mode 100644 api/core/workflow/utils/structured_output/entities.py create mode 100644 api/core/workflow/utils/structured_output/prompt.py diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 8518d34a8e..4046417076 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -85,5 +85,35 @@ class RuleCodeGenerateApi(Resource): return code_result +class RuleStructuredOutputGenerateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + account = current_user + try: + structured_output = LLMGenerator.generate_structured_output( + tenant_id=account.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + + return structured_output + + api.add_resource(RuleGenerateApi, "/rule-generate") api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") +api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate") diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 75687f9ae3..d5d2ca60fa 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -10,6 +10,7 @@ from 
core.llm_generator.prompts import ( GENERATOR_QA_PROMPT, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, + SYSTEM_STRUCTURED_OUTPUT_GENERATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager @@ -340,3 +341,37 @@ class LLMGenerator: answer = cast(str, response.message.content) return answer.strip() + + @classmethod + def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict): + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), + ) + + prompt_messages = [ + SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), + UserPromptMessage(content=instruction), + ] + model_parameters = model_config.get("model_parameters", {}) + + try: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ), + ) + + generated_json_schema = cast(str, response.message.content) + return {"output": generated_json_schema, "error": ""} + + except InvokeError as e: + error = str(e) + return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} + except Exception as e: + logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}") + return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index cf20e60c82..82d22d7f89 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -220,3 +220,110 @@ Here is the task description: {{INPUT_TEXT}} You just need to generate the output """ # noqa: E501 + +SYSTEM_STRUCTURED_OUTPUT_GENERATE = """ +Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements. + +## Instructions: + +1. Analyze the user's description of their data needs +2. Identify each property that should be included in the schema +3. Determine the appropriate data type for each property +4. Decide which properties should be required +5. Generate a complete JSON Schema with proper syntax +6. Include appropriate constraints when specified (min/max values, patterns, formats) +7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting. +8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly. 
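For reference, a minimal caller-side sketch of how the schema string returned by LLMGenerator.generate_structured_output could be checked before being stored; this is not part of the patch, the helper name is illustrative, and it only relies on the documented return shape {"output": <schema string>, "error": <message>}:

import json

def validate_generated_schema(result: dict) -> dict:
    # The generator returns {"output": "", "error": "..."} when invocation fails.
    if result.get("error"):
        raise ValueError(result["error"])
    # Instruction 8 of the system prompt asks the model for raw JSON with no
    # code fences, so a strict json.loads is expected to succeed on a
    # well-behaved reply.
    schema = json.loads(result["output"])
    if not isinstance(schema, dict):
        raise ValueError("expected a JSON Schema object")
    return schema
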
+ +## Examples: + +### Example 1: +**User Input:** I need name and age +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "number" } + }, + "required": ["name", "age"] +} + +### Example 2: +**User Input:** I want to store information about books including title, author, publication year and optional page count +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "title": { "type": "string" }, + "author": { "type": "string" }, + "publicationYear": { "type": "integer" }, + "pageCount": { "type": "integer" } + }, + "required": ["title", "author", "publicationYear"] +} + +### Example 3: +**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18) +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "email": { + "type": "string", + "format": "email" + }, + "password": { + "type": "string", + "minLength": 8 + }, + "age": { + "type": "integer", + "minimum": 18 + } + }, + "required": ["email", "password", "age"] +} + +### Example 4: +**User Input:** I need album schema, the ablum has songs, and each song has name, duration, and artist. +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "properties": { + "songs": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "id": { + "type": "string" + }, + "duration": { + "type": "string" + }, + "aritst": { + "type": "string" + } + }, + "required": [ + "name", + "id", + "duration", + "aritst" + ] + } + } + } + }, + "required": [ + "songs" + ] +} + +Now, generate a JSON Schema based on my description +""" # noqa: E501 diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 3225f03fbd..373ef2bbe2 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -2,7 +2,7 @@ from decimal import Decimal from enum import Enum, StrEnum from typing import Any, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from core.model_runtime.entities.common_entities import I18nObject @@ -85,6 +85,7 @@ class ModelFeature(Enum): DOCUMENT = "document" VIDEO = "video" AUDIO = "audio" + STRUCTURED_OUTPUT = "structured-output" class DefaultParameterName(StrEnum): @@ -197,6 +198,19 @@ class AIModelEntity(ProviderModel): parameter_rules: list[ParameterRule] = [] pricing: Optional[PriceConfig] = None + @model_validator(mode="after") + def validate_model(self): + supported_schema_keys = ["json_schema"] + schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None) + if not schema_key: + return self + if self.features is None: + self.features = [ModelFeature.STRUCTURED_OUTPUT] + else: + if ModelFeature.STRUCTURED_OUTPUT not in self.features: + self.features.append(ModelFeature.STRUCTURED_OUTPUT) + return self + class ModelUsage(BaseModel): pass diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7c8960fe49..da40cbcdea 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -16,7 +16,7 @@ from core.variables.segments import StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from 
core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated +from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event.event import RunCompletedEvent @@ -251,7 +251,12 @@ class AgentNode(ToolNode): prompt_message.model_dump(mode="json") for prompt_message in prompt_messages ] value["history_prompt_messages"] = history_prompt_messages - value["entity"] = model_schema.model_dump(mode="json") if model_schema else None + if model_schema: + # remove structured output feature to support old version agent plugin + model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) + value["entity"] = model_schema.model_dump(mode="json") + else: + value["entity"] = None result[parameter_name] = value return result @@ -348,3 +353,10 @@ class AgentNode(ToolNode): ) model_schema = model_type_instance.get_model_schema(model_name, model_credentials) return model_instance, model_schema + + def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity: + if model_schema.features: + for feature in model_schema.features: + if feature.value not in AgentOldVersionModelFeatures: + model_schema.features.remove(feature) + return model_schema diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 87cc7e9824..77e94375bf 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -24,3 +24,18 @@ class AgentNodeData(BaseNodeData): class ParamsAutoGenerated(Enum): CLOSE = 0 OPEN = 1 + + +class AgentOldVersionModelFeatures(Enum): + """ + Enum class for old SDK version llm feature. 
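A rough sketch of the intent behind the old-plugin feature filtering above; the helper is hypothetical and simply drops any feature name that is not in the enum defined next (notably the new "structured-output" feature):

SUPPORTED_BY_OLD_AGENT_PLUGINS = {
    "tool-call", "multi-tool-call", "agent-thought", "vision",
    "stream-tool-call", "document", "video", "audio",
}

def filter_features_for_old_plugin(features: list[str]) -> list[str]:
    # e.g. "structured-output" (introduced in this patch) is removed before the
    # model schema is handed to an old-version agent plugin.
    return [f for f in features if f in SUPPORTED_BY_OLD_AGENT_PLUGINS]
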
+ """ + + TOOL_CALL = "tool-call" + MULTI_TOOL_CALL = "multi-tool-call" + AGENT_THOUGHT = "agent-thought" + VISION = "vision" + STREAM_TOOL_CALL = "stream-tool-call" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index bf54fdb80c..486b4b01af 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData): memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) + structured_output: dict | None = None + structured_output_enabled: bool = False @field_validator("prompt_config", mode="before") @classmethod diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index fe0ed3e564..8db7394e54 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -4,6 +4,8 @@ from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Optional, cast +import json_repair + from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus @@ -27,7 +29,13 @@ from core.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, +) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ModelProviderID @@ -57,6 +65,12 @@ from core.workflow.nodes.event import ( RunRetrieverResourceEvent, RunStreamChunkEvent, ) +from core.workflow.utils.structured_output.entities import ( + ResponseFormat, + SpecialModelType, + SupportStructuredOutputStatus, +) +from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models.model import Conversation @@ -92,6 +106,12 @@ class LLMNode(BaseNode[LLMNodeData]): _node_type = NodeType.LLM def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: + def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]: + """Process structured output if enabled""" + if not self.node_data.structured_output_enabled or not self.node_data.structured_output: + return None + return self._parse_structured_output(text) + node_inputs: Optional[dict[str, Any]] = None process_data = None result_text = "" @@ -130,7 +150,6 @@ class LLMNode(BaseNode[LLMNodeData]): if isinstance(event, RunRetrieverResourceEvent): context = event.context yield event - if context: node_inputs["#context#"] = context @@ -192,7 +211,9 @@ class LLMNode(BaseNode[LLMNodeData]): self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) break outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} - + structured_output = process_structured_output(result_text) + if structured_output: + outputs["structured_output"] = structured_output yield RunCompletedEvent( run_result=NodeRunResult( 
status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -513,7 +534,12 @@ class LLMNode(BaseNode[LLMNodeData]): if not model_schema: raise ModelNotExistError(f"Model {model_name} not exist.") - + support_structured_output = self._check_model_structured_output_support() + if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: + completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) + elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: + # Set appropriate response format based on model capabilities + self._set_response_format(completion_params, model_schema.parameter_rules) return model_instance, ModelConfigWithCredentialsEntity( provider=provider_name, model=model_name, @@ -724,10 +750,29 @@ class LLMNode(BaseNode[LLMNodeData]): "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) - + support_structured_output = self._check_model_structured_output_support() + if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: + filtered_prompt_messages = self._handle_prompt_based_schema( + prompt_messages=filtered_prompt_messages, + ) stop = model_config.stop return filtered_prompt_messages, stop + def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]: + structured_output: dict[str, Any] | list[Any] = {} + try: + parsed = json.loads(result_text) + if not isinstance(parsed, (dict | list)): + raise LLMNodeError(f"Failed to parse structured output: {result_text}") + structured_output = parsed + except json.JSONDecodeError as e: + # if the result_text is not a valid json, try to repair it + parsed = json_repair.loads(result_text) + if not isinstance(parsed, (dict | list)): + raise LLMNodeError(f"Failed to parse structured output: {result_text}") + structured_output = parsed + return structured_output + @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: provider_model_bundle = model_instance.provider_model_bundle @@ -926,6 +971,166 @@ class LLMNode(BaseNode[LLMNodeData]): return prompt_messages + def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict: + """ + Handle structured output for models with native JSON schema support. + + :param model_parameters: Model parameters to update + :param rules: Model parameter rules + :return: Updated model parameters with JSON schema configuration + """ + # Process schema according to model requirements + schema = self._fetch_structured_output_schema() + schema_json = self._prepare_schema_for_model(schema) + + # Set JSON schema in parameters + model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False) + + # Set appropriate response format if required by the model + for rule in rules: + if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value + + return model_parameters + + def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]: + """ + Handle structured output for models without native JSON schema support. + This function modifies the prompt messages to include schema-based output requirements. 
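The tolerant parsing used by _parse_structured_output can be illustrated in isolation; a standalone sketch, assuming the json-repair package added to pyproject.toml in this patch:

import json
import json_repair

def parse_model_json(text: str) -> dict | list:
    # Strict parse first; fall back to json_repair, which attempts to fix
    # common LLM JSON mistakes such as trailing commas or missing braces.
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = json_repair.loads(text)
    if not isinstance(parsed, (dict, list)):
        raise ValueError(f"structured output is not a JSON object or array: {text!r}")
    return parsed
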
+ + Args: + prompt_messages: Original sequence of prompt messages + + Returns: + list[PromptMessage]: Updated prompt messages with structured output requirements + """ + # Convert schema to string format + schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False) + + # Find existing system prompt with schema placeholder + system_prompt = next( + (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)), + None, + ) + structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str) + # Prepare system prompt content + system_prompt_content = ( + structured_output_prompt + "\n\n" + system_prompt.content + if system_prompt and isinstance(system_prompt.content, str) + else structured_output_prompt + ) + system_prompt = SystemPromptMessage(content=system_prompt_content) + + # Extract content from the last user message + + filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)] + updated_prompt = [system_prompt] + filtered_prompts + + return updated_prompt + + def _set_response_format(self, model_parameters: dict, rules: list) -> None: + """ + Set the appropriate response format parameter based on model rules. + + :param model_parameters: Model parameters to update + :param rules: Model parameter rules + """ + for rule in rules: + if rule.name == "response_format": + if ResponseFormat.JSON.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON.value + elif ResponseFormat.JSON_OBJECT.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value + + def _prepare_schema_for_model(self, schema: dict) -> dict: + """ + Prepare JSON schema based on model requirements. + + Different models have different requirements for JSON schema formatting. + This function handles these differences. + + :param schema: The original JSON schema + :return: Processed schema compatible with the current model + """ + + # Deep copy to avoid modifying the original schema + processed_schema = schema.copy() + + # Convert boolean types to string types (common requirement) + convert_boolean_to_string(processed_schema) + + # Apply model-specific transformations + if SpecialModelType.GEMINI in self.node_data.model.name: + remove_additional_properties(processed_schema) + return processed_schema + elif SpecialModelType.OLLAMA in self.node_data.model.provider: + return processed_schema + else: + # Default format with name field + return {"schema": processed_schema, "name": "llm_response"} + + def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: + """ + Fetch model schema + """ + model_name = self.node_data.model.name + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name + ) + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_credentials = model_instance.credentials + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + return model_schema + + def _fetch_structured_output_schema(self) -> dict[str, Any]: + """ + Fetch the structured output schema from the node data. 
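A condensed sketch of the prompt-based path for models without native JSON-schema support; the helper name is illustrative, while the template comes from the structured_output/prompt.py module added in this patch:

import json

from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT

def build_structured_system_prompt(schema: dict, existing_system_prompt: str | None = None) -> str:
    # Inject the serialized schema into the template, then keep any system
    # prompt the node already had after the structured-output instructions.
    schema_str = json.dumps(schema, ensure_ascii=False)
    prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
    if existing_system_prompt:
        return prompt + "\n\n" + existing_system_prompt
    return prompt
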
+ + Returns: + dict[str, Any]: The structured output schema + """ + if not self.node_data.structured_output: + raise LLMNodeError("Please provide a valid structured output schema") + structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False) + if not structured_output_schema: + raise LLMNodeError("Please provide a valid structured output schema") + + try: + schema = json.loads(structured_output_schema) + if not isinstance(schema, dict): + raise LLMNodeError("structured_output_schema must be a JSON object") + return schema + except json.JSONDecodeError: + raise LLMNodeError("structured_output_schema is not valid JSON format") + + def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus: + """ + Check if the current model supports structured output. + + Returns: + SupportStructuredOutput: The support status of structured output + """ + # Early return if structured output is disabled + if ( + not isinstance(self.node_data, LLMNodeData) + or not self.node_data.structured_output_enabled + or not self.node_data.structured_output + ): + return SupportStructuredOutputStatus.DISABLED + # Get model schema and check if it exists + model_schema = self._fetch_model_schema(self.node_data.model.provider) + if not model_schema: + return SupportStructuredOutputStatus.DISABLED + + # Check if model supports structured output feature + return ( + SupportStructuredOutputStatus.SUPPORTED + if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features) + else SupportStructuredOutputStatus.UNSUPPORTED + ) + def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): match role: @@ -1064,3 +1269,49 @@ def _handle_completion_template( ) prompt_messages.append(prompt_message) return prompt_messages + + +def remove_additional_properties(schema: dict) -> None: + """ + Remove additionalProperties fields from JSON schema. + Used for models like Gemini that don't support this property. + + :param schema: JSON schema to modify in-place + """ + if not isinstance(schema, dict): + return + + # Remove additionalProperties at current level + schema.pop("additionalProperties", None) + + # Process nested structures recursively + for value in schema.values(): + if isinstance(value, dict): + remove_additional_properties(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + remove_additional_properties(item) + + +def convert_boolean_to_string(schema: dict) -> None: + """ + Convert boolean type specifications to string in JSON schema. + + :param schema: JSON schema to modify in-place + """ + if not isinstance(schema, dict): + return + + # Check for boolean type at current level + if schema.get("type") == "boolean": + schema["type"] = "string" + + # Process nested dictionaries and lists recursively + for value in schema.values(): + if isinstance(value, dict): + convert_boolean_to_string(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + convert_boolean_to_string(item) diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py new file mode 100644 index 0000000000..7954acbaee --- /dev/null +++ b/api/core/workflow/utils/structured_output/entities.py @@ -0,0 +1,24 @@ +from enum import StrEnum + + +class ResponseFormat(StrEnum): + """Constants for model response formats""" + + JSON_SCHEMA = "json_schema" # model's structured output mode. 
some model like gemini, gpt-4o, support this mode. + JSON = "JSON" # model's json mode. some model like claude support this mode. + JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias. + + +class SpecialModelType(StrEnum): + """Constants for identifying model types""" + + GEMINI = "gemini" + OLLAMA = "ollama" + + +class SupportStructuredOutputStatus(StrEnum): + """Constants for structured output support status""" + + SUPPORTED = "supported" + UNSUPPORTED = "unsupported" + DISABLED = "disabled" diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py new file mode 100644 index 0000000000..06d9b2056e --- /dev/null +++ b/api/core/workflow/utils/structured_output/prompt.py @@ -0,0 +1,17 @@ +STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format. +constraints: + - You must output in JSON format. + - Do not output boolean value, use string type instead. + - Do not output integer or float value, use number type instead. +eg: + Here is the JSON schema: + {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"} + + Here is the user's question: + My name is John Doe and I am 30 years old. + + output: + {"name": "John Doe", "age": 30} +Here is the JSON schema: +{{schema}} +""" # noqa: E501 diff --git a/api/pyproject.toml b/api/pyproject.toml index 85679a6359..08f9c1e229 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "gunicorn~=23.0.0", "httpx[socks]~=0.27.0", "jieba==0.42.1", + "json-repair>=0.41.1", "langfuse~=2.51.3", "langsmith~=0.1.77", "mailchimp-transactional~=1.0.50", @@ -163,10 +164,7 @@ storage = [ ############################################################ # [ Tools ] dependency group ############################################################ -tools = [ - "cloudscraper~=1.2.71", - "nltk~=3.9.1", -] +tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"] ############################################################ # [ VDB ] dependency group diff --git a/api/uv.lock b/api/uv.lock index 4ff9c34446..4384e1abb5 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy'", @@ -1178,6 +1177,7 @@ dependencies = [ { name = "gunicorn" }, { name = "httpx", extra = ["socks"] }, { name = "jieba" }, + { name = "json-repair" }, { name = "langfuse" }, { name = "langsmith" }, { name = "mailchimp-transactional" }, @@ -1346,6 +1346,7 @@ requires-dist = [ { name = "gunicorn", specifier = "~=23.0.0" }, { name = "httpx", extras = ["socks"], specifier = "~=0.27.0" }, { name = "jieba", specifier = "==0.42.1" }, + { name = "json-repair", specifier = ">=0.41.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.1.77" }, { name = "mailchimp-transactional", specifier = "~=1.0.50" }, @@ -2524,6 +2525,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, ] +[[package]] +name = "json-repair" +version = "0.41.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/6d/6a/6c7a75a10da6dc807b582f2449034da1ed74415e8899746bdfff97109012/json_repair-0.41.1.tar.gz", hash = "sha256:bba404b0888c84a6b86ecc02ec43b71b673cfee463baf6da94e079c55b136565", size = 31208 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/5c/abd7495c934d9af5c263c2245ae30cfaa716c3c0cf027b2b8fa686ee7bd4/json_repair-0.41.1-py3-none-any.whl", hash = "sha256:0e181fd43a696887881fe19fed23422a54b3e4c558b6ff27a86a8c3ddde9ae79", size = 21578 }, +] + [[package]] name = "jsonpath-python" version = "1.0.6" @@ -4074,6 +4084,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914 }, { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105 }, { url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222 }, + { url = "https://files.pythonhosted.org/packages/1d/e3/0c9679cd66cf5604b1f070bdf4525a0c01a15187be287d8348b2eafb718e/pycryptodome-3.19.1-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:ed932eb6c2b1c4391e166e1a562c9d2f020bfff44a0e1b108f67af38b390ea89", size = 1629005 }, + { url = "https://files.pythonhosted.org/packages/13/75/0d63bf0daafd0580b17202d8a9dd57f28c8487f26146b3e2799b0c5a059c/pycryptodome-3.19.1-pp27-pypy_73-win32.whl", hash = "sha256:81e9d23c0316fc1b45d984a44881b220062336bbdc340aa9218e8d0656587934", size = 1697997 }, ] [[package]] From 775dc47abec20a46873f05b8ee56cfde5bd6fa0b Mon Sep 17 00:00:00 2001 From: Joel Date: Fri, 18 Apr 2025 16:53:43 +0800 Subject: [PATCH 51/68] feat: llm support struct output (#17994) Co-authored-by: twwu Co-authored-by: zxhlyh --- .../solid/general/arrow-down-round-fill.svg | 5 + .../solid/general/ArrowDownRoundFill.json | 36 ++ .../solid/general/ArrowDownRoundFill.tsx | 20 + .../icons/src/vender/solid/general/index.ts | 1 + .../plugins/history-block/node.tsx | 2 +- .../workflow-variable-block/component.tsx | 43 +- .../plugins/workflow-variable-block/index.tsx | 8 +- .../plugins/workflow-variable-block/node.tsx | 22 +- ...kflow-variable-block-replacement-block.tsx | 5 +- .../components/base/prompt-editor/types.ts | 8 + .../base/segmented-control/index.tsx | 68 +++ web/app/components/base/textarea/index.tsx | 5 +- .../model-provider-page/declarations.ts | 1 + .../model-provider-page/model-modal/Form.tsx | 1 + .../model-parameter-modal/parameter-item.tsx | 2 + .../multiple-tool-selector/index.tsx | 16 +- .../workflow/hooks/use-workflow-variables.ts | 36 ++ .../components/collapse/field-collapse.tsx | 9 + .../nodes/_base/components/collapse/index.tsx | 61 ++- .../error-handle/error-handle-on-panel.tsx | 25 +- .../error-handle-type-selector.tsx | 2 + .../nodes/_base/components/output-vars.tsx | 58 ++- .../nodes/_base/components/prompt/editor.tsx | 4 + .../readonly-input-with-select-var.tsx | 8 + .../object-child-tree-panel/picker/field.tsx | 77 +++ .../object-child-tree-panel/picker/index.tsx | 82 ++++ .../object-child-tree-panel/show/field.tsx | 74 +++ 
.../object-child-tree-panel/show/index.tsx | 39 ++ .../tree-indent-line.tsx | 24 + .../nodes/_base/components/variable/utils.ts | 9 +- .../variable/var-full-path-panel.tsx | 59 +++ .../variable/var-reference-picker.tsx | 45 +- .../variable/var-reference-vars.tsx | 92 ++-- .../metadata/metadata-filter/index.tsx | 6 +- .../json-schema-config-modal/code-editor.tsx | 140 ++++++ .../error-message.tsx | 27 ++ .../json-schema-config-modal/index.tsx | 34 ++ .../json-importer.tsx | 136 ++++++ .../json-schema-config.tsx | 301 ++++++++++++ .../json-schema-generator/assets/index.tsx | 7 + .../assets/schema-generator-dark.tsx | 15 + .../assets/schema-generator-light.tsx | 15 + .../generated-result.tsx | 121 +++++ .../json-schema-generator/index.tsx | 183 ++++++++ .../json-schema-generator/prompt-editor.tsx | 108 +++++ .../schema-editor.tsx | 23 + .../visual-editor/add-field.tsx | 33 ++ .../visual-editor/card.tsx | 46 ++ .../visual-editor/context.tsx | 50 ++ .../visual-editor/edit-card/actions.tsx | 56 +++ .../edit-card/advanced-actions.tsx | 59 +++ .../edit-card/advanced-options.tsx | 77 +++ .../edit-card/auto-width-input.tsx | 81 ++++ .../visual-editor/edit-card/index.tsx | 277 +++++++++++ .../edit-card/required-switch.tsx | 25 + .../visual-editor/edit-card/type-selector.tsx | 69 +++ .../visual-editor/hooks.ts | 441 ++++++++++++++++++ .../visual-editor/index.tsx | 28 ++ .../visual-editor/schema-node.tsx | 194 ++++++++ .../visual-editor/store.ts | 34 ++ .../nodes/llm/components/structure-output.tsx | 75 +++ .../components/workflow/nodes/llm/panel.tsx | 54 ++- .../workflow/nodes/llm/use-config.ts | 34 +- .../components/workflow/nodes/llm/utils.ts | 333 ++++++++++++- .../components/workflow/nodes/tool/panel.tsx | 40 +- .../workflow/nodes/tool/use-config.ts | 31 +- web/config/index.ts | 1 + web/hooks/use-mitt.ts | 18 +- web/i18n/en-US/app.ts | 11 + web/i18n/en-US/common.ts | 1 + web/i18n/en-US/workflow.ts | 28 ++ web/i18n/language.ts | 2 +- web/i18n/zh-Hans/app.ts | 11 + web/i18n/zh-Hans/common.ts | 1 + web/i18n/zh-Hans/workflow.ts | 28 ++ web/models/common.ts | 11 + web/package.json | 1 + web/pnpm-lock.yaml | 8 + web/service/use-common.ts | 18 +- web/tailwind-common-config.ts | 1 + web/themes/manual-dark.css | 124 ++--- web/themes/manual-light.css | 102 ++-- 82 files changed, 4190 insertions(+), 276 deletions(-) create mode 100644 web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg create mode 100644 web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json create mode 100644 web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx create mode 100644 web/app/components/base/segmented-control/index.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/picker/field.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/picker/index.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/show/field.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/show/index.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/tree-indent-line.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/variable/var-full-path-panel.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/code-editor.tsx create mode 100644 
web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-dark.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-light.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/schema-editor.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/add-field.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/card.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/actions.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/advanced-actions.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/advanced-options.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/auto-width-input.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/index.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/required-switch.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/type-selector.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/hooks.ts create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/index.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/schema-node.tsx create mode 100644 web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/store.ts create mode 100644 web/app/components/workflow/nodes/llm/components/structure-output.tsx diff --git a/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg b/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg new file mode 100644 index 0000000000..9566fcc0c3 --- /dev/null +++ b/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg @@ 
-0,0 +1,5 @@ + + + + + diff --git a/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json new file mode 100644 index 0000000000..4e7da3c801 --- /dev/null +++ b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json @@ -0,0 +1,36 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "arrow-down-round-fill" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "id": "Vector", + "d": "M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z", + "fill": "currentColor" + }, + "children": [] + } + ] + } + ] + }, + "name": "ArrowDownRoundFill" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx new file mode 100644 index 0000000000..c766a72b94 --- /dev/null +++ b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx @@ -0,0 +1,20 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './ArrowDownRoundFill.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconData } from '@/app/components/base/icons/IconBase' + +const Icon = ( + { + ref, + ...props + }: React.SVGProps & { + ref?: React.RefObject>; + }, +) => + +Icon.displayName = 'ArrowDownRoundFill' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/solid/general/index.ts b/web/app/components/base/icons/src/vender/solid/general/index.ts index 52647905ab..4c4dd9a437 100644 --- a/web/app/components/base/icons/src/vender/solid/general/index.ts +++ b/web/app/components/base/icons/src/vender/solid/general/index.ts @@ -1,4 +1,5 @@ export { default as AnswerTriangle } from './AnswerTriangle' +export { default as ArrowDownRoundFill } from './ArrowDownRoundFill' export { default as CheckCircle } from './CheckCircle' export { default as CheckDone01 } from './CheckDone01' export { default as Download02 } from './Download02' diff --git a/web/app/components/base/prompt-editor/plugins/history-block/node.tsx b/web/app/components/base/prompt-editor/plugins/history-block/node.tsx index 1a2600d568..1cb33fcc49 100644 --- a/web/app/components/base/prompt-editor/plugins/history-block/node.tsx +++ b/web/app/components/base/prompt-editor/plugins/history-block/node.tsx @@ -14,7 +14,7 @@ export class HistoryBlockNode extends DecoratorNode { } static clone(node: HistoryBlockNode): HistoryBlockNode { - return new HistoryBlockNode(node.__roleName, node.__onEditRole) + return new HistoryBlockNode(node.__roleName, node.__onEditRole, node.__key) } constructor(roleName: RoleName, onEditRole: () => void, key?: NodeKey) { diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx index 2cf4c95b87..2f6c3374a7 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx +++ 
b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx @@ -11,6 +11,7 @@ import { mergeRegister } from '@lexical/utils' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { RiErrorWarningFill, + RiMoreLine, } from '@remixicon/react' import { useSelectOrDelete } from '../../hooks' import type { WorkflowNodesMap } from './node' @@ -27,26 +28,35 @@ import { Line3 } from '@/app/components/base/icons/src/public/common' import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' import Tooltip from '@/app/components/base/tooltip' import { isExceptionVariable } from '@/app/components/workflow/utils' +import VarFullPathPanel from '@/app/components/workflow/nodes/_base/components/variable/var-full-path-panel' +import { Type } from '@/app/components/workflow/nodes/llm/types' +import type { ValueSelector } from '@/app/components/workflow/types' type WorkflowVariableBlockComponentProps = { nodeKey: string variables: string[] workflowNodesMap: WorkflowNodesMap + getVarType?: (payload: { + nodeId: string, + valueSelector: ValueSelector, + }) => Type } const WorkflowVariableBlockComponent = ({ nodeKey, variables, workflowNodesMap = {}, + getVarType, }: WorkflowVariableBlockComponentProps) => { const { t } = useTranslation() const [editor] = useLexicalComposerContext() const [ref, isSelected] = useSelectOrDelete(nodeKey, DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND) const variablesLength = variables.length + const isShowAPart = variablesLength > 2 const varName = ( () => { const isSystem = isSystemVar(variables) - const varName = variablesLength >= 3 ? (variables).slice(-2).join('.') : variables[variablesLength - 1] + const varName = variables[variablesLength - 1] return `${isSystem ? 'sys.' : ''}${varName}` } )() @@ -76,7 +86,7 @@ const WorkflowVariableBlockComponent = ({ const Item = (
)} + {isShowAPart && ( +
+ + +
+ )} +
{!isEnv && !isChatVar && } {isEnv && } @@ -126,7 +143,27 @@ const WorkflowVariableBlockComponent = ({ ) } - return Item + if (!node) + return null + + return ( + } + disabled={!isShowAPart} + > +
{Item}
+
+ ) } export default memo(WorkflowVariableBlockComponent) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx index 05d4505e20..479dce9615 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx @@ -9,7 +9,7 @@ import { } from 'lexical' import { mergeRegister } from '@lexical/utils' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' -import type { WorkflowVariableBlockType } from '../../types' +import type { GetVarType, WorkflowVariableBlockType } from '../../types' import { $createWorkflowVariableBlockNode, WorkflowVariableBlockNode, @@ -25,11 +25,13 @@ export type WorkflowVariableBlockProps = { getWorkflowNode: (nodeId: string) => Node onInsert?: () => void onDelete?: () => void + getVarType: GetVarType } const WorkflowVariableBlock = memo(({ workflowNodesMap, onInsert, onDelete, + getVarType, }: WorkflowVariableBlockType) => { const [editor] = useLexicalComposerContext() @@ -48,7 +50,7 @@ const WorkflowVariableBlock = memo(({ INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, (variables: string[]) => { editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap) + const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) $insertNodes([workflowVariableBlockNode]) if (onInsert) @@ -69,7 +71,7 @@ const WorkflowVariableBlock = memo(({ COMMAND_PRIORITY_EDITOR, ), ) - }, [editor, onInsert, onDelete, workflowNodesMap]) + }, [editor, onInsert, onDelete, workflowNodesMap, getVarType]) return null }) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx index 0564e6f16d..dce636d92d 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx @@ -2,34 +2,39 @@ import type { LexicalNode, NodeKey, SerializedLexicalNode } from 'lexical' import { DecoratorNode } from 'lexical' import type { WorkflowVariableBlockType } from '../../types' import WorkflowVariableBlockComponent from './component' +import type { GetVarType } from '../../types' export type WorkflowNodesMap = WorkflowVariableBlockType['workflowNodesMap'] + export type SerializedNode = SerializedLexicalNode & { variables: string[] workflowNodesMap: WorkflowNodesMap + getVarType?: GetVarType } export class WorkflowVariableBlockNode extends DecoratorNode { __variables: string[] __workflowNodesMap: WorkflowNodesMap + __getVarType?: GetVarType static getType(): string { return 'workflow-variable-block' } static clone(node: WorkflowVariableBlockNode): WorkflowVariableBlockNode { - return new WorkflowVariableBlockNode(node.__variables, node.__workflowNodesMap, node.__key) + return new WorkflowVariableBlockNode(node.__variables, node.__workflowNodesMap, node.__getVarType, node.__key) } isInline(): boolean { return true } - constructor(variables: string[], workflowNodesMap: WorkflowNodesMap, key?: NodeKey) { + constructor(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType: any, key?: NodeKey) { super(key) this.__variables = variables this.__workflowNodesMap = workflowNodesMap + this.__getVarType = 
getVarType } createDOM(): HTMLElement { @@ -48,12 +53,13 @@ export class WorkflowVariableBlockNode extends DecoratorNode nodeKey={this.getKey()} variables={this.__variables} workflowNodesMap={this.__workflowNodesMap} + getVarType={this.__getVarType!} /> ) } static importJSON(serializedNode: SerializedNode): WorkflowVariableBlockNode { - const node = $createWorkflowVariableBlockNode(serializedNode.variables, serializedNode.workflowNodesMap) + const node = $createWorkflowVariableBlockNode(serializedNode.variables, serializedNode.workflowNodesMap, serializedNode.getVarType) return node } @@ -64,6 +70,7 @@ export class WorkflowVariableBlockNode extends DecoratorNode version: 1, variables: this.getVariables(), workflowNodesMap: this.getWorkflowNodesMap(), + getVarType: this.getVarType(), } } @@ -77,12 +84,17 @@ export class WorkflowVariableBlockNode extends DecoratorNode return self.__workflowNodesMap } + getVarType(): any { + const self = this.getLatest() + return self.__getVarType + } + getTextContent(): string { return `{{#${this.getVariables().join('.')}#}}` } } -export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap): WorkflowVariableBlockNode { - return new WorkflowVariableBlockNode(variables, workflowNodesMap) +export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType?: GetVarType): WorkflowVariableBlockNode { + return new WorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) } export function $isWorkflowVariableBlockNode( diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx index 22ebc5d248..288008bbcc 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx @@ -16,6 +16,7 @@ import { VAR_REGEX as REGEX, resetReg } from '@/config' const WorkflowVariableBlockReplacementBlock = ({ workflowNodesMap, + getVarType, onInsert, }: WorkflowVariableBlockType) => { const [editor] = useLexicalComposerContext() @@ -30,8 +31,8 @@ const WorkflowVariableBlockReplacementBlock = ({ onInsert() const nodePathString = textNode.getTextContent().slice(3, -3) - return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap)) - }, [onInsert, workflowNodesMap]) + return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap, getVarType)) + }, [onInsert, workflowNodesMap, getVarType]) const getMatch = useCallback((text: string) => { const matchArr = REGEX.exec(text) diff --git a/web/app/components/base/prompt-editor/types.ts b/web/app/components/base/prompt-editor/types.ts index 6d0f307c17..0f09fb2473 100644 --- a/web/app/components/base/prompt-editor/types.ts +++ b/web/app/components/base/prompt-editor/types.ts @@ -1,8 +1,10 @@ +import type { Type } from '../../workflow/nodes/llm/types' import type { Dataset } from './plugins/context-block' import type { RoleName } from './plugins/history-block' import type { Node, NodeOutPutVar, + ValueSelector, } from '@/app/components/workflow/types' export type Option = { @@ -54,12 +56,18 @@ export type ExternalToolBlockType = { onAddExternalTool?: () => void } +export type GetVarType = 
(payload: { + nodeId: string, + valueSelector: ValueSelector, +}) => Type + export type WorkflowVariableBlockType = { show?: boolean variables?: NodeOutPutVar[] workflowNodesMap?: Record> onInsert?: () => void onDelete?: () => void + getVarType?: GetVarType } export type MenuTextMatch = { diff --git a/web/app/components/base/segmented-control/index.tsx b/web/app/components/base/segmented-control/index.tsx new file mode 100644 index 0000000000..bd921e4243 --- /dev/null +++ b/web/app/components/base/segmented-control/index.tsx @@ -0,0 +1,68 @@ +import React from 'react' +import classNames from '@/utils/classnames' +import type { RemixiconComponentType } from '@remixicon/react' +import Divider from '../divider' + +// Updated generic type to allow enum values +type SegmentedControlProps = { + options: { Icon: RemixiconComponentType, text: string, value: T }[] + value: T + onChange: (value: T) => void + className?: string +} + +export const SegmentedControl = ({ + options, + value, + onChange, + className, +}: SegmentedControlProps): JSX.Element => { + const selectedOptionIndex = options.findIndex(option => option.value === value) + + return ( +
+ {options.map((option, index) => { + const { Icon } = option + const isSelected = index === selectedOptionIndex + const isNextSelected = index === selectedOptionIndex - 1 + const isLast = index === options.length - 1 + return ( + + ) + })} +
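[Illustrative sketch, not part of this patch] A minimal usage of the SegmentedControl added above, based only on the option shape (Icon/text/value), value and onChange props shown in this file; the enum, labels and component name below are assumptions for the example.

import { RiBracesLine, RiTimelineView } from '@remixicon/react'
import { SegmentedControl } from '@/app/components/base/segmented-control'

// Hypothetical view switch; the enum values are example data, not taken from the patch.
enum ExampleView {
  Visual = 'visual',
  Json = 'json',
}

type ExampleViewSwitchProps = {
  view: ExampleView
  onChange: (view: ExampleView) => void
}

// The generic option type is inferred as ExampleView from options/value.
const ExampleViewSwitch = ({ view, onChange }: ExampleViewSwitchProps) => (
  <SegmentedControl
    options={[
      { Icon: RiTimelineView, text: 'Visual', value: ExampleView.Visual },
      { Icon: RiBracesLine, text: 'JSON', value: ExampleView.Json },
    ]}
    value={view}
    onChange={onChange}
  />
)

export default ExampleViewSwitch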
+ ) +} + +export default React.memo(SegmentedControl) as typeof SegmentedControl diff --git a/web/app/components/base/textarea/index.tsx b/web/app/components/base/textarea/index.tsx index 0f18bebedf..1e274515f8 100644 --- a/web/app/components/base/textarea/index.tsx +++ b/web/app/components/base/textarea/index.tsx @@ -8,8 +8,9 @@ const textareaVariants = cva( { variants: { size: { - regular: 'px-3 radius-md system-sm-regular', - large: 'px-4 radius-lg system-md-regular', + small: 'py-1 rounded-md system-xs-regular', + regular: 'px-3 rounded-md system-sm-regular', + large: 'px-4 rounded-lg system-md-regular', }, }, defaultVariants: { diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index 39e229cd54..12dd9b3b5b 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -60,6 +60,7 @@ export enum ModelFeatureEnum { video = 'video', document = 'document', audio = 'audio', + StructuredOutput = 'structured-output', } export enum ModelFeatureTextEnum { diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx index 28001bef5e..c5af4ed8a1 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx @@ -376,6 +376,7 @@ function Form< tooltip={tooltip?.[language] || tooltip?.en_US} value={value[variable] || []} onChange={item => handleFormChange(variable, item as any)} + supportCollapse /> {fieldMoreInfo?.(formSchema)} {validating && changeKey === variable && } diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index 4bb3cbf7d5..3e969d708b 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -10,6 +10,7 @@ import Slider from '@/app/components/base/slider' import Radio from '@/app/components/base/radio' import { SimpleSelect } from '@/app/components/base/select' import TagInput from '@/app/components/base/tag-input' +import { useTranslation } from 'react-i18next' export type ParameterValue = number | string | string[] | boolean | undefined @@ -27,6 +28,7 @@ const ParameterItem: FC = ({ onSwitch, isInWorkflow, }) => { + const { t } = useTranslation() const language = useLanguage() const [localValue, setLocalValue] = useState(value) const numberInputRef = useRef(null) diff --git a/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx index fc29feaefc..f243d30aff 100644 --- a/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx @@ -2,7 +2,6 @@ import React from 'react' import { useTranslation } from 'react-i18next' import { RiAddLine, - RiArrowDropDownLine, RiQuestionLine, } from '@remixicon/react' import ToolSelector from 
'@/app/components/plugins/plugin-detail-panel/tool-selector' @@ -13,6 +12,7 @@ import type { ToolValue } from '@/app/components/workflow/block-selector/types' import type { Node } from 'reactflow' import type { NodeOutPutVar } from '@/app/components/workflow/types' import cn from '@/utils/classnames' +import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general' type Props = { disabled?: boolean @@ -98,14 +98,12 @@ const MultipleToolSelector = ({ )} {supportCollapse && ( -
- -
+ )}
{value.length > 0 && ( diff --git a/web/app/components/workflow/hooks/use-workflow-variables.ts b/web/app/components/workflow/hooks/use-workflow-variables.ts index a2863671ed..35637bc775 100644 --- a/web/app/components/workflow/hooks/use-workflow-variables.ts +++ b/web/app/components/workflow/hooks/use-workflow-variables.ts @@ -8,6 +8,8 @@ import type { ValueSelector, Var, } from '@/app/components/workflow/types' +import { useIsChatMode } from './use-workflow' +import { useStoreApi } from 'reactflow' export const useWorkflowVariables = () => { const { t } = useTranslation() @@ -75,3 +77,37 @@ export const useWorkflowVariables = () => { getCurrentVariableType, } } + +export const useWorkflowVariableType = () => { + const store = useStoreApi() + const { + getNodes, + } = store.getState() + const { getCurrentVariableType } = useWorkflowVariables() + + const isChatMode = useIsChatMode() + + const getVarType = ({ + nodeId, + valueSelector, + }: { + nodeId: string, + valueSelector: ValueSelector, + }) => { + const node = getNodes().find(n => n.id === nodeId) + const isInIteration = !!node?.data.isInIteration + const iterationNode = isInIteration ? getNodes().find(n => n.id === node.parentId) : null + const availableNodes = [node] + + const type = getCurrentVariableType({ + parentNode: iterationNode, + valueSelector, + availableNodes, + isChatMode, + isConstant: false, + }) + return type + } + + return getVarType +} diff --git a/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx b/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx index 4b36125575..2390dfd74e 100644 --- a/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx +++ b/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx @@ -4,10 +4,16 @@ import Collapse from '.' type FieldCollapseProps = { title: string children: ReactNode + collapsed?: boolean + onCollapse?: (collapsed: boolean) => void + operations?: ReactNode } const FieldCollapse = ({ title, children, + collapsed, + onCollapse, + operations, }: FieldCollapseProps) => { return (
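[Illustrative sketch, not part of this patch] How the new useWorkflowVariableType hook and the GetVarType callback could be wired together. The node id, selector and panel component below are assumptions; the hook reads the flow store, so it has to run inside the workflow canvas (ReactFlow provider).

import { useWorkflowVariableType } from '@/app/components/workflow/hooks/use-workflow-variables'

// Hypothetical node panel; 'llm-node-1' and ['sys', 'query'] are example values only.
const ExampleNodePanel = () => {
  const getVarType = useWorkflowVariableType()

  // Resolve the type of a variable the user just inserted, e.g. {{#sys.query#}}.
  const varType = getVarType({
    nodeId: 'llm-node-1',
    valueSelector: ['sys', 'query'],
  })

  // The same resolver can be passed as WorkflowVariableBlockType.getVarType so
  // $createWorkflowVariableBlockNode receives it for every inserted variable block.
  console.log(varType)
  return null
}

export default ExampleNodePanel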
@@ -15,6 +21,9 @@ const FieldCollapse = ({ trigger={
{title}
      }
+      operations={operations}
+      collapsed={collapsed}
+      onCollapse={onCollapse}
    >
{children} diff --git a/web/app/components/workflow/nodes/_base/components/collapse/index.tsx b/web/app/components/workflow/nodes/_base/components/collapse/index.tsx index 1f39c1c1c5..16fba88a25 100644 --- a/web/app/components/workflow/nodes/_base/components/collapse/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/collapse/index.tsx @@ -1,15 +1,18 @@ -import { useState } from 'react' -import { RiArrowDropRightLine } from '@remixicon/react' +import type { ReactNode } from 'react' +import { useMemo, useState } from 'react' +import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general' import cn from '@/utils/classnames' export { default as FieldCollapse } from './field-collapse' type CollapseProps = { disabled?: boolean - trigger: React.JSX.Element + trigger: React.JSX.Element | ((collapseIcon: React.JSX.Element | null) => React.JSX.Element) children: React.JSX.Element collapsed?: boolean onCollapse?: (collapsed: boolean) => void + operations?: ReactNode + hideCollapseIcon?: boolean } const Collapse = ({ disabled, @@ -17,34 +20,44 @@ const Collapse = ({ children, collapsed, onCollapse, + operations, + hideCollapseIcon, }: CollapseProps) => { const [collapsedLocal, setCollapsedLocal] = useState(true) const collapsedMerged = collapsed !== undefined ? collapsed : collapsedLocal + const collapseIcon = useMemo(() => { + if (disabled) + return null + return ( + + ) + }, [collapsedMerged, disabled]) return ( <> -
{ - if (!disabled) { - setCollapsedLocal(!collapsedMerged) - onCollapse?.(!collapsedMerged) - } - }} - > -
- { - !disabled && ( - - ) - } +
+
{ + if (!disabled) { + setCollapsedLocal(!collapsedMerged) + onCollapse?.(!collapsedMerged) + } + }} + > + {typeof trigger === 'function' ? trigger(collapseIcon) : trigger} + {!hideCollapseIcon && ( +
+ {collapseIcon} +
+ )}
- {trigger} + {operations}
{ !collapsedMerged && children diff --git a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx index b36abbfb00..cfcbae80f3 100644 --- a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx +++ b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx @@ -49,20 +49,23 @@ const ErrorHandle = ({ disabled={!error_strategy} collapsed={collapsed} onCollapse={setCollapsed} + hideCollapseIcon trigger={ -
-
-
- {t('workflow.nodes.common.errorHandle.title')} + collapseIcon => ( +
+
+
+ {t('workflow.nodes.common.errorHandle.title')} +
+ + {collapseIcon}
- +
- -
- } + )} > <> { diff --git a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx index 190c748831..d9516dfcf5 100644 --- a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx +++ b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx @@ -50,6 +50,7 @@ const ErrorHandleTypeSelector = ({ > { e.stopPropagation() + e.nativeEvent.stopImmediatePropagation() setOpen(v => !v) }}> + + )} + + + +
+
+
+ +
+
+ ) +} + +export default React.memo(CodeEditor) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx new file mode 100644 index 0000000000..2685182f9f --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx @@ -0,0 +1,27 @@ +import React from 'react' +import type { FC } from 'react' +import { RiErrorWarningFill } from '@remixicon/react' +import classNames from '@/utils/classnames' + +type ErrorMessageProps = { + message: string +} & React.HTMLAttributes + +const ErrorMessage: FC = ({ + message, + className, +}) => { + return ( +
+ +
+ {message} +
+
+ ) +} + +export default React.memo(ErrorMessage) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx new file mode 100644 index 0000000000..d34836d5b2 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx @@ -0,0 +1,34 @@ +import React, { type FC } from 'react' +import Modal from '../../../../../base/modal' +import type { SchemaRoot } from '../../types' +import JsonSchemaConfig from './json-schema-config' + +type JsonSchemaConfigModalProps = { + isShow: boolean + defaultSchema?: SchemaRoot + onSave: (schema: SchemaRoot) => void + onClose: () => void +} + +const JsonSchemaConfigModal: FC = ({ + isShow, + defaultSchema, + onSave, + onClose, +}) => { + return ( + + + + ) +} + +export default JsonSchemaConfigModal diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx new file mode 100644 index 0000000000..643059adbd --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx @@ -0,0 +1,136 @@ +import React, { type FC, useCallback, useEffect, useRef, useState } from 'react' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +import cn from '@/utils/classnames' +import { useTranslation } from 'react-i18next' +import { RiCloseLine } from '@remixicon/react' +import Button from '@/app/components/base/button' +import { checkJsonDepth } from '../../utils' +import { JSON_SCHEMA_MAX_DEPTH } from '@/config' +import CodeEditor from './code-editor' +import ErrorMessage from './error-message' +import { useVisualEditorStore } from './visual-editor/store' +import { useMittContext } from './visual-editor/context' + +type JsonImporterProps = { + onSubmit: (schema: any) => void + updateBtnWidth: (width: number) => void +} + +const JsonImporter: FC = ({ + onSubmit, + updateBtnWidth, +}) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + const [json, setJson] = useState('') + const [parseError, setParseError] = useState(null) + const importBtnRef = useRef(null) + const advancedEditing = useVisualEditorStore(state => state.advancedEditing) + const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) + const { emit } = useMittContext() + + useEffect(() => { + if (importBtnRef.current) { + const rect = importBtnRef.current.getBoundingClientRect() + updateBtnWidth(rect.width) + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + const handleTrigger = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + if (advancedEditing || isAddingNewField) + emit('quitEditing', {}) + setOpen(!open) + }, [open, advancedEditing, isAddingNewField, emit]) + + const onClose = useCallback(() => { + setOpen(false) + }, []) + + const handleSubmit = useCallback(() => { + try { + const parsedJSON = JSON.parse(json) + if (typeof parsedJSON !== 'object' || Array.isArray(parsedJSON)) { + setParseError(new Error('Root must be an object, not an array or primitive value.')) + return + } + const maxDepth = checkJsonDepth(parsedJSON) + if (maxDepth > JSON_SCHEMA_MAX_DEPTH) { + setParseError({ + type: 'error', + message: `Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`, + }) + return + } + 
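[Illustrative sketch, not part of this patch] checkJsonDepth above is imported from '../../utils'; a depth counter of that kind might look roughly like the sketch below. The exact counting convention in the repository may differ, so treat this purely as an assumed reference implementation.

// Assumed sketch only — not the implementation shipped in utils.
// Counts nesting levels of objects/arrays below the root value.
const checkJsonDepthSketch = (value: unknown, depth = 0): number => {
  if (value === null || typeof value !== 'object')
    return depth
  const children = Array.isArray(value) ? value : Object.values(value)
  if (children.length === 0)
    return depth
  return Math.max(...children.map(child => checkJsonDepthSketch(child, depth + 1)))
}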
onSubmit(parsedJSON) + setParseError(null) + setOpen(false) + } + catch (e: any) { + if (e instanceof Error) + setParseError(e) + else + setParseError(new Error('Invalid JSON')) + } + }, [onSubmit, json]) + + return ( + + + + + +
+ {/* Title */} +
+
+ +
+
+ {t('workflow.nodes.llm.jsonSchema.import')} +
+
+ {/* Content */} +
+ + {parseError && } +
+ {/* Footer */} +
+ + +
+
+
+
+ ) +} + +export default JsonImporter diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx new file mode 100644 index 0000000000..d125e31dae --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx @@ -0,0 +1,301 @@ +import React, { type FC, useCallback, useState } from 'react' +import { type SchemaRoot, Type } from '../../types' +import { RiBracesLine, RiCloseLine, RiExternalLinkLine, RiTimelineView } from '@remixicon/react' +import { SegmentedControl } from '../../../../../base/segmented-control' +import JsonSchemaGenerator from './json-schema-generator' +import Divider from '@/app/components/base/divider' +import JsonImporter from './json-importer' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' +import VisualEditor from './visual-editor' +import SchemaEditor from './schema-editor' +import { + checkJsonSchemaDepth, + convertBooleanToString, + getValidationErrorMessage, + jsonToSchema, + preValidateSchema, + validateSchemaAgainstDraft7, +} from '../../utils' +import { MittProvider, VisualEditorContextProvider, useMittContext } from './visual-editor/context' +import ErrorMessage from './error-message' +import { useVisualEditorStore } from './visual-editor/store' +import Toast from '@/app/components/base/toast' +import { useGetLanguage } from '@/context/i18n' +import { JSON_SCHEMA_MAX_DEPTH } from '@/config' + +type JsonSchemaConfigProps = { + defaultSchema?: SchemaRoot + onSave: (schema: SchemaRoot) => void + onClose: () => void +} + +enum SchemaView { + VisualEditor = 'visualEditor', + JsonSchema = 'jsonSchema', +} + +const VIEW_TABS = [ + { Icon: RiTimelineView, text: 'Visual Editor', value: SchemaView.VisualEditor }, + { Icon: RiBracesLine, text: 'JSON Schema', value: SchemaView.JsonSchema }, +] + +const DEFAULT_SCHEMA: SchemaRoot = { + type: Type.object, + properties: {}, + required: [], + additionalProperties: false, +} + +const HELP_DOC_URL = { + zh_Hans: 'https://docs.dify.ai/zh-hans/guides/workflow/structured-outputs', + en_US: 'https://docs.dify.ai/guides/workflow/structured-outputs', + ja_JP: 'https://docs.dify.ai/ja-jp/guides/workflow/structured-outputs', +} + +type LocaleKey = keyof typeof HELP_DOC_URL + +const JsonSchemaConfig: FC = ({ + defaultSchema, + onSave, + onClose, +}) => { + const { t } = useTranslation() + const locale = useGetLanguage() as LocaleKey + const [currentTab, setCurrentTab] = useState(SchemaView.VisualEditor) + const [jsonSchema, setJsonSchema] = useState(defaultSchema || DEFAULT_SCHEMA) + const [json, setJson] = useState(JSON.stringify(jsonSchema, null, 2)) + const [btnWidth, setBtnWidth] = useState(0) + const [parseError, setParseError] = useState(null) + const [validationError, setValidationError] = useState('') + const advancedEditing = useVisualEditorStore(state => state.advancedEditing) + const setAdvancedEditing = useVisualEditorStore(state => state.setAdvancedEditing) + const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) + const setIsAddingNewField = useVisualEditorStore(state => state.setIsAddingNewField) + const setHoveringProperty = useVisualEditorStore(state => state.setHoveringProperty) + const { emit } = useMittContext() + + const updateBtnWidth = useCallback((width: number) => { + setBtnWidth(width + 32) + }, []) + + const handleTabChange = 
useCallback((value: SchemaView) => { + if (currentTab === value) return + if (currentTab === SchemaView.JsonSchema) { + try { + const schema = JSON.parse(json) + setParseError(null) + const result = preValidateSchema(schema) + if (!result.success) { + setValidationError(result.error.message) + return + } + const schemaDepth = checkJsonSchemaDepth(schema) + if (schemaDepth > JSON_SCHEMA_MAX_DEPTH) { + setValidationError(`Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`) + return + } + convertBooleanToString(schema) + const validationErrors = validateSchemaAgainstDraft7(schema) + if (validationErrors.length > 0) { + setValidationError(getValidationErrorMessage(validationErrors)) + return + } + setJsonSchema(schema) + setValidationError('') + } + catch (error) { + setValidationError('') + if (error instanceof Error) + setParseError(error) + else + setParseError(new Error('Invalid JSON')) + return + } + } + else if (currentTab === SchemaView.VisualEditor) { + if (advancedEditing || isAddingNewField) + emit('quitEditing', { callback: (backup: SchemaRoot) => setJson(JSON.stringify(backup || jsonSchema, null, 2)) }) + else + setJson(JSON.stringify(jsonSchema, null, 2)) + } + + setCurrentTab(value) + }, [currentTab, jsonSchema, json, advancedEditing, isAddingNewField, emit]) + + const handleApplySchema = useCallback((schema: SchemaRoot) => { + if (currentTab === SchemaView.VisualEditor) + setJsonSchema(schema) + else if (currentTab === SchemaView.JsonSchema) + setJson(JSON.stringify(schema, null, 2)) + }, [currentTab]) + + const handleSubmit = useCallback((schema: any) => { + const jsonSchema = jsonToSchema(schema) as SchemaRoot + if (currentTab === SchemaView.VisualEditor) + setJsonSchema(jsonSchema) + else if (currentTab === SchemaView.JsonSchema) + setJson(JSON.stringify(jsonSchema, null, 2)) + }, [currentTab]) + + const handleVisualEditorUpdate = useCallback((schema: SchemaRoot) => { + setJsonSchema(schema) + }, []) + + const handleSchemaEditorUpdate = useCallback((schema: string) => { + setJson(schema) + }, []) + + const handleResetDefaults = useCallback(() => { + if (currentTab === SchemaView.VisualEditor) { + setHoveringProperty(null) + advancedEditing && setAdvancedEditing(false) + isAddingNewField && setIsAddingNewField(false) + } + setJsonSchema(DEFAULT_SCHEMA) + setJson(JSON.stringify(DEFAULT_SCHEMA, null, 2)) + }, [currentTab, advancedEditing, isAddingNewField, setAdvancedEditing, setIsAddingNewField, setHoveringProperty]) + + const handleCancel = useCallback(() => { + onClose() + }, [onClose]) + + const handleSave = useCallback(() => { + let schema = jsonSchema + if (currentTab === SchemaView.JsonSchema) { + try { + schema = JSON.parse(json) + setParseError(null) + const result = preValidateSchema(schema) + if (!result.success) { + setValidationError(result.error.message) + return + } + const schemaDepth = checkJsonSchemaDepth(schema) + if (schemaDepth > JSON_SCHEMA_MAX_DEPTH) { + setValidationError(`Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`) + return + } + convertBooleanToString(schema) + const validationErrors = validateSchemaAgainstDraft7(schema) + if (validationErrors.length > 0) { + setValidationError(getValidationErrorMessage(validationErrors)) + return + } + setJsonSchema(schema) + setValidationError('') + } + catch (error) { + setValidationError('') + if (error instanceof Error) + setParseError(error) + else + setParseError(new Error('Invalid JSON')) + return + } + } + else if (currentTab === SchemaView.VisualEditor) { + if (advancedEditing || 
isAddingNewField) { + Toast.notify({ + type: 'warning', + message: t('workflow.nodes.llm.jsonSchema.warningTips.saveSchema'), + }) + return + } + } + onSave(schema) + onClose() + }, [currentTab, jsonSchema, json, onSave, onClose, advancedEditing, isAddingNewField, t]) + + return ( +
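[Illustrative sketch, not part of this patch] handleTabChange and handleSave above run the same parse → preValidateSchema → depth check → convertBooleanToString → validateSchemaAgainstDraft7 sequence; a shared helper expressing that flow could look like the sketch below. It reuses the imports already present in this file and is only a possible refactoring, not code from the patch.

type SchemaCheckResult =
  | { ok: true, schema: SchemaRoot }
  | { ok: false, parseError?: Error, validationError?: string }

// Sketch of the validation flow shared by handleTabChange/handleSave above.
const parseAndValidateJsonSchema = (input: string): SchemaCheckResult => {
  let schema: any
  try {
    schema = JSON.parse(input)
  }
  catch (e) {
    return { ok: false, parseError: e instanceof Error ? e : new Error('Invalid JSON') }
  }
  const pre = preValidateSchema(schema)
  if (!pre.success)
    return { ok: false, validationError: pre.error.message }
  if (checkJsonSchemaDepth(schema) > JSON_SCHEMA_MAX_DEPTH)
    return { ok: false, validationError: `Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.` }
  convertBooleanToString(schema)
  const validationErrors = validateSchemaAgainstDraft7(schema)
  if (validationErrors.length > 0)
    return { ok: false, validationError: getValidationErrorMessage(validationErrors) }
  return { ok: true, schema: schema as SchemaRoot }
}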
+ {/* Header */} +
+
+ {t('workflow.nodes.llm.jsonSchema.title')} +
+
+ +
+
+ {/* Content */} +
+ {/* Tab */} + + options={VIEW_TABS} + value={currentTab} + onChange={handleTabChange} + /> +
+ {/* JSON Schema Generator */} + + + {/* JSON Schema Importer */} + +
+
+
+ {currentTab === SchemaView.VisualEditor && ( + + )} + {currentTab === SchemaView.JsonSchema && ( + + )} + {parseError && } + {validationError && } +
+ {/* Footer */} +
+ + {t('workflow.nodes.llm.jsonSchema.doc')} + + +
+
+ + +
+
+ + +
+
+
+
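[Illustrative sketch, not part of this patch] useMittContext/emit('quitEditing', …) in this modal rely on the event context from './visual-editor/context'. Assuming that context is backed by the mitt emitter library (suggested by its name, but not confirmed here), a minimal version could look like this sketch; the shipped provider may be implemented differently.

import mitt, { type Emitter } from 'mitt'
import React, { createContext, useContext, useMemo } from 'react'

type VisualEditorEvents = Record<string, any>

const MittContext = createContext<Emitter<VisualEditorEvents> | null>(null)

// Sketch of a mitt-backed provider; the real MittProvider may differ.
export const MittProviderSketch = ({ children }: { children: React.ReactNode }) => {
  const emitter = useMemo(() => mitt<VisualEditorEvents>(), [])
  return <MittContext.Provider value={emitter}>{children}</MittContext.Provider>
}

export const useMittContextSketch = () => {
  const emitter = useContext(MittContext)
  if (!emitter)
    throw new Error('useMittContextSketch must be used within MittProviderSketch')
  return { emit: emitter.emit, on: emitter.on, off: emitter.off }
}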
+ ) +} + +const JsonSchemaConfigWrapper: FC = (props) => { + return ( + + + + + + ) +} + +export default JsonSchemaConfigWrapper diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx new file mode 100644 index 0000000000..5f1f117086 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx @@ -0,0 +1,7 @@ +import SchemaGeneratorLight from './schema-generator-light' +import SchemaGeneratorDark from './schema-generator-dark' + +export { + SchemaGeneratorLight, + SchemaGeneratorDark, +} diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-dark.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-dark.tsx new file mode 100644 index 0000000000..ac4793b1e3 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-dark.tsx @@ -0,0 +1,15 @@ +const SchemaGeneratorDark = () => { + return ( + + + + + + + + + + ) +} + +export default SchemaGeneratorDark diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-light.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-light.tsx new file mode 100644 index 0000000000..8b898bde68 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-light.tsx @@ -0,0 +1,15 @@ +const SchemaGeneratorLight = () => { + return ( + + + + + + + + + + ) +} + +export default SchemaGeneratorLight diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx new file mode 100644 index 0000000000..00f57237e5 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx @@ -0,0 +1,121 @@ +import React, { type FC, useCallback, useMemo, useState } from 'react' +import type { SchemaRoot } from '../../../types' +import { RiArrowLeftLine, RiCloseLine, RiSparklingLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' +import CodeEditor from '../code-editor' +import ErrorMessage from '../error-message' +import { getValidationErrorMessage, validateSchemaAgainstDraft7 } from '../../../utils' +import Loading from '@/app/components/base/loading' + +type GeneratedResultProps = { + schema: SchemaRoot + isGenerating: boolean + onBack: () => void + onRegenerate: () => void + onClose: () => void + onApply: () => void +} + +const GeneratedResult: FC = ({ + schema, + isGenerating, + onBack, + onRegenerate, + onClose, + onApply, +}) => { + const { t } = useTranslation() + const [parseError, setParseError] = useState(null) + const [validationError, setValidationError] = useState('') + + const formatJSON = (json: SchemaRoot) => { + try { + const schema = JSON.stringify(json, null, 2) + setParseError(null) + return schema + } + catch (e) { + if (e instanceof 
Error) + setParseError(e) + else + setParseError(new Error('Invalid JSON')) + return '' + } + } + + const jsonSchema = useMemo(() => formatJSON(schema), [schema]) + + const handleApply = useCallback(() => { + const validationErrors = validateSchemaAgainstDraft7(schema) + if (validationErrors.length > 0) { + setValidationError(getValidationErrorMessage(validationErrors)) + return + } + onApply() + setValidationError('') + }, [schema, onApply]) + + return ( +
+ { + isGenerating ? ( +
+ +
{t('workflow.nodes.llm.jsonSchema.generating')}
+
+ ) : ( + <> +
+ +
+ {/* Title */} +
+
+ {t('workflow.nodes.llm.jsonSchema.generatedResult')} +
+
+ {t('workflow.nodes.llm.jsonSchema.resultTip')} +
+
+ {/* Content */} +
+ + {parseError && } + {validationError && } +
+ {/* Footer */} +
+ +
+ + +
+
+ + + ) + } +
+ ) +} + +export default React.memo(GeneratedResult) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx new file mode 100644 index 0000000000..4732499f3a --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx @@ -0,0 +1,183 @@ +import React, { type FC, useCallback, useEffect, useState } from 'react' +import type { SchemaRoot } from '../../../types' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import useTheme from '@/hooks/use-theme' +import type { CompletionParams, Model } from '@/types/app' +import { ModelModeType } from '@/types/app' +import { Theme } from '@/types/app' +import { SchemaGeneratorDark, SchemaGeneratorLight } from './assets' +import cn from '@/utils/classnames' +import type { ModelInfo } from './prompt-editor' +import PromptEditor from './prompt-editor' +import GeneratedResult from './generated-result' +import { useGenerateStructuredOutputRules } from '@/service/use-common' +import Toast from '@/app/components/base/toast' +import { type FormValue, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useVisualEditorStore } from '../visual-editor/store' +import { useTranslation } from 'react-i18next' +import { useMittContext } from '../visual-editor/context' + +type JsonSchemaGeneratorProps = { + onApply: (schema: SchemaRoot) => void + crossAxisOffset?: number +} + +enum GeneratorView { + promptEditor = 'promptEditor', + result = 'result', +} + +export const JsonSchemaGenerator: FC = ({ + onApply, + crossAxisOffset, +}) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + const [view, setView] = useState(GeneratorView.promptEditor) + const [model, setModel] = useState({ + name: '', + provider: '', + mode: ModelModeType.completion, + completion_params: {} as CompletionParams, + }) + const [instruction, setInstruction] = useState('') + const [schema, setSchema] = useState(null) + const { theme } = useTheme() + const { + defaultModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) + const advancedEditing = useVisualEditorStore(state => state.advancedEditing) + const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) + const { emit } = useMittContext() + const SchemaGenerator = theme === Theme.light ? 
SchemaGeneratorLight : SchemaGeneratorDark + + useEffect(() => { + if (defaultModel) { + setModel(prev => ({ + ...prev, + name: defaultModel.model, + provider: defaultModel.provider.provider, + })) + } + }, [defaultModel]) + + const handleTrigger = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + if (advancedEditing || isAddingNewField) + emit('quitEditing', {}) + setOpen(!open) + }, [open, advancedEditing, isAddingNewField, emit]) + + const onClose = useCallback(() => { + setOpen(false) + }, []) + + const handleModelChange = useCallback((model: ModelInfo) => { + setModel(prev => ({ + ...prev, + provider: model.provider, + name: model.modelId, + mode: model.mode as ModelModeType, + })) + }, []) + + const handleCompletionParamsChange = useCallback((newParams: FormValue) => { + setModel(prev => ({ + ...prev, + completion_params: newParams as CompletionParams, + }), + ) + }, []) + + const { mutateAsync: generateStructuredOutputRules, isPending: isGenerating } = useGenerateStructuredOutputRules() + + const generateSchema = useCallback(async () => { + const { output, error } = await generateStructuredOutputRules({ instruction, model_config: model! }) + if (error) { + Toast.notify({ + type: 'error', + message: error, + }) + setSchema(null) + setView(GeneratorView.promptEditor) + return + } + return output + }, [instruction, model, generateStructuredOutputRules]) + + const handleGenerate = useCallback(async () => { + setView(GeneratorView.result) + const output = await generateSchema() + if (output === undefined) return + setSchema(JSON.parse(output)) + }, [generateSchema]) + + const goBackToPromptEditor = () => { + setView(GeneratorView.promptEditor) + } + + const handleRegenerate = useCallback(async () => { + const output = await generateSchema() + if (output === undefined) return + setSchema(JSON.parse(output)) + }, [generateSchema]) + + const handleApply = () => { + onApply(schema!) 
+ setOpen(false) + } + + return ( + + + + + + {view === GeneratorView.promptEditor && ( + + )} + {view === GeneratorView.result && ( + + )} + + + ) +} + +export default JsonSchemaGenerator diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx new file mode 100644 index 0000000000..9387813ee5 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx @@ -0,0 +1,108 @@ +import React, { useCallback } from 'react' +import type { FC } from 'react' +import { RiCloseLine, RiSparklingFill } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Textarea from '@/app/components/base/textarea' +import Tooltip from '@/app/components/base/tooltip' +import Button from '@/app/components/base/button' +import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' +import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' +import type { Model } from '@/types/app' + +export type ModelInfo = { + modelId: string + provider: string + mode?: string + features?: string[] +} + +type PromptEditorProps = { + instruction: string + model: Model + onInstructionChange: (instruction: string) => void + onCompletionParamsChange: (newParams: FormValue) => void + onModelChange: (model: ModelInfo) => void + onClose: () => void + onGenerate: () => void +} + +const PromptEditor: FC = ({ + instruction, + model, + onInstructionChange, + onCompletionParamsChange, + onClose, + onGenerate, + onModelChange, +}) => { + const { t } = useTranslation() + + const handleInstructionChange = useCallback((e: React.ChangeEvent) => { + onInstructionChange(e.target.value) + }, [onInstructionChange]) + + return ( +
+
+ +
+ {/* Title */} +
+
+ {t('workflow.nodes.llm.jsonSchema.generateJsonSchema')} +
+
+ {t('workflow.nodes.llm.jsonSchema.generationTip')} +
+
+ {/* Content */} +
+
+ {t('common.modelProvider.model')} +
+ +
+
+
+ {t('workflow.nodes.llm.jsonSchema.instruction')} + +
+
+