Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-20 14:39:16 +08:00)
chore: cleanup ruff flake8-simplify linter rules (#8286)
Co-authored-by: -LAN- <laipz8200@outlook.com>
parent: 0bb7569d46
commit: 0f14873255
@@ -65,7 +65,7 @@ class BasedGenerateTaskPipeline:
         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")
-        elif isinstance(e, InvokeError) or isinstance(e, ValueError):
+        elif isinstance(e, InvokeError | ValueError):
             err = e
         else:
             err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
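The merged check relies on `isinstance` accepting PEP 604 union types (Python 3.10+), which is the rewrite flake8-simplify's duplicate-isinstance-call rule (SIM101) nudges toward. A minimal sketch, not taken from the commit:

```python
def classify(e: Exception) -> str:
    # Union form; equivalent to isinstance(e, (ValueError, KeyError)).
    if isinstance(e, ValueError | KeyError):
        return "expected"
    return "unexpected"

print(classify(ValueError("boom")))  # expected
print(classify(RuntimeError("eh")))  # unexpected
```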
@@ -45,7 +45,7 @@ class BaichuanModel:
         parameters: dict[str, Any],
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> dict[str, Any]:
-        if model in self._model_mapping.keys():
+        if model in self._model_mapping:
             # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
             # we need to rename it to res_format to get its value
             if parameters.get("res_format") == "json_object":
@@ -94,7 +94,7 @@ class BaichuanModel:
         timeout: int,
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> Union[Iterator, dict]:
-        if model in self._model_mapping.keys():
+        if model in self._model_mapping:
             api_base = "https://api.baichuan-ai.com/v1/chat/completions"
         else:
             raise BadRequestError(f"Unknown model: {model}")
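Both Baichuan hunks apply SIM118 (`in-dict-keys`): `key in mapping` already tests the dict's keys, so the `.keys()` call is redundant. A quick illustration with a throwaway mapping:

```python
mapping = {"baichuan2-turbo": "ready"}

# The two membership tests are equivalent; the first is the simplified form.
assert ("baichuan2-turbo" in mapping) == ("baichuan2-turbo" in mapping.keys())
```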
@@ -337,9 +337,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = f"{human_prompt} {content}"
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -442,9 +442,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = f"{human_prompt} {content}"
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -350,9 +350,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
             break
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = content
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = content
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -633,9 +633,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = f"{human_prompt} {content}"
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -272,11 +272,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         text = ""
         for item in message:
-            if isinstance(item, UserPromptMessage):
-                text += item.content
-            elif isinstance(item, SystemPromptMessage):
-                text += item.content
-            elif isinstance(item, AssistantPromptMessage):
+            if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage):
                 text += item.content
             else:
                 raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")
@@ -209,9 +209,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             ):
                 new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
             else:
-                if copy_prompt_message.role == PromptMessageRole.USER:
-                    new_prompt_messages.append(copy_prompt_message)
-                elif copy_prompt_message.role == PromptMessageRole.TOOL:
+                if (
+                    copy_prompt_message.role == PromptMessageRole.USER
+                    or copy_prompt_message.role == PromptMessageRole.TOOL
+                ):
                     new_prompt_messages.append(copy_prompt_message)
                 elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
                     new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
@@ -461,9 +462,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = content
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = content
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -56,14 +56,7 @@ class KeywordsModeration(Moderation):
         )

     def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
-        for value in inputs.values():
-            if self._check_keywords_in_value(keywords_list, value):
-                return True
-
-        return False
-
-    def _check_keywords_in_value(self, keywords_list, value):
-        for keyword in keywords_list:
-            if keyword.lower() in value.lower():
-                return True
-        return False
+        return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
+
+    def _check_keywords_in_value(self, keywords_list, value) -> bool:
+        return any(keyword.lower() in value.lower() for keyword in keywords_list)
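This is the SIM110 (`reimplemented-builtin`) rewrite: a loop that returns True on the first hit and False at the end is exactly `any()` over a generator, and `any()` short-circuits the same way. A standalone sketch of the before/after:

```python
def contains_keyword_loop(keywords: list[str], value: str) -> bool:
    # Original shape: early return on the first match.
    for keyword in keywords:
        if keyword.lower() in value.lower():
            return True
    return False

def contains_keyword_any(keywords: list[str], value: str) -> bool:
    # any() stops at the first truthy element, preserving behavior.
    return any(keyword.lower() in value.lower() for keyword in keywords)

assert contains_keyword_loop(["spam"], "No Spam here")
assert contains_keyword_any(["spam"], "No Spam here")
```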
@@ -223,7 +223,7 @@ class OpsTraceManager:
         :return:
         """
         # auth check
-        if tracing_provider not in provider_config_map.keys() and tracing_provider is not None:
+        if tracing_provider not in provider_config_map and tracing_provider is not None:
             raise ValueError(f"Invalid tracing provider: {tracing_provider}")

         app_config: App = db.session.query(App).filter(App.id == app_id).first()
@@ -127,27 +127,26 @@ class RelytVector(BaseVector):
         )

         chunks_table_data = []
-        with self.client.connect() as conn:
-            with conn.begin():
-                for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
-                    chunks_table_data.append(
-                        {
-                            "id": chunk_id,
-                            "embedding": embedding,
-                            "document": document,
-                            "metadata": metadata,
-                        }
-                    )
-
-                    # Execute the batch insert when the batch size is reached
-                    if len(chunks_table_data) == 500:
-                        conn.execute(insert(chunks_table).values(chunks_table_data))
-                        # Clear the chunks_table_data list for the next batch
-                        chunks_table_data.clear()
-
-                # Insert any remaining records that didn't make up a full batch
-                if chunks_table_data:
-                    conn.execute(insert(chunks_table).values(chunks_table_data))
+        with self.client.connect() as conn, conn.begin():
+            for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
+                chunks_table_data.append(
+                    {
+                        "id": chunk_id,
+                        "embedding": embedding,
+                        "document": document,
+                        "metadata": metadata,
+                    }
+                )
+
+                # Execute the batch insert when the batch size is reached
+                if len(chunks_table_data) == 500:
+                    conn.execute(insert(chunks_table).values(chunks_table_data))
+                    # Clear the chunks_table_data list for the next batch
+                    chunks_table_data.clear()
+
+            # Insert any remaining records that didn't make up a full batch
+            if chunks_table_data:
+                conn.execute(insert(chunks_table).values(chunks_table_data))

         return ids
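This and several later hunks collapse nested `with` blocks into one statement with comma-separated context managers, the form flake8-simplify's SIM117 (`multiple-with-statements`) prefers; the managers are entered and exited in the same order either way. A minimal sketch using `contextlib.nullcontext` stand-ins for the connection and transaction:

```python
from contextlib import nullcontext

# Nested form (what SIM117 flags):
with nullcontext("conn") as conn:
    with nullcontext("txn") as txn:
        print(conn, txn)

# Merged form, as applied to the connect()/begin() pairs above:
with nullcontext("conn") as conn, nullcontext("txn") as txn:
    print(conn, txn)
```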
@@ -186,11 +185,10 @@ class RelytVector(BaseVector):
         )

         try:
-            with self.client.connect() as conn:
-                with conn.begin():
-                    delete_condition = chunks_table.c.id.in_(ids)
-                    conn.execute(chunks_table.delete().where(delete_condition))
-                    return True
+            with self.client.connect() as conn, conn.begin():
+                delete_condition = chunks_table.c.id.in_(ids)
+                conn.execute(chunks_table.delete().where(delete_condition))
+                return True
         except Exception as e:
             print("Delete operation failed:", str(e))
             return False
@@ -63,10 +63,7 @@ class TencentVector(BaseVector):

     def _has_collection(self) -> bool:
         collections = self._db.list_collections()
-        for collection in collections:
-            if collection.collection_name == self._collection_name:
-                return True
-        return False
+        return any(collection.collection_name == self._collection_name for collection in collections)

     def _create_collection(self, dimension: int) -> None:
         lock_name = "vector_indexing_lock_{}".format(self._collection_name)
@@ -124,20 +124,19 @@ class TiDBVector(BaseVector):
         texts = [d.page_content for d in documents]

         chunks_table_data = []
-        with self._engine.connect() as conn:
-            with conn.begin():
-                for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
-                    chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
-
-                    # Execute the batch insert when the batch size is reached
-                    if len(chunks_table_data) == 500:
-                        conn.execute(insert(table).values(chunks_table_data))
-                        # Clear the chunks_table_data list for the next batch
-                        chunks_table_data.clear()
-
-                # Insert any remaining records that didn't make up a full batch
-                if chunks_table_data:
-                    conn.execute(insert(table).values(chunks_table_data))
+        with self._engine.connect() as conn, conn.begin():
+            for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
+                chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
+
+                # Execute the batch insert when the batch size is reached
+                if len(chunks_table_data) == 500:
+                    conn.execute(insert(table).values(chunks_table_data))
+                    # Clear the chunks_table_data list for the next batch
+                    chunks_table_data.clear()
+
+            # Insert any remaining records that didn't make up a full batch
+            if chunks_table_data:
+                conn.execute(insert(table).values(chunks_table_data))

         return ids

     def text_exists(self, id: str) -> bool:
@@ -160,11 +159,10 @@ class TiDBVector(BaseVector):
             raise ValueError("No ids provided to delete.")
         table = self._table(self._dimension)
         try:
-            with self._engine.connect() as conn:
-                with conn.begin():
-                    delete_condition = table.c.id.in_(ids)
-                    conn.execute(table.delete().where(delete_condition))
-                    return True
+            with self._engine.connect() as conn, conn.begin():
+                delete_condition = table.c.id.in_(ids)
+                conn.execute(table.delete().where(delete_condition))
+                return True
         except Exception as e:
             print("Delete operation failed:", str(e))
             return False
@@ -48,7 +48,8 @@ class WordExtractor(BaseExtractor):
                 raise ValueError(f"Check the url of your file; returned status code {r.status_code}")

             self.web_path = self.file_path
-            self.temp_file = tempfile.NamedTemporaryFile()
+            # TODO: use a better way to handle the file
+            self.temp_file = tempfile.NamedTemporaryFile()  # noqa: SIM115
             self.temp_file.write(r.content)
             self.file_path = self.temp_file.name
         elif not os.path.isfile(self.file_path):
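The added `# noqa: SIM115` silences flake8-simplify's open-file-with-context-handler rule for this line: the extractor keeps the temp file on `self` and needs it to stay open after the method returns, so wrapping it in a `with` block would close it too early. What the rule normally flags, in an illustrative sketch:

```python
import tempfile

# SIM115 flags resources opened without a context manager.
temp = tempfile.NamedTemporaryFile()  # noqa: SIM115 - handle must outlive this scope
temp.write(b"payload")
temp.flush()
print(temp.name)  # the path stays valid while the handle is open
temp.close()
```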
@@ -120,8 +120,8 @@ class WeightRerankRunner:
         intersection = set(vec1.keys()) & set(vec2.keys())
         numerator = sum(vec1[x] * vec2[x] for x in intersection)

-        sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
-        sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
+        sum1 = sum(vec1[x] ** 2 for x in vec1)
+        sum2 = sum(vec2[x] ** 2 for x in vec2)
         denominator = math.sqrt(sum1) * math.sqrt(sum2)

         if not denominator:
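Iterating a dict yields its keys, so `for x in vec1` equals `for x in vec1.keys()` (SIM118 again, here inside generator expressions). The surrounding code computes cosine similarity between sparse keyword vectors; a toy run with made-up vectors:

```python
import math

vec1 = {"a": 1.0, "b": 2.0}
vec2 = {"b": 3.0, "c": 4.0}

intersection = set(vec1) & set(vec2)
numerator = sum(vec1[x] * vec2[x] for x in intersection)
sum1 = sum(vec1[x] ** 2 for x in vec1)  # iterating the dict yields keys
sum2 = sum(vec2[x] ** 2 for x in vec2)
denominator = math.sqrt(sum1) * math.sqrt(sum2)
print(numerator / denominator if denominator else 0.0)  # ~0.537
```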
@@ -581,8 +581,8 @@ class DatasetRetrieval:
         intersection = set(vec1.keys()) & set(vec2.keys())
         numerator = sum(vec1[x] * vec2[x] for x in intersection)

-        sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
-        sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
+        sum1 = sum(vec1[x] ** 2 for x in vec1)
+        sum2 = sum(vec2[x] ** 2 for x in vec2)
         denominator = math.sqrt(sum1) * math.sqrt(sum2)

         if not denominator:
@@ -201,9 +201,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
             elif value.startswith('[{"organizeId"'):
                 value = json.loads(value)
                 value = "、".join([item["organizeName"] for item in value])
-            elif value.startswith('[{"file_id"'):
-                value = ""
-            elif value == "[]":
+            elif value.startswith('[{"file_id"') or value == "[]":
                 value = ""
             elif hasattr(value, "accountId"):
                 value = value["fullname"]
@@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
             models_data=[],
             headers=headers,
             params=params,
-            recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True,
+            recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
         )

         result_str = ""
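`False if cond else True` only flips a boolean, so it reduces to `not cond`; this appears to be the pattern flake8-simplify's SIM211 (`if-expr-with-false-true`) targets. A spot check with an illustrative value:

```python
result_type = "first sd_name"

old = False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True
new = not (result_type == "first sd_name" or result_type == "first name sd_name pair")
assert old == new
```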
@@ -39,7 +39,7 @@ class QRCodeGeneratorTool(BuiltinTool):

         # get error_correction
         error_correction = tool_parameters.get("error_correction", "")
-        if error_correction not in self.error_correction_levels.keys():
+        if error_correction not in self.error_correction_levels:
             return self.create_text_message("Invalid parameter error_correction")

         try:
@@ -44,36 +44,36 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict, type: str) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")

         toret = ""
         if type == "text":
-            if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
+            if "answer_box" in res and "answer" in res["answer_box"]:
                 toret += res["answer_box"]["answer"] + "\n"
-            if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
+            if "answer_box" in res and "snippet" in res["answer_box"]:
                 toret += res["answer_box"]["snippet"] + "\n"
-            if "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
+            if "knowledge_graph" in res and "description" in res["knowledge_graph"]:
                 toret += res["knowledge_graph"]["description"] + "\n"
-            if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys():
+            if "organic_results" in res and "snippet" in res["organic_results"][0]:
                 for item in res["organic_results"]:
                     toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
             if toret == "":
                 toret = "No good search result found"

         elif type == "link":
-            if "answer_box" in res.keys() and "organic_result" in res["answer_box"].keys():
-                if "title" in res["answer_box"]["organic_result"].keys():
+            if "answer_box" in res and "organic_result" in res["answer_box"]:
+                if "title" in res["answer_box"]["organic_result"]:
                     toret = f"[{res['answer_box']['organic_result']['title']}]({res['answer_box']['organic_result']['link']})\n"
-            elif "organic_results" in res.keys() and "link" in res["organic_results"][0].keys():
+            elif "organic_results" in res and "link" in res["organic_results"][0]:
                 toret = ""
                 for item in res["organic_results"]:
                     toret += f"[{item['title']}]({item['link']})\n"
-            elif "related_questions" in res.keys() and "link" in res["related_questions"][0].keys():
+            elif "related_questions" in res and "link" in res["related_questions"][0]:
                 toret = ""
                 for item in res["related_questions"]:
                     toret += f"[{item['title']}]({item['link']})\n"
-            elif "related_searches" in res.keys() and "link" in res["related_searches"][0].keys():
+            elif "related_searches" in res and "link" in res["related_searches"][0]:
                 toret = ""
                 for item in res["related_searches"]:
                     toret += f"[{item['title']}]({item['link']})\n"
@@ -44,12 +44,12 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict, type: str) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")

         toret = ""
         if type == "text":
-            if "jobs" in res.keys() and "title" in res["jobs"][0].keys():
+            if "jobs" in res and "title" in res["jobs"][0]:
                 for item in res["jobs"]:
                     toret += (
                         "title: "
@@ -65,7 +65,7 @@ class SearchAPI:
                 toret = "No good search result found"

         elif type == "link":
-            if "jobs" in res.keys() and "apply_link" in res["jobs"][0].keys():
+            if "jobs" in res and "apply_link" in res["jobs"][0]:
                 for item in res["jobs"]:
                     toret += f"[{item['title']} - {item['company_name']}]({item['apply_link']})\n"
             else:
@@ -44,25 +44,25 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict, type: str) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")

         toret = ""
         if type == "text":
-            if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys():
+            if "organic_results" in res and "snippet" in res["organic_results"][0]:
                 for item in res["organic_results"]:
                     toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
-            if "top_stories" in res.keys() and "title" in res["top_stories"][0].keys():
+            if "top_stories" in res and "title" in res["top_stories"][0]:
                 for item in res["top_stories"]:
                     toret += "title: " + item["title"] + "\n" + "link: " + item["link"] + "\n"
             if toret == "":
                 toret = "No good search result found"

         elif type == "link":
-            if "organic_results" in res.keys() and "title" in res["organic_results"][0].keys():
+            if "organic_results" in res and "title" in res["organic_results"][0]:
                 for item in res["organic_results"]:
                     toret += f"[{item['title']}]({item['link']})\n"
-            elif "top_stories" in res.keys() and "title" in res["top_stories"][0].keys():
+            elif "top_stories" in res and "title" in res["top_stories"][0]:
                 for item in res["top_stories"]:
                     toret += f"[{item['title']}]({item['link']})\n"
             else:
@@ -44,11 +44,11 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")

         toret = ""
-        if "transcripts" in res.keys() and "text" in res["transcripts"][0].keys():
+        if "transcripts" in res and "text" in res["transcripts"][0]:
             for item in res["transcripts"]:
                 toret += item["text"] + " "
         if toret == "":
@@ -35,7 +35,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
         if model in ["sd3", "sd3-turbo"]:
             payload["model"] = tool_parameters.get("model")

-        if not model == "sd3-turbo":
+        if model != "sd3-turbo":
             payload["negative_prompt"] = tool_parameters.get("negative_prompt", "")

         response = post(
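`not model == "sd3-turbo"` is the negate-equal-op pattern (SIM201); `!=` states the comparison directly and the two are equivalent for any value. Spot check:

```python
model = "sd3"
assert (not model == "sd3-turbo") == (model != "sd3-turbo")
```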
@@ -206,10 +206,9 @@ class StableDiffusionTool(BuiltinTool):

         # Convert image to RGB and save as PNG
         try:
-            with Image.open(io.BytesIO(image_binary)) as image:
-                with io.BytesIO() as buffer:
-                    image.convert("RGB").save(buffer, format="PNG")
-                    image_binary = buffer.getvalue()
+            with Image.open(io.BytesIO(image_binary)) as image, io.BytesIO() as buffer:
+                image.convert("RGB").save(buffer, format="PNG")
+                image_binary = buffer.getvalue()
         except Exception as e:
             return self.create_text_message(f"Failed to process the image: {str(e)}")
@@ -27,7 +27,7 @@ class WikipediaAPIWrapper:
         self.doc_content_chars_max = doc_content_chars_max

     def run(self, query: str, lang: str = "") -> str:
-        if lang in wikipedia.languages().keys():
+        if lang in wikipedia.languages():
             self.lang = lang

         wikipedia.set_lang(self.lang)
@@ -19,9 +19,7 @@ class ToolFileMessageTransformer:
         result = []

         for message in messages:
-            if message.type == ToolInvokeMessage.MessageType.TEXT:
-                result.append(message)
-            elif message.type == ToolInvokeMessage.MessageType.LINK:
+            if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK:
                 result.append(message)
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                 # try to download image
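Two branches with identical bodies can be merged with `or` (SIM114, `if-with-same-arms`), as here and in the Graph and FileService hunks below. A self-contained sketch with stand-in types:

```python
from enum import Enum

class MessageType(Enum):
    TEXT = "text"
    LINK = "link"
    IMAGE = "image"

def passes_through(message_type: MessageType) -> bool:
    # Merged condition: the TEXT and LINK branches had identical bodies.
    return message_type == MessageType.TEXT or message_type == MessageType.LINK

assert passes_through(MessageType.LINK)
assert not passes_through(MessageType.IMAGE)
```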
@@ -224,9 +224,7 @@ class Graph(BaseModel):
         """
         leaf_node_ids = []
         for node_id in self.node_ids:
-            if node_id not in self.edge_mapping:
-                leaf_node_ids.append(node_id)
-            elif (
+            if node_id not in self.edge_mapping or (
                 len(self.edge_mapping[node_id]) == 1
                 and self.edge_mapping[node_id][0].target_node_id == self.root_node_id
             ):
@@ -24,7 +24,7 @@ class AnswerStreamGeneratorRouter:
         # parse stream output node value selectors of answer nodes
         answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
         for answer_node_id, node_config in node_id_config_mapping.items():
-            if not node_config.get("data", {}).get("type") == NodeType.ANSWER.value:
+            if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
                 continue

             # get generate route for stream output
@@ -17,7 +17,7 @@ class EndStreamGeneratorRouter:
         # parse stream output node value selector of end nodes
         end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
         for end_node_id, node_config in node_id_config_mapping.items():
-            if not node_config.get("data", {}).get("type") == NodeType.END.value:
+            if node_config.get("data", {}).get("type") != NodeType.END.value:
                 continue

             # skip end node in parallel
@@ -20,7 +20,7 @@ class ToolEntity(BaseModel):
         if not isinstance(value, dict):
             raise ValueError("tool_configurations must be a dictionary")

-        for key in values.data.get("tool_configurations", {}).keys():
+        for key in values.data.get("tool_configurations", {}):
             value = values.data.get("tool_configurations", {}).get(key)
             if not isinstance(value, str | int | float | bool):
                 raise ValueError(f"{key} must be a string")
@@ -17,14 +17,12 @@ select = [
     "F", # pyflakes rules
     "I", # isort rules
     "N", # pep8-naming
-    "UP", # pyupgrade rules
     "RUF019", # unnecessary-key-check
     "RUF100", # unused-noqa
     "RUF101", # redirected-noqa
     "S506", # unsafe-yaml-load
-    "SIM116", # if-else-block-instead-of-dict-lookup
-    "SIM401", # if-else-block-instead-of-dict-get
-    "SIM910", # dict-get-with-none-default
+    "SIM", # flake8-simplify rules
+    "UP", # pyupgrade rules
     "W191", # tab-indentation
     "W605", # invalid-escape-sequence
 ]
@@ -50,6 +48,15 @@ ignore = [
     "B905", # zip-without-explicit-strict
     "N806", # non-lowercase-variable-in-function
     "N815", # mixed-case-variable-in-class-scope
+    "SIM102", # collapsible-if
+    "SIM103", # needless-bool
+    "SIM105", # suppressible-exception
+    "SIM107", # return-in-try-except-finally
+    "SIM108", # if-else-block-instead-of-if-exp
+    "SIM113", # enumerate-for-loop
+    "SIM117", # multiple-with-statements
+    "SIM210", # if-expr-with-true-false
+    "SIM300", # yoda-conditions
 ]

 [tool.ruff.lint.per-file-ignores]
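Net effect of the config change: the blanket `SIM` selector enables every flake8-simplify rule, and the ignore list then opts back out of the rules the codebase is not ready to enforce. For instance, with SIM108 ignored, the if/else assignment form below stays allowed even though ruff could otherwise suggest the ternary:

```python
x = 1

# Allowed under the ignore list (SIM108 target):
if x > 0:
    sign = "positive"
else:
    sign = "non-positive"

# The conditional-expression form SIM108 would otherwise propose:
sign = "positive" if x > 0 else "non-positive"
```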
@@ -56,9 +56,7 @@ class FileService:
             if etl_type == "Unstructured"
             else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
         )
-        if extension.lower() not in allowed_extensions:
-            raise UnsupportedFileTypeError()
-        elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
+        if extension.lower() not in allowed_extensions or only_image and extension.lower() not in IMAGE_EXTENSIONS:
             raise UnsupportedFileTypeError()

         # read file content
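One subtlety in this merged check: `and` binds tighter than `or`, so the new condition parses as `not_allowed or (only_image and not_image)`, which matches the original two-branch behavior. A truth-table spot check with illustrative names:

```python
def rejected(not_allowed: bool, only_image: bool, not_image: bool) -> bool:
    # Same shape as the merged FileService condition.
    return not_allowed or only_image and not_image

assert rejected(True, False, False)      # disallowed extension
assert rejected(False, True, True)       # image required, non-image given
assert not rejected(False, True, False)  # image required, image given
assert not rejected(False, False, True)  # no image requirement
```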
@@ -54,7 +54,7 @@ class OpsService:
         :param tracing_config: tracing config
         :return:
         """
-        if tracing_provider not in provider_config_map.keys() and tracing_provider:
+        if tracing_provider not in provider_config_map and tracing_provider:
             return {"error": f"Invalid tracing provider: {tracing_provider}"}

         config_class, other_keys = (
@@ -113,7 +113,7 @@ class OpsService:
         :param tracing_config: tracing config
         :return:
         """
-        if tracing_provider not in provider_config_map.keys():
+        if tracing_provider not in provider_config_map:
             raise ValueError(f"Invalid tracing provider: {tracing_provider}")

         # check if trace config already exists