fix: mypy linter

This commit is contained in:
Yeuoly 2025-01-08 21:11:42 +08:00
parent fb309462ad
commit b7d168ac59
27 changed files with 62 additions and 45 deletions

View File

@ -2,7 +2,7 @@ import base64
import secrets
from flask import request
from flask_restful import Resource, reqparse
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -129,7 +129,7 @@ class ForgotPasswordResetApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
pass
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
return {"result": "success"}

View File

@ -4,7 +4,7 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_restful import Resource
from flask_restful import Resource # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized

View File

@ -2,8 +2,8 @@ import datetime
import json
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

View File

@ -1,7 +1,7 @@
import os
from flask import session
from flask_restful import Resource, reqparse
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session

View File

@ -1,6 +1,6 @@
from functools import wraps
from flask_login import current_user
from flask_login import current_user # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

View File

@ -1,5 +1,5 @@
from flask_login import current_user
from flask_restful import Resource
from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required

View File

@ -1,5 +1,5 @@
from flask_login import current_user
from flask_restful import Resource, reqparse
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api

View File

@ -1,8 +1,8 @@
import io
from flask import request, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from configs import dify_config

View File

@ -1,5 +1,5 @@
from flask import request
from flask_restful import Resource, marshal_with
from flask_restful import Resource, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden
import services

View File

@ -1,4 +1,4 @@
from flask_restful import Resource
from flask_restful import Resource # type: ignore
from controllers.console.wraps import setup_required
from controllers.inner_api import api

View File

@ -3,7 +3,7 @@ from functools import wraps
from typing import Optional
from flask import request
from flask_restful import reqparse
from flask_restful import reqparse # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session

View File

@ -119,7 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[],
)
usage_dict = {}
usage_dict: dict[str, Optional[LLMUsage]] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",

View File

@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
)
else:
inputs = self.application_generate_entity.inputs

View File

@ -644,7 +644,9 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue

View File

@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Any, Union
from typing import Any, Mapping, Union
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@ -15,7 +15,7 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)

View File

@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
) -> Generator[str | Mapping[str, Any], None, None]: ...
@overload
def generate(
@ -57,7 +57,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
def generate(
self,
@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[Mapping, Generator[str, None, None]]:
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
"""
Generate App response.

View File

@ -531,7 +531,9 @@ class WorkflowAppGenerateTaskPipeline:
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue

View File

@ -178,7 +178,7 @@ class ModelInstance:
def get_llm_num_tokens(
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> list[int]:
) -> int:
"""
Get number of tokens for llm
@ -191,7 +191,7 @@ class ModelInstance:
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
list[int],
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,

View File

@ -119,7 +119,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
stream: bool,
inputs: Mapping,
files: list[dict],
):
) -> Generator[Mapping | str, None, None] | Mapping:
"""
invoke workflow app
"""
@ -146,7 +146,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
stream: bool,
inputs: Mapping,
files: list[dict],
):
) -> Generator[Mapping | str, None, None] | Mapping:
"""
invoke completion app
"""

View File

@ -268,7 +268,7 @@ Here is the extra instruction you need to follow:
return summary.message.content
lines = content.split("\n")
new_lines = []
new_lines: list[str] = []
# split long line into multiple lines
for i in range(len(lines)):
line = lines[i]
@ -286,16 +286,16 @@ Here is the extra instruction you need to follow:
# merge lines into messages with max tokens
messages: list[str] = []
for i in new_lines:
for i in new_lines: # type: ignore
if len(messages) == 0:
messages.append(i)
messages.append(i) # type: ignore
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5:
messages[-1] += i
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
messages.append(i)
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
messages[-1] += i # type: ignore
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
messages.append(i) # type: ignore
else:
messages[-1] += i
messages[-1] += i # type: ignore
summaries = []
for i in range(len(messages)):

View File

@ -103,7 +103,7 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API and yield the response as a model.
"""
for line in self._stream_request(method, path, params, headers, data, files):
yield type(**json.loads(line))
yield type(**json.loads(line)) # type: ignore
def _request_with_model(
self,

View File

@ -54,7 +54,12 @@ class ASRTool(BuiltinTool):
items.append((provider, model.model))
return items
def get_runtime_parameters(self) -> list[ToolParameter]:
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
parameters = []
options = []

View File

@ -62,7 +62,12 @@ class TTSTool(BuiltinTool):
items.append((provider, model.model, voices))
return items
def get_runtime_parameters(self) -> list[ToolParameter]:
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
parameters = []
options = []

View File

@ -147,7 +147,7 @@ class ToolInvokeMessage(BaseModel):
@field_validator("variable_name", mode="before")
@classmethod
def transform_variable_name(cls, value) -> str:
def transform_variable_name(cls, value: str) -> str:
"""
The variable name must be a string.
"""

View File

@ -9,7 +9,7 @@ import uuid
from collections.abc import Generator
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union, cast
from zoneinfo import available_timezones
from flask import Response, stream_with_context
@ -182,7 +182,7 @@ def generate_text_hash(text: str) -> str:
return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype="application/json")
else:

View File

@ -900,6 +900,9 @@ class RegisterService:
def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
) -> str:
if not inviter:
raise ValueError("Inviter is required")
"""Invite new member"""
with Session(db.engine) as session:
account = session.query(Account).filter_by(email=email).first()

View File

@ -298,7 +298,7 @@ class WorkflowService:
start_at: float,
tenant_id: str,
node_id: str,
):
) -> WorkflowNodeExecution:
"""
Handle node run result