mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-19 18:29:09 +08:00
refactor: text-embedding interfaces to returns list[int]
This commit is contained in:
parent
a6835ac64d
commit
cfa7c89dfe
@ -183,7 +183,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
inputs: Mapping
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
|
@ -219,7 +219,7 @@ class ModelInstance:
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
|
||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
|
||||
"""
|
||||
Get number of tokens for text embedding
|
||||
|
||||
|
@ -52,7 +52,7 @@ class TextEmbeddingModel(AIModel):
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
|
@ -76,7 +76,7 @@ class PluginNumTokensResponse(BaseModel):
|
||||
Response for number of tokens.
|
||||
"""
|
||||
|
||||
num_tokens: int = Field(description="The number of tokens.")
|
||||
num_tokens: list[int] = Field(description="The number of tokens.")
|
||||
|
||||
|
||||
class PluginStringResultResponse(BaseModel):
|
||||
|
@ -17,6 +17,14 @@ from core.model_runtime.errors.invoke import (
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
|
||||
from core.plugin.manager.exc import (
|
||||
PluginDaemonBadRequestError,
|
||||
PluginDaemonInternalServerError,
|
||||
PluginDaemonNotFoundError,
|
||||
PluginDaemonUnauthorizedError,
|
||||
PluginPermissionDeniedError,
|
||||
PluginUniqueIdentifierError,
|
||||
)
|
||||
|
||||
plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_API_URL
|
||||
plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY
|
||||
@ -190,17 +198,32 @@ class BasePluginManager:
|
||||
"""
|
||||
args = args or {}
|
||||
|
||||
if error_type == PluginDaemonInnerError.__name__:
|
||||
raise PluginDaemonInnerError(code=-500, message=message)
|
||||
elif error_type == InvokeRateLimitError.__name__:
|
||||
raise InvokeRateLimitError(description=args.get("description"))
|
||||
elif error_type == InvokeAuthorizationError.__name__:
|
||||
raise InvokeAuthorizationError(description=args.get("description"))
|
||||
elif error_type == InvokeBadRequestError.__name__:
|
||||
raise InvokeBadRequestError(description=args.get("description"))
|
||||
elif error_type == InvokeConnectionError.__name__:
|
||||
raise InvokeConnectionError(description=args.get("description"))
|
||||
elif error_type == InvokeServerUnavailableError.__name__:
|
||||
raise InvokeServerUnavailableError(description=args.get("description"))
|
||||
else:
|
||||
raise ValueError(f"got unknown error from plugin daemon: {error_type}, message: {message}, args: {args}")
|
||||
match error_type:
|
||||
case PluginDaemonInnerError.__name__:
|
||||
raise PluginDaemonInnerError(code=-500, message=message)
|
||||
case InvokeRateLimitError.__name__:
|
||||
raise InvokeRateLimitError(description=args.get("description"))
|
||||
case InvokeAuthorizationError.__name__:
|
||||
raise InvokeAuthorizationError(description=args.get("description"))
|
||||
case InvokeBadRequestError.__name__:
|
||||
raise InvokeBadRequestError(description=args.get("description"))
|
||||
case InvokeConnectionError.__name__:
|
||||
raise InvokeConnectionError(description=args.get("description"))
|
||||
case InvokeServerUnavailableError.__name__:
|
||||
raise InvokeServerUnavailableError(description=args.get("description"))
|
||||
case PluginDaemonInternalServerError.__name__:
|
||||
raise PluginDaemonInternalServerError(description=message)
|
||||
case PluginDaemonBadRequestError.__name__:
|
||||
raise PluginDaemonBadRequestError(description=message)
|
||||
case PluginDaemonNotFoundError.__name__:
|
||||
raise PluginDaemonNotFoundError(description=message)
|
||||
case PluginUniqueIdentifierError.__name__:
|
||||
raise PluginUniqueIdentifierError(description=message)
|
||||
case PluginDaemonUnauthorizedError.__name__:
|
||||
raise PluginDaemonUnauthorizedError(description=message)
|
||||
case PluginPermissionDeniedError.__name__:
|
||||
raise PluginPermissionDeniedError(description=message)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"got unknown error from plugin daemon: {error_type}, message: {message}, args: {args}"
|
||||
)
|
||||
|
33
api/core/plugin/manager/exc.py
Normal file
33
api/core/plugin/manager/exc.py
Normal file
@ -0,0 +1,33 @@
|
||||
class PluginDaemonError(Exception):
|
||||
"""Base class for all plugin daemon errors."""
|
||||
|
||||
def __init__(self, description: str) -> None:
|
||||
self.description = description
|
||||
|
||||
|
||||
class PluginDaemonInternalServerError(PluginDaemonError):
|
||||
description: str = "Internal Server Error"
|
||||
|
||||
|
||||
class PluginDaemonBadRequestError(PluginDaemonError):
|
||||
description: str = "Bad Request"
|
||||
|
||||
|
||||
class PluginDaemonNotFoundError(PluginDaemonError):
|
||||
description: str = "Not Found"
|
||||
|
||||
|
||||
class PluginUniqueIdentifierError(PluginDaemonError):
|
||||
description: str = "Unique Identifier Error"
|
||||
|
||||
|
||||
class PluginNotFoundError(PluginDaemonError):
|
||||
description: str = "Plugin Not Found"
|
||||
|
||||
|
||||
class PluginDaemonUnauthorizedError(PluginDaemonError):
|
||||
description: str = "Unauthorized"
|
||||
|
||||
|
||||
class PluginPermissionDeniedError(PluginDaemonError):
|
||||
description: str = "Permission Denied"
|
@ -277,7 +277,7 @@ class PluginModelManager(BasePluginManager):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
) -> int:
|
||||
) -> list[int]:
|
||||
"""
|
||||
Get number of tokens for text embedding
|
||||
"""
|
||||
@ -306,7 +306,7 @@ class PluginModelManager(BasePluginManager):
|
||||
for resp in response:
|
||||
return resp.num_tokens
|
||||
|
||||
return 0
|
||||
return []
|
||||
|
||||
def invoke_rerank(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user