fix: mypy linter

This commit is contained in:
Yeuoly 2025-01-08 21:11:42 +08:00
parent fb309462ad
commit b7d168ac59
27 changed files with 62 additions and 45 deletions

View File

@ -2,7 +2,7 @@ import base64
import secrets import secrets
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -129,7 +129,7 @@ class ForgotPasswordResetApi(Resource):
) )
except WorkSpaceNotAllowedCreateError: except WorkSpaceNotAllowedCreateError:
pass pass
except AccountRegisterError as are: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
return {"result": "success"} return {"result": "success"}

View File

@ -4,7 +4,7 @@ from typing import Optional
import requests import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restful import Resource from flask_restful import Resource # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized

View File

@ -2,8 +2,8 @@ import datetime
import json import json
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound

View File

@ -1,7 +1,7 @@
import os import os
from flask import session from flask import session
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session

View File

@ -1,6 +1,6 @@
from functools import wraps from functools import wraps
from flask_login import current_user from flask_login import current_user # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden

View File

@ -1,5 +1,5 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource from flask_restful import Resource # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required

View File

@ -1,5 +1,5 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api

View File

@ -1,8 +1,8 @@
import io import io
from flask import request, send_file from flask import request, send_file
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config

View File

@ -1,5 +1,5 @@
from flask import request from flask import request
from flask_restful import Resource, marshal_with from flask_restful import Resource, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services

View File

@ -1,4 +1,4 @@
from flask_restful import Resource from flask_restful import Resource # type: ignore
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
from controllers.inner_api import api from controllers.inner_api import api

View File

@ -3,7 +3,7 @@ from functools import wraps
from typing import Optional from typing import Optional
from flask import request from flask import request
from flask_restful import reqparse from flask_restful import reqparse # type: ignore
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session

View File

@ -119,7 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[], callbacks=[],
) )
usage_dict = {} usage_dict: dict[str, Optional[LLMUsage]] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit( scratchpad = AgentScratchpadUnit(
agent_response="", agent_response="",

View File

@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs, user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

View File

@ -644,7 +644,9 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent): elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event) yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else: else:
continue continue

View File

@ -1,7 +1,7 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Union from typing import Any, Mapping, Union
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@ -15,7 +15,7 @@ class AppGenerateResponseConverter(ABC):
@classmethod @classmethod
def convert( def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response) return cls.convert_blocking_full_response(response)

View File

@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
) -> Generator[str, None, None]: ... ) -> Generator[str | Mapping[str, Any], None, None]: ...
@overload @overload
def generate( def generate(
@ -57,7 +57,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = False, streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
def generate( def generate(
self, self,
@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
""" """
Generate App response. Generate App response.
@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser], user: Union[Account, EndUser],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
) -> Union[Mapping, Generator[str, None, None]]: ) -> Union[Mapping, Generator[Mapping | str, None, None]]:
""" """
Generate App response. Generate App response.

View File

@ -531,7 +531,9 @@ class WorkflowAppGenerateTaskPipeline:
delta_text, from_variable_selector=event.from_variable_selector delta_text, from_variable_selector=event.from_variable_selector
) )
elif isinstance(event, QueueAgentLogEvent): elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event) yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else: else:
continue continue

View File

@ -178,7 +178,7 @@ class ModelInstance:
def get_llm_num_tokens( def get_llm_num_tokens(
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> list[int]: ) -> int:
""" """
Get number of tokens for llm Get number of tokens for llm
@ -191,7 +191,7 @@ class ModelInstance:
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast( return cast(
list[int], int,
self._round_robin_invoke( self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens, function=self.model_type_instance.get_num_tokens,
model=self.model, model=self.model,

View File

@ -119,7 +119,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
stream: bool, stream: bool,
inputs: Mapping, inputs: Mapping,
files: list[dict], files: list[dict],
): ) -> Generator[Mapping | str, None, None] | Mapping:
""" """
invoke workflow app invoke workflow app
""" """
@ -146,7 +146,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
stream: bool, stream: bool,
inputs: Mapping, inputs: Mapping,
files: list[dict], files: list[dict],
): ) -> Generator[Mapping | str, None, None] | Mapping:
""" """
invoke completion app invoke completion app
""" """

View File

@ -268,7 +268,7 @@ Here is the extra instruction you need to follow:
return summary.message.content return summary.message.content
lines = content.split("\n") lines = content.split("\n")
new_lines = [] new_lines: list[str] = []
# split long line into multiple lines # split long line into multiple lines
for i in range(len(lines)): for i in range(len(lines)):
line = lines[i] line = lines[i]
@ -286,16 +286,16 @@ Here is the extra instruction you need to follow:
# merge lines into messages with max tokens # merge lines into messages with max tokens
messages: list[str] = [] messages: list[str] = []
for i in new_lines: for i in new_lines: # type: ignore
if len(messages) == 0: if len(messages) == 0:
messages.append(i) messages.append(i) # type: ignore
else: else:
if len(messages[-1]) + len(i) < max_tokens * 0.5: if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
messages[-1] += i messages[-1] += i # type: ignore
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
messages.append(i) messages.append(i) # type: ignore
else: else:
messages[-1] += i messages[-1] += i # type: ignore
summaries = [] summaries = []
for i in range(len(messages)): for i in range(len(messages)):

View File

@ -103,7 +103,7 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API and yield the response as a model. Make a stream request to the plugin daemon inner API and yield the response as a model.
""" """
for line in self._stream_request(method, path, params, headers, data, files): for line in self._stream_request(method, path, params, headers, data, files):
yield type(**json.loads(line)) yield type(**json.loads(line)) # type: ignore
def _request_with_model( def _request_with_model(
self, self,

View File

@ -54,7 +54,12 @@ class ASRTool(BuiltinTool):
items.append((provider, model.model)) items.append((provider, model.model))
return items return items
def get_runtime_parameters(self) -> list[ToolParameter]: def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
parameters = [] parameters = []
options = [] options = []

View File

@ -62,7 +62,12 @@ class TTSTool(BuiltinTool):
items.append((provider, model.model, voices)) items.append((provider, model.model, voices))
return items return items
def get_runtime_parameters(self) -> list[ToolParameter]: def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
parameters = [] parameters = []
options = [] options = []

View File

@ -147,7 +147,7 @@ class ToolInvokeMessage(BaseModel):
@field_validator("variable_name", mode="before") @field_validator("variable_name", mode="before")
@classmethod @classmethod
def transform_variable_name(cls, value) -> str: def transform_variable_name(cls, value: str) -> str:
""" """
The variable name must be a string. The variable name must be a string.
""" """

View File

@ -9,7 +9,7 @@ import uuid
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Any, Mapping, Optional, Union, cast
from zoneinfo import available_timezones from zoneinfo import available_timezones
from flask import Response, stream_with_context from flask import Response, stream_with_context
@ -182,7 +182,7 @@ def generate_text_hash(text: str) -> str:
return sha256(hash_text.encode()).hexdigest() return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response: def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict): if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype="application/json") return Response(response=json.dumps(response), status=200, mimetype="application/json")
else: else:

View File

@ -900,6 +900,9 @@ class RegisterService:
def invite_new_member( def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
) -> str: ) -> str:
if not inviter:
raise ValueError("Inviter is required")
"""Invite new member""" """Invite new member"""
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.query(Account).filter_by(email=email).first() account = session.query(Account).filter_by(email=email).first()

View File

@ -298,7 +298,7 @@ class WorkflowService:
start_at: float, start_at: float,
tenant_id: str, tenant_id: str,
node_id: str, node_id: str,
): ) -> WorkflowNodeExecution:
""" """
Handle node run result Handle node run result