mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-17 06:45:55 +08:00
fix: mypy linter
This commit is contained in:
parent
fb309462ad
commit
b7d168ac59
@ -2,7 +2,7 @@ import base64
|
||||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -129,7 +129,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
pass
|
||||
except AccountRegisterError as are:
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
@ -4,7 +4,7 @@ from typing import Optional
|
||||
|
||||
import requests
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restful import Resource
|
||||
from flask_restful import Resource # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
@ -2,8 +2,8 @@ import datetime
|
||||
import json
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
|
||||
from flask import session
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_login import current_user # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource # type: ignore
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
|
@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
|
@ -1,8 +1,8 @@
|
||||
import io
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
|
@ -1,5 +1,5 @@
|
||||
from flask import request
|
||||
from flask_restful import Resource, marshal_with
|
||||
from flask_restful import Resource, marshal_with # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
|
@ -1,4 +1,4 @@
|
||||
from flask_restful import Resource
|
||||
from flask_restful import Resource # type: ignore
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
|
@ -3,7 +3,7 @@ from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from flask_restful import reqparse
|
||||
from flask_restful import reqparse # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
@ -119,7 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
usage_dict = {}
|
||||
usage_dict: dict[str, Optional[LLMUsage]] = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
|
@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
@ -644,7 +644,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
|
||||
yield self._workflow_cycle_manager._handle_agent_log(
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
from typing import Any, Mapping, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||
@ -15,7 +15,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
@classmethod
|
||||
def convert(
|
||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
|
@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
) -> Generator[str | Mapping[str, Any], None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@ -57,7 +57,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[Mapping, Generator[str, None, None]]:
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -531,7 +531,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
|
||||
yield self._workflow_cycle_manager._handle_agent_log(
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
@ -178,7 +178,7 @@ class ModelInstance:
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> list[int]:
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for llm
|
||||
|
||||
@ -191,7 +191,7 @@ class ModelInstance:
|
||||
|
||||
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
|
||||
return cast(
|
||||
list[int],
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
|
@ -119,7 +119,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
):
|
||||
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||
"""
|
||||
invoke workflow app
|
||||
"""
|
||||
@ -146,7 +146,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
):
|
||||
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||
"""
|
||||
invoke completion app
|
||||
"""
|
||||
|
@ -268,7 +268,7 @@ Here is the extra instruction you need to follow:
|
||||
return summary.message.content
|
||||
|
||||
lines = content.split("\n")
|
||||
new_lines = []
|
||||
new_lines: list[str] = []
|
||||
# split long line into multiple lines
|
||||
for i in range(len(lines)):
|
||||
line = lines[i]
|
||||
@ -286,16 +286,16 @@ Here is the extra instruction you need to follow:
|
||||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for i in new_lines:
|
||||
for i in new_lines: # type: ignore
|
||||
if len(messages) == 0:
|
||||
messages.append(i)
|
||||
messages.append(i) # type: ignore
|
||||
else:
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5:
|
||||
messages[-1] += i
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
|
||||
messages.append(i)
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
|
||||
messages[-1] += i # type: ignore
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
|
||||
messages.append(i) # type: ignore
|
||||
else:
|
||||
messages[-1] += i
|
||||
messages[-1] += i # type: ignore
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
|
@ -103,7 +103,7 @@ class BasePluginManager:
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
yield type(**json.loads(line))
|
||||
yield type(**json.loads(line)) # type: ignore
|
||||
|
||||
def _request_with_model(
|
||||
self,
|
||||
|
@ -54,7 +54,12 @@ class ASRTool(BuiltinTool):
|
||||
items.append((provider, model.model))
|
||||
return items
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
parameters = []
|
||||
|
||||
options = []
|
||||
|
@ -62,7 +62,12 @@ class TTSTool(BuiltinTool):
|
||||
items.append((provider, model.model, voices))
|
||||
return items
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
parameters = []
|
||||
|
||||
options = []
|
||||
|
@ -147,7 +147,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
|
||||
@field_validator("variable_name", mode="before")
|
||||
@classmethod
|
||||
def transform_variable_name(cls, value) -> str:
|
||||
def transform_variable_name(cls, value: str) -> str:
|
||||
"""
|
||||
The variable name must be a string.
|
||||
"""
|
||||
|
@ -9,7 +9,7 @@ import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union, cast
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
@ -182,7 +182,7 @@ def generate_text_hash(text: str) -> str:
|
||||
return sha256(hash_text.encode()).hexdigest()
|
||||
|
||||
|
||||
def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
|
||||
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype="application/json")
|
||||
else:
|
||||
|
@ -900,6 +900,9 @@ class RegisterService:
|
||||
def invite_new_member(
|
||||
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
|
||||
) -> str:
|
||||
if not inviter:
|
||||
raise ValueError("Inviter is required")
|
||||
|
||||
"""Invite new member"""
|
||||
with Session(db.engine) as session:
|
||||
account = session.query(Account).filter_by(email=email).first()
|
||||
|
@ -298,7 +298,7 @@ class WorkflowService:
|
||||
start_at: float,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
):
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Handle node run result
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user