From 24af4b931384ecf3118a8aa07e0ea7de12605d75 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 13 Sep 2024 15:37:54 +0800 Subject: [PATCH 01/43] fix: o1-series model encounters an error when the generate mode is blocking (#8363) --- .../model_runtime/model_providers/openai/llm/llm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 95533ccfaf..60d69c6e47 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -615,10 +615,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): block_as_stream = False if model.startswith("o1"): - block_as_stream = True - stream = False - if "stream_options" in extra_model_kwargs: - del extra_model_kwargs["stream_options"] + if stream: + block_as_stream = True + stream = False + + if "stream_options" in extra_model_kwargs: + del extra_model_kwargs["stream_options"] if "stop" in extra_model_kwargs: del extra_model_kwargs["stop"] From 5dfd7abb2b9730a1d57d3aad2a251ccecee6de3e Mon Sep 17 00:00:00 2001 From: Joel Date: Fri, 13 Sep 2024 16:05:26 +0800 Subject: [PATCH 02/43] fix: when edit load balancing config not pass the empty filed value hidden (#8366) --- .../model-modal/model-load-balancing-entry-modal.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx index 86857f1ab2..e18b490844 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx @@ -192,12 +192,12 @@ const ModelLoadBalancingEntryModal: FC = ({ }) const getSecretValues = useCallback((v: FormValue) => { return secretFormSchemas.reduce((prev, next) => { - if (v[next.variable] === initialFormSchemasValue[next.variable]) + if (isEditMode && v[next.variable] && v[next.variable] === initialFormSchemasValue[next.variable]) prev[next.variable] = '[__HIDDEN__]' return prev }, {} as Record) - }, [initialFormSchemasValue, secretFormSchemas]) + }, [initialFormSchemasValue, isEditMode, secretFormSchemas]) // const handleValueChange = ({ __model_type, __model_name, ...v }: FormValue) => { const handleValueChange = (v: FormValue) => { From 84ac5ccc8f3179efcf998b8d138c7bab310373ca Mon Sep 17 00:00:00 2001 From: Joe <79627742+ZhouhaoJiang@users.noreply.github.com> Date: Fri, 13 Sep 2024 16:08:08 +0800 Subject: [PATCH 03/43] fix: add before send to remove langfuse defaultErrorResponse (#8361) --- api/extensions/ext_sentry.py | 10 ++++++++++ api/services/ops_service.py | 24 ++++++++++++++++++------ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 86e4720b3f..c2dc736038 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -6,6 +6,15 @@ from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException +def before_send(event, hint): + if "exc_info" in hint: + exc_type, exc_value, tb = hint["exc_info"] + if parse_error.defaultErrorResponse in str(exc_value): + return None + + return event + + def init_app(app): if app.config.get("SENTRY_DSN"): sentry_sdk.init( @@ -16,4 +25,5 @@ def init_app(app): profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), environment=app.config.get("DEPLOY_ENV"), release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", + before_send=before_send, ) diff --git a/api/services/ops_service.py b/api/services/ops_service.py index d8e2b1689a..1160a1f275 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -31,16 +31,28 @@ class OpsService: if tracing_provider == "langfuse" and ( "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") ): - project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) - new_decrypt_tracing_config.update( - {"project_url": "{host}/project/{key}".format(host=decrypt_tracing_config.get("host"), key=project_key)} - ) + try: + project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update( + { + "project_url": "{host}/project/{key}".format( + host=decrypt_tracing_config.get("host"), key=project_key + ) + } + ) + except Exception: + new_decrypt_tracing_config.update( + {"project_url": "{host}/".format(host=decrypt_tracing_config.get("host"))} + ) if tracing_provider == "langsmith" and ( "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") ): - project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider) - new_decrypt_tracing_config.update({"project_url": project_url}) + try: + project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update({"project_url": project_url}) + except Exception: + new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"}) trace_config_data.tracing_config = new_decrypt_tracing_config return trace_config_data.to_dict() From 9d80d7def7c46d576a73496ff9542d01c7841894 Mon Sep 17 00:00:00 2001 From: Joel Date: Fri, 13 Sep 2024 17:15:03 +0800 Subject: [PATCH 04/43] fix: edit load balancing not pass id (#8370) --- .../model-modal/model-load-balancing-entry-modal.tsx | 1 + .../header/account-setting/model-provider-page/utils.ts | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx index e18b490844..1c318b9baf 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx @@ -214,6 +214,7 @@ const ModelLoadBalancingEntryModal: FC = ({ ...value, ...getSecretValues(value), }, + entry?.id, ) if (res.status === ValidatedStatus.Success) { // notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) diff --git a/web/app/components/header/account-setting/model-provider-page/utils.ts b/web/app/components/header/account-setting/model-provider-page/utils.ts index 8cad399763..165926b2bb 100644 --- a/web/app/components/header/account-setting/model-provider-page/utils.ts +++ b/web/app/components/header/account-setting/model-provider-page/utils.ts @@ -56,14 +56,14 @@ export const validateCredentials = async (predefined: boolean, provider: string, } } -export const validateLoadBalancingCredentials = async (predefined: boolean, provider: string, v: FormValue): Promise<{ +export const validateLoadBalancingCredentials = async (predefined: boolean, provider: string, v: FormValue, id?: string): Promise<{ status: ValidatedStatus message?: string }> => { const { __model_name, __model_type, ...credentials } = v try { const res = await validateModelLoadBalancingCredentials({ - url: `/workspaces/current/model-providers/${provider}/models/load-balancing-configs/credentials-validate`, + url: `/workspaces/current/model-providers/${provider}/models/load-balancing-configs/${id ? `${id}/` : ''}credentials-validate`, body: { model: __model_name, model_type: __model_type, From cd3eaed3353bc00b89be0b367ee8f12bbc1582c1 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 13 Sep 2024 19:55:54 +0800 Subject: [PATCH 05/43] fix(workflow): both parallel and single branch errors occur in if-else (#8378) --- .../workflow/graph_engine/entities/graph.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 20efe91a59..1d7e9158d8 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -405,21 +405,22 @@ class Graph(BaseModel): if condition_edge_mappings: for condition_hash, graph_edges in condition_edge_mappings.items(): - current_parallel = cls._get_current_parallel( - parallel_mapping=parallel_mapping, - graph_edge=graph_edge, - parallel=condition_parallels.get(condition_hash), - parent_parallel=parent_parallel, - ) + for graph_edge in graph_edges: + current_parallel: GraphParallel | None = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=condition_parallels.get(condition_hash), + parent_parallel=parent_parallel, + ) - cls._recursively_add_parallels( - edge_mapping=edge_mapping, - reverse_edge_mapping=reverse_edge_mapping, - start_node_id=graph_edge.target_node_id, - parallel_mapping=parallel_mapping, - node_parallel_mapping=node_parallel_mapping, - parent_parallel=current_parallel, - ) + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) else: for graph_edge in target_node_edges: current_parallel = cls._get_current_parallel( From 06b66216d76062768f892dce0239a668e4b87e38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Fri, 13 Sep 2024 20:02:00 +0800 Subject: [PATCH 06/43] chore: update firecrawl scrape to V1 api (#8367) --- .../builtin/firecrawl/firecrawl_appx.py | 25 +-- .../provider/builtin/firecrawl/tools/crawl.py | 47 ++-- .../builtin/firecrawl/tools/crawl.yaml | 163 +++++--------- .../provider/builtin/firecrawl/tools/map.py | 25 +++ .../provider/builtin/firecrawl/tools/map.yaml | 59 +++++ .../builtin/firecrawl/tools/scrape.py | 40 ++-- .../builtin/firecrawl/tools/scrape.yaml | 204 ++++++++---------- .../builtin/firecrawl/tools/search.py | 27 --- .../builtin/firecrawl/tools/search.yaml | 75 ------- 9 files changed, 287 insertions(+), 378 deletions(-) create mode 100644 api/core/tools/provider/builtin/firecrawl/tools/map.py create mode 100644 api/core/tools/provider/builtin/firecrawl/tools/map.yaml delete mode 100644 api/core/tools/provider/builtin/firecrawl/tools/search.py delete mode 100644 api/core/tools/provider/builtin/firecrawl/tools/search.yaml diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py index a0e4cdf933..d9fb6f04bc 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py @@ -37,9 +37,8 @@ class FirecrawlApp: for i in range(retries): try: response = requests.request(method, url, json=data, headers=headers) - response.raise_for_status() return response.json() - except requests.exceptions.RequestException as e: + except requests.exceptions.RequestException: if i < retries - 1: time.sleep(backoff_factor * (2**i)) else: @@ -47,7 +46,7 @@ class FirecrawlApp: return None def scrape_url(self, url: str, **kwargs): - endpoint = f"{self.base_url}/v0/scrape" + endpoint = f"{self.base_url}/v1/scrape" data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") response = self._request("POST", endpoint, data) @@ -55,39 +54,41 @@ class FirecrawlApp: raise HTTPError("Failed to scrape URL after multiple retries") return response - def search(self, query: str, **kwargs): - endpoint = f"{self.base_url}/v0/search" - data = {"query": query, **kwargs} + def map(self, url: str, **kwargs): + endpoint = f"{self.base_url}/v1/map" + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") response = self._request("POST", endpoint, data) if response is None: - raise HTTPError("Failed to perform search after multiple retries") + raise HTTPError("Failed to perform map after multiple retries") return response def crawl_url( self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs ): - endpoint = f"{self.base_url}/v0/crawl" + endpoint = f"{self.base_url}/v1/crawl" headers = self._prepare_headers(idempotency_key) data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate crawl after multiple retries") - job_id: str = response["jobId"] + elif response.get("success") == False: + raise HTTPError(f'Failed to crawl: {response.get("error")}') + job_id: str = response["id"] if wait: return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval) return response def check_crawl_status(self, job_id: str): - endpoint = f"{self.base_url}/v0/crawl/status/{job_id}" + endpoint = f"{self.base_url}/v1/crawl/{job_id}" response = self._request("GET", endpoint) if response is None: raise HTTPError(f"Failed to check status for job {job_id} after multiple retries") return response def cancel_crawl_job(self, job_id: str): - endpoint = f"{self.base_url}/v0/crawl/cancel/{job_id}" + endpoint = f"{self.base_url}/v1/crawl/{job_id}" response = self._request("DELETE", endpoint) if response is None: raise HTTPError(f"Failed to cancel job {job_id} after multiple retries") @@ -116,6 +117,6 @@ def get_json_params(tool_parameters: dict[str, Any], key): # support both single quotes and double quotes param = param.replace("'", '"') param = json.loads(param) - except: + except Exception: raise ValueError(f"Invalid {key} format.") return param diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py index 94717cbbfb..15ab510c6c 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py @@ -8,39 +8,38 @@ from core.tools.tool.builtin_tool import BuiltinTool class CrawlTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: """ - the crawlerOptions and pageOptions comes from doc here: + the api doc: https://docs.firecrawl.dev/api-reference/endpoint/crawl """ app = FirecrawlApp( api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] ) - crawlerOptions = {} - pageOptions = {} + + scrapeOptions = {} + payload = {} wait_for_results = tool_parameters.get("wait_for_results", True) - crawlerOptions["excludes"] = get_array_params(tool_parameters, "excludes") - crawlerOptions["includes"] = get_array_params(tool_parameters, "includes") - crawlerOptions["returnOnlyUrls"] = tool_parameters.get("returnOnlyUrls", False) - crawlerOptions["maxDepth"] = tool_parameters.get("maxDepth") - crawlerOptions["mode"] = tool_parameters.get("mode") - crawlerOptions["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False) - crawlerOptions["limit"] = tool_parameters.get("limit", 5) - crawlerOptions["allowBackwardCrawling"] = tool_parameters.get("allowBackwardCrawling", False) - crawlerOptions["allowExternalContentLinks"] = tool_parameters.get("allowExternalContentLinks", False) + payload["excludePaths"] = get_array_params(tool_parameters, "excludePaths") + payload["includePaths"] = get_array_params(tool_parameters, "includePaths") + payload["maxDepth"] = tool_parameters.get("maxDepth") + payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False) + payload["limit"] = tool_parameters.get("limit", 5) + payload["allowBackwardLinks"] = tool_parameters.get("allowBackwardLinks", False) + payload["allowExternalLinks"] = tool_parameters.get("allowExternalLinks", False) + payload["webhook"] = tool_parameters.get("webhook") - pageOptions["headers"] = get_json_params(tool_parameters, "headers") - pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) - pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) - pageOptions["onlyIncludeTags"] = get_array_params(tool_parameters, "onlyIncludeTags") - pageOptions["removeTags"] = get_array_params(tool_parameters, "removeTags") - pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) - pageOptions["replaceAllPathsWithAbsolutePaths"] = tool_parameters.get("replaceAllPathsWithAbsolutePaths", False) - pageOptions["screenshot"] = tool_parameters.get("screenshot", False) - pageOptions["waitFor"] = tool_parameters.get("waitFor", 0) + scrapeOptions["formats"] = get_array_params(tool_parameters, "formats") + scrapeOptions["headers"] = get_json_params(tool_parameters, "headers") + scrapeOptions["includeTags"] = get_array_params(tool_parameters, "includeTags") + scrapeOptions["excludeTags"] = get_array_params(tool_parameters, "excludeTags") + scrapeOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) + scrapeOptions["waitFor"] = tool_parameters.get("waitFor", 0) + scrapeOptions = {k: v for k, v in scrapeOptions.items() if v not in (None, "")} + payload["scrapeOptions"] = scrapeOptions or None - crawl_result = app.crawl_url( - url=tool_parameters["url"], wait=wait_for_results, crawlerOptions=crawlerOptions, pageOptions=pageOptions - ) + payload = {k: v for k, v in payload.items() if v not in (None, "")} + + crawl_result = app.crawl_url(url=tool_parameters["url"], wait=wait_for_results, **payload) return self.create_json_message(crawl_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml index 0c5399f973..0d7dbcac20 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml @@ -31,8 +31,21 @@ parameters: en_US: If you choose not to wait, it will directly return a job ID. You can use this job ID to check the crawling results or cancel the crawling task, which is usually very useful for a large-scale crawling task. zh_Hans: 如果选择不等待,则会直接返回一个job_id,可以通过job_id查询爬取结果或取消爬取任务,这通常对于一个大型爬取任务来说非常有用。 form: form -############## Crawl Options ####################### - - name: includes +############## Payload ####################### + - name: excludePaths + type: string + label: + en_US: URL patterns to exclude + zh_Hans: 要排除的URL模式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Pages matching these patterns will be skipped. Example: blog/*, about/* + zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/* + form: form + - name: includePaths type: string required: false label: @@ -46,30 +59,6 @@ parameters: Only pages matching these patterns will be crawled. Example: blog/*, about/* zh_Hans: 只有与这些模式匹配的页面才会被爬取。示例:blog/*, about/* form: form - - name: excludes - type: string - label: - en_US: URL patterns to exclude - zh_Hans: 要排除的URL模式 - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - human_description: - en_US: | - Pages matching these patterns will be skipped. Example: blog/*, about/* - zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/* - form: form - - name: returnOnlyUrls - type: boolean - default: false - label: - en_US: return Only Urls - zh_Hans: 仅返回URL - human_description: - en_US: | - If true, returns only the URLs as a list on the crawl status. Attention: the return response will be a list of URLs inside the data, not a list of documents. - zh_Hans: 只返回爬取到的网页链接,而不是网页内容本身。 - form: form - name: maxDepth type: number label: @@ -80,27 +69,10 @@ parameters: zh_Hans: 相对于输入的URL,爬取的最大深度。maxDepth为0时,仅抓取输入的URL。maxDepth为1时,抓取输入的URL以及所有一级深层页面。maxDepth为2时,抓取输入的URL以及所有两级深层页面。更高值遵循相同模式。 form: form min: 0 - - name: mode - type: select - required: false - form: form - options: - - value: default - label: - en_US: default - - value: fast - label: - en_US: fast - default: default - label: - en_US: Crawl Mode - zh_Hans: 爬取模式 - human_description: - en_US: The crawling mode to use. Fast mode crawls 4x faster websites without sitemap, but may not be as accurate and shouldn't be used in heavy js-rendered websites. - zh_Hans: 使用fast模式将不会使用其站点地图,比普通模式快4倍,但是可能不够准确,也不适用于大量js渲染的网站。 + default: 2 - name: ignoreSitemap type: boolean - default: false + default: true label: en_US: ignore Sitemap zh_Hans: 忽略站点地图 @@ -120,7 +92,7 @@ parameters: form: form min: 1 default: 5 - - name: allowBackwardCrawling + - name: allowBackwardLinks type: boolean default: false label: @@ -130,7 +102,7 @@ parameters: en_US: Enables the crawler to navigate from a specific URL to previously linked pages. For instance, from 'example.com/product/123' back to 'example.com/product' zh_Hans: 使爬虫能够从特定URL导航到之前链接的页面。例如,从'example.com/product/123'返回到'example.com/product' form: form - - name: allowExternalContentLinks + - name: allowExternalLinks type: boolean default: false label: @@ -140,7 +112,30 @@ parameters: en_US: Allows the crawler to follow links to external websites. zh_Hans: form: form -############## Page Options ####################### + - name: webhook + type: string + label: + en_US: webhook + human_description: + en_US: | + The URL to send the webhook to. This will trigger for crawl started (crawl.started) ,every page crawled (crawl.page) and when the crawl is completed (crawl.completed or crawl.failed). The response will be the same as the /scrape endpoint. + zh_Hans: 发送Webhook的URL。这将在开始爬取(crawl.started)、每爬取一个页面(crawl.page)以及爬取完成(crawl.completed或crawl.failed)时触发。响应将与/scrape端点相同。 + form: form +############## Scrape Options ####################### + - name: formats + type: string + label: + en_US: Formats + zh_Hans: 结果的格式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot + zh_Hans: | + 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot + form: form - name: headers type: string label: @@ -155,30 +150,10 @@ parameters: en_US: Please enter an object that can be serialized in JSON zh_Hans: 请输入可以json序列化的对象 form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form - - name: onlyIncludeTags + - name: includeTags type: string label: - en_US: only Include Tags + en_US: Include Tags zh_Hans: 仅抓取这些标签 placeholder: en_US: Use commas to separate multiple tags @@ -189,6 +164,20 @@ parameters: zh_Hans: | 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer form: form + - name: excludeTags + type: string + label: + en_US: Exclude Tags + zh_Hans: 要移除这些标签 + human_description: + en_US: | + Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + form: form - name: onlyMainContent type: boolean default: false @@ -199,40 +188,6 @@ parameters: en_US: Only return the main content of the page excluding headers, navs, footers, etc. zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 form: form - - name: removeTags - type: string - label: - en_US: remove Tags - zh_Hans: 要移除这些标签 - human_description: - en_US: | - Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer - zh_Hans: | - 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - form: form - - name: replaceAllPathsWithAbsolutePaths - type: boolean - default: false - label: - en_US: All AbsolutePaths - zh_Hans: 使用绝对路径 - human_description: - en_US: Replace all relative paths with absolute paths for images and links. - zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。 - form: form - - name: screenshot - type: boolean - default: false - label: - en_US: screenshot - zh_Hans: 截图 - human_description: - en_US: Include a screenshot of the top of the page that you are scraping. - zh_Hans: 提供正在抓取的页面的顶部的截图。 - form: form - name: waitFor type: number min: 0 diff --git a/api/core/tools/provider/builtin/firecrawl/tools/map.py b/api/core/tools/provider/builtin/firecrawl/tools/map.py new file mode 100644 index 0000000000..bdfb5faeb8 --- /dev/null +++ b/api/core/tools/provider/builtin/firecrawl/tools/map.py @@ -0,0 +1,25 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp +from core.tools.tool.builtin_tool import BuiltinTool + + +class MapTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + """ + the api doc: + https://docs.firecrawl.dev/api-reference/endpoint/map + """ + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) + payload = {} + payload["search"] = tool_parameters.get("search") + payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", True) + payload["includeSubdomains"] = tool_parameters.get("includeSubdomains", False) + payload["limit"] = tool_parameters.get("limit", 5000) + + map_result = app.map(url=tool_parameters["url"], **payload) + + return self.create_json_message(map_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/map.yaml b/api/core/tools/provider/builtin/firecrawl/tools/map.yaml new file mode 100644 index 0000000000..9913756983 --- /dev/null +++ b/api/core/tools/provider/builtin/firecrawl/tools/map.yaml @@ -0,0 +1,59 @@ +identity: + name: map + author: hjlarry + label: + en_US: Map + zh_Hans: 地图式快爬 +description: + human: + en_US: Input a website and get all the urls on the website - extremly fast + zh_Hans: 输入一个网站,快速获取网站上的所有网址。 + llm: Input a website and get all the urls on the website - extremly fast +parameters: + - name: url + type: string + required: true + label: + en_US: Start URL + zh_Hans: 起始URL + human_description: + en_US: The base URL to start crawling from. + zh_Hans: 要爬取网站的起始URL。 + llm_description: The URL of the website that needs to be crawled. This is a required parameter. + form: llm + - name: search + type: string + label: + en_US: search + zh_Hans: 搜索查询 + human_description: + en_US: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied. + zh_Hans: 用于映射的搜索查询。在Alpha阶段,搜索功能的“智能”部分限制为最多100个搜索结果。然而,如果地图找到了更多结果,则不施加任何限制。 + llm_description: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied. + form: llm +############## Page Options ####################### + - name: ignoreSitemap + type: boolean + default: true + label: + en_US: ignore Sitemap + zh_Hans: 忽略站点地图 + human_description: + en_US: Ignore the website sitemap when crawling. + zh_Hans: 爬取时忽略网站站点地图。 + form: form + - name: includeSubdomains + type: boolean + default: false + label: + en_US: include Subdomains + zh_Hans: 包含子域名 + form: form + - name: limit + type: number + min: 0 + default: 5000 + label: + en_US: Maximum results + zh_Hans: 最大结果数量 + form: form diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py index 962570bf73..f00a9b31ce 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py @@ -6,34 +6,34 @@ from core.tools.tool.builtin_tool import BuiltinTool class ScrapeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: """ - the pageOptions and extractorOptions comes from doc here: + the api doc: https://docs.firecrawl.dev/api-reference/endpoint/scrape """ app = FirecrawlApp( api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] ) - pageOptions = {} - extractorOptions = {} + payload = {} + extract = {} - pageOptions["headers"] = get_json_params(tool_parameters, "headers") - pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) - pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) - pageOptions["onlyIncludeTags"] = get_array_params(tool_parameters, "onlyIncludeTags") - pageOptions["removeTags"] = get_array_params(tool_parameters, "removeTags") - pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) - pageOptions["replaceAllPathsWithAbsolutePaths"] = tool_parameters.get("replaceAllPathsWithAbsolutePaths", False) - pageOptions["screenshot"] = tool_parameters.get("screenshot", False) - pageOptions["waitFor"] = tool_parameters.get("waitFor", 0) + payload["formats"] = get_array_params(tool_parameters, "formats") + payload["onlyMainContent"] = tool_parameters.get("onlyMainContent", True) + payload["includeTags"] = get_array_params(tool_parameters, "includeTags") + payload["excludeTags"] = get_array_params(tool_parameters, "excludeTags") + payload["headers"] = get_json_params(tool_parameters, "headers") + payload["waitFor"] = tool_parameters.get("waitFor", 0) + payload["timeout"] = tool_parameters.get("timeout", 30000) - extractorOptions["mode"] = tool_parameters.get("mode", "") - extractorOptions["extractionPrompt"] = tool_parameters.get("extractionPrompt", "") - extractorOptions["extractionSchema"] = get_json_params(tool_parameters, "extractionSchema") + extract["schema"] = get_json_params(tool_parameters, "schema") + extract["systemPrompt"] = tool_parameters.get("systemPrompt") + extract["prompt"] = tool_parameters.get("prompt") + extract = {k: v for k, v in extract.items() if v not in (None, "")} + payload["extract"] = extract or None - crawl_result = app.scrape_url( - url=tool_parameters["url"], pageOptions=pageOptions, extractorOptions=extractorOptions - ) + payload = {k: v for k, v in payload.items() if v not in (None, "")} - return self.create_json_message(crawl_result) + crawl_result = app.scrape_url(url=tool_parameters["url"], **payload) + markdown_result = crawl_result.get("data", {}).get("markdown", "") + return [self.create_text_message(markdown_result), self.create_json_message(crawl_result)] diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml index 598429de5e..8f1f1348a4 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml @@ -6,8 +6,8 @@ identity: zh_Hans: 单页面抓取 description: human: - en_US: Extract data from a single URL. - zh_Hans: 从单个URL抓取数据。 + en_US: Turn any url into clean data. + zh_Hans: 将任何网址转换为干净的数据。 llm: This tool is designed to scrape URL and output the content in Markdown format. parameters: - name: url @@ -21,7 +21,59 @@ parameters: zh_Hans: 要抓取并提取数据的网站URL。 llm_description: The URL of the website that needs to be crawled. This is a required parameter. form: llm -############## Page Options ####################### +############## Payload ####################### + - name: formats + type: string + label: + en_US: Formats + zh_Hans: 结果的格式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage + zh_Hans: | + 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage + form: form + - name: onlyMainContent + type: boolean + default: false + label: + en_US: only Main Content + zh_Hans: 仅抓取主要内容 + human_description: + en_US: Only return the main content of the page excluding headers, navs, footers, etc. + zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 + form: form + - name: includeTags + type: string + label: + en_US: Include Tags + zh_Hans: 仅抓取这些标签 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + form: form + - name: excludeTags + type: string + label: + en_US: Exclude Tags + zh_Hans: 要移除这些标签 + human_description: + en_US: | + Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + form: form - name: headers type: string label: @@ -36,87 +88,10 @@ parameters: en_US: Please enter an object that can be serialized in JSON zh_Hans: 请输入可以json序列化的对象 form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form - - name: onlyIncludeTags - type: string - label: - en_US: only Include Tags - zh_Hans: 仅抓取这些标签 - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - human_description: - en_US: | - Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer - zh_Hans: | - 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer - form: form - - name: onlyMainContent - type: boolean - default: false - label: - en_US: only Main Content - zh_Hans: 仅抓取主要内容 - human_description: - en_US: Only return the main content of the page excluding headers, navs, footers, etc. - zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 - form: form - - name: removeTags - type: string - label: - en_US: remove Tags - zh_Hans: 要移除这些标签 - human_description: - en_US: | - Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer - zh_Hans: | - 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - form: form - - name: replaceAllPathsWithAbsolutePaths - type: boolean - default: false - label: - en_US: All AbsolutePaths - zh_Hans: 使用绝对路径 - human_description: - en_US: Replace all relative paths with absolute paths for images and links. - zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。 - form: form - - name: screenshot - type: boolean - default: false - label: - en_US: screenshot - zh_Hans: 截图 - human_description: - en_US: Include a screenshot of the top of the page that you are scraping. - zh_Hans: 提供正在抓取的页面的顶部的截图。 - form: form - name: waitFor type: number min: 0 + default: 0 label: en_US: wait For zh_Hans: 等待时间 @@ -124,57 +99,54 @@ parameters: en_US: Wait x amount of milliseconds for the page to load to fetch content. zh_Hans: 等待x毫秒以使页面加载并获取内容。 form: form + - name: timeout + type: number + min: 0 + default: 30000 + label: + en_US: Timeout + human_description: + en_US: Timeout in milliseconds for the request. + zh_Hans: 请求的超时时间(以毫秒为单位)。 + form: form ############## Extractor Options ####################### - - name: mode - type: select - options: - - value: markdown - label: - en_US: markdown - - value: llm-extraction - label: - en_US: llm-extraction - - value: llm-extraction-from-raw-html - label: - en_US: llm-extraction-from-raw-html - - value: llm-extraction-from-markdown - label: - en_US: llm-extraction-from-markdown - label: - en_US: Extractor Mode - zh_Hans: 提取模式 - human_description: - en_US: | - The extraction mode to use. 'markdown': Returns the scraped markdown content, does not perform LLM extraction. 'llm-extraction': Extracts information from the cleaned and parsed content using LLM. - zh_Hans: 使用的提取模式。“markdown”:返回抓取的markdown内容,不执行LLM提取。“llm-extractioin”:使用LLM按Extractor Schema从内容中提取信息。 - form: form - - name: extractionPrompt - type: string - label: - en_US: Extractor Prompt - zh_Hans: 提取时的提示词 - human_description: - en_US: A prompt describing what information to extract from the page, applicable for LLM extraction modes. - zh_Hans: 当使用LLM提取模式时,用于给LLM描述提取规则。 - form: form - - name: extractionSchema + - name: schema type: string label: en_US: Extractor Schema zh_Hans: 提取时的结构 placeholder: en_US: Please enter an object that can be serialized in JSON + zh_Hans: 请输入可以json序列化的对象 human_description: en_US: | - The schema for the data to be extracted, required only for LLM extraction modes. Example: { + The schema for the data to be extracted. Example: { "type": "object", "properties": {"company_mission": {"type": "string"}}, "required": ["company_mission"] } zh_Hans: | - 当使用LLM提取模式时,使用该结构去提取,示例:{ + 使用该结构去提取,示例:{ "type": "object", "properties": {"company_mission": {"type": "string"}}, "required": ["company_mission"] } form: form + - name: systemPrompt + type: string + label: + en_US: Extractor System Prompt + zh_Hans: 提取时的系统提示词 + human_description: + en_US: The system prompt to use for the extraction. + zh_Hans: 用于提取的系统提示。 + form: form + - name: prompt + type: string + label: + en_US: Extractor Prompt + zh_Hans: 提取时的提示词 + human_description: + en_US: The prompt to use for the extraction without a schema. + zh_Hans: 用于无schema时提取的提示词 + form: form diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.py b/api/core/tools/provider/builtin/firecrawl/tools/search.py deleted file mode 100644 index f077e7d8ea..0000000000 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp -from core.tools.tool.builtin_tool import BuiltinTool - - -class SearchTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - """ - the pageOptions and searchOptions comes from doc here: - https://docs.firecrawl.dev/api-reference/endpoint/search - """ - app = FirecrawlApp( - api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] - ) - pageOptions = {} - pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) - pageOptions["fetchPageContent"] = tool_parameters.get("fetchPageContent", True) - pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) - pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) - searchOptions = {"limit": tool_parameters.get("limit")} - search_result = app.search( - query=tool_parameters["keyword"], pageOptions=pageOptions, searchOptions=searchOptions - ) - - return self.create_json_message(search_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.yaml b/api/core/tools/provider/builtin/firecrawl/tools/search.yaml deleted file mode 100644 index 29df0cfaaa..0000000000 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.yaml +++ /dev/null @@ -1,75 +0,0 @@ -identity: - name: search - author: ahasasjeb - label: - en_US: Search - zh_Hans: 搜索 -description: - human: - en_US: Search, and output in Markdown format - zh_Hans: 搜索,并且以Markdown格式输出 - llm: This tool can perform online searches and convert the results to Markdown format. -parameters: - - name: keyword - type: string - required: true - label: - en_US: keyword - zh_Hans: 关键词 - human_description: - en_US: Input keywords to use Firecrawl API for search. - zh_Hans: 输入关键词即可使用Firecrawl API进行搜索。 - llm_description: Efficiently extract keywords from user text. - form: llm -############## Page Options ####################### - - name: onlyMainContent - type: boolean - default: false - label: - en_US: only Main Content - zh_Hans: 仅抓取主要内容 - human_description: - en_US: Only return the main content of the page excluding headers, navs, footers, etc. - zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 - form: form - - name: fetchPageContent - type: boolean - default: true - label: - en_US: fetch Page Content - zh_Hans: 抓取页面内容 - human_description: - en_US: Fetch the content of each page. If false, defaults to a basic fast serp API. - zh_Hans: 获取每个页面的内容。如果为否,则使用基本的快速搜索结果页面API。 - form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form -############## Search Options ####################### - - name: limit - type: number - min: 0 - label: - en_US: Maximum results - zh_Hans: 最大结果数量 - human_description: - en_US: Maximum number of results. Max is 20 during beta. - zh_Hans: 最大结果数量。在测试阶段,最大为20。 - form: form From 1ab81b49728b43bb5d232aac6a6bfcd37c4c18ab Mon Sep 17 00:00:00 2001 From: xiandan-erizo Date: Fri, 13 Sep 2024 20:21:48 +0800 Subject: [PATCH 07/43] support hunyuan-turbo (#8372) Co-authored-by: sunkesi --- .../hunyuan/llm/_position.yaml | 1 + .../hunyuan/llm/hunyuan-turbo.yaml | 38 +++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml index 2c1b981f85..ca8600a534 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml @@ -2,3 +2,4 @@ - hunyuan-standard - hunyuan-standard-256k - hunyuan-pro +- hunyuan-turbo diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml new file mode 100644 index 0000000000..4837fed4ba --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml @@ -0,0 +1,38 @@ +model: hunyuan-turbo +label: + zh_Hans: hunyuan-turbo + en_US: hunyuan-turbo +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.015' + output: '0.05' + unit: '0.001' + currency: RMB From a1104ab97ee5822f6caac7321089aba6064efb0c Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Fri, 13 Sep 2024 22:42:08 +0800 Subject: [PATCH 08/43] chore: refurish python code by applying Pylint linter rules (#8322) --- api/app.py | 2 +- api/commands.py | 4 ++-- api/controllers/console/app/audio.py | 2 +- api/controllers/console/auth/oauth.py | 2 +- .../console/datasets/datasets_document.py | 6 +++--- api/controllers/console/explore/audio.py | 2 +- api/controllers/console/explore/completion.py | 4 ++-- .../console/explore/conversation.py | 10 ++++----- .../console/explore/installed_app.py | 2 +- api/controllers/console/explore/message.py | 4 ++-- api/controllers/console/explore/parameter.py | 2 +- .../console/workspace/workspace.py | 2 +- api/controllers/service_api/app/app.py | 2 +- api/controllers/service_api/app/audio.py | 2 +- api/controllers/service_api/app/completion.py | 4 ++-- .../service_api/app/conversation.py | 6 +++--- api/controllers/service_api/app/message.py | 4 ++-- api/controllers/web/app.py | 2 +- api/controllers/web/audio.py | 2 +- api/controllers/web/completion.py | 4 ++-- api/controllers/web/conversation.py | 10 ++++----- api/controllers/web/message.py | 4 ++-- .../agent/output_parser/cot_output_parser.py | 4 ++-- .../app/app_config/base_app_config_manager.py | 2 +- .../easy_ui_based_app/agent/manager.py | 6 +++--- .../easy_ui_based_app/dataset/manager.py | 2 +- .../easy_ui_based_app/variables/manager.py | 6 +++--- .../features/file_upload/manager.py | 4 ++-- api/core/app/apps/advanced_chat/app_runner.py | 4 ++-- .../base_app_generate_response_converter.py | 2 +- api/core/app/apps/base_app_generator.py | 6 +++--- api/core/app/apps/base_app_queue_manager.py | 4 ++-- .../app/apps/message_based_app_generator.py | 6 +++--- api/core/app/apps/workflow/app_runner.py | 4 ++-- .../annotation_reply/annotation_reply.py | 2 +- .../easy_ui_based_generate_task_pipeline.py | 2 +- .../task_pipeline/workflow_cycle_manage.py | 4 ++-- .../index_tool_callback_handler.py | 2 +- api/core/indexing_runner.py | 2 +- api/core/memory/token_buffer_memory.py | 2 +- .../model_runtime/entities/model_entities.py | 12 +++++------ .../model_providers/anthropic/llm/llm.py | 2 +- .../model_providers/azure_openai/tts/tts.py | 4 ++-- .../model_providers/bedrock/llm/llm.py | 10 ++++----- .../bedrock/text_embedding/text_embedding.py | 8 +++---- .../model_providers/google/llm/llm.py | 4 ++-- .../huggingface_hub/llm/llm.py | 4 ++-- .../huggingface_tei/tei_helper.py | 2 +- .../minimax/llm/chat_completion.py | 4 ++-- .../minimax/llm/chat_completion_pro.py | 4 ++-- .../minimax/text_embedding/text_embedding.py | 2 +- .../model_providers/openai/llm/llm.py | 2 +- .../model_providers/openai/tts/tts.py | 4 ++-- .../model_providers/openrouter/llm/llm.py | 1 - .../model_providers/replicate/llm/llm.py | 2 +- .../text_embedding/text_embedding.py | 4 ++-- .../model_providers/tongyi/llm/llm.py | 8 +++---- .../model_providers/upstage/llm/llm.py | 2 +- .../model_providers/vertex_ai/llm/llm.py | 4 ++-- .../legacy/volc_sdk/base/auth.py | 3 +-- .../model_providers/wenxin/llm/llm.py | 2 +- .../xinference/xinference_helper.py | 2 +- .../model_providers/zhipuai/llm/llm.py | 21 ++++++++----------- .../zhipuai_sdk/types/fine_tuning/__init__.py | 5 ++--- .../schema_validators/common_validator.py | 4 ++-- .../vdb/elasticsearch/elasticsearch_vector.py | 4 ++-- .../datasource/vdb/myscale/myscale_vector.py | 4 ++-- .../rag/datasource/vdb/oracle/oraclevector.py | 10 +-------- api/core/rag/extractor/extract_processor.py | 12 +++++------ .../rag/extractor/firecrawl/firecrawl_app.py | 2 +- api/core/rag/extractor/notion_extractor.py | 4 ++-- api/core/rag/retrieval/dataset_retrieval.py | 2 +- api/core/rag/splitter/text_splitter.py | 2 +- api/core/tools/provider/app_tool_provider.py | 2 +- .../provider/builtin/aippt/tools/aippt.py | 4 ++-- .../builtin/azuredalle/tools/dalle3.py | 4 ++-- .../builtin/code/tools/simple_code.py | 2 +- .../builtin/cogview/tools/cogview3.py | 4 ++-- .../provider/builtin/dalle/tools/dalle3.py | 4 ++-- .../builtin/hap/tools/get_worksheet_fields.py | 4 ++-- .../hap/tools/list_worksheet_records.py | 6 +++--- .../novitaai/tools/novitaai_modelquery.py | 2 +- .../builtin/searchapi/tools/google.py | 2 +- .../builtin/searchapi/tools/google_jobs.py | 2 +- .../builtin/searchapi/tools/google_news.py | 2 +- .../searchapi/tools/youtube_transcripts.py | 2 +- .../provider/builtin/spider/spiderApp.py | 2 +- .../builtin/stability/tools/text2image.py | 2 +- .../provider/builtin/vanna/tools/vanna.py | 2 +- .../tools/provider/builtin_tool_provider.py | 2 +- api/core/tools/provider/tool_provider.py | 18 ++++++++-------- api/core/tools/tool/api_tool.py | 8 +++---- api/core/tools/tool_engine.py | 12 +++-------- api/core/tools/utils/message_transformer.py | 2 +- api/core/tools/utils/parser.py | 2 +- api/core/tools/utils/web_reader_tool.py | 2 +- .../entities/runtime_route_state.py | 2 +- .../answer/answer_stream_generate_router.py | 4 ++-- .../nodes/end/end_stream_generate_router.py | 4 ++-- .../nodes/http_request/http_executor.py | 8 +++---- .../nodes/parameter_extractor/entities.py | 4 ++-- .../parameter_extractor_node.py | 10 ++++----- api/core/workflow/nodes/tool/tool_node.py | 5 +---- api/libs/oauth_data_source.py | 4 ++-- api/libs/rsa.py | 2 +- api/models/dataset.py | 6 +++--- api/models/model.py | 6 +++--- api/pyproject.toml | 14 +++++++++---- api/services/account_service.py | 6 +++--- api/services/app_dsl_service.py | 8 +++---- api/services/app_service.py | 2 +- api/services/audio_service.py | 4 ++-- api/services/auth/firecrawl.py | 2 +- api/services/dataset_service.py | 2 +- api/services/tools/tools_transform_service.py | 2 +- api/services/workflow_service.py | 2 +- api/tasks/recover_document_indexing_task.py | 2 +- .../model_runtime/__mock/google.py | 4 +--- .../model_runtime/__mock/openai_chat.py | 4 ++-- .../model_runtime/__mock/openai_completion.py | 2 +- .../model_runtime/__mock/openai_embeddings.py | 2 +- .../model_runtime/__mock/openai_moderation.py | 2 +- .../__mock/openai_speech2text.py | 2 +- .../model_runtime/__mock/xinference.py | 4 ++-- .../nodes/test_parameter_extractor.py | 2 +- .../graph_engine/test_graph_engine.py | 6 +++--- 126 files changed, 253 insertions(+), 272 deletions(-) diff --git a/api/app.py b/api/app.py index ad219ca0d6..91a49337fc 100644 --- a/api/app.py +++ b/api/app.py @@ -164,7 +164,7 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint not in ["console", "inner_api"]: + if request.blueprint not in {"console", "inner_api"}: return None # Check if the user_id contains a dot, indicating the old format auth_header = request.headers.get("Authorization", "") diff --git a/api/commands.py b/api/commands.py index 887270b43e..3a6b4963cf 100644 --- a/api/commands.py +++ b/api/commands.py @@ -140,9 +140,9 @@ def reset_encrypt_key_pair(): @click.command("vdb-migrate", help="migrate vector db.") @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") def vdb_migrate(scope: str): - if scope in ["knowledge", "all"]: + if scope in {"knowledge", "all"}: migrate_knowledge_vector_database() - if scope in ["annotation", "all"]: + if scope in {"annotation", "all"}: migrate_annotation_vector_database() diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 7332758e83..c1ef05a488 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -94,7 +94,7 @@ class ChatMessageTextApi(Resource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 1df0f5de9d..ad0c0580ae 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -71,7 +71,7 @@ class OAuthCallback(Resource): account = _generate_account(provider, user_info) # Check account status - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: return {"error": "Account is banned or closed."}, 403 if account.status == AccountStatus.PENDING.value: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 076f3cd44d..829ef11e52 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in ["completed", "error"]: + if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule @@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): info_list = [] extract_settings = [] for document in documents: - if document.indexing_status in ["completed", "error"]: + if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict # format document files info @@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource): db.session.commit() elif action == "resume": - if document.indexing_status not in ["paused", "error"]: + if document.indexing_status not in {"paused", "error"}: raise InvalidActionError("Document not in paused or error state.") document.paused_by = None diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 2eb7e04490..9690677f61 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -81,7 +81,7 @@ class ChatTextApi(InstalledAppResource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index c039e8bca5..f464692098 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 2918024b64..6f9d7769b9 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource): def get(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource): def delete(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource): def post(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3f1e64a247..408afc33a0 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource): "app_owner_tenant_id": installed_app.app_owner_tenant_id, "is_pinned": installed_app.is_pinned, "last_used_at": installed_app.last_used_at, - "editable": current_user.role in ["owner", "admin"], + "editable": current_user.role in {"owner", "admin"}, "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, } for installed_app in installed_apps diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index f5eb185172..0e0238556c 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() message_id = str(message_id) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index ad55b04043..aab7dd7888 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource): """Retrieve app parameters.""" app_model = installed_app.app - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 623f0b8b74..af3ebc099b 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource): raise TooManyFilesError() extension = file.filename.split(".")[-1] - if extension.lower() not in ["svg", "png"]: + if extension.lower() not in {"svg", "png"}: raise UnsupportedFileTypeError() try: diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index ecc2d73deb..f7c091217b 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -42,7 +42,7 @@ class AppParameterApi(Resource): @marshal_with(parameters_fields) def get(self, app_model: App): """Retrieve app parameters.""" - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 8d8ca8d78c..5db4163647 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -79,7 +79,7 @@ class TextApi(Resource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index f1771baf31..8d8e356c4c 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -96,7 +96,7 @@ class ChatApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -144,7 +144,7 @@ class ChatStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 734027a1c5..527ef4ecd3 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -18,7 +18,7 @@ class ConversationApi(Resource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -52,7 +52,7 @@ class ConversationDetailApi(Resource): @marshal_with(simple_conversation_fields) def delete(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -69,7 +69,7 @@ class ConversationRenameApi(Resource): @marshal_with(simple_conversation_fields) def post(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index b39aaf7dd8..e54e6f4903 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -76,7 +76,7 @@ class MessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource): def get(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() try: diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index aabca93338..20b4e4674c 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource): @marshal_with(parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 49c467dbe1..23550efe2e 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -78,7 +78,7 @@ class TextApi(WebApiResource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 0837eedfb0..115492b796 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource): class ChatApi(WebApiResource): def post(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -136,7 +136,7 @@ class ChatApi(WebApiResource): class ChatStopApi(WebApiResource): def post(self, app_model, end_user, task_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 6bbfa94c27..c3b0cd4f44 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource): class ConversationApi(WebApiResource): def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource): @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource): class ConversationPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource): class ConversationUnPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 56aaaa930a..0d4047f4ef 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -78,7 +78,7 @@ class MessageListApi(WebApiResource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource): class MessageSuggestedQuestionApi(WebApiResource): def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotCompletionAppError() message_id = str(message_id) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 1a161677dd..d04e38777a 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -90,7 +90,7 @@ class CotAgentOutputParser: if not in_code_block and not in_json: if delta.lower() == action_str[action_idx] and action_idx == 0: - if last_character not in ["\n", " ", ""]: + if last_character not in {"\n", " ", ""}: index += steps yield delta continue @@ -117,7 +117,7 @@ class CotAgentOutputParser: action_idx = 0 if delta.lower() == thought_str[thought_idx] and thought_idx == 0: - if last_character not in ["\n", " ", ""]: + if last_character not in {"\n", " ", ""}: index += steps yield delta continue diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py index 0fd2a779a4..24d80f9cdd 100644 --- a/api/core/app/app_config/base_app_config_manager.py +++ b/api/core/app/app_config/base_app_config_manager.py @@ -29,7 +29,7 @@ class BaseAppConfigManager: additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict) additional_features.file_upload = FileUploadConfigManager.convert( - config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] + config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT} ) additional_features.opening_statement, additional_features.suggested_questions = ( diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 6e89f19508..f503543d7b 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -18,7 +18,7 @@ class AgentConfigManager: if agent_strategy == "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == "cot" or agent_strategy == "react": + elif agent_strategy in {"cot", "react"}: strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT else: # old configs, try to detect default strategy @@ -43,10 +43,10 @@ class AgentConfigManager: agent_tools.append(AgentToolEntity(**agent_tool_properties)) - if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [ + if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { "react_router", "router", - ]: + }: agent_prompt = agent_dict.get("prompt", None) or {} # check model mode model_mode = config.get("model", {}).get("mode", "completion") diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index ff131b62e2..a22395b8e3 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -167,7 +167,7 @@ class DatasetConfigManager: config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value has_datasets = False - if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]: + if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}: for tool in config["agent_mode"]["tools"]: key = list(tool.keys())[0] if key == "dataset": diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index e70522f21d..a1bfde3208 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -42,12 +42,12 @@ class BasicVariablesConfigManager: variable=variable["variable"], type=variable["type"], config=variable["config"] ) ) - elif variable_type in [ + elif variable_type in { VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH, VariableEntityType.NUMBER, VariableEntityType.SELECT, - ]: + }: variable = variables[variable_type] variable_entities.append( VariableEntity( @@ -97,7 +97,7 @@ class BasicVariablesConfigManager: variables = [] for item in config["user_input_form"]: key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: + if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}: raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 5f7fc99151..7a275cb532 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -54,14 +54,14 @@ class FileUploadConfigManager: if is_vision: detail = config["file_upload"]["image"]["detail"] - if detail not in ["high", "low"]: + if detail not in {"high", "low"}: raise ValueError("detail must be in ['high', 'low']") transfer_methods = config["file_upload"]["image"]["transfer_methods"] if not isinstance(transfer_methods, list): raise ValueError("transfer_methods must be of list type") for method in transfer_methods: - if method not in ["remote_url", "local_file"]: + if method not in {"remote_url", "local_file"}: raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") return config, ["file_upload"] diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c4cdba6441..1bca1e1b71 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -73,7 +73,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): raise ValueError("Workflow not initialized") user_id = None - if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id @@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT - if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else UserFrom.END_USER ), invoke_from=self.application_generate_entity.invoke_from, diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 73025d99d0..c6855ac854 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC): def convert( cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom ) -> dict[str, Any] | Generator[str, Any, None]: - if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) else: diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index ce6f7d4338..15be7000fc 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -22,11 +22,11 @@ class BaseAppGenerator: return var.default or "" if ( var.type - in ( + in { VariableEntityType.TEXT_INPUT, VariableEntityType.SELECT, VariableEntityType.PARAGRAPH, - ) + } and user_input_value and not isinstance(user_input_value, str) ): @@ -44,7 +44,7 @@ class BaseAppGenerator: options = var.options or [] if user_input_value not in options: raise ValueError(f"{var.variable} in input form must be one of the following: {options}") - elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): + elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}: if var.max_length and user_input_value and len(user_input_value) > var.max_length: raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index f3c3199354..4c4d282e99 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -32,7 +32,7 @@ class AppQueueManager: self._user_id = user_id self._invoke_from = invoke_from - user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" redis_client.setex( AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" ) @@ -118,7 +118,7 @@ class AppQueueManager: if result is None: return - user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" if result.decode("utf-8") != f"{user_prefix}-{user_id}": return diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index f629c5c8b7..c4db95cbd0 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -148,7 +148,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): # get from source end_user_id = None account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: from_source = "api" end_user_id = application_generate_entity.user_id else: @@ -165,11 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): model_provider = application_generate_entity.model_conf.provider model_id = application_generate_entity.model_conf.model override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [ + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in { AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION, - ]: + }: override_model_configs = app_config.app_model_config_dict # get conversation introduction diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 81c8463dd5..22ec228fa7 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_config = cast(WorkflowAppConfig, app_config) user_id = None - if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id @@ -113,7 +113,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT - if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else UserFrom.END_USER ), invoke_from=self.application_generate_entity.invoke_from, diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 2e37a126c3..77b6bb554c 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -63,7 +63,7 @@ class AnnotationReplyFeature: score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: - if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: + if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}: from_source = "api" else: from_source = "console" diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 659503301e..8f834b6458 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -372,7 +372,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._message, application_generate_entity=self._application_generate_entity, conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT] + is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} and self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras, ) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index a030d5dcbf..f10189798f 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -383,7 +383,7 @@ class WorkflowCycleManage: :param workflow_node_execution: workflow node execution :return: """ - if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None response = NodeStartStreamResponse( @@ -430,7 +430,7 @@ class WorkflowCycleManage: :param workflow_node_execution: workflow node execution :return: """ - if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None return NodeFinishStreamResponse( diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 6d5393ce5c..7cf472d984 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -29,7 +29,7 @@ class DatasetIndexToolCallbackHandler: source="app", source_app_id=self._app_id, created_by_role=( - "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" + "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" ), created_by=self._user_id, ) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index eeb1dbfda0..af20df41b1 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -292,7 +292,7 @@ class IndexingRunner: self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict ) -> list[Document]: # load file - if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: + if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: return [] data_source_info = dataset_document.data_source_info_dict diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index a14d237a12..d3185c3b11 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -52,7 +52,7 @@ class TokenBufferMemory: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index d898ef1490..52ea787c3a 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -27,17 +27,17 @@ class ModelType(Enum): :return: model type """ - if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value: + if origin_model_type in {"text-generation", cls.LLM.value}: return cls.LLM - elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value: + elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}: return cls.TEXT_EMBEDDING - elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value: + elif origin_model_type in {"reranking", cls.RERANK.value}: return cls.RERANK - elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value: + elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}: return cls.SPEECH2TEXT - elif origin_model_type == "tts" or origin_model_type == cls.TTS.value: + elif origin_model_type in {"tts", cls.TTS.value}: return cls.TTS - elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value: + elif origin_model_type in {"text2img", cls.TEXT2IMG.value}: return cls.TEXT2IMG elif origin_model_type == cls.MODERATION.value: return cls.MODERATION diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index ff741e0240..46e1b415b8 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: raise ValueError( f"Unsupported image type {mime_type}, " f"only support image/jpeg, image/png, image/gif, and image/webp" diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 8db044b24d..af178703a0 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -85,14 +85,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): for i in range(len(sentences)) ] for future in futures: - yield from future.result().__enter__().iter_bytes(1024) + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 else: response = client.audio.speech.with_streaming_response.create( model=model, voice=voice, response_format="mp3", input=content_text.strip() ) - yield from response.__enter__().iter_bytes(1024) + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index c34c20ced3..06a8606901 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -454,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): base64_data = data_split[1] image_content = base64.b64decode(base64_data) - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: raise ValueError( f"Unsupported image type {mime_type}, " f"only support image/jpeg, image/png, image/gif, and image/webp" @@ -886,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if error_code == "AccessDeniedException": return InvokeAuthorizationError(error_msg) - elif error_code in ["ResourceNotFoundException", "ValidationException"]: + elif error_code in {"ResourceNotFoundException", "ValidationException"}: return InvokeBadRequestError(error_msg) - elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: return InvokeRateLimitError(error_msg) - elif error_code in [ + elif error_code in { "ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException", - ]: + }: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 2d898e3aaa..251170d1ae 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): if error_code == "AccessDeniedException": return InvokeAuthorizationError(error_msg) - elif error_code in ["ResourceNotFoundException", "ValidationException"]: + elif error_code in {"ResourceNotFoundException", "ValidationException"}: return InvokeBadRequestError(error_msg) - elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: return InvokeRateLimitError(error_msg) - elif error_code in [ + elif error_code in { "ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException", - ]: + }: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index b10d0edba3..3fc6787a44 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -6,10 +6,10 @@ from collections.abc import Generator from typing import Optional, Union, cast import google.ai.generativelanguage as glm -import google.api_core.exceptions as exceptions import google.generativeai as genai -import google.generativeai.client as client import requests +from google.api_core import exceptions +from google.generativeai import client from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part from PIL import Image diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index 48ab477c50..9d29237fdd 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel if "huggingfacehub_api_type" not in credentials: raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"): + if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}: raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") if "huggingfacehub_api_token" not in credentials: @@ -94,7 +94,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel credentials["huggingfacehub_api_token"], model ) - if credentials["task_type"] not in ("text2text-generation", "text-generation"): + if credentials["task_type"] not in {"text2text-generation", "text-generation"}: raise CredentialsValidateFailedError( "Huggingface Hub Task Type must be one of text2text-generation, text-generation." ) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py index 288637495f..81ab249214 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -75,7 +75,7 @@ class TeiHelper: if len(model_type.keys()) < 1: raise RuntimeError("model_type is empty") model_type = list(model_type.keys())[0] - if model_type not in ["embedding", "reranker"]: + if model_type not in {"embedding", "reranker"}: raise RuntimeError(f"invalid model_type: {model_type}") max_input_length = response_json.get("max_input_length", 512) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 96f99c8929..88cc0e8e0f 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -100,9 +100,9 @@ class MinimaxChatCompletion: return self._handle_chat_generate_response(response) def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001 or code == 1013 or code == 1027: + if code in {1000, 1001, 1013, 1027}: raise InternalServerError(msg) - elif code == 1002 or code == 1039: + elif code in {1002, 1039}: raise RateLimitReachedError(msg) elif code == 1004: raise InvalidAuthenticationError(msg) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 0a2a67a56d..8b8fdbb6bd 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -105,9 +105,9 @@ class MinimaxChatCompletionPro: return self._handle_chat_generate_response(response) def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001 or code == 1013 or code == 1027: + if code in {1000, 1001, 1013, 1027}: raise InternalServerError(msg) - elif code == 1002 or code == 1039: + elif code in {1002, 1039}: raise RateLimitReachedError(msg) elif code == 1004: raise InvalidAuthenticationError(msg) diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 02a53708be..76fd1342bd 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -114,7 +114,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): raise CredentialsValidateFailedError("Invalid api key") def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001: + if code in {1000, 1001}: raise InternalServerError(msg) elif code == 1002: raise RateLimitReachedError(msg) diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 60d69c6e47..d42fce528a 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -125,7 +125,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): model_mode = self.get_model_mode(base_model, credentials) # transform response format - if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: stop = stop or [] if model_mode == LLMMode.CHAT: # chat model diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index b50b43199f..a14c91639b 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -89,14 +89,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): for i in range(len(sentences)) ] for future in futures: - yield from future.result().__enter__().iter_bytes(1024) + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 else: response = client.audio.speech.with_streaming_response.create( model=model, voice=voice, response_format="mp3", input=content_text.strip() ) - yield from response.__enter__().iter_bytes(1024) + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py index 71b5745f7d..b6bb249a04 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llm.py +++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py @@ -12,7 +12,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): credentials["endpoint_url"] = "https://openrouter.ai/api/v1" credentials["mode"] = self.get_model_mode(model).value credentials["function_calling_type"] = "tool_call" - return def _invoke( self, diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index daef8949fb..3641b35dc0 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -154,7 +154,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): ) for key, value in input_properties: - if key not in ["system_prompt", "prompt"] and "stop" not in key: + if key not in {"system_prompt", "prompt"} and "stop" not in key: value_type = value.get("type") if not value_type: diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index f6b7754d74..71b6fb99c4 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): ) for input_property in input_properties: - if input_property[0] in ("text", "texts", "inputs"): + if input_property[0] in {"text", "texts", "inputs"}: text_input_key = input_property[0] return text_input_key @@ -96,7 +96,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): def _generate_embeddings_by_text_input_key( client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str] ) -> list[list[float]]: - if text_input_key in ("text", "inputs"): + if text_input_key in {"text", "inputs"}: embeddings = [] for text in texts: result = client.run(replicate_model_version, input={text_input_key: text}) diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index cd7718361f..1d4eba6668 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -89,7 +89,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return: """ - if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + if model in {"qwen-turbo-chat", "qwen-plus-chat"}: model = model.replace("-chat", "") if model == "farui-plus": model = "qwen-farui-plus" @@ -157,7 +157,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): mode = self.get_model_mode(model, credentials) - if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + if model in {"qwen-turbo-chat", "qwen-plus-chat"}: model = model.replace("-chat", "") extra_model_kwargs = {} @@ -201,7 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - if response.status_code != 200 and response.status_code != HTTPStatus.OK: + if response.status_code not in {200, HTTPStatus.OK}: raise ServiceUnavailableError(response.message) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( @@ -240,7 +240,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): full_text = "" tool_calls = [] for index, response in enumerate(responses): - if response.status_code != 200 and response.status_code != HTTPStatus.OK: + if response.status_code not in {200, HTTPStatus.OK}: raise ServiceUnavailableError( f"Failed to invoke model {model}, status code: {response.status_code}, " f"message: {response.message}" diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index 74524e81e2..a18ee90624 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -93,7 +93,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): """ Code block mode wrapper for invoking large language model """ - if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: stop = stop or [] self._transform_chat_json_prompts( model=model, diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index dad5002f35..da69b7cdf3 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -5,7 +5,6 @@ import logging from collections.abc import Generator from typing import Optional, Union, cast -import google.api_core.exceptions as exceptions import google.auth.transport.requests import vertexai.generative_models as glm from anthropic import AnthropicVertex, Stream @@ -17,6 +16,7 @@ from anthropic.types import ( MessageStopEvent, MessageStreamEvent, ) +from google.api_core import exceptions from google.cloud import aiplatform from google.oauth2 import service_account from PIL import Image @@ -346,7 +346,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: raise ValueError( f"Unsupported image type {mime_type}, " f"only support image/jpeg, image/png, image/gif, and image/webp" diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py index 97c77de8d3..c22bf8e76d 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py @@ -96,7 +96,6 @@ class Signer: signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials) - return @staticmethod def hashed_canonical_request_v4(request, meta): @@ -105,7 +104,7 @@ class Signer: signed_headers = {} for key in request.headers: - if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"): + if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"): signed_headers[key.lower()] = request.headers[key] if "host" in signed_headers: diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index ec3556f7da..f7c160b6b4 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -69,7 +69,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): """ Code block mode wrapper for invoking large language model """ - if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: response_format = model_parameters["response_format"] stop = stop or [] self._transform_json_prompts( diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 1e05da9c56..619ee1492a 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -103,7 +103,7 @@ class XinferenceHelper: model_handle_type = "embedding" elif response_json.get("model_type") == "audio": model_handle_type = "audio" - if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]: + if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}: model_ability.append("text-to-audio") else: model_ability.append("audio-to-text") diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index f76e51fee9..ea331701ab 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -186,10 +186,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): new_prompt_messages: list[PromptMessage] = [] for prompt_message in prompt_messages: copy_prompt_message = prompt_message.copy() - if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: + if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model not in ("glm-4v", "glm-4v-plus"): + if model not in {"glm-4v", "glm-4v-plus"}: # not support list message continue # get image and @@ -209,10 +209,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): ): new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: - if ( - copy_prompt_message.role == PromptMessageRole.USER - or copy_prompt_message.role == PromptMessageRole.TOOL - ): + if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}: new_prompt_messages.append(copy_prompt_message) elif copy_prompt_message.role == PromptMessageRole.SYSTEM: new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) @@ -226,7 +223,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): else: new_prompt_messages.append(copy_prompt_message) - if model == "glm-4v" or model == "glm-4v-plus": + if model in {"glm-4v", "glm-4v-plus"}: params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: params = {"model": model, "messages": [], **model_parameters} @@ -270,11 +267,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # chatglm model for prompt_message in new_prompt_messages: # merge system message to user message - if ( - prompt_message.role == PromptMessageRole.SYSTEM - or prompt_message.role == PromptMessageRole.TOOL - or prompt_message.role == PromptMessageRole.USER - ): + if prompt_message.role in { + PromptMessageRole.SYSTEM, + PromptMessageRole.TOOL, + PromptMessageRole.USER, + }: if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": params["messages"][-1]["content"] += "\n\n" + prompt_message.content else: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py index af0991892e..416f516ef7 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py @@ -1,5 +1,4 @@ from __future__ import annotations -from .fine_tuning_job import FineTuningJob as FineTuningJob -from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob -from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent +from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob +from .fine_tuning_job_event import FineTuningJobEvent diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index c05edb72e3..029ec1a581 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -75,7 +75,7 @@ class CommonValidator: if not isinstance(value, str): raise ValueError(f"Variable {credential_form_schema.variable} should be string") - if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: + if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: # If the value is in options, no validation is performed if credential_form_schema.options: if value not in [option.value for option in credential_form_schema.options]: @@ -83,7 +83,7 @@ class CommonValidator: if credential_form_schema.type == FormType.SWITCH: # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in ["true", "false"]: + if value.lower() not in {"true", "false"}: raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") value = True if value.lower() == "true" else False diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index f8300cc271..8d57855120 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -51,7 +51,7 @@ class ElasticSearchVector(BaseVector): def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: try: parsed_url = urlparse(config.host) - if parsed_url.scheme in ["http", "https"]: + if parsed_url.scheme in {"http", "https"}: hosts = f"{config.host}:{config.port}" else: hosts = f"http://{config.host}:{config.port}" @@ -94,7 +94,7 @@ class ElasticSearchVector(BaseVector): return uuids def text_exists(self, id: str) -> bool: - return self._client.exists(index=self._collection_name, id=id).__bool__() + return bool(self._client.exists(index=self._collection_name, id=id)) def delete_by_ids(self, ids: list[str]) -> None: for id in ids: diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 1bd5bcd3e4..2320a69a30 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -35,7 +35,7 @@ class MyScaleVector(BaseVector): super().__init__(collection_name) self._config = config self._metric = metric - self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC + self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC self._client = get_client( host=config.host, port=config.port, @@ -92,7 +92,7 @@ class MyScaleVector(BaseVector): @staticmethod def escape_str(value: Any) -> str: - return "".join(" " if c in ("\\", "'") else c for c in str(value)) + return "".join(" " if c in {"\\", "'"} else c for c in str(value)) def text_exists(self, id: str) -> bool: results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index b974fa80a4..77ec45b4d3 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -223,15 +223,7 @@ class OracleVector(BaseVector): words = pseg.cut(query) current_entity = "" for word, pos in words: - if ( - pos == "nr" - or pos == "Ng" - or pos == "eng" - or pos == "nz" - or pos == "n" - or pos == "ORG" - or pos == "v" - ): # nr: 人名, ns: 地名, nt: 机构名 + if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 current_entity += word else: if current_entity: diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 3181656f59..fe7eaa32e6 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -98,17 +98,17 @@ class ExtractProcessor: unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY if etl_type == "Unstructured": - if file_extension == ".xlsx" or file_extension == ".xls": + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in [".md", ".markdown"]: + elif file_extension in {".md", ".markdown"}: extractor = ( UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic else MarkdownExtractor(file_path, autodetect_encoding=True) ) - elif file_extension in [".htm", ".html"]: + elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) @@ -134,13 +134,13 @@ class ExtractProcessor: else TextExtractor(file_path, autodetect_encoding=True) ) else: - if file_extension == ".xlsx" or file_extension == ".xls": + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in [".md", ".markdown"]: + elif file_extension in {".md", ".markdown"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in [".htm", ".html"]: + elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 054ce5f4b2..17c2087a0a 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -32,7 +32,7 @@ class FirecrawlApp: else: raise Exception(f'Failed to scrape URL. Error: {response["error"]}') - elif response.status_code in [402, 409, 500]: + elif response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}") else: diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 0ee24983a4..87a4ce08bf 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -103,12 +103,12 @@ class NotionExtractor(BaseExtractor): multi_select_list = property_value[type] for multi_select in multi_select_list: value.append(multi_select["name"]) - elif type == "rich_text" or type == "title": + elif type in {"rich_text", "title"}: if len(property_value[type]) > 0: value = property_value[type][0]["plain_text"] else: value = "" - elif type == "select" or type == "status": + elif type in {"select", "status"}: if property_value[type]: value = property_value[type]["name"] else: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 12868d6ae4..124c58f0fe 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -115,7 +115,7 @@ class DatasetRetrieval: available_datasets.append(dataset) all_documents = [] - user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" + user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: all_documents = self.single_retrieve( app_id, diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 161c36607d..7dd62f8de1 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -35,7 +35,7 @@ def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> l splits = re.split(separator, text) else: splits = list(text) - return [s for s in splits if (s != "" and s != "\n")] + return [s for s in splits if (s not in {"", "\n"})] class TextSplitter(BaseDocumentTransformer, ABC): diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 01544d7e56..09f328cd1f 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -68,7 +68,7 @@ class AppToolProviderEntity(ToolProviderController): label = input_form[form_type]["label"] variable_name = input_form[form_type]["variable_name"] options = input_form[form_type].get("options", []) - if form_type == "paragraph" or form_type == "text-input": + if form_type in {"paragraph", "text-input"}: tool["parameters"].append( ToolParameter( name=variable_name, diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index a2d69fbcd1..dd9371f70d 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -168,7 +168,7 @@ class AIPPTGenerateTool(BuiltinTool): pass elif event == "close": break - elif event == "error" or event == "filter": + elif event in {"error", "filter"}: raise Exception(f"Failed to generate outline: {data}") return outline @@ -213,7 +213,7 @@ class AIPPTGenerateTool(BuiltinTool): pass elif event == "close": break - elif event == "error" or event == "filter": + elif event in {"error", "filter"}: raise Exception(f"Failed to generate content: {data}") return content diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 7462824be1..cfa3cfb092 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -39,11 +39,11 @@ class DallE3Tool(BuiltinTool): n = tool_parameters.get("n", 1) # get quality quality = tool_parameters.get("quality", "standard") - if quality not in ["standard", "hd"]: + if quality not in {"standard", "hd"}: return self.create_text_message("Invalid quality") # get style style = tool_parameters.get("style", "vivid") - if style not in ["natural", "vivid"]: + if style not in {"natural", "vivid"}: return self.create_text_message("Invalid style") # set extra body seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py index 017fe548f7..632c9fc7f1 100644 --- a/api/core/tools/provider/builtin/code/tools/simple_code.py +++ b/api/core/tools/provider/builtin/code/tools/simple_code.py @@ -14,7 +14,7 @@ class SimpleCode(BuiltinTool): language = tool_parameters.get("language", CodeLanguage.PYTHON3) code = tool_parameters.get("code", "") - if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]: + if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}: raise ValueError(f"Only python3 and javascript are supported, not {language}") result = CodeExecutor.execute_code(language, "", code) diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 9776bd7dd1..9039708588 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -34,11 +34,11 @@ class CogView3Tool(BuiltinTool): n = tool_parameters.get("n", 1) # get quality quality = tool_parameters.get("quality", "standard") - if quality not in ["standard", "hd"]: + if quality not in {"standard", "hd"}: return self.create_text_message("Invalid quality") # get style style = tool_parameters.get("style", "vivid") - if style not in ["natural", "vivid"]: + if style not in {"natural", "vivid"}: return self.create_text_message("Invalid style") # set extra body seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index bcfa2212b6..a8c647d71e 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -49,11 +49,11 @@ class DallE3Tool(BuiltinTool): n = tool_parameters.get("n", 1) # get quality quality = tool_parameters.get("quality", "standard") - if quality not in ["standard", "hd"]: + if quality not in {"standard", "hd"}: return self.create_text_message("Invalid quality") # get style style = tool_parameters.get("style", "vivid") - if style not in ["natural", "vivid"]: + if style not in {"natural", "vivid"}: return self.create_text_message("Invalid style") # call openapi dalle3 diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py index 40e1af043b..79e5889eae 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py @@ -133,9 +133,9 @@ class GetWorksheetFieldsTool(BuiltinTool): def _extract_options(self, control: dict) -> list: options = [] - if control["type"] in [9, 10, 11]: + if control["type"] in {9, 10, 11}: options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) - elif control["type"] in [28, 36]: + elif control["type"] in {28, 36}: itemnames = control["advancedSetting"].get("itemnames") if itemnames and itemnames.startswith("[{"): try: diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index 171895a306..44c7e52307 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -183,11 +183,11 @@ class ListWorksheetRecordsTool(BuiltinTool): type_id = field.get("typeId") if type_id == 10: value = value if isinstance(value, str) else "、".join(value) - elif type_id in [28, 36]: + elif type_id in {28, 36}: value = field.get("options", {}).get(value, value) - elif type_id in [26, 27, 48, 14]: + elif type_id in {26, 27, 48, 14}: value = self.process_value(value) - elif type_id in [35, 29]: + elif type_id in {35, 29}: value = self.parse_cascade_or_associated(field, value) elif type_id == 40: value = self.parse_location(value) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py index 9ca14b327c..a200ee8123 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py @@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool): models_data=[], headers=headers, params=params, - recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"), + recursive=result_type not in {"first sd_name", "first name sd_name pair"}, ) result_str = "" diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index 16ae14549d..17e2978194 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -38,7 +38,7 @@ class SearchAPI: return { "engine": "google", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index d29cb0ae3f..c478bc108b 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -38,7 +38,7 @@ class SearchAPI: return { "engine": "google_jobs", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 8458c8c958..562bc01964 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -38,7 +38,7 @@ class SearchAPI: return { "engine": "google_news", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 09725cf8a2..1867cf7be7 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -38,7 +38,7 @@ class SearchAPI: "engine": "youtube_transcripts", "video_id": video_id, "lang": language or "en", - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/spider/spiderApp.py b/api/core/tools/provider/builtin/spider/spiderApp.py index 3972e560c4..4bc446a1a0 100644 --- a/api/core/tools/provider/builtin/spider/spiderApp.py +++ b/api/core/tools/provider/builtin/spider/spiderApp.py @@ -214,7 +214,7 @@ class Spider: return requests.delete(url, headers=headers, stream=stream) def _handle_error(self, response, action): - if response.status_code in [402, 409, 500]: + if response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") else: diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index 9f415ceb55..6bcf315484 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -32,7 +32,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): model = tool_parameters.get("model", "core") - if model in ["sd3", "sd3-turbo"]: + if model in {"sd3", "sd3-turbo"}: payload["model"] = tool_parameters.get("model") if model != "sd3-turbo": diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py index a6efb0f79a..c90d766e48 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.py +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -38,7 +38,7 @@ class VannaTool(BuiltinTool): vn = VannaDefault(model=model, api_key=api_key) db_type = tool_parameters.get("db_type", "") - if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]: + if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}: if not db_name: return self.create_text_message("Please input database name") if not username: diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 6b64dd1b4e..ff022812ef 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -19,7 +19,7 @@ from core.tools.utils.yaml_utils import load_yaml_file class BuiltinToolProviderController(ToolProviderController): def __init__(self, **data: Any) -> None: - if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: + if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}: super().__init__(**data) return diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index f4008eedce..7ba9dda179 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -153,10 +153,10 @@ class ToolProviderController(BaseModel, ABC): # check type credential_schema = credentials_need_to_validate[credential_name] - if ( - credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT - or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT - ): + if credential_schema in { + ToolProviderCredentials.CredentialsType.SECRET_INPUT, + ToolProviderCredentials.CredentialsType.TEXT_INPUT, + }: if not isinstance(credentials[credential_name], str): raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") @@ -184,11 +184,11 @@ class ToolProviderController(BaseModel, ABC): if credential_schema.default is not None: default_value = credential_schema.default # parse default value into the correct type - if ( - credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT - or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT - or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT - ): + if credential_schema.type in { + ToolProviderCredentials.CredentialsType.SECRET_INPUT, + ToolProviderCredentials.CredentialsType.TEXT_INPUT, + ToolProviderCredentials.CredentialsType.SELECT, + }: default_value = str(default_value) credentials[credential_name] = default_value diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index bf336b48f3..c779d704c3 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -5,7 +5,7 @@ from urllib.parse import urlencode import httpx -import core.helper.ssrf_proxy as ssrf_proxy +from core.helper import ssrf_proxy from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError @@ -191,7 +191,7 @@ class ApiTool(Tool): else: body = body - if method in ("get", "head", "post", "put", "delete", "patch"): + if method in {"get", "head", "post", "put", "delete", "patch"}: response = getattr(ssrf_proxy, method)( url, params=params, @@ -224,9 +224,9 @@ class ApiTool(Tool): elif option["type"] == "string": return str(value) elif option["type"] == "boolean": - if str(value).lower() in ["true", "1"]: + if str(value).lower() in {"true", "1"}: return True - elif str(value).lower() in ["false", "0"]: + elif str(value).lower() in {"false", "0"}: return False else: continue # Not a boolean, try next option diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 645f0861fa..9912114dd6 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -189,10 +189,7 @@ class ToolEngine: result += response.message elif response.type == ToolInvokeMessage.MessageType.LINK: result += f"result link: {response.message}. please tell user to check it." - elif ( - response.type == ToolInvokeMessage.MessageType.IMAGE_LINK - or response.type == ToolInvokeMessage.MessageType.IMAGE - ): + elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: result += ( "image has been created and sent to user already, you do not need to create it," " just tell the user to check it now." @@ -212,10 +209,7 @@ class ToolEngine: result = [] for response in tool_response: - if ( - response.type == ToolInvokeMessage.MessageType.IMAGE_LINK - or response.type == ToolInvokeMessage.MessageType.IMAGE - ): + if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: mimetype = None if response.meta.get("mime_type"): mimetype = response.meta.get("mime_type") @@ -297,7 +291,7 @@ class ToolEngine: belongs_to="assistant", url=message.url, upload_file_id=None, - created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"), + created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"), created_by=user_id, ) diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index bf040d91d3..3cfab207ba 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -19,7 +19,7 @@ class ToolFileMessageTransformer: result = [] for message in messages: - if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK: + if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: result.append(message) elif message.type == ToolInvokeMessage.MessageType.IMAGE: # try to download image diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 210b84b29a..9ead4f8e5c 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -165,7 +165,7 @@ class ApiBasedToolSchemaParser: elif "schema" in parameter and "type" in parameter["schema"]: typ = parameter["schema"]["type"] - if typ == "integer" or typ == "number": + if typ in {"integer", "number"}: return ToolParameter.ToolParameterType.NUMBER elif typ == "boolean": return ToolParameter.ToolParameterType.BOOLEAN diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index e57cae9f16..1ced7d0488 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -313,7 +313,7 @@ def normalize_whitespace(text): def is_leaf(element): - return element.name in ["p", "li"] + return element.name in {"p", "li"} def is_text(element): diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index 8fc8047426..bb24b51112 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -51,7 +51,7 @@ class RouteNodeState(BaseModel): :param run_result: run result """ - if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: + if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}: raise Exception(f"Route state {self.id} already finished") if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index e31a1479a8..5e6de8fb15 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -148,11 +148,11 @@ class AnswerStreamGeneratorRouter: for edge in reverse_edges: source_node_id = edge.source_node_id source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - if source_node_type in ( + if source_node_type in { NodeType.ANSWER.value, NodeType.IF_ELSE.value, NodeType.QUESTION_CLASSIFIER.value, - ): + }: answer_dependencies[answer_node_id].append(source_node_id) else: cls._recursive_fetch_answer_dependencies( diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index a38d982393..9a7d2ecde3 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -136,10 +136,10 @@ class EndStreamGeneratorRouter: for edge in reverse_edges: source_node_id = edge.source_node_id source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - if source_node_type in ( + if source_node_type in { NodeType.IF_ELSE.value, NodeType.QUESTION_CLASSIFIER, - ): + }: end_dependencies[end_node_id].append(source_node_id) else: cls._recursive_fetch_end_dependencies( diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 49102dc3ab..f8ab4e3132 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -6,8 +6,8 @@ from urllib.parse import urlencode import httpx -import core.helper.ssrf_proxy as ssrf_proxy from configs import dify_config +from core.helper import ssrf_proxy from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.http_request.entities import ( @@ -176,7 +176,7 @@ class HttpExecutor: elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: self.headers["Content-Type"] = "application/x-www-form-urlencoded" - if node_data.body.type in ["form-data", "x-www-form-urlencoded"]: + if node_data.body.type in {"form-data", "x-www-form-urlencoded"}: body = self._to_dict(body_data) if node_data.body.type == "form-data": @@ -187,7 +187,7 @@ class HttpExecutor: self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" else: self.body = urlencode(body) - elif node_data.body.type in ["json", "raw-text"]: + elif node_data.body.type in {"json", "raw-text"}: self.body = body_data elif node_data.body.type == "none": self.body = "" @@ -258,7 +258,7 @@ class HttpExecutor: "follow_redirects": True, } - if self.method in ("get", "head", "post", "put", "delete", "patch"): + if self.method in {"get", "head", "post", "put", "delete", "patch"}: response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) else: raise ValueError(f"Invalid http method {self.method}") diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 802ed31e27..5697d7c049 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -33,7 +33,7 @@ class ParameterConfig(BaseModel): def validate_name(cls, value) -> str: if not value: raise ValueError("Parameter name is required") - if value in ["__reason", "__is_success"]: + if value in {"__reason", "__is_success"}: raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return value @@ -66,7 +66,7 @@ class ParameterExtractorNodeData(BaseNodeData): for parameter in self.parameters: parameter_schema = {"description": parameter.description} - if parameter.type in ["string", "select"]: + if parameter.type in {"string", "select"}: parameter_schema["type"] = "string" elif parameter.type.startswith("array"): parameter_schema["type"] = "array" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 131d26b19e..a6454bd1cd 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -467,7 +467,7 @@ class ParameterExtractorNode(LLMNode): # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') # elif isinstance(result[parameter.name], int): # transformed_result[parameter.name] = bool(result[parameter.name]) - elif parameter.type in ["string", "select"]: + elif parameter.type in {"string", "select"}: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] elif parameter.type.startswith("array"): @@ -498,7 +498,7 @@ class ParameterExtractorNode(LLMNode): transformed_result[parameter.name] = 0 elif parameter.type == "bool": transformed_result[parameter.name] = False - elif parameter.type in ["string", "select"]: + elif parameter.type in {"string", "select"}: transformed_result[parameter.name] = "" elif parameter.type.startswith("array"): transformed_result[parameter.name] = [] @@ -516,9 +516,9 @@ class ParameterExtractorNode(LLMNode): """ stack = [] for i, c in enumerate(text): - if c == "{" or c == "[": + if c in {"{", "["}: stack.append(c) - elif c == "}" or c == "]": + elif c in {"}", "]"}: # check if stack is empty if not stack: return text[:i] @@ -560,7 +560,7 @@ class ParameterExtractorNode(LLMNode): result[parameter.name] = 0 elif parameter.type == "bool": result[parameter.name] = False - elif parameter.type in ["string", "select"]: + elif parameter.type in {"string", "select"}: result[parameter.name] = "" return result diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index e55adfc1f4..3b86b29cf8 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -163,10 +163,7 @@ class ToolNode(BaseNode): result = [] for response in tool_response: - if ( - response.type == ToolInvokeMessage.MessageType.IMAGE_LINK - or response.type == ToolInvokeMessage.MessageType.IMAGE - ): + if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: url = response.message ext = path.splitext(url)[1] mimetype = response.meta.get("mime_type", "image/jpeg") diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 6da1a6d39b..05a73b09b7 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -158,7 +158,7 @@ class NotionOAuth(OAuthDataSource): page_icon = page_result["icon"] if page_icon: icon_type = page_icon["type"] - if icon_type == "external" or icon_type == "file": + if icon_type in {"external", "file"}: url = page_icon[icon_type]["url"] icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: @@ -191,7 +191,7 @@ class NotionOAuth(OAuthDataSource): page_icon = database_result["icon"] if page_icon: icon_type = page_icon["type"] - if icon_type == "external" or icon_type == "file": + if icon_type in {"external", "file"}: url = page_icon[icon_type]["url"] icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: diff --git a/api/libs/rsa.py b/api/libs/rsa.py index a578bf3e56..637bcc4a1d 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -4,9 +4,9 @@ from Crypto.Cipher import AES from Crypto.PublicKey import RSA from Crypto.Random import get_random_bytes -import libs.gmpy2_pkcs10aep_cipher as gmpy2_pkcs10aep_cipher from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from libs import gmpy2_pkcs10aep_cipher def generate_key_pair(tenant_id): diff --git a/api/models/dataset.py b/api/models/dataset.py index 0da35910cd..a2d2a3454d 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -284,9 +284,9 @@ class Document(db.Model): status = None if self.indexing_status == "waiting": status = "queuing" - elif self.indexing_status not in ["completed", "error", "waiting"] and self.is_paused: + elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused: status = "paused" - elif self.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]: + elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}: status = "indexing" elif self.indexing_status == "error": status = "error" @@ -331,7 +331,7 @@ class Document(db.Model): "created_at": file_detail.created_at.timestamp(), } } - elif self.data_source_type == "notion_import" or self.data_source_type == "website_crawl": + elif self.data_source_type in {"notion_import", "website_crawl"}: return json.loads(self.data_source_info) return {} diff --git a/api/models/model.py b/api/models/model.py index a8b2e00ee4..ae0bc3210b 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -134,7 +134,7 @@ class App(db.Model): return False if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get( "strategy", "" - ) in ["function_call", "react"]: + ) in {"function_call", "react"}: self.mode = AppMode.AGENT_CHAT.value db.session.commit() return True @@ -1501,6 +1501,6 @@ class TraceAppConfig(db.Model): "tracing_provider": self.tracing_provider, "tracing_config": self.tracing_config_dict, "is_active": self.is_active, - "created_at": self.created_at.__str__() if self.created_at else None, - "updated_at": self.updated_at.__str__() if self.updated_at else None, + "created_at": str(self.created_at) if self.created_at else None, + "updated_at": str(self.updated_at) if self.updated_at else None, } diff --git a/api/pyproject.toml b/api/pyproject.toml index 83aa35c542..57a3844200 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -6,6 +6,9 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] +exclude=[ + "migrations/*", +] line-length = 120 [tool.ruff.lint] @@ -19,6 +22,13 @@ select = [ "I", # isort rules "N", # pep8-naming "PT", # flake8-pytest-style rules + "PLC0208", # iteration-over-set + "PLC2801", # unnecessary-dunder-call + "PLC0414", # useless-import-alias + "PLR0402", # manual-from-import + "PLR1711", # useless-return + "PLR1714", # repeated-equality-comparison + "PLR6201", # literal-membership "RUF019", # unnecessary-key-check "RUF100", # unused-noqa "RUF101", # redirected-noqa @@ -78,9 +88,6 @@ ignore = [ "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name ] -"migrations/versions/*" = [ - "E501", # line-too-long -] "tests/*" = [ "F401", # unused-import "F811", # redefined-while-unused @@ -88,7 +95,6 @@ ignore = [ [tool.ruff.format] exclude = [ - "migrations/**/*", ] [tool.pytest_env] diff --git a/api/services/account_service.py b/api/services/account_service.py index e839ae54ba..66ff5d2b7c 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -47,7 +47,7 @@ class AccountService: if not account: return None - if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise Unauthorized("Account is banned or closed.") current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( @@ -92,7 +92,7 @@ class AccountService: if not account: raise AccountLoginError("Invalid email or password.") - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise AccountLoginError("Account is banned or closed.") if account.status == AccountStatus.PENDING.value: @@ -427,7 +427,7 @@ class TenantService: "remove": [TenantAccountRole.OWNER], "update": [TenantAccountRole.OWNER], } - if action not in ["add", "remove", "update"]: + if action not in {"add", "remove", "update"}: raise InvalidActionError("Invalid action.") if member: diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 2fe39b5224..54594e1175 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -90,7 +90,7 @@ class AppDslService: # import dsl and create app app_mode = AppMode.value_of(app_data.get("mode")) - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, @@ -103,7 +103,7 @@ class AppDslService: icon_background=icon_background, use_icon_as_answer_icon=use_icon_as_answer_icon, ) - elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: + elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, @@ -143,7 +143,7 @@ class AppDslService: # import dsl and overwrite app app_mode = AppMode.value_of(app_data.get("mode")) - if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: raise ValueError("Only support import workflow in advanced-chat or workflow app.") if app_data.get("mode") != app_model.mode: @@ -177,7 +177,7 @@ class AppDslService: }, } - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: cls._append_workflow_export_data( export_data=export_data, app_model=app_model, include_secret=include_secret ) diff --git a/api/services/app_service.py b/api/services/app_service.py index 1dacfea246..ac45d623e8 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -316,7 +316,7 @@ class AppService: meta = {"tool_icons": {}} - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: return meta diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 05cd1c96a1..7a0cd5725b 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) class AudioService: @classmethod def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise ValueError("Speech to text is not enabled") @@ -83,7 +83,7 @@ class AudioService: def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): with app.app_context(): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise ValueError("TTS is not enabled") diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl.py index 30e4ee57c0..afc491398f 100644 --- a/api/services/auth/firecrawl.py +++ b/api/services/auth/firecrawl.py @@ -37,7 +37,7 @@ class FirecrawlAuth(ApiKeyAuthBase): return requests.post(url, headers=headers, json=data) def _handle_error(self, response): - if response.status_code in [402, 409, 500]: + if response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") else: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index fa017bfa42..30c010ef29 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -544,7 +544,7 @@ class DocumentService: @staticmethod def pause_document(document): - if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]: + if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: raise DocumentIndexingError() # update document to be paused document.is_paused = True diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 6fb0f2f517..7ae1b9f231 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -33,7 +33,7 @@ class ToolTransformService: if provider_type == ToolProviderType.BUILT_IN.value: return url_prefix + "builtin/" + provider_name + "/icon" - elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: + elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: try: return json.loads(icon) except: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 357ffd41c1..0ff81f1f7e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -295,7 +295,7 @@ class WorkflowService: # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: + if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 21ea11d4dd..934eb7430c 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -29,7 +29,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): try: indexing_runner = IndexingRunner() - if document.indexing_status in ["waiting", "parsing", "cleaning"]: + if document.indexing_status in {"waiting", "parsing", "cleaning"}: indexing_runner.run([document]) elif document.indexing_status == "splitting": indexing_runner.run_in_splitting_status(document) diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index bc0684086f..402bd9c2c2 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,15 +1,13 @@ from collections.abc import Generator -import google.generativeai.types.content_types as content_types import google.generativeai.types.generation_types as generation_config_types -import google.generativeai.types.safety_types as safety_types import pytest from _pytest.monkeypatch import MonkeyPatch from google.ai import generativelanguage as glm from google.ai.generativelanguage_v1beta.types import content as gag_content from google.generativeai import GenerativeModel from google.generativeai.client import _ClientManager, configure -from google.generativeai.types import GenerateContentResponse +from google.generativeai.types import GenerateContentResponse, content_types, safety_types from google.generativeai.types.generation_types import BaseGenerateContentResponse current_api_key = "" diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index d9cd7b046e..439f7d56e9 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -6,7 +6,6 @@ from time import time # import monkeypatch from typing import Any, Literal, Optional, Union -import openai.types.chat.completion_create_params as completion_create_params from openai import AzureOpenAI, OpenAI from openai._types import NOT_GIVEN, NotGiven from openai.resources.chat.completions import Completions @@ -18,6 +17,7 @@ from openai.types.chat import ( ChatCompletionMessageToolCall, ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam, + completion_create_params, ) from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice @@ -254,7 +254,7 @@ class MockChatClass: "gpt-3.5-turbo-16k-0613", ] azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py index c27e89248f..14223668e0 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -112,7 +112,7 @@ class MockCompletionsClass: ] azure_openai_models = ["gpt-35-turbo-instruct"] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py index 025913cb17..e27b9891f5 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -22,7 +22,7 @@ class MockEmbeddingsClass: if isinstance(input, str): input = [input] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py index 270a88e85f..4262d40f3e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -20,7 +20,7 @@ class MockModerationClass: if isinstance(input, str): input = [input] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py index ef361e8613..a51dcab4be 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -20,7 +20,7 @@ class MockSpeech2TextClass: temperature: float | NotGiven = NOT_GIVEN, **kwargs: Any, ) -> Transcription: - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 777737187e..299523f4f5 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -42,7 +42,7 @@ class MockXinferenceClass: model_uid = url.split("/")[-1] or "" if not re.match( r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid - ) and model_uid not in ["generate", "chat", "embedding", "rerank"]: + ) and model_uid not in {"generate", "chat", "embedding", "rerank"}: response.status_code = 404 response._content = b"{}" return response @@ -53,7 +53,7 @@ class MockXinferenceClass: response._content = b"{}" return response - if model_uid in ["generate", "chat"]: + if model_uid in {"generate", "chat"}: response.status_code = 200 response._content = b"""{ "model_type": "LLM", diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index cbe9c5914f..88435c4022 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -411,5 +411,5 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): if latest_role is not None: assert latest_role != prompt.get("role") - if prompt.get("role") in ["user", "assistant"]: + if prompt.get("role") in {"user", "assistant"}: latest_role = prompt.get("role") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index a2d71d61fc..197288adba 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -210,7 +210,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): assert not isinstance(item, NodeRunFailedEvent) assert not isinstance(item, GraphRunFailedEvent) - if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]: + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in {"llm2", "llm3", "end1", "end2"}: assert item.parallel_id is not None assert len(items) == 18 @@ -315,12 +315,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): assert not isinstance(item, NodeRunFailedEvent) assert not isinstance(item, GraphRunFailedEvent) - if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [ + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in { "answer2", "answer3", "answer4", "answer5", - ]: + }: assert item.parallel_id is not None assert len(items) == 23 From aad6f340b3a7402cd73c40d51802c190bb7fd269 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Fri, 13 Sep 2024 23:19:36 +0800 Subject: [PATCH 09/43] fix (#8322 followup): resolve the violation of pylint rules (#8391) --- api/core/tools/provider/builtin/firecrawl/tools/crawl.py | 4 ++-- api/core/tools/provider/builtin/firecrawl/tools/scrape.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py index 15ab510c6c..9675b8eb91 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py @@ -35,10 +35,10 @@ class CrawlTool(BuiltinTool): scrapeOptions["excludeTags"] = get_array_params(tool_parameters, "excludeTags") scrapeOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) scrapeOptions["waitFor"] = tool_parameters.get("waitFor", 0) - scrapeOptions = {k: v for k, v in scrapeOptions.items() if v not in (None, "")} + scrapeOptions = {k: v for k, v in scrapeOptions.items() if v not in {None, ""}} payload["scrapeOptions"] = scrapeOptions or None - payload = {k: v for k, v in payload.items() if v not in (None, "")} + payload = {k: v for k, v in payload.items() if v not in {None, ""}} crawl_result = app.crawl_url(url=tool_parameters["url"], wait=wait_for_results, **payload) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py index f00a9b31ce..538b4a1fcb 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py @@ -29,10 +29,10 @@ class ScrapeTool(BuiltinTool): extract["schema"] = get_json_params(tool_parameters, "schema") extract["systemPrompt"] = tool_parameters.get("systemPrompt") extract["prompt"] = tool_parameters.get("prompt") - extract = {k: v for k, v in extract.items() if v not in (None, "")} + extract = {k: v for k, v in extract.items() if v not in {None, ""}} payload["extract"] = extract or None - payload = {k: v for k, v in payload.items() if v not in (None, "")} + payload = {k: v for k, v in payload.items() if v not in {None, ""}} crawl_result = app.scrape_url(url=tool_parameters["url"], **payload) markdown_result = crawl_result.get("data", {}).get("markdown", "") From 5b98acde2fb09b0efcff32024d796bdca760ae18 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Fri, 13 Sep 2024 23:34:39 +0800 Subject: [PATCH 10/43] chore: improve usage of striping prefix or suffix of string with Ruff 0.6.5 (#8392) --- .../huggingface_tei/rerank/rerank.py | 3 +- .../text_embedding/text_embedding.py | 6 +-- .../model_providers/jina/rerank/rerank.py | 3 +- .../jina/text_embedding/text_embedding.py | 3 +- .../siliconflow/rerank/rerank.py | 3 +- .../model_providers/xinference/llm/llm.py | 3 +- .../xinference/rerank/rerank.py | 6 +-- .../xinference/speech2text/speech2text.py | 6 +-- .../text_embedding/text_embedding.py | 6 +-- .../model_providers/xinference/tts/tts.py | 6 +-- .../zhipuai/zhipuai_sdk/core/_sse_client.py | 3 +- .../builtin/hap/tools/add_worksheet_record.py | 2 +- .../hap/tools/delete_worksheet_record.py | 2 +- .../builtin/hap/tools/get_worksheet_fields.py | 2 +- .../hap/tools/get_worksheet_pivot_data.py | 2 +- .../hap/tools/list_worksheet_records.py | 2 +- .../builtin/hap/tools/list_worksheets.py | 2 +- .../hap/tools/update_worksheet_record.py | 2 +- api/poetry.lock | 40 +++++++++---------- api/pyproject.toml | 2 +- 20 files changed, 44 insertions(+), 60 deletions(-) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py index c128c35f6d..74a1dfc3ff 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py @@ -49,8 +49,7 @@ class HuggingfaceTeiRerankModel(RerankModel): return RerankResult(model=model, docs=[]) server_url = credentials["server_url"] - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") try: results = TeiHelper.invoke_rerank(server_url, query, docs) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index 2d04abb277..55f3c25804 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -42,8 +42,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): """ server_url = credentials["server_url"] - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") # get model properties context_size = self._get_context_size(model, credentials) @@ -119,8 +118,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): num_tokens = 0 server_url = credentials["server_url"] - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) num_tokens = sum(len(tokens) for tokens in batch_tokens) diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py index d8394f7a4c..79ca68914f 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -48,8 +48,7 @@ class JinaRerankModel(RerankModel): return RerankResult(model=model, docs=[]) base_url = credentials.get("base_url", "https://api.jina.ai/v1") - if base_url.endswith("/"): - base_url = base_url[:-1] + base_url = base_url.removesuffix("/") try: response = httpx.post( diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 7ed3e4d384..ef12e534db 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -44,8 +44,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): raise CredentialsValidateFailedError("api_key is required") base_url = credentials.get("base_url", self.api_base) - if base_url.endswith("/"): - base_url = base_url[:-1] + base_url = base_url.removesuffix("/") url = base_url + "/embeddings" headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py index 6f652e9d52..58b033d28a 100644 --- a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py @@ -30,8 +30,7 @@ class SiliconflowRerankModel(RerankModel): return RerankResult(model=model, docs=[]) base_url = credentials.get("base_url", "https://api.siliconflow.cn/v1") - if base_url.endswith("/"): - base_url = base_url[:-1] + base_url = base_url.removesuffix("/") try: response = httpx.post( base_url + "/rerank", diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index f8f6c6b12d..4fadda5df5 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -459,8 +459,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if "server_url" not in credentials: raise CredentialsValidateFailedError("server_url is required in credentials") - if credentials["server_url"].endswith("/"): - credentials["server_url"] = credentials["server_url"][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") api_key = credentials.get("api_key") or "abc" diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 1582fe43b9..8f18bc42d2 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -50,8 +50,7 @@ class XinferenceRerankModel(RerankModel): server_url = credentials["server_url"] model_uid = credentials["model_uid"] api_key = credentials.get("api_key") - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} params = {"documents": docs, "query": query, "top_n": top_n, "return_documents": True} @@ -98,8 +97,7 @@ class XinferenceRerankModel(RerankModel): if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials["server_url"].endswith("/"): - credentials["server_url"] = credentials["server_url"][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") # initialize client client = Client( diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 18efde758c..a6c5b8a0a5 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -45,8 +45,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel): if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials["server_url"].endswith("/"): - credentials["server_url"] = credentials["server_url"][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") # initialize client client = Client( @@ -116,8 +115,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel): server_url = credentials["server_url"] model_uid = credentials["model_uid"] api_key = credentials.get("api_key") - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index ac704e7de8..8043af1d6c 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -45,8 +45,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): server_url = credentials["server_url"] model_uid = credentials["model_uid"] api_key = credentials.get("api_key") - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: @@ -118,8 +117,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): if extra_args.max_tokens: credentials["max_tokens"] = extra_args.max_tokens - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") client = Client( base_url=server_url, diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index d29e8ed8f9..10538b5788 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -73,8 +73,7 @@ class XinferenceText2SpeechModel(TTSModel): if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials["server_url"].endswith("/"): - credentials["server_url"] = credentials["server_url"][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials["server_url"], @@ -189,8 +188,7 @@ class XinferenceText2SpeechModel(TTSModel): :param voice: model timbre :return: text translated to audio file """ - if credentials["server_url"].endswith("/"): - credentials["server_url"] = credentials["server_url"][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") try: api_key = credentials.get("api_key") diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index 3566c6b332..ec2745d059 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -127,8 +127,7 @@ class SSELineParser: field, _p, value = line.partition(":") - if value.startswith(" "): - value = value[1:] + value = value.removeprefix(" ") if field == "data": self._data.append(value) elif field == "event": diff --git a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py index f2288ed81c..597adc91db 100644 --- a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py @@ -30,7 +30,7 @@ class AddWorksheetRecordTool(BuiltinTool): elif not host.startswith(("http://", "https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/addRow" headers = {"Content-Type": "application/json"} diff --git a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py index 1df5f6d5cf..5d42af4c49 100644 --- a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py @@ -29,7 +29,7 @@ class DeleteWorksheetRecordTool(BuiltinTool): elif not host.startswith(("http://", "https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/deleteRow" headers = {"Content-Type": "application/json"} diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py index 79e5889eae..6887b8b4e9 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py @@ -27,7 +27,7 @@ class GetWorksheetFieldsTool(BuiltinTool): elif not host.startswith(("http://", "https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/getWorksheetInfo" headers = {"Content-Type": "application/json"} diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py index 01c1af9b3e..26d7116869 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py @@ -38,7 +38,7 @@ class GetWorksheetPivotDataTool(BuiltinTool): elif not host.startswith(("http://", "https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/report/getPivotData" headers = {"Content-Type": "application/json"} diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index 44c7e52307..d6ac3688b7 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -30,7 +30,7 @@ class ListWorksheetRecordsTool(BuiltinTool): elif not (host.startswith("http://") or host.startswith("https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url_fields = f"{host}/v2/open/worksheet/getWorksheetInfo" headers = {"Content-Type": "application/json"} diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py index 4dba2df1f1..4e852c0028 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py @@ -24,7 +24,7 @@ class ListWorksheetsTool(BuiltinTool): elif not (host.startswith("http://") or host.startswith("https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v1/open/app/get" result_type = tool_parameters.get("result_type", "") diff --git a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py index 32abb18f9a..971f3d37f6 100644 --- a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py @@ -33,7 +33,7 @@ class UpdateWorksheetRecordTool(BuiltinTool): elif not host.startswith(("http://", "https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/editRow" headers = {"Content-Type": "application/json"} diff --git a/api/poetry.lock b/api/poetry.lock index 36b52f68be..191db600e4 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -8003,29 +8003,29 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.6.4" +version = "0.6.5" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.6.4-py3-none-linux_armv6l.whl", hash = "sha256:c4b153fc152af51855458e79e835fb6b933032921756cec9af7d0ba2aa01a258"}, - {file = "ruff-0.6.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:bedff9e4f004dad5f7f76a9d39c4ca98af526c9b1695068198b3bda8c085ef60"}, - {file = "ruff-0.6.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d02a4127a86de23002e694d7ff19f905c51e338c72d8e09b56bfb60e1681724f"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7862f42fc1a4aca1ea3ffe8a11f67819d183a5693b228f0bb3a531f5e40336fc"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eebe4ff1967c838a1a9618a5a59a3b0a00406f8d7eefee97c70411fefc353617"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:932063a03bac394866683e15710c25b8690ccdca1cf192b9a98260332ca93408"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:50e30b437cebef547bd5c3edf9ce81343e5dd7c737cb36ccb4fe83573f3d392e"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c44536df7b93a587de690e124b89bd47306fddd59398a0fb12afd6133c7b3818"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ea086601b22dc5e7693a78f3fcfc460cceabfdf3bdc36dc898792aba48fbad6"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b52387d3289ccd227b62102c24714ed75fbba0b16ecc69a923a37e3b5e0aaaa"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0308610470fcc82969082fc83c76c0d362f562e2f0cdab0586516f03a4e06ec6"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:803b96dea21795a6c9d5bfa9e96127cc9c31a1987802ca68f35e5c95aed3fc0d"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:66dbfea86b663baab8fcae56c59f190caba9398df1488164e2df53e216248baa"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:34d5efad480193c046c86608dbba2bccdc1c5fd11950fb271f8086e0c763a5d1"}, - {file = "ruff-0.6.4-py3-none-win32.whl", hash = "sha256:f0f8968feea5ce3777c0d8365653d5e91c40c31a81d95824ba61d871a11b8523"}, - {file = "ruff-0.6.4-py3-none-win_amd64.whl", hash = "sha256:549daccee5227282289390b0222d0fbee0275d1db6d514550d65420053021a58"}, - {file = "ruff-0.6.4-py3-none-win_arm64.whl", hash = "sha256:ac4b75e898ed189b3708c9ab3fc70b79a433219e1e87193b4f2b77251d058d14"}, - {file = "ruff-0.6.4.tar.gz", hash = "sha256:ac3b5bfbee99973f80aa1b7cbd1c9cbce200883bdd067300c22a6cc1c7fba212"}, + {file = "ruff-0.6.5-py3-none-linux_armv6l.whl", hash = "sha256:7e4e308f16e07c95fc7753fc1aaac690a323b2bb9f4ec5e844a97bb7fbebd748"}, + {file = "ruff-0.6.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:932cd69eefe4daf8c7d92bd6689f7e8182571cb934ea720af218929da7bd7d69"}, + {file = "ruff-0.6.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3a8d42d11fff8d3143ff4da41742a98f8f233bf8890e9fe23077826818f8d680"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a50af6e828ee692fb10ff2dfe53f05caecf077f4210fae9677e06a808275754f"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:794ada3400a0d0b89e3015f1a7e01f4c97320ac665b7bc3ade24b50b54cb2972"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:381413ec47f71ce1d1c614f7779d88886f406f1fd53d289c77e4e533dc6ea200"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:52e75a82bbc9b42e63c08d22ad0ac525117e72aee9729a069d7c4f235fc4d276"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09c72a833fd3551135ceddcba5ebdb68ff89225d30758027280968c9acdc7810"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:800c50371bdcb99b3c1551d5691e14d16d6f07063a518770254227f7f6e8c178"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e25ddd9cd63ba1f3bd51c1f09903904a6adf8429df34f17d728a8fa11174253"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7291e64d7129f24d1b0c947ec3ec4c0076e958d1475c61202497c6aced35dd19"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9ad7dfbd138d09d9a7e6931e6a7e797651ce29becd688be8a0d4d5f8177b4b0c"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:005256d977021790cc52aa23d78f06bb5090dc0bfbd42de46d49c201533982ae"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:482c1e6bfeb615eafc5899127b805d28e387bd87db38b2c0c41d271f5e58d8cc"}, + {file = "ruff-0.6.5-py3-none-win32.whl", hash = "sha256:cf4d3fa53644137f6a4a27a2b397381d16454a1566ae5335855c187fbf67e4f5"}, + {file = "ruff-0.6.5-py3-none-win_amd64.whl", hash = "sha256:3e42a57b58e3612051a636bc1ac4e6b838679530235520e8f095f7c44f706ff9"}, + {file = "ruff-0.6.5-py3-none-win_arm64.whl", hash = "sha256:51935067740773afdf97493ba9b8231279e9beef0f2a8079188c4776c25688e0"}, + {file = "ruff-0.6.5.tar.gz", hash = "sha256:4d32d87fab433c0cf285c3683dd4dae63be05fd7a1d65b3f5bf7cdd05a6b96fb"}, ] [[package]] @@ -10416,4 +10416,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "726af69ca5a577808dfe76dbce098de77ce358bf64862a4d27309cb1900cea0c" +content-hash = "9173a56b2efea12804c980511e1465fba43c7a3d83b1ad284ee149851ed67fc5" diff --git a/api/pyproject.toml b/api/pyproject.toml index 57a3844200..166ddcec50 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -283,4 +283,4 @@ optional = true [tool.poetry.group.lint.dependencies] dotenv-linter = "~0.5.0" -ruff = "~0.6.4" +ruff = "~0.6.5" From b6b1057a182bbf9af878391de66287f5bfd4a4e4 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Sat, 14 Sep 2024 02:02:55 +0800 Subject: [PATCH 11/43] fix: sandbox issue related httpx and requests (#8397) --- .../workflow/nodes/code_executor/test_code_python3.py | 1 - docker/docker-compose.middleware.yaml | 2 +- docker/docker-compose.yaml | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py index cbe4a5d335..25af312afa 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -1,4 +1,3 @@ -import json from textwrap import dedent from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index dbfc1ea531..00faa2960a 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -41,7 +41,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.7 + image: langgenius/dify-sandbox:0.2.9 restart: always environment: # The DifySandbox configurations diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index b8e068fde0..d080731a28 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -292,7 +292,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.7 + image: langgenius/dify-sandbox:0.2.9 restart: always environment: # The DifySandbox configurations From 71b4480c4a7e6db54a1599c053250a74b57683c1 Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Sat, 14 Sep 2024 02:39:58 +0800 Subject: [PATCH 12/43] fix: o1-mini 65563 -> 65536 (#8388) --- .../model_providers/openai/llm/o1-mini-2024-09-12.yaml | 4 ++-- .../model_runtime/model_providers/openai/llm/o1-mini.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml index 07a3bc9a7a..0ade7f8ded 100644 --- a/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml @@ -11,9 +11,9 @@ model_properties: parameter_rules: - name: max_tokens use_template: max_tokens - default: 65563 + default: 65536 min: 1 - max: 65563 + max: 65536 - name: response_format label: zh_Hans: 回复格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml index 3e83529201..60816c5d1e 100644 --- a/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml @@ -11,9 +11,9 @@ model_properties: parameter_rules: - name: max_tokens use_template: max_tokens - default: 65563 + default: 65536 min: 1 - max: 65563 + max: 65536 - name: response_format label: zh_Hans: 回复格式 From bf55b1910f8c75d9693f42f83d7d0233be5b5d9a Mon Sep 17 00:00:00 2001 From: Nam Vu Date: Sat, 14 Sep 2024 08:45:49 +0700 Subject: [PATCH 13/43] fix: pyproject.toml typo (#8396) --- api/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 166ddcec50..8c10f1dad9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -209,7 +209,7 @@ zhipuai = "1.0.7" # Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. ############################################################ -# Related transparent dependencies with pinned verion +# Related transparent dependencies with pinned version # required by main implementations ############################################################ azure-ai-ml = "^1.19.0" From 8efae1cba29ee36127a986315e19a43d36d4f943 Mon Sep 17 00:00:00 2001 From: Incca Date: Sat, 14 Sep 2024 09:52:59 +0800 Subject: [PATCH 14/43] fix(docker): aliyun oss path env key (#8394) --- docker/docker-compose.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index d080731a28..10bd1d1ae2 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -75,7 +75,7 @@ x-shared-env: &shared-api-worker-env ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-} ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} - ALIYUN_OSS_PATHS: ${ALIYUN_OSS_PATH:-} + ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-} TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-} TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-} TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-} From b613b114228cd30f1d141327acda7a7ca3758ec0 Mon Sep 17 00:00:00 2001 From: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Date: Sat, 14 Sep 2024 11:06:20 +0800 Subject: [PATCH 15/43] Fix: Support Bedrock cross region inference #8190 (Update Model name to distinguish between different region groups) (#8402) Co-authored-by: Yuanbo Li --- .../bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml | 2 +- .../bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml | 2 +- .../bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml | 2 +- .../bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml | 2 +- .../bedrock/llm/us.anthropic.claude-3-opus-v1.yaml | 2 +- .../bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml | 2 +- .../bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml index fe5f54de13..24a65ef1bb 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml @@ -1,6 +1,6 @@ model: eu.anthropic.claude-3-haiku-20240307-v1:0 label: - en_US: Claude 3 Haiku(Cross Region Inference) + en_US: Claude 3 Haiku(EU.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml index 9f8d029a57..e3d25c7d8f 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml @@ -1,6 +1,6 @@ model: eu.anthropic.claude-3-5-sonnet-20240620-v1:0 label: - en_US: Claude 3.5 Sonnet(Cross Region Inference) + en_US: Claude 3.5 Sonnet(EU.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml index bfaf5abb8e..9a06a4ad6d 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml @@ -1,6 +1,6 @@ model: eu.anthropic.claude-3-sonnet-20240229-v1:0 label: - en_US: Claude 3 Sonnet(Cross Region Inference) + en_US: Claude 3 Sonnet(EU.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml index 58c1f05779..9247f46974 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml @@ -1,6 +1,6 @@ model: us.anthropic.claude-3-haiku-20240307-v1:0 label: - en_US: Claude 3 Haiku(Cross Region Inference) + en_US: Claude 3 Haiku(US.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml index 6b9e1ec067..f9854d51f0 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml @@ -1,6 +1,6 @@ model: us.anthropic.claude-3-opus-20240229-v1:0 label: - en_US: Claude 3 Opus(Cross Region Inference) + en_US: Claude 3 Opus(US.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml index f1e0d6c5a2..fbcab2d5f3 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml @@ -1,6 +1,6 @@ model: us.anthropic.claude-3-5-sonnet-20240620-v1:0 label: - en_US: Claude 3.5 Sonnet(Cross Region Inference) + en_US: Claude 3.5 Sonnet(US.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml index dce50bf4b5..9f5a1501f0 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml @@ -1,6 +1,6 @@ model: us.anthropic.claude-3-sonnet-20240229-v1:0 label: - en_US: Claude 3 Sonnet(Cross Region Inference) + en_US: Claude 3 Sonnet(US.Cross Region Inference) model_type: llm features: - agent-thought From f55e06d8bf38d6cfef1baa0b21b5ce663dc0a71e Mon Sep 17 00:00:00 2001 From: swingchen01 Date: Sat, 14 Sep 2024 11:07:16 +0800 Subject: [PATCH 16/43] fix: resolve runtime error when self.folder is None (#8401) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 陈长君 --- api/extensions/storage/aliyun_storage.py | 48 +++++++----------------- 1 file changed, 14 insertions(+), 34 deletions(-) diff --git a/api/extensions/storage/aliyun_storage.py b/api/extensions/storage/aliyun_storage.py index bee237fc17..2677912aa9 100644 --- a/api/extensions/storage/aliyun_storage.py +++ b/api/extensions/storage/aliyun_storage.py @@ -31,54 +31,34 @@ class AliyunStorage(BaseStorage): ) def save(self, filename, data): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - self.client.put_object(filename, data) + self.client.put_object(self.__wrapper_folder_filename(filename), data) def load_once(self, filename: str) -> bytes: - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - with closing(self.client.get_object(filename)) as obj: + with closing(self.client.get_object(self.__wrapper_folder_filename(filename))) as obj: data = obj.read() return data def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - with closing(self.client.get_object(filename)) as obj: + with closing(self.client.get_object(self.__wrapper_folder_filename(filename))) as obj: while chunk := obj.read(4096): yield chunk return generate() def download(self, filename, target_filepath): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - self.client.get_object_to_file(filename, target_filepath) + self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) def exists(self, filename): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - return self.client.object_exists(filename) + return self.client.object_exists(self.__wrapper_folder_filename(filename)) def delete(self, filename): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - self.client.delete_object(filename) + self.client.delete_object(self.__wrapper_folder_filename(filename)) + + def __wrapper_folder_filename(self, filename) -> str: + if self.folder: + if self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + return filename From 0123498452fb2ebb89a157925113168b5cc7fe7a Mon Sep 17 00:00:00 2001 From: HowardChan Date: Sat, 14 Sep 2024 12:56:45 +0800 Subject: [PATCH 17/43] fix:logs and rm unused codes in CacheEmbedding (#8409) --- api/core/embedding/cached_embedding.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 4cc793b0d7..8ce12fd59f 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -65,7 +65,7 @@ class CacheEmbedding(Embeddings): except IntegrityError: db.session.rollback() except Exception as e: - logging.exception("Failed transform embedding: ", e) + logging.exception("Failed transform embedding: %s", e) cache_embeddings = [] try: for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): @@ -85,7 +85,7 @@ class CacheEmbedding(Embeddings): db.session.rollback() except Exception as ex: db.session.rollback() - logger.error("Failed to embed documents: ", ex) + logger.error("Failed to embed documents: %s", ex) raise ex return text_embeddings @@ -116,10 +116,7 @@ class CacheEmbedding(Embeddings): # Transform to string encoded_str = encoded_vector.decode("utf-8") redis_client.setex(embedding_cache_key, 600, encoded_str) - - except IntegrityError: - db.session.rollback() - except: - logging.exception("Failed to add embedding to redis") + except Exception as ex: + logging.exception("Failed to add embedding to redis %s", ex) return embedding_results From f01602b5703f0daad62b781dec91fce72cf784d9 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 14 Sep 2024 14:02:09 +0800 Subject: [PATCH 18/43] fix(workflow): the answer node after the iteration node containing the answer was output prematurely (#8419) --- api/core/workflow/nodes/answer/answer_stream_generate_router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 5e6de8fb15..bbd1f88867 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -152,6 +152,7 @@ class AnswerStreamGeneratorRouter: NodeType.ANSWER.value, NodeType.IF_ELSE.value, NodeType.QUESTION_CLASSIFIER.value, + NodeType.ITERATION.value, }: answer_dependencies[answer_node_id].append(source_node_id) else: From 5b18e851d28ef0241d16c4c1f412e73c17e76649 Mon Sep 17 00:00:00 2001 From: Pika Date: Sat, 14 Sep 2024 14:08:10 +0800 Subject: [PATCH 19/43] fix: when the variable does not exist, an error should be prompted (#8413) Co-authored-by: Chen(MAC) --- .../nodes/_base/components/variable-tag.tsx | 81 +++++++++++------- .../variable/var-reference-picker.tsx | 85 +++++++++++-------- .../condition-list/condition-item.tsx | 1 + 3 files changed, 100 insertions(+), 67 deletions(-) diff --git a/web/app/components/workflow/nodes/_base/components/variable-tag.tsx b/web/app/components/workflow/nodes/_base/components/variable-tag.tsx index 6edda1d7e8..0a80bfe37f 100644 --- a/web/app/components/workflow/nodes/_base/components/variable-tag.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable-tag.tsx @@ -1,9 +1,12 @@ import { useMemo } from 'react' import { useNodes } from 'reactflow' import { capitalize } from 'lodash-es' +import { useTranslation } from 'react-i18next' +import { RiErrorWarningFill } from '@remixicon/react' import { VarBlockIcon } from '@/app/components/workflow/block-icon' import type { CommonNodeType, + Node, ValueSelector, VarType, } from '@/app/components/workflow/types' @@ -11,63 +14,75 @@ import { BlockEnum } from '@/app/components/workflow/types' import { Line3 } from '@/app/components/base/icons/src/public/common' import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' import { BubbleX, Env } from '@/app/components/base/icons/src/vender/line/others' -import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' +import { getNodeInfoById, isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' +import Tooltip from '@/app/components/base/tooltip' import cn from '@/utils/classnames' type VariableTagProps = { valueSelector: ValueSelector varType: VarType + availableNodes?: Node[] } const VariableTag = ({ valueSelector, varType, + availableNodes, }: VariableTagProps) => { const nodes = useNodes() const node = useMemo(() => { - if (isSystemVar(valueSelector)) - return nodes.find(node => node.data.type === BlockEnum.Start) + if (isSystemVar(valueSelector)) { + const startNode = availableNodes?.find(n => n.data.type === BlockEnum.Start) + if (startNode) + return startNode + } + return getNodeInfoById(availableNodes || nodes, valueSelector[0]) + }, [nodes, valueSelector, availableNodes]) - return nodes.find(node => node.id === valueSelector[0]) - }, [nodes, valueSelector]) const isEnv = isENV(valueSelector) const isChatVar = isConversationVar(valueSelector) + const isValid = Boolean(node) || isEnv || isChatVar const variableName = isSystemVar(valueSelector) ? valueSelector.slice(0).join('.') : valueSelector.slice(1).join('.') + const { t } = useTranslation() return ( -
- {!isEnv && !isChatVar && ( - <> + +
+ {(!isEnv && !isChatVar && <> {node && ( - + <> + +
+ {node?.data.title} +
+ )} -
- {node?.data.title} -
- - )} - {isEnv && } - {isChatVar && } -
- {variableName} + )} + {isEnv && } + {isChatVar && } +
+ {variableName} +
+ { + varType && ( +
{capitalize(varType)}
+ ) + } + {!isValid && }
- { - varType && ( -
{capitalize(varType)}
- ) - } -
+
) } diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx index 7fb4ad68d8..67c839bf76 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx @@ -5,9 +5,10 @@ import { useTranslation } from 'react-i18next' import { RiArrowDownSLine, RiCloseLine, + RiErrorWarningFill, } from '@remixicon/react' import produce from 'immer' -import { useStoreApi } from 'reactflow' +import { useEdges, useStoreApi } from 'reactflow' import useAvailableVarList from '../../hooks/use-available-var-list' import VarReferencePopup from './var-reference-popup' import { getNodeInfoById, isConversationVar, isENV, isSystemVar } from './utils' @@ -33,6 +34,8 @@ import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/typ import TypeSelector from '@/app/components/workflow/nodes/_base/components/selector' import AddButton from '@/app/components/base/button/add-button' import Badge from '@/app/components/base/badge' +import Tooltip from '@/app/components/base/tooltip' + const TRIGGER_DEFAULT_WIDTH = 227 type Props = { @@ -77,6 +80,7 @@ const VarReferencePicker: FC = ({ const { getNodes, } = store.getState() + const edges = useEdges() const isChatMode = useIsChatMode() const { getCurrentVariableType } = useWorkflowVariables() @@ -206,8 +210,16 @@ const VarReferencePicker: FC = ({ isConstant: !!isConstant, }) - const isEnv = isENV(value as ValueSelector) - const isChatVar = isConversationVar(value as ValueSelector) + const { isEnv, isChatVar, isValidVar } = useMemo(() => { + const isEnv = isENV(value as ValueSelector) + const isChatVar = isConversationVar(value as ValueSelector) + const isValidVar = Boolean(outputVarNode) || isEnv || isChatVar + return { + isEnv, + isChatVar, + isValidVar, + } + }, [value, edges, outputVarNode]) // 8(left/right-padding) + 14(icon) + 4 + 14 + 2 = 42 + 17 buff const availableWidth = triggerWidth - 56 @@ -285,39 +297,44 @@ const VarReferencePicker: FC = ({ className='grow h-full' >
-
- {hasValue - ? ( - <> - {isShowNodeName && !isEnv && !isChatVar && ( -
-
- + +
+ {hasValue + ? ( + <> + {isShowNodeName && !isEnv && !isChatVar && ( +
+
+ {outputVarNode?.type && } +
+
{outputVarNode?.title}
+
-
{outputVarNode?.title}
- + )} +
+ {!hasValue && } + {isEnv && } + {isChatVar && } +
{varName}
- )} -
- {!hasValue && } - {isEnv && } - {isChatVar && } -
{varName}
-
-
{type}
- - ) - :
{t('workflow.common.setVarValuePlaceholder')}
} -
+
{type}
+ {!isValidVar && } + + ) + :
{t('workflow.common.setVarValuePlaceholder')}
} +
+
diff --git a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx index c6cb580118..c96f6b3ef6 100644 --- a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx +++ b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx @@ -80,6 +80,7 @@ const ConditionItem = ({
From 032dd93b2f792ad630fc00c84c3a9b7d99ce0ff3 Mon Sep 17 00:00:00 2001 From: KVOJJJin Date: Sat, 14 Sep 2024 14:08:31 +0800 Subject: [PATCH 20/43] Fix: operation postion of answer in logs (#8411) Co-authored-by: Yi --- .../annotation/annotation-ctrl-btn/index.tsx | 4 +- web/app/components/app/log/list.tsx | 4 +- .../base/chat/chat/answer/index.tsx | 53 ++++++------------- .../base/chat/chat/answer/operation.tsx | 2 +- 4 files changed, 21 insertions(+), 42 deletions(-) diff --git a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx index 111c380afc..809b907d62 100644 --- a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx +++ b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx @@ -73,7 +73,7 @@ const CacheCtrlBtn: FC = ({ setShowModal(false) } return ( -
+
{cached ? ( @@ -101,7 +101,6 @@ const CacheCtrlBtn: FC = ({ ? (
= ({ }
{/* Panel Header */} -
+
{isChatMode ? t('appLog.detail.conversationId') : t('appLog.detail.time')}
{isChatMode && ( @@ -725,7 +725,7 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) onClose={onCloseDrawer} mask={isMobile} footer={null} - panelClassname='mt-16 mx-2 sm:mr-2 mb-4 !p-0 !max-w-[640px] rounded-xl' + panelClassname='mt-16 mx-2 sm:mr-2 mb-4 !p-0 !max-w-[640px] rounded-xl bg-background-gradient-bg-fill-chat-bg-1' > = ({ } = item const hasAgentThoughts = !!agent_thoughts?.length - const [containerWidth] = useState(0) + const [containerWidth, setContainerWidth] = useState(0) const [contentWidth, setContentWidth] = useState(0) const containerRef = useRef(null) const contentRef = useRef(null) - const { - config: chatContextConfig, - } = useChatContext() + const getContainerWidth = () => { + if (containerRef.current) + setContainerWidth(containerRef.current?.clientWidth + 16) + } + useEffect(() => { + getContainerWidth() + }, []) - const voiceRef = useRef(chatContextConfig?.text_to_speech?.voice) const getContentWidth = () => { if (contentRef.current) setContentWidth(contentRef.current?.clientWidth) } - useEffect(() => { - voiceRef.current = chatContextConfig?.text_to_speech?.voice - } - , [chatContextConfig?.text_to_speech?.voice]) - useEffect(() => { if (!responding) getContentWidth() @@ -89,36 +86,20 @@ const Answer: FC = ({ return (
- { - answerIcon || - } - { - responding && ( -
- -
- ) - } + {answerIcon || } + {responding && ( +
+ +
+ )}
-
+
- {annotation?.id && ( -
-
- -
-
- )} { !responding && ( = ({ /> )} { - !positionRight && annotation?.id && ( + annotation?.id && (
From 52857dc0a60d4407a5680900e2ad7245c8b31e41 Mon Sep 17 00:00:00 2001 From: kurokobo Date: Sat, 14 Sep 2024 15:11:45 +0900 Subject: [PATCH 21/43] feat: allow users to specify timeout for text generations and workflows by environment variable (#8395) --- docker/.env.example | 7 +++++++ docker/docker-compose.yaml | 1 + web/.env.example | 3 +++ .../share/text-generation/result/index.tsx | 12 +++++++++--- web/app/layout.tsx | 1 + web/config/index.ts | 9 ++++++++- web/docker/entrypoint.sh | 2 ++ web/i18n/en-US/app-debug.ts | 3 +++ 8 files changed, 34 insertions(+), 4 deletions(-) diff --git a/docker/.env.example b/docker/.env.example index 7e4430a37d..c892c15636 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -568,6 +568,13 @@ SSRF_PROXY_HTTP_URL=http://ssrf_proxy:3128 # SSRF Proxy server HTTPS URL SSRF_PROXY_HTTPS_URL=http://ssrf_proxy:3128 +# ------------------------------ +# Environment Variables for web Service +# ------------------------------ + +# The timeout for the text generation in millisecond +TEXT_GENERATION_TIMEOUT_MS=60000 + # ------------------------------ # Environment Variables for db Service # ------------------------------ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 10bd1d1ae2..0fbc695177 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -254,6 +254,7 @@ services: APP_API_URL: ${APP_API_URL:-} SENTRY_DSN: ${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} + TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} # The postgres database. db: diff --git a/web/.env.example b/web/.env.example index 3045cef2f9..8e254082b3 100644 --- a/web/.env.example +++ b/web/.env.example @@ -19,3 +19,6 @@ NEXT_TELEMETRY_DISABLED=1 # Disable Upload Image as WebApp icon default is false NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false + +# The timeout for the text generation in millisecond +NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000 diff --git a/web/app/components/share/text-generation/result/index.tsx b/web/app/components/share/text-generation/result/index.tsx index 2d5546f9b4..a61302fc98 100644 --- a/web/app/components/share/text-generation/result/index.tsx +++ b/web/app/components/share/text-generation/result/index.tsx @@ -269,8 +269,10 @@ const Result: FC = ({ })) }, onWorkflowFinished: ({ data }) => { - if (isTimeout) + if (isTimeout) { + notify({ type: 'warning', message: t('appDebug.warningMessage.timeoutExceeded') }) return + } if (data.error) { notify({ type: 'error', message: data.error }) setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { @@ -326,8 +328,10 @@ const Result: FC = ({ setCompletionRes(res.join('')) }, onCompleted: () => { - if (isTimeout) + if (isTimeout) { + notify({ type: 'warning', message: t('appDebug.warningMessage.timeoutExceeded') }) return + } setRespondingFalse() setMessageId(tempMessageId) onCompleted(getCompletionRes(), taskId, true) @@ -338,8 +342,10 @@ const Result: FC = ({ setCompletionRes(res.join('')) }, onError() { - if (isTimeout) + if (isTimeout) { + notify({ type: 'warning', message: t('appDebug.warningMessage.timeoutExceeded') }) return + } setRespondingFalse() onCompleted(getCompletionRes(), taskId, false) isEnd = true diff --git a/web/app/layout.tsx b/web/app/layout.tsx index 008fdefa0b..e9242edfad 100644 --- a/web/app/layout.tsx +++ b/web/app/layout.tsx @@ -43,6 +43,7 @@ const LocaleLayout = ({ data-public-sentry-dsn={process.env.NEXT_PUBLIC_SENTRY_DSN} data-public-maintenance-notice={process.env.NEXT_PUBLIC_MAINTENANCE_NOTICE} data-public-site-about={process.env.NEXT_PUBLIC_SITE_ABOUT} + data-public-text-generation-timeout-ms={process.env.NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS} > diff --git a/web/config/index.ts b/web/config/index.ts index 71edf8f939..21fc2f211c 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -246,6 +246,13 @@ Thought: {{agent_scratchpad}} export const VAR_REGEX = /\{\{(#[a-zA-Z0-9_-]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}/gi -export const TEXT_GENERATION_TIMEOUT_MS = 60000 +export let textGenerationTimeoutMs = 60000 + +if (process.env.NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS && process.env.NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS !== '') + textGenerationTimeoutMs = parseInt(process.env.NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS) +else if (globalThis.document?.body?.getAttribute('data-public-text-generation-timeout-ms') && globalThis.document.body.getAttribute('data-public-text-generation-timeout-ms') !== '') + textGenerationTimeoutMs = parseInt(globalThis.document.body.getAttribute('data-public-text-generation-timeout-ms') as string) + +export const TEXT_GENERATION_TIMEOUT_MS = textGenerationTimeoutMs export const DISABLE_UPLOAD_IMAGE_AS_ICON = process.env.NEXT_PUBLIC_DISABLE_UPLOAD_IMAGE_AS_ICON === 'true' diff --git a/web/docker/entrypoint.sh b/web/docker/entrypoint.sh index a19c543d68..fc4a8f45bc 100755 --- a/web/docker/entrypoint.sh +++ b/web/docker/entrypoint.sh @@ -21,4 +21,6 @@ export NEXT_PUBLIC_SENTRY_DSN=${SENTRY_DSN} export NEXT_PUBLIC_SITE_ABOUT=${SITE_ABOUT} export NEXT_TELEMETRY_DISABLED=${NEXT_TELEMETRY_DISABLED} +export NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=${TEXT_GENERATION_TIMEOUT_MS} + pm2 start ./pm2.json --no-daemon diff --git a/web/i18n/en-US/app-debug.ts b/web/i18n/en-US/app-debug.ts index b1f3f33cd8..3156c9f6cc 100644 --- a/web/i18n/en-US/app-debug.ts +++ b/web/i18n/en-US/app-debug.ts @@ -268,6 +268,9 @@ const translation = { notSelectModel: 'Please choose a model', waitForImgUpload: 'Please wait for the image to upload', }, + warningMessage: { + timeoutExceeded: 'Results are not displayed due to timeout. Please refer to the logs to gather complete results.', + }, chatSubTitle: 'Instructions', completionSubTitle: 'Prefix Prompt', promptTip: From de7bc226493bfe5b8431afc30e1ed1a864826d9f Mon Sep 17 00:00:00 2001 From: yanxiyue Date: Sat, 14 Sep 2024 15:16:12 +0800 Subject: [PATCH 22/43] fix: sys_var startwith 'sys.' not 'sys' #8421 (#8422) Co-authored-by: wuling --- .../nodes/_base/components/variable/var-reference-picker.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx index 67c839bf76..5e51173673 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx @@ -183,7 +183,7 @@ const VarReferencePicker: FC = ({ const handleVarReferenceChange = useCallback((value: ValueSelector, varInfo: Var) => { // sys var not passed to backend const newValue = produce(value, (draft) => { - if (draft[1] && draft[1].startsWith('sys')) { + if (draft[1] && draft[1].startsWith('sys.')) { draft.shift() const paths = draft[0].split('.') paths.forEach((p, i) => { From 6f7625fa47225b41b515137a6f6a2caea959a705 Mon Sep 17 00:00:00 2001 From: Aaron Ji <127167174+DresAaron@users.noreply.github.com> Date: Sat, 14 Sep 2024 16:21:17 +0800 Subject: [PATCH 23/43] chore: update Jina embedding model (#8376) --- api/core/model_runtime/model_providers/jina/jina.py | 4 ++-- .../jina/text_embedding/jina-embeddings-v3.yaml | 9 +++++++++ .../jina/text_embedding/text_embedding.py | 3 +++ 3 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 api/core/model_runtime/model_providers/jina/text_embedding/jina-embeddings-v3.yaml diff --git a/api/core/model_runtime/model_providers/jina/jina.py b/api/core/model_runtime/model_providers/jina/jina.py index 33977b6a33..186a0a0fa7 100644 --- a/api/core/model_runtime/model_providers/jina/jina.py +++ b/api/core/model_runtime/model_providers/jina/jina.py @@ -18,9 +18,9 @@ class JinaProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING) - # Use `jina-embeddings-v2-base-en` model for validate, + # Use `jina-embeddings-v3` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials(model="jina-embeddings-v2-base-en", credentials=credentials) + model_instance.validate_credentials(model="jina-embeddings-v3", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina-embeddings-v3.yaml b/api/core/model_runtime/model_providers/jina/text_embedding/jina-embeddings-v3.yaml new file mode 100644 index 0000000000..4e5374dc9d --- /dev/null +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina-embeddings-v3.yaml @@ -0,0 +1,9 @@ +model: jina-embeddings-v3 +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 2048 +pricing: + input: '0.001' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index ef12e534db..5033f0f748 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -56,6 +56,9 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]} + if model == "jina-embeddings-v3": + data["task_type"] = "retrieval.passage" + try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: From b6ad7a1e06cc67b293b2c6ed3449fdd105d800f7 Mon Sep 17 00:00:00 2001 From: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Date: Sat, 14 Sep 2024 17:14:18 +0800 Subject: [PATCH 24/43] =?UTF-8?q?Fix:=20https://github.com/langgenius/dify?= =?UTF-8?q?/issues/8190=20(Update=20Model=20nam=E2=80=A6=20(#8426)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Yuanbo Li --- api/core/model_runtime/model_providers/bedrock/llm/llm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 06a8606901..77bab0c294 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -1,8 +1,8 @@ # standard import import base64 -import io import json import logging +import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -17,7 +17,6 @@ from botocore.exceptions import ( ServiceNotInRegionError, UnknownServiceError, ) -from PIL.Image import Image # local import from core.model_runtime.callbacks.base_callback import Callback @@ -443,8 +442,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): try: url = message_content.data image_content = requests.get(url).content - with Image.open(io.BytesIO(image_content)) as img: - mime_type = f"image/{img.format.lower()}" + if "?" in url: + url = url.split("?")[0] + mime_type, _ = mimetypes.guess_type(url) base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") From d882348f390dbf935158837a5f797392a47b2953 Mon Sep 17 00:00:00 2001 From: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Date: Sat, 14 Sep 2024 17:24:31 +0800 Subject: [PATCH 25/43] fix: delete the delay for the tooltips inside the add tool panel (#8436) --- web/app/components/tools/add-tool-modal/tools.tsx | 1 - 1 file changed, 1 deletion(-) diff --git a/web/app/components/tools/add-tool-modal/tools.tsx b/web/app/components/tools/add-tool-modal/tools.tsx index af26dd3e25..f6080a1c23 100644 --- a/web/app/components/tools/add-tool-modal/tools.tsx +++ b/web/app/components/tools/add-tool-modal/tools.tsx @@ -90,7 +90,6 @@ const Blocks = ({ )}
)} - needsDelay >
Date: Sat, 14 Sep 2024 18:02:43 +0800 Subject: [PATCH 26/43] chore(workflow): Optimize the iteration when selecting a variable from a branch in the output variable causes iteration index err (#8440) --- .../workflow/graph_engine/entities/graph.py | 16 +- .../nodes/iteration/iteration_node.py | 191 +++++++----------- 2 files changed, 80 insertions(+), 127 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 1d7e9158d8..1175f4af2a 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -689,23 +689,11 @@ class Graph(BaseModel): parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) - parallel_start_node_id = None - for p_start_node_id, branch_node_ids in parallel_start_node_ids.items(): + for _, branch_node_ids in parallel_start_node_ids.items(): if set(branch_node_ids) == set(routes_node_ids.keys()): - parallel_start_node_id = p_start_node_id return True - if not parallel_start_node_id: - raise Exception("Parallel start node id not found") - - for graph_edge in reverse_edge_mapping[start_node_id]: - if ( - graph_edge.source_node_id not in all_routes_node_ids - or graph_edge.source_node_id != parallel_start_node_id - ): - return False - - return True + return False @classmethod def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 4d944e93db..6f20745daf 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -20,11 +20,9 @@ from core.workflow.graph_engine.entities.event import ( NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.event import RunCompletedEvent, RunEvent from core.workflow.nodes.iteration.entities import IterationNodeData -from core.workflow.utils.condition.entities import Condition from models.workflow import WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) @@ -68,38 +66,6 @@ class IterationNode(BaseNode): if not iteration_graph: raise ValueError("iteration graph not found") - leaf_node_ids = iteration_graph.get_leaf_node_ids() - iteration_leaf_node_ids = [] - for leaf_node_id in leaf_node_ids: - node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id) - if not node_config: - continue - - leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id") - if not leaf_node_iteration_id: - continue - - if leaf_node_iteration_id != self.node_id: - continue - - iteration_leaf_node_ids.append(leaf_node_id) - - # add condition of end nodes to root node - iteration_graph.add_extra_edge( - source_node_id=leaf_node_id, - target_node_id=root_node_id, - run_condition=RunCondition( - type="condition", - conditions=[ - Condition( - variable_selector=[self.node_id, "index"], - comparison_operator="<", - value=str(len(iterator_list_value)), - ) - ], - ), - ) - variable_pool = self.graph_runtime_state.variable_pool # append iteration variable (item, index) to variable pool @@ -149,91 +115,90 @@ class IterationNode(BaseNode): outputs: list[Any] = [] try: - # run workflow - rst = graph_engine.run() - for event in rst: - if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: - event.in_iteration_id = self.node_id + for _ in range(len(iterator_list_value)): + # run workflow + rst = graph_engine.run() + for event in rst: + if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: + event.in_iteration_id = self.node_id - if ( - isinstance(event, BaseNodeEvent) - and event.node_type == NodeType.ITERATION_START - and not isinstance(event, NodeRunStreamChunkEvent) - ): - continue + if ( + isinstance(event, BaseNodeEvent) + and event.node_type == NodeType.ITERATION_START + and not isinstance(event, NodeRunStreamChunkEvent) + ): + continue - if isinstance(event, NodeRunSucceededEvent): - if event.route_node_state.node_run_result: - metadata = event.route_node_state.node_run_result.metadata - if not metadata: - metadata = {} + if isinstance(event, NodeRunSucceededEvent): + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( - [self.node_id, "index"] - ) - event.route_node_state.node_run_result.metadata = metadata + if NodeRunMetadataKey.ITERATION_ID not in metadata: + metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id + metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( + [self.node_id, "index"] + ) + event.route_node_state.node_run_result.metadata = metadata - yield event - - # handle iteration run result - if event.route_node_state.node_id in iteration_leaf_node_ids: - # append to iteration output variable list - current_iteration_output = variable_pool.get_any(self.node_data.output_selector) - outputs.append(current_iteration_output) - - # remove all nodes outputs from variable pool - for node_id in iteration_graph.node_ids: - variable_pool.remove_node(node_id) - - # move to next iteration - current_index = variable_pool.get([self.node_id, "index"]) - if current_index is None: - raise ValueError(f"iteration {self.node_id} current index not found") - - next_index = int(current_index.to_object()) + 1 - variable_pool.add([self.node_id, "index"], next_index) - - if next_index < len(iterator_list_value): - variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - - yield IterationRunNextEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - index=next_index, - pre_iteration_output=jsonable_encoder(current_iteration_output) - if current_iteration_output - else None, - ) - elif isinstance(event, BaseGraphEvent): - if isinstance(event, GraphRunFailedEvent): - # iteration run failed - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, - steps=len(iterator_list_value), - metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, - error=event.error, - ) - - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, + yield event + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=event.error, ) - ) - break - else: - event = cast(InNodeEvent, event) - yield event + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return + else: + event = cast(InNodeEvent, event) + yield event + + # append to iteration output variable list + current_iteration_output = variable_pool.get_any(self.node_data.output_selector) + outputs.append(current_iteration_output) + + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove_node(node_id) + + # move to next iteration + current_index = variable_pool.get([self.node_id, "index"]) + if current_index is None: + raise ValueError(f"iteration {self.node_id} current index not found") + + next_index = int(current_index.to_object()) + 1 + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + pre_iteration_output=jsonable_encoder(current_iteration_output) + if current_iteration_output + else None, + ) yield IterationRunSucceededEvent( iteration_id=self.id, From 72b7f8a949e5a82dde94e5bfd3a4c37b265e2cd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B0=E5=9C=A8=E4=BF=AE=E8=A1=8C=E7=9A=84=E5=A4=A7?= =?UTF-8?q?=E8=A1=97=E4=B8=8A?= Date: Sat, 14 Sep 2024 18:59:06 +0800 Subject: [PATCH 27/43] Bugfix/fix feishu plugins (#8443) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 黎斌 --- ...raw_content.py => get_document_content.py} | 4 +- .../tools/get_document_content.yaml | 49 ++++++++++++ .../tools/get_document_raw_content.yaml | 23 ------ .../tools/list_document_block.yaml | 48 ------------ ...ument_block.py => list_document_blocks.py} | 2 +- .../tools/list_document_blocks.yaml | 74 +++++++++++++++++++ .../feishu_document/tools/write_document.yaml | 33 +++++---- api/core/tools/utils/feishu_api_utils.py | 12 +-- 8 files changed, 152 insertions(+), 93 deletions(-) rename api/core/tools/provider/builtin/feishu_document/tools/{get_document_raw_content.py => get_document_content.py} (79%) create mode 100644 api/core/tools/provider/builtin/feishu_document/tools/get_document_content.yaml delete mode 100644 api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml delete mode 100644 api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml rename api/core/tools/provider/builtin/feishu_document/tools/{list_document_block.py => list_document_blocks.py} (90%) create mode 100644 api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.yaml diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.py similarity index 79% rename from api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py rename to api/core/tools/provider/builtin/feishu_document/tools/get_document_content.py index 83073e0822..c94a5f70ed 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.py @@ -12,6 +12,8 @@ class GetDocumentRawContentTool(BuiltinTool): client = FeishuRequest(app_id, app_secret) document_id = tool_parameters.get("document_id") + mode = tool_parameters.get("mode") + lang = tool_parameters.get("lang", 0) - res = client.get_document_raw_content(document_id) + res = client.get_document_content(document_id, mode, lang) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.yaml b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.yaml new file mode 100644 index 0000000000..51eda73a60 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.yaml @@ -0,0 +1,49 @@ +identity: + name: get_document_content + author: Doug Lea + label: + en_US: Get Document Content + zh_Hans: 获取飞书云文档的内容 +description: + human: + en_US: Get document content + zh_Hans: 获取飞书云文档的内容 + llm: A tool for retrieving content from Feishu cloud documents. +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique identifier for a Feishu document. You can also input the document's URL. + zh_Hans: 飞书文档的唯一标识,支持输入文档的 URL。 + llm_description: 飞书文档的唯一标识,支持输入文档的 URL。 + form: llm + + - name: mode + type: string + required: false + label: + en_US: mode + zh_Hans: 文档返回格式 + human_description: + en_US: Format of the document return, optional values are text, markdown, can be empty, default is markdown. + zh_Hans: 文档返回格式,可选值有 text、markdown,可以为空,默认值为 markdown。 + llm_description: 文档返回格式,可选值有 text、markdown,可以为空,默认值为 markdown。 + form: llm + + - name: lang + type: number + required: false + default: 0 + label: + en_US: lang + zh_Hans: 指定@用户的语言 + human_description: + en_US: | + Specifies the language for MentionUser, optional values are [0, 1]. 0: User's default name, 1: User's English name, default is 0. + zh_Hans: 指定返回的 MentionUser,即 @用户 的语言,可选值有 [0,1]。0:该用户的默认名称,1:该用户的英文名称,默认值为 0。 + llm_description: 指定返回的 MentionUser,即 @用户 的语言,可选值有 [0,1]。0:该用户的默认名称,1:该用户的英文名称,默认值为 0。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml deleted file mode 100644 index e5b0937e03..0000000000 --- a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml +++ /dev/null @@ -1,23 +0,0 @@ -identity: - name: get_document_raw_content - author: Doug Lea - label: - en_US: Get Document Raw Content - zh_Hans: 获取文档纯文本内容 -description: - human: - en_US: Get document raw content - zh_Hans: 获取文档纯文本内容 - llm: A tool for getting the plain text content of Feishu documents -parameters: - - name: document_id - type: string - required: true - label: - en_US: document_id - zh_Hans: 飞书文档的唯一标识 - human_description: - en_US: Unique ID of Feishu document document_id - zh_Hans: 飞书文档的唯一标识 document_id - llm_description: 飞书文档的唯一标识 document_id - form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml deleted file mode 100644 index d51e5a837c..0000000000 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml +++ /dev/null @@ -1,48 +0,0 @@ -identity: - name: list_document_block - author: Doug Lea - label: - en_US: List Document Block - zh_Hans: 获取飞书文档所有块 -description: - human: - en_US: List document block - zh_Hans: 获取飞书文档所有块的富文本内容并分页返回。 - llm: A tool to get all blocks of Feishu documents -parameters: - - name: document_id - type: string - required: true - label: - en_US: document_id - zh_Hans: 飞书文档的唯一标识 - human_description: - en_US: Unique ID of Feishu document document_id - zh_Hans: 飞书文档的唯一标识 document_id - llm_description: 飞书文档的唯一标识 document_id - form: llm - - - name: page_size - type: number - required: false - default: 500 - label: - en_US: page_size - zh_Hans: 分页大小 - human_description: - en_US: Paging size, the default and maximum value is 500. - zh_Hans: 分页大小, 默认值和最大值为 500。 - llm_description: 分页大小, 表示一次请求最多返回多少条数据,默认值和最大值为 500。 - form: llm - - - name: page_token - type: string - required: false - label: - en_US: page_token - zh_Hans: 分页标记 - human_description: - en_US: Pagination tag, used to paginate query results so that more items can be obtained in the next traversal. - zh_Hans: 分页标记,用于分页查询结果,以便下次遍历时获取更多项。 - llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 - form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py similarity index 90% rename from api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py rename to api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py index 8c0c4a3c97..572a7abf28 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py @@ -15,5 +15,5 @@ class ListDocumentBlockTool(BuiltinTool): page_size = tool_parameters.get("page_size", 500) page_token = tool_parameters.get("page_token", "") - res = client.list_document_block(document_id, page_token, page_size) + res = client.list_document_blocks(document_id, page_token, page_size) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.yaml b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.yaml new file mode 100644 index 0000000000..019ac98390 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.yaml @@ -0,0 +1,74 @@ +identity: + name: list_document_blocks + author: Doug Lea + label: + en_US: List Document Blocks + zh_Hans: 获取飞书文档所有块 +description: + human: + en_US: List document blocks + zh_Hans: 获取飞书文档所有块的富文本内容并分页返回 + llm: A tool to get all blocks of Feishu documents +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique identifier for a Feishu document. You can also input the document's URL. + zh_Hans: 飞书文档的唯一标识,支持输入文档的 URL。 + llm_description: 飞书文档的唯一标识,支持输入文档的 URL。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: llm + + - name: page_size + type: number + required: false + default: "500" + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: Paging size, the default and maximum value is 500. + zh_Hans: 分页大小, 默认值和最大值为 500。 + llm_description: 分页大小, 表示一次请求最多返回多少条数据,默认值和最大值为 500。 + form: llm + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: Pagination token used to navigate through query results, allowing retrieval of additional items in subsequent requests. + zh_Hans: 分页标记,用于分页查询结果,以便下次遍历时获取更多项。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml b/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml index 8ee219d4a7..4282e3dcf3 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml @@ -17,33 +17,35 @@ parameters: en_US: document_id zh_Hans: 飞书文档的唯一标识 human_description: - en_US: Unique ID of Feishu document document_id - zh_Hans: 飞书文档的唯一标识 document_id - llm_description: 飞书文档的唯一标识 document_id + en_US: Unique identifier for a Feishu document. You can also input the document's URL. + zh_Hans: 飞书文档的唯一标识,支持输入文档的 URL。 + llm_description: 飞书文档的唯一标识,支持输入文档的 URL。 form: llm - name: content type: string required: true label: - en_US: document content - zh_Hans: 文档内容 + en_US: Plain text or Markdown content + zh_Hans: 纯文本或 Markdown 内容 human_description: - en_US: Document content, supports markdown syntax, can be empty. - zh_Hans: 文档内容,支持 markdown 语法,可以为空。 - llm_description: + en_US: Plain text or Markdown content. Note that embedded tables in the document should not have merged cells. + zh_Hans: 纯文本或 Markdown 内容。注意文档的内嵌套表格不允许有单元格合并。 + llm_description: 纯文本或 Markdown 内容,注意文档的内嵌套表格不允许有单元格合并。 form: llm - name: position - type: select - required: true - default: start + type: string + required: false label: - en_US: Choose where to add content - zh_Hans: 选择添加内容的位置 + en_US: position + zh_Hans: 添加位置 human_description: - en_US: Please fill in start or end to add content at the beginning or end of the document respectively. - zh_Hans: 请填入 start 或 end, 分别表示在文档开头(start)或结尾(end)添加内容。 + en_US: | + Enumeration values: start or end. Use 'start' to add content at the beginning of the document, and 'end' to add content at the end. The default value is 'end'. + zh_Hans: 枚举值:start 或 end。使用 'start' 在文档开头添加内容,使用 'end' 在文档结尾添加内容,默认值为 'end'。 + llm_description: | + 枚举值 start、end,start: 在文档开头添加内容;end: 在文档结尾添加内容,默认值为 end。 form: llm options: - value: start @@ -54,3 +56,4 @@ parameters: label: en_US: end zh_Hans: 在文档结尾添加内容 + default: start diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py index 44803d7d65..ffdb06498f 100644 --- a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -76,9 +76,9 @@ class FeishuRequest: url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/write_document" payload = {"document_id": document_id, "content": content, "position": position} res = self._send_request(url, payload=payload) - return res.get("data") + return res - def get_document_raw_content(self, document_id: str) -> dict: + def get_document_content(self, document_id: str, mode: str, lang: int = 0) -> dict: """ API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/raw_content Example Response: @@ -92,16 +92,18 @@ class FeishuRequest: """ # noqa: E501 params = { "document_id": document_id, + "mode": mode, + "lang": lang, } - url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/get_document_raw_content" + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/get_document_content" res = self._send_request(url, method="get", params=params) return res.get("data").get("content") - def list_document_block(self, document_id: str, page_token: str, page_size: int = 500) -> dict: + def list_document_blocks(self, document_id: str, page_token: str, page_size: int = 500) -> dict: """ API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/list """ - url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/list_document_block" + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/list_document_blocks" params = { "document_id": document_id, "page_size": page_size, From 624331472a179f355f1a72e134bde5c683198fe7 Mon Sep 17 00:00:00 2001 From: Nam Vu Date: Sat, 14 Sep 2024 18:05:19 +0700 Subject: [PATCH 28/43] fix: Improve scrolling behavior for Conversation Opener (#8437) Co-authored-by: crazywoola <427733928@qq.com> --- web/app/components/base/chat/chat/index.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index 65e49eff67..68194193c4 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -109,9 +109,9 @@ const Chat: FC = ({ const userScrolledRef = useRef(false) const handleScrollToBottom = useCallback(() => { - if (chatContainerRef.current && !userScrolledRef.current) + if (chatList.length > 1 && chatContainerRef.current && !userScrolledRef.current) chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight - }, []) + }, [chatList.length]) const handleWindowResize = useCallback(() => { if (chatContainerRef.current) From fa1af8e47bf11a44b67b025a6599e2239a1d4bd0 Mon Sep 17 00:00:00 2001 From: Ying Wang Date: Sat, 14 Sep 2024 19:06:37 +0800 Subject: [PATCH 29/43] add WorkflowClient.get_result, increase version number (#8435) Co-authored-by: wangying --- sdks/python-client/dify_client/client.py | 3 +++ sdks/python-client/setup.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index a8da1d7cae..b6b0ced2ce 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -131,3 +131,6 @@ class WorkflowClient(DifyClient): def stop(self, task_id, user): data = {"user": user} return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) + + def get_result(self, workflow_run_id): + return self._send_request("GET", f"/workflows/run/{workflow_run_id}") \ No newline at end of file diff --git a/sdks/python-client/setup.py b/sdks/python-client/setup.py index e74748377e..e7253f7391 100644 --- a/sdks/python-client/setup.py +++ b/sdks/python-client/setup.py @@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh: setup( name="dify-client", - version="0.1.10", + version="0.1.11", author="Dify", author_email="hello@dify.ai", description="A package for interacting with the Dify Service-API", From 445497cf892ee0caaba80c55032810e579a4507e Mon Sep 17 00:00:00 2001 From: "Charlie.Wei" Date: Sat, 14 Sep 2024 19:24:53 +0800 Subject: [PATCH 30/43] add svg render & Image preview optimization (#8387) Co-authored-by: crazywoola <427733928@qq.com> --- .../base/image-uploader/image-preview.tsx | 252 ++++++++++++++++-- web/app/components/base/markdown.tsx | 37 +-- web/app/components/base/svg-gallery/index.tsx | 79 ++++++ web/package.json | 1 + web/yarn.lock | 5 + 5 files changed, 334 insertions(+), 40 deletions(-) create mode 100644 web/app/components/base/svg-gallery/index.tsx diff --git a/web/app/components/base/image-uploader/image-preview.tsx b/web/app/components/base/image-uploader/image-preview.tsx index 41f29fda2e..e5bd4c1bbc 100644 --- a/web/app/components/base/image-uploader/image-preview.tsx +++ b/web/app/components/base/image-uploader/image-preview.tsx @@ -1,26 +1,42 @@ import type { FC } from 'react' -import { useRef } from 'react' +import React, { useCallback, useEffect, useRef, useState } from 'react' import { t } from 'i18next' import { createPortal } from 'react-dom' -import { RiCloseLine, RiExternalLinkLine } from '@remixicon/react' +import { RiAddBoxLine, RiCloseLine, RiDownloadCloud2Line, RiFileCopyLine, RiZoomInLine, RiZoomOutLine } from '@remixicon/react' import Tooltip from '@/app/components/base/tooltip' -import { randomString } from '@/utils' +import Toast from '@/app/components/base/toast' type ImagePreviewProps = { url: string title: string onCancel: () => void } + +const isBase64 = (str: string): boolean => { + try { + return btoa(atob(str)) === str + } + catch (err) { + return false + } +} + const ImagePreview: FC = ({ url, title, onCancel, }) => { - const selector = useRef(`copy-tooltip-${randomString(4)}`) + const [scale, setScale] = useState(1) + const [position, setPosition] = useState({ x: 0, y: 0 }) + const [isDragging, setIsDragging] = useState(false) + const imgRef = useRef(null) + const dragStartRef = useRef({ x: 0, y: 0 }) + const [isCopied, setIsCopied] = useState(false) + const containerRef = useRef(null) const openInNewTab = () => { // Open in a new window, considering the case when the page is inside an iframe - if (url.startsWith('http')) { + if (url.startsWith('http') || url.startsWith('https')) { window.open(url, '_blank') } else if (url.startsWith('data:image')) { @@ -29,34 +45,224 @@ const ImagePreview: FC = ({ win?.document.write(`${title}`) } else { - console.error('Unable to open image', url) + Toast.notify({ + type: 'error', + message: `Unable to open image: ${url}`, + }) + } + } + const downloadImage = () => { + // Open in a new window, considering the case when the page is inside an iframe + if (url.startsWith('http') || url.startsWith('https')) { + const a = document.createElement('a') + a.href = url + a.download = title + a.click() + } + else if (url.startsWith('data:image')) { + // Base64 image + const a = document.createElement('a') + a.href = url + a.download = title + a.click() + } + else { + Toast.notify({ + type: 'error', + message: `Unable to open image: ${url}`, + }) } } + const zoomIn = () => { + setScale(prevScale => Math.min(prevScale * 1.2, 15)) + } + + const zoomOut = () => { + setScale((prevScale) => { + const newScale = Math.max(prevScale / 1.2, 0.5) + if (newScale === 1) + setPosition({ x: 0, y: 0 }) // Reset position when fully zoomed out + + return newScale + }) + } + + const imageTobase64ToBlob = (base64: string, type = 'image/png'): Blob => { + const byteCharacters = atob(base64) + const byteArrays = [] + + for (let offset = 0; offset < byteCharacters.length; offset += 512) { + const slice = byteCharacters.slice(offset, offset + 512) + const byteNumbers = new Array(slice.length) + for (let i = 0; i < slice.length; i++) + byteNumbers[i] = slice.charCodeAt(i) + + const byteArray = new Uint8Array(byteNumbers) + byteArrays.push(byteArray) + } + + return new Blob(byteArrays, { type }) + } + + const imageCopy = useCallback(() => { + const shareImage = async () => { + try { + const base64Data = url.split(',')[1] + const blob = imageTobase64ToBlob(base64Data, 'image/png') + + await navigator.clipboard.write([ + new ClipboardItem({ + [blob.type]: blob, + }), + ]) + setIsCopied(true) + + Toast.notify({ + type: 'success', + message: t('common.operation.imageCopied'), + }) + } + catch (err) { + console.error('Failed to copy image:', err) + + const link = document.createElement('a') + link.href = url + link.download = `${title}.png` + document.body.appendChild(link) + link.click() + document.body.removeChild(link) + + Toast.notify({ + type: 'info', + message: t('common.operation.imageDownloaded'), + }) + } + } + shareImage() + }, [title, url]) + + const handleWheel = useCallback((e: React.WheelEvent) => { + if (e.deltaY < 0) + zoomIn() + else + zoomOut() + }, []) + + const handleMouseDown = useCallback((e: React.MouseEvent) => { + if (scale > 1) { + setIsDragging(true) + dragStartRef.current = { x: e.clientX - position.x, y: e.clientY - position.y } + } + }, [scale, position]) + + const handleMouseMove = useCallback((e: React.MouseEvent) => { + if (isDragging && scale > 1) { + const deltaX = e.clientX - dragStartRef.current.x + const deltaY = e.clientY - dragStartRef.current.y + + // Calculate boundaries + const imgRect = imgRef.current?.getBoundingClientRect() + const containerRect = imgRef.current?.parentElement?.getBoundingClientRect() + + if (imgRect && containerRect) { + const maxX = (imgRect.width * scale - containerRect.width) / 2 + const maxY = (imgRect.height * scale - containerRect.height) / 2 + + setPosition({ + x: Math.max(-maxX, Math.min(maxX, deltaX)), + y: Math.max(-maxY, Math.min(maxY, deltaY)), + }) + } + } + }, [isDragging, scale]) + + const handleMouseUp = useCallback(() => { + setIsDragging(false) + }, []) + + useEffect(() => { + document.addEventListener('mouseup', handleMouseUp) + return () => { + document.removeEventListener('mouseup', handleMouseUp) + } + }, [handleMouseUp]) + + useEffect(() => { + const handleKeyDown = (event: KeyboardEvent) => { + if (event.key === 'Escape') + onCancel() + } + + window.addEventListener('keydown', handleKeyDown) + + // Set focus to the container element + if (containerRef.current) + containerRef.current.focus() + + // Cleanup function + return () => { + window.removeEventListener('keydown', handleKeyDown) + } + }, [onCancel]) + return createPortal( -
e.stopPropagation()}> +
e.stopPropagation()} + onWheel={handleWheel} + onMouseDown={handleMouseDown} + onMouseMove={handleMouseMove} + onMouseUp={handleMouseUp} + style={{ cursor: scale > 1 ? 'move' : 'default' }} + tabIndex={-1}> {/* eslint-disable-next-line @next/next/no-img-element */} {title} -
- -
- + +
+ {isCopied + ? + : } +
+
+ +
+ +
+
+ +
+ +
+
+ +
+ +
+
+ +
+ +
+
+
- + className='absolute top-6 right-6 flex items-center justify-center w-8 h-8 bg-white/8 rounded-lg backdrop-blur-[2px] cursor-pointer' + onClick={onCancel}> +
, diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index d4e7dac4ae..443ee3410c 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -5,6 +5,7 @@ import RemarkMath from 'remark-math' import RemarkBreaks from 'remark-breaks' import RehypeKatex from 'rehype-katex' import RemarkGfm from 'remark-gfm' +import RehypeRaw from 'rehype-raw' import SyntaxHighlighter from 'react-syntax-highlighter' import { atelierHeathLight } from 'react-syntax-highlighter/dist/esm/styles/hljs' import type { RefObject } from 'react' @@ -18,6 +19,7 @@ import ImageGallery from '@/app/components/base/image-gallery' import { useChatContext } from '@/app/components/base/chat/chat/context' import VideoGallery from '@/app/components/base/video-gallery' import AudioGallery from '@/app/components/base/audio-gallery' +import SVGRenderer from '@/app/components/base/svg-gallery' // Available language https://github.com/react-syntax-highlighter/react-syntax-highlighter/blob/master/AVAILABLE_LANGUAGES_HLJS.MD const capitalizationLanguageNameMap: Record = { @@ -40,6 +42,7 @@ const capitalizationLanguageNameMap: Record = { powershell: 'PowerShell', json: 'JSON', latex: 'Latex', + svg: 'SVG', } const getCorrectCapitalizationLanguageName = (language: string) => { if (!language) @@ -107,6 +110,7 @@ const useLazyLoad = (ref: RefObject): boolean => { // Error: Minified React error 185; // visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message // or use the non-minified dev environment for full errors and additional helpful warnings. + const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }) => { const [isSVG, setIsSVG] = useState(true) const match = /language-(\w+)/.exec(className || '') @@ -134,7 +138,7 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props } >
{languageShowName}
- {language === 'mermaid' && } + {language === 'mermaid' && } {(language === 'mermaid' && isSVG) ? () - : ( - (language === 'echarts') - ? (
-
) + : (language === 'echarts' + ? (
) + : (language === 'svg' + ? () : ( {String(children).replace(/\n$/, '')} - ))} + )))}
) - : ( - - {children} - - ) + : ({children}) }, [chartData, children, className, inline, isSVG, language, languageShowName, match, props]) }) - CodeBlock.displayName = 'CodeBlock' const VideoBlock: CodeComponent = memo(({ node }) => { @@ -230,6 +227,7 @@ export function Markdown(props: { content: string; className?: string }) { remarkPlugins={[[RemarkGfm, RemarkMath, { singleDollarTextMath: false }], RemarkBreaks]} rehypePlugins={[ RehypeKatex, + RehypeRaw as any, // The Rehype plug-in is used to remove the ref attribute of an element () => { return (tree) => { @@ -244,6 +242,7 @@ export function Markdown(props: { content: string; className?: string }) { } }, ]} + disallowedElements={['script', 'iframe', 'head', 'html', 'meta', 'link', 'style', 'body']} components={{ code: CodeBlock, img: Img, @@ -266,19 +265,23 @@ export function Markdown(props: { content: string; className?: string }) { // This can happen when a component attempts to access an undefined object that references an unregistered map, causing the program to crash. export default class ErrorBoundary extends Component { - constructor(props) { + constructor(props: any) { super(props) this.state = { hasError: false } } - componentDidCatch(error, errorInfo) { + componentDidCatch(error: any, errorInfo: any) { this.setState({ hasError: true }) console.error(error, errorInfo) } render() { + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-expect-error if (this.state.hasError) - return
Oops! ECharts reported a runtime error.
(see the browser console for more information)
+ return
Oops! An error occurred. This could be due to an ECharts runtime error or invalid SVG content.
(see the browser console for more information)
+ // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-expect-error return this.props.children } } diff --git a/web/app/components/base/svg-gallery/index.tsx b/web/app/components/base/svg-gallery/index.tsx new file mode 100644 index 0000000000..81e8e87655 --- /dev/null +++ b/web/app/components/base/svg-gallery/index.tsx @@ -0,0 +1,79 @@ +import { useEffect, useRef, useState } from 'react' +import { SVG } from '@svgdotjs/svg.js' +import ImagePreview from '@/app/components/base/image-uploader/image-preview' + +export const SVGRenderer = ({ content }: { content: string }) => { + const svgRef = useRef(null) + const [imagePreview, setImagePreview] = useState('') + const [windowSize, setWindowSize] = useState({ + width: typeof window !== 'undefined' ? window.innerWidth : 0, + height: typeof window !== 'undefined' ? window.innerHeight : 0, + }) + + const svgToDataURL = (svgElement: Element): string => { + const svgString = new XMLSerializer().serializeToString(svgElement) + const base64String = Buffer.from(svgString).toString('base64') + return `data:image/svg+xml;base64,${base64String}` + } + + useEffect(() => { + const handleResize = () => { + setWindowSize({ width: window.innerWidth, height: window.innerHeight }) + } + + window.addEventListener('resize', handleResize) + return () => window.removeEventListener('resize', handleResize) + }, []) + + useEffect(() => { + if (svgRef.current) { + try { + svgRef.current.innerHTML = '' + const draw = SVG().addTo(svgRef.current).size('100%', '100%') + + const parser = new DOMParser() + const svgDoc = parser.parseFromString(content, 'image/svg+xml') + const svgElement = svgDoc.documentElement + + if (!(svgElement instanceof SVGElement)) + throw new Error('Invalid SVG content') + + const originalWidth = parseInt(svgElement.getAttribute('width') || '400', 10) + const originalHeight = parseInt(svgElement.getAttribute('height') || '600', 10) + const scale = Math.min(windowSize.width / originalWidth, windowSize.height / originalHeight, 1) + const scaledWidth = originalWidth * scale + const scaledHeight = originalHeight * scale + draw.size(scaledWidth, scaledHeight) + + const rootElement = draw.svg(content) + rootElement.scale(scale) + + rootElement.click(() => { + setImagePreview(svgToDataURL(svgElement as Element)) + }) + } + catch (error) { + if (svgRef.current) + svgRef.current.innerHTML = 'Error rendering SVG. Wait for the image content to complete.' + } + } + }, [content, windowSize]) + + return ( + <> +
+ {imagePreview && ( setImagePreview('')} />)} + + ) +} + +export default SVGRenderer diff --git a/web/package.json b/web/package.json index 197a1e7e05..bc532fb242 100644 --- a/web/package.json +++ b/web/package.json @@ -44,6 +44,7 @@ "classnames": "^2.3.2", "copy-to-clipboard": "^3.3.3", "crypto-js": "^4.2.0", + "@svgdotjs/svg.js": "^3.2.4", "dayjs": "^1.11.7", "echarts": "^5.4.1", "echarts-for-react": "^3.0.2", diff --git a/web/yarn.lock b/web/yarn.lock index 3c020c9664..bec2059a47 100644 --- a/web/yarn.lock +++ b/web/yarn.lock @@ -1489,6 +1489,11 @@ dependencies: "@sinonjs/commons" "^3.0.0" +"@svgdotjs/svg.js@^3.2.4": + version "3.2.4" + resolved "https://registry.yarnpkg.com/@svgdotjs/svg.js/-/svg.js-3.2.4.tgz#4716be92a64c66b29921b63f7235fcfb953fb13a" + integrity sha512-BjJ/7vWNowlX3Z8O4ywT58DqbNRyYlkk6Yz/D13aB7hGmfQTvGX4Tkgtm/ApYlu9M7lCQi15xUEidqMUmdMYwg== + "@swc/counter@^0.1.3": version "0.1.3" resolved "https://registry.npmjs.org/@swc/counter/-/counter-0.1.3.tgz" From 65162a87b6b1889b7b432e669c92777725456024 Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Sat, 14 Sep 2024 21:48:24 +0800 Subject: [PATCH 31/43] fix:docker-compose.middleware.yaml start the Weaviate container by default (#8446) (#8447) --- docker/docker-compose.middleware.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 00faa2960a..251c62fee1 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -89,6 +89,7 @@ services: weaviate: image: semitechnologies/weaviate:1.19.0 profiles: + - "" - weaviate restart: always volumes: From 7e611ffbf3e018b98803123319d2e8bae81eaccd Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Sat, 14 Sep 2024 21:48:44 +0800 Subject: [PATCH 32/43] multi-retrival use dataset's top-k (#8416) --- api/core/rag/retrieval/dataset_retrieval.py | 2 +- .../tool/dataset_retriever/dataset_multi_retriever_tool.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 124c58f0fe..286ecd4c03 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -426,7 +426,7 @@ class DatasetRetrieval: retrieval_method=retrieval_model["search_method"], dataset_id=dataset.id, query=query, - top_k=top_k, + top_k=retrieval_model.get("top_k") or 2, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 6073b8e92e..ab7b40a253 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -165,7 +165,10 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( - retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + retrieval_method="keyword_search", + dataset_id=dataset.id, + query=query, + top_k=retrieval_model.get("top_k") or 2, ) if documents: all_documents.extend(documents) @@ -176,7 +179,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): retrieval_method=retrieval_model["search_method"], dataset_id=dataset.id, query=query, - top_k=self.top_k, + top_k=retrieval_model.get("top_k") or 2, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, From bf16de50fe2bf16be0a3981ee8eae8c02bd20a8b Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 14 Sep 2024 21:50:02 +0800 Subject: [PATCH 33/43] fix: internal error when tool authorization (#8449) --- api/core/tools/provider/tool_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index 7ba9dda179..05c88b904e 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -153,7 +153,7 @@ class ToolProviderController(BaseModel, ABC): # check type credential_schema = credentials_need_to_validate[credential_name] - if credential_schema in { + if credential_schema.type in { ToolProviderCredentials.CredentialsType.SECRET_INPUT, ToolProviderCredentials.CredentialsType.TEXT_INPUT, }: From 4788e1c8c834ec251d3f5b9c0fadb0b1685c5c5b Mon Sep 17 00:00:00 2001 From: Ying Wang Date: Sun, 15 Sep 2024 17:08:52 +0800 Subject: [PATCH 34/43] [Python SDK] Add KnowledgeBaseClient and the corresponding test cases. (#8465) Co-authored-by: Wang Ying --- sdks/python-client/dify_client/client.py | 281 ++++++++++++++++++++++- sdks/python-client/setup.py | 2 +- sdks/python-client/tests/test_client.py | 149 +++++++++++- 3 files changed, 429 insertions(+), 3 deletions(-) diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index b6b0ced2ce..2be079bdf3 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -1,3 +1,4 @@ +import json import requests @@ -133,4 +134,282 @@ class WorkflowClient(DifyClient): return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) def get_result(self, workflow_run_id): - return self._send_request("GET", f"/workflows/run/{workflow_run_id}") \ No newline at end of file + return self._send_request("GET", f"/workflows/run/{workflow_run_id}") + + + +class KnowledgeBaseClient(DifyClient): + + def __init__(self, api_key, base_url: str = 'https://api.dify.ai/v1', dataset_id: str = None): + """ + Construct a KnowledgeBaseClient object. + + Args: + api_key (str): API key of Dify. + base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'. + dataset_id (str, optional): ID of the dataset. Defaults to None. You don't need this if you just want to + create a new dataset. or list datasets. otherwise you need to set this. + """ + super().__init__( + api_key=api_key, + base_url=base_url + ) + self.dataset_id = dataset_id + + def _get_dataset_id(self): + if self.dataset_id is None: + raise ValueError("dataset_id is not set") + return self.dataset_id + + def create_dataset(self, name: str, **kwargs): + return self._send_request('POST', '/datasets', {'name': name}, **kwargs) + + def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): + return self._send_request('GET', f'/datasets?page={page}&limit={page_size}', **kwargs) + + def create_document_by_text(self, name, text, extra_params: dict = None, **kwargs): + """ + Create a document by text. + + :param name: Name of the document + :param text: Text content of the document + :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) + e.g. + { + 'indexing_technique': 'high_quality', + 'process_rule': { + 'rules': { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': True} + ], + 'segmentation': { + 'separator': '\n', + 'max_tokens': 500 + } + }, + 'mode': 'custom' + } + } + :return: Response from the API + """ + data = { + 'indexing_technique': 'high_quality', + 'process_rule': { + 'mode': 'automatic' + }, + 'name': name, + 'text': text + } + if extra_params is not None and isinstance(extra_params, dict): + data.update(extra_params) + url = f"/datasets/{self._get_dataset_id()}/document/create_by_text" + return self._send_request("POST", url, json=data, **kwargs) + + def update_document_by_text(self, document_id, name, text, extra_params: dict = None, **kwargs): + """ + Update a document by text. + + :param document_id: ID of the document + :param name: Name of the document + :param text: Text content of the document + :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) + e.g. + { + 'indexing_technique': 'high_quality', + 'process_rule': { + 'rules': { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': True} + ], + 'segmentation': { + 'separator': '\n', + 'max_tokens': 500 + } + }, + 'mode': 'custom' + } + } + :return: Response from the API + """ + data = { + 'name': name, + 'text': text + } + if extra_params is not None and isinstance(extra_params, dict): + data.update(extra_params) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" + return self._send_request("POST", url, json=data, **kwargs) + + def create_document_by_file(self, file_path, original_document_id=None, extra_params: dict = None): + """ + Create a document by file. + + :param file_path: Path to the file + :param original_document_id: pass this ID if you want to replace the original document (optional) + :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) + e.g. + { + 'indexing_technique': 'high_quality', + 'process_rule': { + 'rules': { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': True} + ], + 'segmentation': { + 'separator': '\n', + 'max_tokens': 500 + } + }, + 'mode': 'custom' + } + } + :return: Response from the API + """ + files = {"file": open(file_path, "rb")} + data = { + 'process_rule': { + 'mode': 'automatic' + }, + 'indexing_technique': 'high_quality' + } + if extra_params is not None and isinstance(extra_params, dict): + data.update(extra_params) + if original_document_id is not None: + data['original_document_id'] = original_document_id + url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" + return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) + + def update_document_by_file(self, document_id, file_path, extra_params: dict = None): + """ + Update a document by file. + + :param document_id: ID of the document + :param file_path: Path to the file + :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) + e.g. + { + 'indexing_technique': 'high_quality', + 'process_rule': { + 'rules': { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': True} + ], + 'segmentation': { + 'separator': '\n', + 'max_tokens': 500 + } + }, + 'mode': 'custom' + } + } + :return: + """ + files = {"file": open(file_path, "rb")} + data = {} + if extra_params is not None and isinstance(extra_params, dict): + data.update(extra_params) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" + return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) + + def batch_indexing_status(self, batch_id: str, **kwargs): + """ + Get the status of the batch indexing. + + :param batch_id: ID of the batch uploading + :return: Response from the API + """ + url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status" + return self._send_request("GET", url, **kwargs) + + def delete_dataset(self): + """ + Delete this dataset. + + :return: Response from the API + """ + url = f"/datasets/{self._get_dataset_id()}" + return self._send_request("DELETE", url) + + def delete_document(self, document_id): + """ + Delete a document. + + :param document_id: ID of the document + :return: Response from the API + """ + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" + return self._send_request("DELETE", url) + + def list_documents(self, page: int = None, page_size: int = None, keyword: str = None, **kwargs): + """ + Get a list of documents in this dataset. + + :return: Response from the API + """ + params = {} + if page is not None: + params['page'] = page + if page_size is not None: + params['limit'] = page_size + if keyword is not None: + params['keyword'] = keyword + url = f"/datasets/{self._get_dataset_id()}/documents" + return self._send_request("GET", url, params=params, **kwargs) + + def add_segments(self, document_id, segments, **kwargs): + """ + Add segments to a document. + + :param document_id: ID of the document + :param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}] + :return: Response from the API + """ + data = {"segments": segments} + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" + return self._send_request("POST", url, json=data, **kwargs) + + def query_segments(self, document_id, keyword: str = None, status: str = None, **kwargs): + """ + Query segments in this document. + + :param document_id: ID of the document + :param keyword: query keyword, optional + :param status: status of the segment, optional, e.g. completed + """ + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" + params = {} + if keyword is not None: + params['keyword'] = keyword + if status is not None: + params['status'] = status + if "params" in kwargs: + params.update(kwargs["params"]) + return self._send_request("GET", url, params=params, **kwargs) + + def delete_document_segment(self, document_id, segment_id): + """ + Delete a segment from a document. + + :param document_id: ID of the document + :param segment_id: ID of the segment + :return: Response from the API + """ + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" + return self._send_request("DELETE", url) + + def update_document_segment(self, document_id, segment_id, segment_data, **kwargs): + """ + Update a segment in a document. + + :param document_id: ID of the document + :param segment_id: ID of the segment + :param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True} + :return: Response from the API + """ + data = {"segment": segment_data} + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" + return self._send_request("POST", url, json=data, **kwargs) diff --git a/sdks/python-client/setup.py b/sdks/python-client/setup.py index e7253f7391..bb8ca46d97 100644 --- a/sdks/python-client/setup.py +++ b/sdks/python-client/setup.py @@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh: setup( name="dify-client", - version="0.1.11", + version="0.1.12", author="Dify", author_email="hello@dify.ai", description="A package for interacting with the Dify Service-API", diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py index 5259d082ca..301e733b6b 100644 --- a/sdks/python-client/tests/test_client.py +++ b/sdks/python-client/tests/test_client.py @@ -1,10 +1,157 @@ import os +import time import unittest -from dify_client.client import ChatClient, CompletionClient, DifyClient +from dify_client.client import ChatClient, CompletionClient, DifyClient, KnowledgeBaseClient API_KEY = os.environ.get("API_KEY") APP_ID = os.environ.get("APP_ID") +API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.dify.ai/v1") +FILE_PATH_BASE = os.path.dirname(__file__) + + +class TestKnowledgeBaseClient(unittest.TestCase): + def setUp(self): + self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL) + self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md")) + self.dataset_id = None + self.document_id = None + self.segment_id = None + self.batch_id = None + + def _get_dataset_kb_client(self): + self.assertIsNotNone(self.dataset_id) + return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id) + + def test_001_create_dataset(self): + response = self.knowledge_base_client.create_dataset(name="test_dataset") + data = response.json() + self.assertIn("id", data) + self.dataset_id = data["id"] + self.assertEqual("test_dataset", data["name"]) + + # the following tests require to be executed in order because they use + # the dataset/document/segment ids from the previous test + self._test_002_list_datasets() + self._test_003_create_document_by_text() + time.sleep(1) + self._test_004_update_document_by_text() + # self._test_005_batch_indexing_status() + time.sleep(1) + self._test_006_update_document_by_file() + time.sleep(1) + self._test_007_list_documents() + self._test_008_delete_document() + self._test_009_create_document_by_file() + time.sleep(1) + self._test_010_add_segments() + self._test_011_query_segments() + self._test_012_update_document_segment() + self._test_013_delete_document_segment() + self._test_014_delete_dataset() + + def _test_002_list_datasets(self): + response = self.knowledge_base_client.list_datasets() + data = response.json() + self.assertIn("data", data) + self.assertIn("total", data) + + def _test_003_create_document_by_text(self): + client = self._get_dataset_kb_client() + response = client.create_document_by_text("test_document", "test_text") + data = response.json() + self.assertIn("document", data) + self.document_id = data["document"]["id"] + self.batch_id = data["batch"] + + def _test_004_update_document_by_text(self): + client = self._get_dataset_kb_client() + self.assertIsNotNone(self.document_id) + response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated") + data = response.json() + self.assertIn("document", data) + self.assertIn("batch", data) + self.batch_id = data["batch"] + + def _test_005_batch_indexing_status(self): + client = self._get_dataset_kb_client() + response = client.batch_indexing_status(self.batch_id) + data = response.json() + self.assertEqual(response.status_code, 200) + + def _test_006_update_document_by_file(self): + client = self._get_dataset_kb_client() + self.assertIsNotNone(self.document_id) + response = client.update_document_by_file(self.document_id, self.README_FILE_PATH) + data = response.json() + self.assertIn("document", data) + self.assertIn("batch", data) + self.batch_id = data["batch"] + + def _test_007_list_documents(self): + client = self._get_dataset_kb_client() + response = client.list_documents() + data = response.json() + self.assertIn("data", data) + + def _test_008_delete_document(self): + client = self._get_dataset_kb_client() + self.assertIsNotNone(self.document_id) + response = client.delete_document(self.document_id) + data = response.json() + self.assertIn("result", data) + self.assertEqual("success", data["result"]) + + def _test_009_create_document_by_file(self): + client = self._get_dataset_kb_client() + response = client.create_document_by_file(self.README_FILE_PATH) + data = response.json() + self.assertIn("document", data) + self.document_id = data["document"]["id"] + self.batch_id = data["batch"] + + def _test_010_add_segments(self): + client = self._get_dataset_kb_client() + response = client.add_segments(self.document_id, [ + {"content": "test text segment 1"} + ]) + data = response.json() + self.assertIn("data", data) + self.assertGreater(len(data["data"]), 0) + segment = data["data"][0] + self.segment_id = segment["id"] + + def _test_011_query_segments(self): + client = self._get_dataset_kb_client() + response = client.query_segments(self.document_id) + data = response.json() + self.assertIn("data", data) + self.assertGreater(len(data["data"]), 0) + + def _test_012_update_document_segment(self): + client = self._get_dataset_kb_client() + self.assertIsNotNone(self.segment_id) + response = client.update_document_segment(self.document_id, self.segment_id, + {"content": "test text segment 1 updated"} + ) + data = response.json() + self.assertIn("data", data) + self.assertGreater(len(data["data"]), 0) + segment = data["data"] + self.assertEqual("test text segment 1 updated", segment["content"]) + + def _test_013_delete_document_segment(self): + client = self._get_dataset_kb_client() + self.assertIsNotNone(self.segment_id) + response = client.delete_document_segment(self.document_id, self.segment_id) + data = response.json() + self.assertIn("result", data) + self.assertEqual("success", data["result"]) + + def _test_014_delete_dataset(self): + client = self._get_dataset_kb_client() + response = client.delete_dataset() + self.assertEqual(204, response.status_code) class TestChatClient(unittest.TestCase): From b73faae0d072a2f658f5496c81449a121bee1f4f Mon Sep 17 00:00:00 2001 From: Hirotaka Miyagi <31152321+MH4GF@users.noreply.github.com> Date: Sun, 15 Sep 2024 18:09:47 +0900 Subject: [PATCH 35/43] fix(RunOnce): change to form submission instead of onKeyDown and onClick (#8460) --- .../share/text-generation/run-once/index.tsx | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/web/app/components/share/text-generation/run-once/index.tsx b/web/app/components/share/text-generation/run-once/index.tsx index e3b8d6a0a4..e74616b08d 100644 --- a/web/app/components/share/text-generation/run-once/index.tsx +++ b/web/app/components/share/text-generation/run-once/index.tsx @@ -1,4 +1,4 @@ -import type { FC } from 'react' +import type { FC, FormEvent } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' import { @@ -39,11 +39,16 @@ const RunOnce: FC = ({ onInputsChange(newInputs) } + const onSubmit = (e: FormEvent) => { + e.preventDefault() + onSend() + } + return (
{/* input form */} -
+ {promptConfig.prompt_variables.map(item => (
@@ -65,12 +70,6 @@ const RunOnce: FC = ({ placeholder={`${item.name}${!item.required ? `(${t('appDebug.variableTable.optional')})` : ''}`} value={inputs[item.key]} onChange={(e) => { onInputsChange({ ...inputs, [item.key]: e.target.value }) }} - onKeyDown={(e) => { - if (e.key === 'Enter') { - e.preventDefault() - onSend() - } - }} maxLength={item.max_length || DEFAULT_VALUE_MAX_LEN} /> )} @@ -124,8 +123,8 @@ const RunOnce: FC = ({ {t('common.operation.clear')}

{t(`${prefixSettings}.sso.description`)}