mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 01:56:01 +08:00
fix: drop some type fixme (#20344)
This commit is contained in:
parent
9b47f9f786
commit
5a991295e0
@@ -129,17 +129,18 @@ def jsonable_encoder(
|
|||||||
sqlalchemy_safe=sqlalchemy_safe,
|
sqlalchemy_safe=sqlalchemy_safe,
|
||||||
)
|
)
|
||||||
if dataclasses.is_dataclass(obj):
|
if dataclasses.is_dataclass(obj):
|
||||||
# FIXME: mypy error, try to fix it instead of using type: ignore
|
# Ensure obj is a dataclass instance, not a dataclass type
|
||||||
obj_dict = dataclasses.asdict(obj) # type: ignore
|
if not isinstance(obj, type):
|
||||||
return jsonable_encoder(
|
obj_dict = dataclasses.asdict(obj)
|
||||||
obj_dict,
|
return jsonable_encoder(
|
||||||
by_alias=by_alias,
|
obj_dict,
|
||||||
exclude_unset=exclude_unset,
|
by_alias=by_alias,
|
||||||
exclude_defaults=exclude_defaults,
|
exclude_unset=exclude_unset,
|
||||||
exclude_none=exclude_none,
|
exclude_defaults=exclude_defaults,
|
||||||
custom_encoder=custom_encoder,
|
exclude_none=exclude_none,
|
||||||
sqlalchemy_safe=sqlalchemy_safe,
|
custom_encoder=custom_encoder,
|
||||||
)
|
sqlalchemy_safe=sqlalchemy_safe,
|
||||||
|
)
|
||||||
if isinstance(obj, Enum):
|
if isinstance(obj, Enum):
|
||||||
return obj.value
|
return obj.value
|
||||||
if isinstance(obj, PurePath):
|
if isinstance(obj, PurePath):
|
||||||
|
@@ -85,7 +85,6 @@ class BaiduVector(BaseVector):
|
|||||||
end = min(start + batch_size, total_count)
|
end = min(start + batch_size, total_count)
|
||||||
rows = []
|
rows = []
|
||||||
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
|
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
|
||||||
# FIXME do you need this assert?
|
|
||||||
for i in range(start, end, 1):
|
for i in range(start, end, 1):
|
||||||
row = Row(
|
row = Row(
|
||||||
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
||||||
|
@@ -245,4 +245,4 @@ class TidbService:
|
|||||||
return cluster_infos
|
return cluster_infos
|
||||||
else:
|
else:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
|
return []
|
||||||
|
@@ -279,7 +279,6 @@ class ToolParameter(PluginParameter):
|
|||||||
:param options: the options of the parameter
|
:param options: the options of the parameter
|
||||||
"""
|
"""
|
||||||
# convert options to ToolParameterOption
|
# convert options to ToolParameterOption
|
||||||
# FIXME fix the type error
|
|
||||||
if options:
|
if options:
|
||||||
option_objs = [
|
option_objs = [
|
||||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||||
|
@@ -66,7 +66,6 @@ class ToolFileMessageTransformer:
|
|||||||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||||
raise ValueError("unexpected message type")
|
raise ValueError("unexpected message type")
|
||||||
|
|
||||||
# FIXME: should do a type check here.
|
|
||||||
assert isinstance(message.message.blob, bytes)
|
assert isinstance(message.message.blob, bytes)
|
||||||
tool_file_manager = ToolFileManager()
|
tool_file_manager = ToolFileManager()
|
||||||
file = tool_file_manager.create_file_by_raw(
|
file = tool_file_manager.create_file_by_raw(
|
||||||
|
@@ -816,7 +816,6 @@ class ParameterExtractorNode(LLMNode):
|
|||||||
:param node_data: node data
|
:param node_data: node data
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# FIXME: fix the type error later
|
|
||||||
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
||||||
|
|
||||||
if node_data.instruction:
|
if node_data.instruction:
|
||||||
|
@@ -84,8 +84,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
|||||||
raise VariableError("missing value type")
|
raise VariableError("missing value type")
|
||||||
if (value := mapping.get("value")) is None:
|
if (value := mapping.get("value")) is None:
|
||||||
raise VariableError("missing value")
|
raise VariableError("missing value")
|
||||||
# FIXME: using Any here, fix it later
|
|
||||||
result: Any
|
result: Variable
|
||||||
match value_type:
|
match value_type:
|
||||||
case SegmentType.STRING:
|
case SegmentType.STRING:
|
||||||
result = StringVariable.model_validate(mapping)
|
result = StringVariable.model_validate(mapping)
|
||||||
|
@@ -34,9 +34,8 @@ def clean_messages():
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# Main query with join and filter
|
# Main query with join and filter
|
||||||
# FIXME:for mypy no paginate method error
|
|
||||||
messages = (
|
messages = (
|
||||||
db.session.query(Message) # type: ignore
|
db.session.query(Message)
|
||||||
.filter(Message.created_at < plan_sandbox_clean_message_day)
|
.filter(Message.created_at < plan_sandbox_clean_message_day)
|
||||||
.order_by(Message.created_at.desc())
|
.order_by(Message.created_at.desc())
|
||||||
.limit(100)
|
.limit(100)
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from core.ops.entities.config_entity import BaseTracingConfig
|
||||||
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, TraceAppConfig
|
from models.model import App, TraceAppConfig
|
||||||
@@ -92,13 +93,12 @@ class OpsService:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
return {"error": f"Invalid tracing provider: {tracing_provider}"}
|
return {"error": f"Invalid tracing provider: {tracing_provider}"}
|
||||||
|
|
||||||
config_class, other_keys = (
|
provider_config: dict[str, Any] = provider_config_map[tracing_provider]
|
||||||
provider_config_map[tracing_provider]["config_class"],
|
config_class: type[BaseTracingConfig] = provider_config["config_class"]
|
||||||
provider_config_map[tracing_provider]["other_keys"],
|
other_keys: list[str] = provider_config["other_keys"]
|
||||||
)
|
|
||||||
# FIXME: ignore type error
|
default_config_instance: BaseTracingConfig = config_class(**tracing_config)
|
||||||
default_config_instance = config_class(**tracing_config) # type: ignore
|
for key in other_keys:
|
||||||
for key in other_keys: # type: ignore
|
|
||||||
if key in tracing_config and tracing_config[key] == "":
|
if key in tracing_config and tracing_config[key] == "":
|
||||||
tracing_config[key] = getattr(default_config_instance, key, None)
|
tracing_config[key] = getattr(default_config_instance, key, None)
|
||||||
|
|
||||||
|
@@ -173,26 +173,27 @@ class WebsiteService:
|
|||||||
return crawl_status_data
|
return crawl_status_data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None:
|
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
|
||||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
||||||
# decrypt api_key
|
# decrypt api_key
|
||||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||||
# FIXME data is redefine too many times here, use Any to ease the type checking, fix it later
|
|
||||||
data: Any
|
|
||||||
if provider == "firecrawl":
|
if provider == "firecrawl":
|
||||||
|
crawl_data: list[dict[str, Any]] | None = None
|
||||||
file_key = "website_files/" + job_id + ".txt"
|
file_key = "website_files/" + job_id + ".txt"
|
||||||
if storage.exists(file_key):
|
if storage.exists(file_key):
|
||||||
d = storage.load_once(file_key)
|
stored_data = storage.load_once(file_key)
|
||||||
if d:
|
if stored_data:
|
||||||
data = json.loads(d.decode("utf-8"))
|
crawl_data = json.loads(stored_data.decode("utf-8"))
|
||||||
else:
|
else:
|
||||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||||
result = firecrawl_app.check_crawl_status(job_id)
|
result = firecrawl_app.check_crawl_status(job_id)
|
||||||
if result.get("status") != "completed":
|
if result.get("status") != "completed":
|
||||||
raise ValueError("Crawl job is not completed")
|
raise ValueError("Crawl job is not completed")
|
||||||
data = result.get("data")
|
crawl_data = result.get("data")
|
||||||
if data:
|
|
||||||
for item in data:
|
if crawl_data:
|
||||||
|
for item in crawl_data:
|
||||||
if item.get("source_url") == url:
|
if item.get("source_url") == url:
|
||||||
return dict(item)
|
return dict(item)
|
||||||
return None
|
return None
|
||||||
@@ -211,23 +212,24 @@ class WebsiteService:
|
|||||||
raise ValueError("Failed to crawl")
|
raise ValueError("Failed to crawl")
|
||||||
return dict(response.json().get("data", {}))
|
return dict(response.json().get("data", {}))
|
||||||
else:
|
else:
|
||||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
# Get crawl status first
|
||||||
response = requests.post(
|
status_response = requests.post(
|
||||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||||
json={"taskId": job_id},
|
json={"taskId": job_id},
|
||||||
)
|
)
|
||||||
data = response.json().get("data", {})
|
status_data = status_response.json().get("data", {})
|
||||||
if data.get("status") != "completed":
|
if status_data.get("status") != "completed":
|
||||||
raise ValueError("Crawl job is not completed")
|
raise ValueError("Crawl job is not completed")
|
||||||
|
|
||||||
response = requests.post(
|
# Get processed data
|
||||||
|
data_response = requests.post(
|
||||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||||
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
|
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
|
||||||
)
|
)
|
||||||
data = response.json().get("data", {})
|
processed_data = data_response.json().get("data", {})
|
||||||
for item in data.get("processed", {}).values():
|
for item in processed_data.get("processed", {}).values():
|
||||||
if item.get("data", {}).get("url") == url:
|
if item.get("data", {}).get("url") == url:
|
||||||
return dict(item.get("data", {}))
|
return dict(item.get("data", {}))
|
||||||
return None
|
return None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user