chore: refurbish Python code by applying refurb linter rules (#8296)

This commit is contained in:
Bowen Liang 2024-09-12 15:50:49 +08:00 committed by GitHub
parent c69f5b07ba
commit 40fb4d16ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
105 changed files with 220 additions and 276 deletions

View File

@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):
site = app.site site = app.site
if not site: if not site:
desc = args["desc"] if args["desc"] else "" desc = args["desc"] or ""
copy_right = args["copyright"] if args["copyright"] else "" copy_right = args["copyright"] or ""
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else "" privacy_policy = args["privacy_policy"] or ""
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else "" custom_disclaimer = args["custom_disclaimer"] or ""
else: else:
desc = site.description if site.description else args["desc"] if args["desc"] else "" desc = site.description or args["desc"] or ""
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else "" copy_right = site.copyright or args["copyright"] or ""
privacy_policy = ( privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else "" custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
)
custom_disclaimer = (
site.custom_disclaimer
if site.custom_disclaimer
else args["custom_disclaimer"]
if args["custom_disclaimer"]
else ""
)
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()

View File

@ -99,14 +99,10 @@ class ChatMessageTextApi(Resource):
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
voice = ( voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception: except Exception:
voice = None voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice) response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)

View File

@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if not account: if not account:
# Create account # Create account
account_name = user_info.name if user_info.name else "Dify" account_name = user_info.name or "Dify"
account = RegisterService.register( account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
) )

View File

@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
return { return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
"api_base_url": (
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
)
+ "/v1"
}
class DatasetRetrievalSettingApi(Resource): class DatasetRetrievalSettingApi(Resource):

View File

@ -86,14 +86,10 @@ class ChatTextApi(InstalledAppResource):
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
voice = ( voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception: except Exception:
voice = None voice = None
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text) response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)

View File

@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):
return ApiToolManageService.test_api_tool_preview( return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id, current_user.current_tenant_id,
args["provider_name"] if args["provider_name"] else "", args["provider_name"] or "",
args["tool_name"], args["tool_name"],
args["credentials"], args["credentials"],
args["parameters"], args["parameters"],

View File

@ -84,14 +84,10 @@ class TextApi(Resource):
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
voice = ( voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception: except Exception:
voice = None voice = None
response = AudioService.transcript_tts( response = AudioService.transcript_tts(

View File

@ -83,14 +83,10 @@ class TextApi(WebApiResource):
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
voice = ( voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception: except Exception:
voice = None voice = None

View File

@ -256,7 +256,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
model=model_instance.model, model=model_instance.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer), message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(), usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="", system_fingerprint="",
) )
), ),

View File

@ -298,7 +298,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model=model_instance.model, model=model_instance.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer), message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(), usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="", system_fingerprint="",
) )
), ),

View File

@ -161,7 +161,7 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode), app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity, prompt_template_entity=prompt_template_entity,
inputs=inputs, inputs=inputs,
query=query if query else "", query=query or "",
files=files, files=files,
context=context, context=context,
memory=memory, memory=memory,
@ -189,7 +189,7 @@ class AppRunner:
prompt_messages = prompt_transform.get_prompt( prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template, prompt_template=prompt_template,
inputs=inputs, inputs=inputs,
query=query if query else "", query=query or "",
files=files, files=files,
context=context, context=context,
memory_config=memory_config, memory_config=memory_config,
@ -238,7 +238,7 @@ class AppRunner:
model=app_generate_entity.model_conf.model, model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text), message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage(), usage=usage or LLMUsage.empty_usage(),
), ),
), ),
PublishFrom.APPLICATION_MANAGER, PublishFrom.APPLICATION_MANAGER,
@ -351,7 +351,7 @@ class AppRunner:
tenant_id=tenant_id, tenant_id=tenant_id,
app_config=app_generate_entity.app_config, app_config=app_generate_entity.app_config,
inputs=inputs, inputs=inputs,
query=query if query else "", query=query or "",
message_id=message_id, message_id=message_id,
trace_manager=app_generate_entity.trace_manager, trace_manager=app_generate_entity.trace_manager,
) )

View File

@ -3,6 +3,7 @@ import importlib.util
import json import json
import logging import logging
import os import os
from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -63,8 +64,7 @@ class Extensible:
builtin_file_path = os.path.join(subdir_path, "__builtin__") builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path): if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding="utf-8") as f: position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position = int(f.read().strip())
position_map[extension_name] = position position_map[extension_name] = position
if (extension_name + ".py") not in file_names: if (extension_name + ".py") not in file_names:

View File

@ -39,7 +39,7 @@ class TokenBufferMemory:
) )
if message_limit and message_limit > 0: if message_limit and message_limit > 0:
message_limit = message_limit if message_limit <= 500 else 500 message_limit = min(message_limit, 500)
else: else:
message_limit = 500 message_limit = 500

View File

@ -449,7 +449,7 @@ if you are not sure about the structure.
model=real_model, model=real_model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=prompt_message, message=prompt_message,
usage=usage if usage else LLMUsage.empty_usage(), usage=usage or LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint, system_fingerprint=system_fingerprint,
), ),
credentials=credentials, credentials=credentials,

View File

@ -409,7 +409,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
), ),
) )
elif isinstance(chunk, ContentBlockDeltaEvent): elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else "" chunk_text = chunk.delta.text or ""
full_assistant_content += chunk_text full_assistant_content += chunk_text
# transform assistant message to prompt message # transform assistant message to prompt message

View File

@ -213,7 +213,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
model=real_model, model=real_model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=prompt_message, message=prompt_message,
usage=usage if usage else LLMUsage.empty_usage(), usage=usage or LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint, system_fingerprint=system_fingerprint,
), ),
credentials=credentials, credentials=credentials,

View File

@ -225,7 +225,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
continue continue
# transform assistant message to prompt message # transform assistant message to prompt message
text = delta.text if delta.text else "" text = delta.text or ""
assistant_prompt_message = AssistantPromptMessage(content=text) assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text full_text += text
@ -400,15 +400,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
continue continue
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
)
full_assistant_content += delta.delta.content if delta.delta.content else "" full_assistant_content += delta.delta.content or ""
real_model = chunk.model real_model = chunk.model
system_fingerprint = chunk.system_fingerprint system_fingerprint = chunk.system_fingerprint
completion += delta.delta.content if delta.delta.content else "" completion += delta.delta.content or ""
yield LLMResultChunk( yield LLMResultChunk(
model=real_model, model=real_model,

View File

@ -84,7 +84,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
) )
for i in range(len(sentences)) for i in range(len(sentences))
] ]
for index, future in enumerate(futures): for future in futures:
yield from future.result().__enter__().iter_bytes(1024) yield from future.result().__enter__().iter_bytes(1024)
else: else:

View File

@ -331,10 +331,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
elif "contentBlockDelta" in chunk: elif "contentBlockDelta" in chunk:
delta = chunk["contentBlockDelta"]["delta"] delta = chunk["contentBlockDelta"]["delta"]
if "text" in delta: if "text" in delta:
chunk_text = delta["text"] if delta["text"] else "" chunk_text = delta["text"] or ""
full_assistant_content += chunk_text full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=chunk_text if chunk_text else "", content=chunk_text or "",
) )
index = chunk["contentBlockDelta"]["contentBlockIndex"] index = chunk["contentBlockDelta"]["contentBlockIndex"]
yield LLMResultChunk( yield LLMResultChunk(
@ -751,7 +751,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
elif model_prefix == "cohere": elif model_prefix == "cohere":
output = response_body.get("generations")[0].get("text") output = response_body.get("generations")[0].get("text")
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, output if output else "") completion_tokens = self.get_num_tokens(model, credentials, output or "")
else: else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@ -828,7 +828,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=content_delta if content_delta else "", content=content_delta or "",
) )
index += 1 index += 1

View File

@ -302,11 +302,11 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
if delta.delta.function_call: if delta.delta.function_call:
function_calls = [delta.delta.function_call] function_calls = [delta.delta.function_call]
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else []) assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
) )
if delta.finish_reason is not None: if delta.finish_reason is not None:

View File

@ -511,7 +511,7 @@ class LocalAILanguageModel(LargeLanguageModel):
delta = chunk.choices[0] delta = chunk.choices[0]
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[]) assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])
if delta.finish_reason is not None: if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage # temp_assistant_prompt_message is used to calculate usage
@ -578,11 +578,11 @@ class LocalAILanguageModel(LargeLanguageModel):
if delta.delta.function_call: if delta.delta.function_call:
function_calls = [delta.delta.function_call] function_calls = [delta.delta.function_call]
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else []) assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
) )
if delta.finish_reason is not None: if delta.finish_reason is not None:

View File

@ -211,7 +211,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
index=0, index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]), message=AssistantPromptMessage(content=message.content, tool_calls=[]),
usage=usage, usage=usage,
finish_reason=message.stop_reason if message.stop_reason else None, finish_reason=message.stop_reason or None,
), ),
) )
elif message.function_call: elif message.function_call:
@ -244,7 +244,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(
index=0, index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]), message=AssistantPromptMessage(content=message.content, tool_calls=[]),
finish_reason=message.stop_reason if message.stop_reason else None, finish_reason=message.stop_reason or None,
), ),
) )

View File

@ -65,7 +65,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
inputs = [] inputs = []
used_tokens = 0 used_tokens = 0
for i, text in enumerate(texts): for text in texts:
# Here token count is only an approximation based on the GPT2 tokenizer # Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text) num_tokens = self._get_num_tokens_by_gpt2(text)

View File

@ -508,7 +508,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
continue continue
# transform assistant message to prompt message # transform assistant message to prompt message
text = delta.text if delta.text else "" text = delta.text or ""
assistant_prompt_message = AssistantPromptMessage(content=text) assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text full_text += text
@ -760,11 +760,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
final_tool_calls.extend(tool_calls) final_tool_calls.extend(tool_calls)
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
)
full_assistant_content += delta.delta.content if delta.delta.content else "" full_assistant_content += delta.delta.content or ""
if has_finish_reason: if has_finish_reason:
final_chunk = LLMResultChunk( final_chunk = LLMResultChunk(

View File

@ -88,7 +88,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
) )
for i in range(len(sentences)) for i in range(len(sentences))
] ]
for index, future in enumerate(futures): for future in futures:
yield from future.result().__enter__().iter_bytes(1024) yield from future.result().__enter__().iter_bytes(1024)
else: else:

View File

@ -179,9 +179,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
features = [] features = []
function_calling_type = credentials.get("function_calling_type", "no_call") function_calling_type = credentials.get("function_calling_type", "no_call")
if function_calling_type in ["function_call"]: if function_calling_type == "function_call":
features.append(ModelFeature.TOOL_CALL) features.append(ModelFeature.TOOL_CALL)
elif function_calling_type in ["tool_call"]: elif function_calling_type == "tool_call":
features.append(ModelFeature.MULTI_TOOL_CALL) features.append(ModelFeature.MULTI_TOOL_CALL)
stream_function_calling = credentials.get("stream_function_calling", "supported") stream_function_calling = credentials.get("stream_function_calling", "supported")

View File

@ -179,7 +179,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
index=0, index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]), message=AssistantPromptMessage(content=message.content, tool_calls=[]),
usage=usage, usage=usage,
finish_reason=message.stop_reason if message.stop_reason else None, finish_reason=message.stop_reason or None,
), ),
) )
else: else:
@ -189,7 +189,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(
index=0, index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]), message=AssistantPromptMessage(content=message.content, tool_calls=[]),
finish_reason=message.stop_reason if message.stop_reason else None, finish_reason=message.stop_reason or None,
), ),
) )

View File

@ -106,7 +106,7 @@ class OpenLLMGenerate:
timeout = 120 timeout = 120
data = { data = {
"stop": stop if stop else [], "stop": stop or [],
"prompt": "\n".join([message.content for message in prompt_messages]), "prompt": "\n".join([message.content for message in prompt_messages]),
"llm_config": default_llm_config, "llm_config": default_llm_config,
} }

View File

@ -214,7 +214,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
index += 1 index += 1
assistant_prompt_message = AssistantPromptMessage(content=output if output else "") assistant_prompt_message = AssistantPromptMessage(content=output or "")
if index < prediction_output_length: if index < prediction_output_length:
yield LLMResultChunk( yield LLMResultChunk(

View File

@ -1,5 +1,6 @@
import json import json
import logging import logging
import operator
from typing import Any, Optional from typing import Any, Optional
import boto3 import boto3
@ -94,7 +95,7 @@ class SageMakerRerankModel(RerankModel):
for idx in range(len(scores)): for idx in range(len(scores)):
candidate_docs.append({"content": docs[idx], "score": scores[idx]}) candidate_docs.append({"content": docs[idx], "score": scores[idx]})
sorted(candidate_docs, key=lambda x: x["score"], reverse=True) sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
line = 3 line = 3
rerank_documents = [] rerank_documents = []

View File

@ -260,7 +260,7 @@ class SageMakerText2SpeechModel(TTSModel):
for payload in payloads for payload in payloads
] ]
for index, future in enumerate(futures): for future in futures:
resp = future.result() resp = future.result()
audio_bytes = requests.get(resp.get("s3_presign_url")).content audio_bytes = requests.get(resp.get("s3_presign_url")).content
for i in range(0, len(audio_bytes), 1024): for i in range(0, len(audio_bytes), 1024):

View File

@ -220,7 +220,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
delta = content delta = content
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=delta if delta else "", content=delta or "",
) )
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)

View File

@ -1,6 +1,7 @@
import base64 import base64
import hashlib import hashlib
import hmac import hmac
import operator
import time import time
import requests import requests
@ -127,7 +128,7 @@ class FlashRecognizer:
return s return s
def _build_req_with_signature(self, secret_key, params, header): def _build_req_with_signature(self, secret_key, params, header):
query = sorted(params.items(), key=lambda d: d[0]) query = sorted(params.items(), key=operator.itemgetter(0))
signstr = self._format_sign_string(query) signstr = self._format_sign_string(query)
signature = self._sign(signstr, secret_key) signature = self._sign(signstr, secret_key)
header["Authorization"] = signature header["Authorization"] = signature

View File

@ -4,6 +4,7 @@ import tempfile
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path
from typing import Optional, Union, cast from typing import Optional, Union, cast
from dashscope import Generation, MultiModalConversation, get_tokenizer from dashscope import Generation, MultiModalConversation, get_tokenizer
@ -454,8 +455,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}") file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}")
with open(file_path, "wb") as image_file: Path(file_path).write_bytes(base64.b64decode(encoded_string))
image_file.write(base64.b64decode(encoded_string))
return f"file://{file_path}" return f"file://{file_path}"

View File

@ -368,11 +368,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
final_tool_calls.extend(tool_calls) final_tool_calls.extend(tool_calls)
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
)
full_assistant_content += delta.delta.content if delta.delta.content else "" full_assistant_content += delta.delta.content or ""
if has_finish_reason: if has_finish_reason:
final_chunk = LLMResultChunk( final_chunk = LLMResultChunk(

View File

@ -231,10 +231,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
), ),
) )
elif isinstance(chunk, ContentBlockDeltaEvent): elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else "" chunk_text = chunk.delta.text or ""
full_assistant_content += chunk_text full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=chunk_text if chunk_text else "", content=chunk_text or "",
) )
index = chunk.index index = chunk.index
yield LLMResultChunk( yield LLMResultChunk(

View File

@ -1,5 +1,6 @@
# coding : utf-8 # coding : utf-8
import datetime import datetime
from itertools import starmap
import pytz import pytz
@ -48,7 +49,7 @@ class SignResult:
self.authorization = "" self.authorization = ""
def __str__(self): def __str__(self):
return "\n".join(["{}:{}".format(*item) for item in self.__dict__.items()]) return "\n".join(list(starmap("{}:{}".format, self.__dict__.items())))
class Credentials: class Credentials:

View File

@ -1,5 +1,6 @@
import hashlib import hashlib
import hmac import hmac
import operator
from functools import reduce from functools import reduce
from urllib.parse import quote from urllib.parse import quote
@ -40,4 +41,4 @@ class Util:
if len(hv) == 1: if len(hv) == 1:
hv = "0" + hv hv = "0" + hv
lst.append(hv) lst.append(hv)
return reduce(lambda x, y: x + y, lst) return reduce(operator.add, lst)

View File

@ -174,9 +174,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(
index=index, index=index,
message=AssistantPromptMessage( message=AssistantPromptMessage(content=message["content"] or "", tool_calls=[]),
content=message["content"] if message["content"] else "", tool_calls=[]
),
usage=usage, usage=usage,
finish_reason=choice.get("finish_reason"), finish_reason=choice.get("finish_reason"),
), ),
@ -208,7 +206,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model=model, model=model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=AssistantPromptMessage( message=AssistantPromptMessage(
content=message["content"] if message["content"] else "", content=message["content"] or "",
tool_calls=tool_calls, tool_calls=tool_calls,
), ),
usage=self._calc_response_usage( usage=self._calc_response_usage(
@ -284,7 +282,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model=model, model=model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=AssistantPromptMessage( message=AssistantPromptMessage(
content=message.content if message.content else "", content=message.content or "",
tool_calls=tool_calls, tool_calls=tool_calls,
), ),
usage=self._calc_response_usage( usage=self._calc_response_usage(

View File

@ -199,7 +199,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
secret_key=credentials["secret_key"], secret_key=credentials["secret_key"],
) )
user = user if user else "ErnieBotDefault" user = user or "ErnieBotDefault"
# convert prompt messages to baichuan messages # convert prompt messages to baichuan messages
messages = [ messages = [
@ -289,7 +289,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
index=0, index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]), message=AssistantPromptMessage(content=message.content, tool_calls=[]),
usage=usage, usage=usage,
finish_reason=message.stop_reason if message.stop_reason else None, finish_reason=message.stop_reason or None,
), ),
) )
else: else:
@ -299,7 +299,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(
index=0, index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]), message=AssistantPromptMessage(content=message.content, tool_calls=[]),
finish_reason=message.stop_reason if message.stop_reason else None, finish_reason=message.stop_reason or None,
), ),
) )

View File

@ -85,7 +85,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel):
api_key = credentials["api_key"] api_key = credentials["api_key"]
secret_key = credentials["secret_key"] secret_key = credentials["secret_key"]
embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
user = user if user else "ErnieBotDefault" user = user or "ErnieBotDefault"
context_size = self._get_context_size(model, credentials) context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials) max_chunks = self._get_max_chunks(model, credentials)

View File

@ -589,7 +589,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
# convert tool call to assistant message tool call # convert tool call to assistant message tool call
tool_calls = assistant_message.tool_calls tool_calls = assistant_message.tool_calls
assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else []) assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls or [])
function_call = assistant_message.function_call function_call = assistant_message.function_call
if function_call: if function_call:
assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)] assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)]
@ -652,7 +652,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
) )
if delta.finish_reason is not None: if delta.finish_reason is not None:
@ -749,7 +749,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
delta = chunk.choices[0] delta = chunk.choices[0]
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[]) assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])
if delta.finish_reason is not None: if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage # temp_assistant_prompt_message is used to calculate usage

View File

@ -215,7 +215,7 @@ class XinferenceText2SpeechModel(TTSModel):
for i in range(len(sentences)) for i in range(len(sentences))
] ]
for index, future in enumerate(futures): for future in futures:
response = future.result() response = future.result()
for i in range(0, len(response), 1024): for i in range(0, len(response), 1024):
yield response[i : i + 1024] yield response[i : i + 1024]

View File

@ -414,10 +414,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_tool_calls content=delta.delta.content or "", tool_calls=assistant_tool_calls
) )
full_assistant_content += delta.delta.content if delta.delta.content else "" full_assistant_content += delta.delta.content or ""
if delta.finish_reason is not None and chunk.usage is not None: if delta.finish_reason is not None and chunk.usage is not None:
completion_tokens = chunk.usage.completion_tokens completion_tokens = chunk.usage.completion_tokens

View File

@ -30,6 +30,8 @@ def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
return {key: val for key, val in merged.items() if val is not None} return {key: val for key, val in merged.items() if val is not None}
from itertools import starmap
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0) ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
@ -159,7 +161,7 @@ class HttpClient:
return [(key, str_data)] return [(key, str_data)]
def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
items = flatten([self._object_to_formdata(k, v) for k, v in data.items()]) items = flatten(list(starmap(self._object_to_formdata, data.items())))
serialized: dict[str, object] = {} serialized: dict[str, object] = {}
for key, value in items: for key, value in items:

View File

@ -65,7 +65,7 @@ class LangFuseDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info) self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
user_id = trace_info.metadata.get("user_id") user_id = trace_info.metadata.get("user_id")
if trace_info.message_id: if trace_info.message_id:
trace_id = trace_info.message_id trace_id = trace_info.message_id
@ -84,7 +84,7 @@ class LangFuseDataTrace(BaseTraceInstance):
) )
self.add_trace(langfuse_trace_data=trace_data) self.add_trace(langfuse_trace_data=trace_data)
workflow_span_data = LangfuseSpan( workflow_span_data = LangfuseSpan(
id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id), id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
name=TraceTaskName.WORKFLOW_TRACE.value, name=TraceTaskName.WORKFLOW_TRACE.value,
input=trace_info.workflow_run_inputs, input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs, output=trace_info.workflow_run_outputs,
@ -93,7 +93,7 @@ class LangFuseDataTrace(BaseTraceInstance):
end_time=trace_info.end_time, end_time=trace_info.end_time,
metadata=trace_info.metadata, metadata=trace_info.metadata,
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR, level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
status_message=trace_info.error if trace_info.error else "", status_message=trace_info.error or "",
) )
self.add_span(langfuse_span_data=workflow_span_data) self.add_span(langfuse_span_data=workflow_span_data)
else: else:
@ -143,7 +143,7 @@ class LangFuseDataTrace(BaseTraceInstance):
else: else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
created_at = node_execution.created_at if node_execution.created_at else datetime.now() created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time) finished_at = created_at + timedelta(seconds=elapsed_time)
@ -172,10 +172,8 @@ class LangFuseDataTrace(BaseTraceInstance):
end_time=finished_at, end_time=finished_at,
metadata=metadata, metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error if trace_info.error else "", status_message=trace_info.error or "",
parent_observation_id=( parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
),
) )
else: else:
span_data = LangfuseSpan( span_data = LangfuseSpan(
@ -188,7 +186,7 @@ class LangFuseDataTrace(BaseTraceInstance):
end_time=finished_at, end_time=finished_at,
metadata=metadata, metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error if trace_info.error else "", status_message=trace_info.error or "",
) )
self.add_span(langfuse_span_data=span_data) self.add_span(langfuse_span_data=span_data)
@ -212,7 +210,7 @@ class LangFuseDataTrace(BaseTraceInstance):
output=outputs, output=outputs,
metadata=metadata, metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error if trace_info.error else "", status_message=trace_info.error or "",
usage=generation_usage, usage=generation_usage,
) )
@ -277,7 +275,7 @@ class LangFuseDataTrace(BaseTraceInstance):
output=message_data.answer, output=message_data.answer,
metadata=metadata, metadata=metadata,
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR), level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
status_message=message_data.error if message_data.error else "", status_message=message_data.error or "",
usage=generation_usage, usage=generation_usage,
) )
@ -319,7 +317,7 @@ class LangFuseDataTrace(BaseTraceInstance):
end_time=trace_info.end_time, end_time=trace_info.end_time,
metadata=trace_info.metadata, metadata=trace_info.metadata,
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR), level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
status_message=message_data.error if message_data.error else "", status_message=message_data.error or "",
usage=generation_usage, usage=generation_usage,
) )

View File

@ -82,7 +82,7 @@ class LangSmithDataTrace(BaseTraceInstance):
langsmith_run = LangSmithRunModel( langsmith_run = LangSmithRunModel(
file_list=trace_info.file_list, file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens, total_tokens=trace_info.total_tokens,
id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id, id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value, name=TraceTaskName.WORKFLOW_TRACE.value,
inputs=trace_info.workflow_run_inputs, inputs=trace_info.workflow_run_inputs,
run_type=LangSmithRunType.tool, run_type=LangSmithRunType.tool,
@ -94,7 +94,7 @@ class LangSmithDataTrace(BaseTraceInstance):
}, },
error=trace_info.error, error=trace_info.error,
tags=["workflow"], tags=["workflow"],
parent_run_id=trace_info.message_id if trace_info.message_id else None, parent_run_id=trace_info.message_id or None,
) )
self.add_run(langsmith_run) self.add_run(langsmith_run)
@ -133,7 +133,7 @@ class LangSmithDataTrace(BaseTraceInstance):
else: else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
created_at = node_execution.created_at if node_execution.created_at else datetime.now() created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time) finished_at = created_at + timedelta(seconds=elapsed_time)
@ -180,9 +180,7 @@ class LangSmithDataTrace(BaseTraceInstance):
extra={ extra={
"metadata": metadata, "metadata": metadata,
}, },
parent_run_id=trace_info.workflow_app_log_id parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
if trace_info.workflow_app_log_id
else trace_info.workflow_run_id,
tags=["node_execution"], tags=["node_execution"],
) )

View File

@ -354,11 +354,11 @@ class TraceTask:
workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {} workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {} workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
workflow_run_version = workflow_run.version workflow_run_version = workflow_run.version
error = workflow_run.error if workflow_run.error else "" error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") if workflow_run_inputs.get("sys.file") else [] file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
# get workflow_app_log_id # get workflow_app_log_id
@ -452,7 +452,7 @@ class TraceTask:
message_tokens=message_tokens, message_tokens=message_tokens,
answer_tokens=message_data.answer_tokens, answer_tokens=message_data.answer_tokens,
total_tokens=message_tokens + message_data.answer_tokens, total_tokens=message_tokens + message_data.answer_tokens,
error=message_data.error if message_data.error else "", error=message_data.error or "",
inputs=inputs, inputs=inputs,
outputs=message_data.answer, outputs=message_data.answer,
file_list=file_list, file_list=file_list,
@ -487,7 +487,7 @@ class TraceTask:
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
moderation_trace_info = ModerationTraceInfo( moderation_trace_info = ModerationTraceInfo(
message_id=workflow_app_log_id if workflow_app_log_id else message_id, message_id=workflow_app_log_id or message_id,
inputs=inputs, inputs=inputs,
message_data=message_data.to_dict(), message_data=message_data.to_dict(),
flagged=moderation_result.flagged, flagged=moderation_result.flagged,
@ -527,7 +527,7 @@ class TraceTask:
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
suggested_question_trace_info = SuggestedQuestionTraceInfo( suggested_question_trace_info = SuggestedQuestionTraceInfo(
message_id=workflow_app_log_id if workflow_app_log_id else message_id, message_id=workflow_app_log_id or message_id,
message_data=message_data.to_dict(), message_data=message_data.to_dict(),
inputs=message_data.message, inputs=message_data.message,
outputs=message_data.answer, outputs=message_data.answer,
@ -569,7 +569,7 @@ class TraceTask:
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
message_id=message_id, message_id=message_id,
inputs=message_data.query if message_data.query else message_data.inputs, inputs=message_data.query or message_data.inputs,
documents=[doc.model_dump() for doc in documents], documents=[doc.model_dump() for doc in documents],
start_time=timer.get("start"), start_time=timer.get("start"),
end_time=timer.get("end"), end_time=timer.get("end"),
@ -695,8 +695,7 @@ class TraceQueueManager:
self.start_timer() self.start_timer()
def add_trace_task(self, trace_task: TraceTask): def add_trace_task(self, trace_task: TraceTask):
global trace_manager_timer global trace_manager_timer, trace_manager_queue
global trace_manager_queue
try: try:
if self.trace_instance: if self.trace_instance:
trace_task.app_id = self.app_id trace_task.app_id = self.app_id

View File

@ -112,11 +112,11 @@ class SimplePromptTransform(PromptTransform):
for v in prompt_template_config["special_variable_keys"]: for v in prompt_template_config["special_variable_keys"]:
# support #context#, #query# and #histories# # support #context#, #query# and #histories#
if v == "#context#": if v == "#context#":
variables["#context#"] = context if context else "" variables["#context#"] = context or ""
elif v == "#query#": elif v == "#query#":
variables["#query#"] = query if query else "" variables["#query#"] = query or ""
elif v == "#histories#": elif v == "#histories#":
variables["#histories#"] = histories if histories else "" variables["#histories#"] = histories or ""
prompt_template = prompt_template_config["prompt_template"] prompt_template = prompt_template_config["prompt_template"]
prompt = prompt_template.format(variables) prompt = prompt_template.format(variables)

View File

@ -34,7 +34,7 @@ class BaseKeyword(ABC):
raise NotImplementedError raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]: for text in texts.copy():
doc_id = text.metadata["doc_id"] doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id) exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node: if exists_duplicate_node:

View File

@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest( request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id, dbinstance_id=self.config.instance_id,
region_id=self.config.region_id, region_id=self.config.region_id,
@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest( request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id, dbinstance_id=self.config.instance_id,
region_id=self.config.region_id, region_id=self.config.region_id,

View File

@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name) collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
ids: list[str] = results["ids"][0] ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0] documents: list[str] = results["documents"][0]

View File

@ -86,8 +86,8 @@ class ElasticSearchVector(BaseVector):
id=uuids[i], id=uuids[i],
document={ document={
Field.CONTENT_KEY.value: documents[i].page_content, Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] if embeddings[i] else None, Field.VECTOR.value: embeddings[i] or None,
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}, Field.METADATA_KEY.value: documents[i].metadata or {},
}, },
) )
self._client.indices.refresh(index=self._collection_name) self._client.indices.refresh(index=self._collection_name)
@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):
docs = [] docs = []
for doc, score in docs_and_scores: for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
if score > score_threshold: if score > score_threshold:
doc.metadata["score"] = score doc.metadata["score"] = score
docs.append(doc) docs.append(doc)

View File

@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
for result in results[0]: for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value) metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"] metadata["score"] = result["distance"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
if result["distance"] > score_threshold: if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc) docs.append(doc)

View File

@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5) top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") or 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
where_str = ( where_str = (
f"WHERE dist < {1 - score_threshold}" f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0

View File

@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
metadata = {} metadata = {}
metadata["score"] = hit["_score"] metadata["score"] = hit["_score"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
if hit["_score"] > score_threshold: if hit["_score"] > score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc) docs.append(doc)

View File

@ -200,7 +200,7 @@ class OracleVector(BaseVector):
[numpy.array(query_vector)], [numpy.array(query_vector)],
) )
docs = [] docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
for record in cur: for record in cur:
metadata, text, distance = record metadata, text, distance = record
score = 1 - distance score = 1 - distance
@ -212,7 +212,7 @@ class OracleVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5) top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later # just not implement fetch by score_threshold now, may be later
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
if len(query) > 0: if len(query) > 0:
# Check which language the query is in # Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+") zh_pattern = re.compile("[\u4e00-\u9fa5]+")

View File

@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
metadata = record.meta metadata = record.meta
score = 1 - dis score = 1 - dis
metadata["score"] = score metadata["score"] = score
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
if score > score_threshold: if score > score_threshold:
doc = Document(page_content=record.text, metadata=metadata) doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc) docs.append(doc)

View File

@ -144,7 +144,7 @@ class PGVector(BaseVector):
(json.dumps(query_vector),), (json.dumps(query_vector),),
) )
docs = [] docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
for record in cur: for record in cur:
metadata, text, distance = record metadata, text, distance = record
score = 1 - distance score = 1 - distance

View File

@ -339,7 +339,7 @@ class QdrantVector(BaseVector):
for result in results: for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {} metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold # duplicate check score threshold
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
if result.score > score_threshold: if result.score > score_threshold:
metadata["score"] = result.score metadata["score"] = result.score
doc = Document( doc = Document(

View File

@ -230,7 +230,7 @@ class RelytVector(BaseVector):
# Organize results. # Organize results.
docs = [] docs = []
for document, score in results: for document, score in results:
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
if 1 - score > score_threshold: if 1 - score > score_threshold:
docs.append(document) docs.append(document)
return docs return docs

View File

@ -153,7 +153,7 @@ class TencentVector(BaseVector):
limit=kwargs.get("top_k", 4), limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout, timeout=self._client_config.timeout,
) )
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
return self._get_search_res(res, score_threshold) return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

View File

@ -185,7 +185,7 @@ class TiDBVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5) top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
filter = kwargs.get("filter") filter = kwargs.get("filter")
distance = 1 - score_threshold distance = 1 - score_threshold

View File

@ -49,7 +49,7 @@ class BaseVector(ABC):
raise NotImplementedError raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]: for text in texts.copy():
doc_id = text.metadata["doc_id"] doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id) exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node: if exists_duplicate_node:

View File

@ -153,7 +153,7 @@ class Vector:
return CacheEmbedding(embedding_model) return CacheEmbedding(embedding_model)
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]: for text in texts.copy():
doc_id = text.metadata["doc_id"] doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id) exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node: if exists_duplicate_node:

View File

@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):
docs = [] docs = []
for doc, score in docs_and_scores: for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 score_threshold = kwargs.get("score_threshold", 0.0)
# check score threshold # check score threshold
if score > score_threshold: if score > score_threshold:
doc.metadata["score"] = score doc.metadata["score"] = score

View File

@ -12,7 +12,7 @@ import mimetypes
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Mapping from collections.abc import Generator, Iterable, Mapping
from io import BufferedReader, BytesIO from io import BufferedReader, BytesIO
from pathlib import PurePath from pathlib import Path, PurePath
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, model_validator from pydantic import BaseModel, ConfigDict, model_validator
@ -56,8 +56,7 @@ class Blob(BaseModel):
def as_string(self) -> str: def as_string(self) -> str:
"""Read data as a string.""" """Read data as a string."""
if self.data is None and self.path: if self.data is None and self.path:
with open(str(self.path), encoding=self.encoding) as f: return Path(str(self.path)).read_text(encoding=self.encoding)
return f.read()
elif isinstance(self.data, bytes): elif isinstance(self.data, bytes):
return self.data.decode(self.encoding) return self.data.decode(self.encoding)
elif isinstance(self.data, str): elif isinstance(self.data, str):
@ -72,8 +71,7 @@ class Blob(BaseModel):
elif isinstance(self.data, str): elif isinstance(self.data, str):
return self.data.encode(self.encoding) return self.data.encode(self.encoding)
elif self.data is None and self.path: elif self.data is None and self.path:
with open(str(self.path), "rb") as f: return Path(str(self.path)).read_bytes()
return f.read()
else: else:
raise ValueError(f"Unable to get bytes for blob {self}") raise ValueError(f"Unable to get bytes for blob {self}")

View File

@ -68,8 +68,7 @@ class ExtractProcessor:
suffix = "." + re.search(r"\.(\w+)$", filename).group(1) suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, "wb") as file: Path(file_path).write_bytes(response.content)
file.write(response.content)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
if return_text: if return_text:
delimiter = "\n" delimiter = "\n"
@ -111,7 +110,7 @@ class ExtractProcessor:
) )
elif file_extension in [".htm", ".html"]: elif file_extension in [".htm", ".html"]:
extractor = HtmlExtractor(file_path) extractor = HtmlExtractor(file_path)
elif file_extension in [".docx"]: elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv": elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True) extractor = CSVExtractor(file_path, autodetect_encoding=True)
@ -143,7 +142,7 @@ class ExtractProcessor:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True) extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in [".htm", ".html"]: elif file_extension in [".htm", ".html"]:
extractor = HtmlExtractor(file_path) extractor = HtmlExtractor(file_path)
elif file_extension in [".docx"]: elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv": elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True) extractor = CSVExtractor(file_path, autodetect_encoding=True)

View File

@ -1,6 +1,7 @@
"""Document loader helpers.""" """Document loader helpers."""
import concurrent.futures import concurrent.futures
from pathlib import Path
from typing import NamedTuple, Optional, cast from typing import NamedTuple, Optional, cast
@ -28,8 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
import chardet import chardet
def read_and_detect(file_path: str) -> list[dict]: def read_and_detect(file_path: str) -> list[dict]:
with open(file_path, "rb") as f: rawdata = Path(file_path).read_bytes()
rawdata = f.read()
return cast(list[dict], chardet.detect_all(rawdata)) return cast(list[dict], chardet.detect_all(rawdata))
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:

View File

@ -1,6 +1,7 @@
"""Abstract interface for document loader implementations.""" """Abstract interface for document loader implementations."""
import re import re
from pathlib import Path
from typing import Optional, cast from typing import Optional, cast
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
@ -102,15 +103,13 @@ class MarkdownExtractor(BaseExtractor):
"""Parse file into tuples.""" """Parse file into tuples."""
content = "" content = ""
try: try:
with open(filepath, encoding=self._encoding) as f: content = Path(filepath).read_text(encoding=self._encoding)
content = f.read()
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
if self._autodetect_encoding: if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath) detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings: for encoding in detected_encodings:
try: try:
with open(filepath, encoding=encoding.encoding) as f: content = Path(filepath).read_text(encoding=encoding.encoding)
content = f.read()
break break
except UnicodeDecodeError: except UnicodeDecodeError:
continue continue

View File

@ -1,5 +1,6 @@
"""Abstract interface for document loader implementations.""" """Abstract interface for document loader implementations."""
from pathlib import Path
from typing import Optional from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
@ -25,15 +26,13 @@ class TextExtractor(BaseExtractor):
"""Load from file path.""" """Load from file path."""
text = "" text = ""
try: try:
with open(self._file_path, encoding=self._encoding) as f: text = Path(self._file_path).read_text(encoding=self._encoding)
text = f.read()
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
if self._autodetect_encoding: if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path) detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings: for encoding in detected_encodings:
try: try:
with open(self._file_path, encoding=encoding.encoding) as f: text = Path(self._file_path).read_text(encoding=encoding.encoding)
text = f.read()
break break
except UnicodeDecodeError: except UnicodeDecodeError:
continue continue

View File

@ -153,7 +153,7 @@ class WordExtractor(BaseExtractor):
if col_index >= total_cols: if col_index >= total_cols:
break break
cell_content = self._parse_cell(cell, image_map).strip() cell_content = self._parse_cell(cell, image_map).strip()
cell_colspan = cell.grid_span if cell.grid_span else 1 cell_colspan = cell.grid_span or 1
for i in range(cell_colspan): for i in range(cell_colspan):
if col_index + i < total_cols: if col_index + i < total_cols:
row_cells[col_index + i] = cell_content if i == 0 else "" row_cells[col_index + i] = cell_content if i == 0 else ""

View File

@ -256,7 +256,7 @@ class DatasetRetrieval:
# get retrieval model config # get retrieval model config
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset: if dataset:
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model_config = dataset.retrieval_model or default_retrieval_model
# get top k # get top k
top_k = retrieval_model_config["top_k"] top_k = retrieval_model_config["top_k"]
@ -410,7 +410,7 @@ class DatasetRetrieval:
return [] return []
# get retrieval model , if the model is not setting , using default # get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
@ -433,9 +433,7 @@ class DatasetRetrieval:
reranking_model=retrieval_model.get("reranking_model", None) reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"] if retrieval_model["reranking_enable"]
else None, else None,
reranking_mode=retrieval_model.get("reranking_mode") reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
if retrieval_model.get("reranking_mode")
else "reranking_model",
weights=retrieval_model.get("weights", None), weights=retrieval_model.get("weights", None),
) )
@ -486,7 +484,7 @@ class DatasetRetrieval:
} }
for dataset in available_datasets: for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model_config = dataset.retrieval_model or default_retrieval_model
# get top k # get top k
top_k = retrieval_model_config["top_k"] top_k = retrieval_model_config["top_k"]

View File

@ -106,7 +106,7 @@ class ApiToolProviderController(ToolProviderController):
"human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""}, "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
"llm": tool_bundle.summary or "", "llm": tool_bundle.summary or "",
}, },
"parameters": tool_bundle.parameters if tool_bundle.parameters else [], "parameters": tool_bundle.parameters or [],
} }
) )

View File

@ -1,4 +1,5 @@
import json import json
import operator
from typing import Any, Union from typing import Any, Union
import boto3 import boto3
@ -71,7 +72,7 @@ class SageMakerReRankTool(BuiltinTool):
candidate_docs[idx]["score"] = scores[idx] candidate_docs[idx]["score"] = scores[idx]
line = 8 line = 8
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True) sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
line = 9 line = 9
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]] return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]

View File

@ -115,7 +115,7 @@ class GetWorksheetFieldsTool(BuiltinTool):
fields.append(field) fields.append(field)
fields_list.append( fields_list.append(
f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}" f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}"
f"|{field['options'] if field['options'] else ''}|" f"|{field['options'] or ''}|"
) )
fields.append( fields.append(

View File

@ -130,7 +130,7 @@ class GetWorksheetPivotDataTool(BuiltinTool):
# ] # ]
rows = [] rows = []
for row in data["data"]: for row in data["data"]:
row_data = row["rows"] if row["rows"] else {} row_data = row["rows"] or {}
row_data.update(row["columns"]) row_data.update(row["columns"])
row_data.update(row["values"]) row_data.update(row["values"])
rows.append(row_data) rows.append(row_data)

View File

@ -113,7 +113,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"." result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"."
if result["total"] > 0: if result["total"] > 0:
result_text += ( result_text += (
f" The following are {result['total'] if result['total'] < limit else limit}" f" The following are {min(limit, result['total'])}"
f" pieces of data presented in a table format:\n\n{table_header}" f" pieces of data presented in a table format:\n\n{table_header}"
) )
for row in rows: for row in rows:

View File

@ -37,7 +37,7 @@ class SearchAPI:
return { return {
"engine": "youtube_transcripts", "engine": "youtube_transcripts",
"video_id": video_id, "video_id": video_id,
"lang": language if language else "en", "lang": language or "en",
**{key: value for key, value in kwargs.items() if value not in [None, ""]}, **{key: value for key, value in kwargs.items() if value not in [None, ""]},
} }

View File

@ -160,7 +160,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
hit_callback.on_query(query, dataset.id) hit_callback.on_query(query, dataset.id)
# get retrieval model , if the model is not setting , using default # get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
@ -183,9 +183,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
reranking_model=retrieval_model.get("reranking_model", None) reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"] if retrieval_model["reranking_enable"]
else None, else None,
reranking_mode=retrieval_model.get("reranking_mode") reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
if retrieval_model.get("reranking_mode")
else "reranking_model",
weights=retrieval_model.get("weights", None), weights=retrieval_model.get("weights", None),
) )

View File

@ -55,7 +55,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
hit_callback.on_query(query, dataset.id) hit_callback.on_query(query, dataset.id)
# get retrieval model , if the model is not setting , using default # get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
documents = RetrievalService.retrieve( documents = RetrievalService.retrieve(
@ -76,9 +76,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
reranking_model=retrieval_model.get("reranking_model", None) reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"] if retrieval_model["reranking_enable"]
else None, else None,
reranking_mode=retrieval_model.get("reranking_mode") reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
if retrieval_model.get("reranking_mode")
else "reranking_model",
weights=retrieval_model.get("weights", None), weights=retrieval_model.get("weights", None),
) )
else: else:

View File

@ -8,6 +8,7 @@ import subprocess
import tempfile import tempfile
import unicodedata import unicodedata
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path
from urllib.parse import unquote from urllib.parse import unquote
import chardet import chardet
@ -98,7 +99,7 @@ def get_url(url: str, user_agent: str = None) -> str:
authors=a["byline"], authors=a["byline"],
publish_date=a["date"], publish_date=a["date"],
top_image="", top_image="",
text=a["plain_text"] if a["plain_text"] else "", text=a["plain_text"] or "",
) )
return res return res
@ -117,8 +118,7 @@ def extract_using_readabilipy(html):
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
# Read output of call to Readability.parse() from JSON file and return as Python dictionary # Read output of call to Readability.parse() from JSON file and return as Python dictionary
with open(article_json_path, encoding="utf-8") as json_file: input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
input_json = json.loads(json_file.read())
# Deleting files after processing # Deleting files after processing
os.unlink(article_json_path) os.unlink(article_json_path)

View File

@ -21,7 +21,7 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
with open(file_path, encoding="utf-8") as yaml_file: with open(file_path, encoding="utf-8") as yaml_file:
try: try:
yaml_content = yaml.safe_load(yaml_file) yaml_content = yaml.safe_load(yaml_file)
return yaml_content if yaml_content else default_value return yaml_content or default_value
except Exception as e: except Exception as e:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
except Exception as e: except Exception as e:

View File

@ -268,7 +268,7 @@ class Graph(BaseModel):
f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph." f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
) )
new_route = route[:] new_route = route.copy()
new_route.append(graph_edge.target_node_id) new_route.append(graph_edge.target_node_id)
cls._check_connected_to_previous_node( cls._check_connected_to_previous_node(
route=new_route, route=new_route,
@ -679,8 +679,7 @@ class Graph(BaseModel):
all_routes_node_ids = set() all_routes_node_ids = set()
parallel_start_node_ids: dict[str, list[str]] = {} parallel_start_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items(): for branch_node_id, node_ids in routes_node_ids.items():
for node_id in node_ids: all_routes_node_ids.update(node_ids)
all_routes_node_ids.add(node_id)
if branch_node_id in reverse_edge_mapping: if branch_node_id in reverse_edge_mapping:
for graph_edge in reverse_edge_mapping[branch_node_id]: for graph_edge in reverse_edge_mapping[branch_node_id]:

View File

@ -74,7 +74,7 @@ class CodeNode(BaseNode):
:return: :return:
""" """
if not isinstance(value, str): if not isinstance(value, str):
if isinstance(value, type(None)): if value is None:
return None return None
else: else:
raise ValueError(f"Output variable `{variable}` must be a string") raise ValueError(f"Output variable `{variable}` must be a string")
@ -95,7 +95,7 @@ class CodeNode(BaseNode):
:return: :return:
""" """
if not isinstance(value, int | float): if not isinstance(value, int | float):
if isinstance(value, type(None)): if value is None:
return None return None
else: else:
raise ValueError(f"Output variable `{variable}` must be a number") raise ValueError(f"Output variable `{variable}` must be a number")
@ -182,7 +182,7 @@ class CodeNode(BaseNode):
f"Output {prefix}.{output_name} is not a valid array." f"Output {prefix}.{output_name} is not a valid array."
f" make sure all elements are of the same type." f" make sure all elements are of the same type."
) )
elif isinstance(output_value, type(None)): elif output_value is None:
pass pass
else: else:
raise ValueError(f"Output {prefix}.{output_name} is not a valid type.") raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")
@ -284,7 +284,7 @@ class CodeNode(BaseNode):
for i, value in enumerate(result[output_name]): for i, value in enumerate(result[output_name]):
if not isinstance(value, dict): if not isinstance(value, dict):
if isinstance(value, type(None)): if value is None:
pass pass
else: else:
raise ValueError( raise ValueError(

View File

@ -79,7 +79,7 @@ class IfElseNode(BaseNode):
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs, inputs=node_inputs,
process_data=process_datas, process_data=process_datas,
edge_source_handle=selected_case_id if selected_case_id else "false", # Use case ID or 'default' edge_source_handle=selected_case_id or "false", # Use case ID or 'default'
outputs=outputs, outputs=outputs,
) )

View File

@ -580,7 +580,7 @@ class LLMNode(BaseNode):
prompt_messages = prompt_transform.get_prompt( prompt_messages = prompt_transform.get_prompt(
prompt_template=node_data.prompt_template, prompt_template=node_data.prompt_template,
inputs=inputs, inputs=inputs,
query=query if query else "", query=query or "",
files=files, files=files,
context=context, context=context,
memory_config=node_data.memory, memory_config=node_data.memory,

View File

@ -250,7 +250,7 @@ class QuestionClassifierNode(LLMNode):
for class_ in classes: for class_ in classes:
category = {"category_id": class_.id, "category_name": class_.name} category = {"category_id": class_.id, "category_name": class_.name}
categories.append(category) categories.append(category)
instruction = node_data.instruction if node_data.instruction else "" instruction = node_data.instruction or ""
input_text = query input_text = query
memory_str = "" memory_str = ""
if memory: if memory:

View File

@ -18,8 +18,7 @@ def handle(sender, **kwargs):
added_dataset_ids = dataset_ids added_dataset_ids = dataset_ids
else: else:
old_dataset_ids = set() old_dataset_ids = set()
for app_dataset_join in app_dataset_joins: old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
old_dataset_ids.add(app_dataset_join.dataset_id)
added_dataset_ids = dataset_ids - old_dataset_ids added_dataset_ids = dataset_ids - old_dataset_ids
removed_dataset_ids = old_dataset_ids - dataset_ids removed_dataset_ids = old_dataset_ids - dataset_ids

View File

@ -22,8 +22,7 @@ def handle(sender, **kwargs):
added_dataset_ids = dataset_ids added_dataset_ids = dataset_ids
else: else:
old_dataset_ids = set() old_dataset_ids = set()
for app_dataset_join in app_dataset_joins: old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
old_dataset_ids.add(app_dataset_join.dataset_id)
added_dataset_ids = dataset_ids - old_dataset_ids added_dataset_ids = dataset_ids - old_dataset_ids
removed_dataset_ids = old_dataset_ids - dataset_ids removed_dataset_ids = old_dataset_ids - dataset_ids

View File

@ -1,6 +1,7 @@
import os import os
import shutil import shutil
from collections.abc import Generator from collections.abc import Generator
from pathlib import Path
from flask import Flask from flask import Flask
@ -26,8 +27,7 @@ class LocalStorage(BaseStorage):
folder = os.path.dirname(filename) folder = os.path.dirname(filename)
os.makedirs(folder, exist_ok=True) os.makedirs(folder, exist_ok=True)
with open(os.path.join(os.getcwd(), filename), "wb") as f: Path(os.path.join(os.getcwd(), filename)).write_bytes(data)
f.write(data)
def load_once(self, filename: str) -> bytes: def load_once(self, filename: str) -> bytes:
if not self.folder or self.folder.endswith("/"): if not self.folder or self.folder.endswith("/"):
@ -38,9 +38,7 @@ class LocalStorage(BaseStorage):
if not os.path.exists(filename): if not os.path.exists(filename):
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
with open(filename, "rb") as f: data = Path(filename).read_bytes()
data = f.read()
return data return data
def load_stream(self, filename: str) -> Generator: def load_stream(self, filename: str) -> Generator:

View File

@ -144,7 +144,7 @@ class Dataset(db.Model):
"top_k": 2, "top_k": 2,
"score_threshold_enabled": False, "score_threshold_enabled": False,
} }
return self.retrieval_model if self.retrieval_model else default_retrieval_model return self.retrieval_model or default_retrieval_model
@property @property
def tags(self): def tags(self):
@ -160,7 +160,7 @@ class Dataset(db.Model):
.all() .all()
) )
return tags if tags else [] return tags or []
@staticmethod @staticmethod
def gen_collection_name_by_id(dataset_id: str) -> str: def gen_collection_name_by_id(dataset_id: str) -> str:

View File

@ -118,7 +118,7 @@ class App(db.Model):
@property @property
def api_base_url(self): def api_base_url(self):
return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")) + "/v1" return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"
@property @property
def tenant(self): def tenant(self):
@ -207,7 +207,7 @@ class App(db.Model):
.all() .all()
) )
return tags if tags else [] return tags or []
class AppModelConfig(db.Model): class AppModelConfig(db.Model):
@ -908,7 +908,7 @@ class Message(db.Model):
"id": message_file.id, "id": message_file.id,
"type": message_file.type, "type": message_file.type,
"url": url, "url": url,
"belongs_to": message_file.belongs_to if message_file.belongs_to else "user", "belongs_to": message_file.belongs_to or "user",
} }
) )
@ -1212,7 +1212,7 @@ class Site(db.Model):
@property @property
def app_base_url(self): def app_base_url(self):
return dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip("/") return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
class ApiToken(db.Model): class ApiToken(db.Model):
@ -1488,7 +1488,7 @@ class TraceAppConfig(db.Model):
@property @property
def tracing_config_dict(self): def tracing_config_dict(self):
return self.tracing_config if self.tracing_config else {} return self.tracing_config or {}
@property @property
def tracing_config_str(self): def tracing_config_str(self):

View File

@ -15,6 +15,7 @@ select = [
"C4", # flake8-comprehensions "C4", # flake8-comprehensions
"E", # pycodestyle E rules "E", # pycodestyle E rules
"F", # pyflakes rules "F", # pyflakes rules
"FURB", # refurb rules
"I", # isort rules "I", # isort rules
"N", # pep8-naming "N", # pep8-naming
"RUF019", # unnecessary-key-check "RUF019", # unnecessary-key-check
@ -37,6 +38,8 @@ ignore = [
"F405", # undefined-local-with-import-star-usage "F405", # undefined-local-with-import-star-usage
"F821", # undefined-name "F821", # undefined-name
"F841", # unused-variable "F841", # unused-variable
"FURB113", # repeated-append
"FURB152", # math-constant
"UP007", # non-pep604-annotation "UP007", # non-pep604-annotation
"UP032", # f-string "UP032", # f-string
"B005", # strip-with-multi-characters "B005", # strip-with-multi-characters

View File

@ -544,7 +544,7 @@ class RegisterService:
"""Register account""" """Register account"""
try: try:
account = AccountService.create_account( account = AccountService.create_account(
email=email, name=name, interface_language=language if language else languages[0], password=password email=email, name=name, interface_language=language or languages[0], password=password
) )
account.status = AccountStatus.ACTIVE.value if not status else status.value account.status = AccountStatus.ACTIVE.value if not status else status.value
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)

View File

@ -81,13 +81,11 @@ class AppDslService:
raise ValueError("Missing app in data argument") raise ValueError("Missing app in data argument")
# get app basic info # get app basic info
name = args.get("name") if args.get("name") else app_data.get("name") name = args.get("name") or app_data.get("name")
description = args.get("description") if args.get("description") else app_data.get("description", "") description = args.get("description") or app_data.get("description", "")
icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type") icon_type = args.get("icon_type") or app_data.get("icon_type")
icon = args.get("icon") if args.get("icon") else app_data.get("icon") icon = args.get("icon") or app_data.get("icon")
icon_background = ( icon_background = args.get("icon_background") or app_data.get("icon_background")
args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background")
)
use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False)
# import dsl and create app # import dsl and create app

View File

@ -155,7 +155,7 @@ class DatasetService:
dataset.tenant_id = tenant_id dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
dataset.embedding_model = embedding_model.model if embedding_model else None dataset.embedding_model = embedding_model.model if embedding_model else None
dataset.permission = permission if permission else DatasetPermissionEnum.ONLY_ME dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
db.session.add(dataset) db.session.add(dataset)
db.session.commit() db.session.commit()
return dataset return dataset
@ -681,11 +681,7 @@ class DocumentService:
"score_threshold_enabled": False, "score_threshold_enabled": False,
} }
dataset.retrieval_model = ( dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model
document_data.get("retrieval_model")
if document_data.get("retrieval_model")
else default_retrieval_model
)
documents = [] documents = []
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))

View File

@ -33,7 +33,7 @@ class HitTestingService:
# get retrieval model , if the model is not setting , using default # get retrieval model , if the model is not setting , using default
if not retrieval_model: if not retrieval_model:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model = dataset.retrieval_model or default_retrieval_model
all_documents = RetrievalService.retrieve( all_documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"), retrieval_method=retrieval_model.get("search_method", "semantic_search"),
@ -46,9 +46,7 @@ class HitTestingService:
reranking_model=retrieval_model.get("reranking_model", None) reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"] if retrieval_model["reranking_enable"]
else None, else None,
reranking_mode=retrieval_model.get("reranking_mode") reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
if retrieval_model.get("reranking_mode")
else "reranking_model",
weights=retrieval_model.get("weights", None), weights=retrieval_model.get("weights", None),
) )

View File

@ -1,6 +1,7 @@
import logging import logging
import mimetypes import mimetypes
import os import os
from pathlib import Path
from typing import Optional, cast from typing import Optional, cast
import requests import requests
@ -453,9 +454,8 @@ class ModelProviderService:
mimetype = mimetype or "application/octet-stream" mimetype = mimetype or "application/octet-stream"
# read binary from file # read binary from file
with open(file_path, "rb") as f: byte_data = Path(file_path).read_bytes()
byte_data = f.read() return byte_data, mimetype
return byte_data, mimetype
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None: def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
""" """

View File

@ -1,6 +1,7 @@
import json import json
import logging import logging
from os import path from os import path
from pathlib import Path
from typing import Optional from typing import Optional
import requests import requests
@ -218,10 +219,9 @@ class RecommendedAppService:
return cls.builtin_data return cls.builtin_data
root_path = current_app.root_path root_path = current_app.root_path
with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f: cls.builtin_data = json.loads(
json_data = f.read() Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8")
data = json.loads(json_data) )
cls.builtin_data = data
return cls.builtin_data return cls.builtin_data

Some files were not shown because too many files have changed in this diff Show More