fix(app_generator_service): overload type hints (#11507)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-10 09:06:34 +08:00 committed by GitHub
parent ec00b25793
commit fd354d999d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 155 additions and 22 deletions

View File

@ -3,7 +3,7 @@ import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Optional, Union from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
@ -36,6 +36,29 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator): class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int _dialogue_count: int
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate( def generate(
self, self,
app_model: App, app_model: App,
@ -44,7 +67,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]: ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
""" """
Generate App response. Generate App response.

View File

@ -2,7 +2,7 @@ import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Union from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
@ -28,6 +28,39 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator): class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate( def generate(
self, self,
*, *,
@ -36,7 +69,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]: ):
""" """
Generate App response. Generate App response.

View File

@ -1,7 +1,7 @@
import logging import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
@ -34,9 +34,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Any, args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: Literal[True] = True, streaming: Literal[True],
) -> Generator[str, None, None]: ... ) -> Generator[str, None, None]: ...
@overload @overload
@ -44,19 +44,29 @@ class ChatAppGenerator(MessageBasedAppGenerator):
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Any, args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: Literal[False] = False, streaming: Literal[False],
) -> dict: ... ) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
def generate( def generate(
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Any, args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
) -> Union[dict, Generator[str, None, None]]: ):
""" """
Generate App response. Generate App response.

View File

@ -1,7 +1,7 @@
import logging import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
@ -34,9 +34,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: Literal[True] = True, streaming: Literal[True],
) -> Generator[str, None, None]: ... ) -> Generator[str, None, None]: ...
@overload @overload
@ -44,14 +44,29 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: Literal[False] = False, streaming: Literal[False],
) -> dict: ... ) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate( def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True self,
) -> Union[dict, Generator[str, None, None]]: app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
""" """
Generate App response. Generate App response.

View File

@ -3,7 +3,7 @@ import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, Union from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
@ -30,6 +30,35 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator): class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any]: ...
@overload
def generate( def generate(
self, self,
*, *,
@ -41,7 +70,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True, streaming: bool = True,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]: ) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
files: Sequence[Mapping[str, Any]] = args.get("files") or [] files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files # parse files