mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-19 22:09:11 +08:00
feat: add interfaces of OAuth handler methods for authorization (#18889)
This commit is contained in:
parent
7ccec5cd95
commit
0e0ec4691a
@ -1,6 +1,7 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Generic, Optional, TypeVar
|
from typing import Any, Generic, Optional, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
@ -158,3 +159,11 @@ class PluginInstallTaskStartResponse(BaseModel):
|
|||||||
class PluginUploadResponse(BaseModel):
|
class PluginUploadResponse(BaseModel):
|
||||||
unique_identifier: str = Field(description="The unique identifier of the plugin.")
|
unique_identifier: str = Field(description="The unique identifier of the plugin.")
|
||||||
manifest: PluginDeclaration
|
manifest: PluginDeclaration
|
||||||
|
|
||||||
|
|
||||||
|
class PluginOAuthAuthorizationUrlResponse(BaseModel):
|
||||||
|
authorization_url: str = Field(description="The URL of the authorization.")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginOAuthCredentialsResponse(BaseModel):
|
||||||
|
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")
|
||||||
|
@ -1,6 +1,98 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from werkzeug import Request
|
||||||
|
|
||||||
|
from core.plugin.entities.plugin_daemon import PluginOAuthAuthorizationUrlResponse, PluginOAuthCredentialsResponse
|
||||||
from core.plugin.impl.base import BasePluginClient
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
|
||||||
|
|
||||||
class OAuthHandler(BasePluginClient):
|
class OAuthHandler(BasePluginClient):
|
||||||
def get_authorization_url(self, tenant_id: str, user_id: str, provider_name: str) -> str:
|
def get_authorization_url(
|
||||||
return "1234567890"
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
plugin_id: str,
|
||||||
|
provider: str,
|
||||||
|
system_credentials: Mapping[str, Any],
|
||||||
|
) -> PluginOAuthAuthorizationUrlResponse:
|
||||||
|
return self._request_with_plugin_daemon_response(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
|
||||||
|
PluginOAuthAuthorizationUrlResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": provider,
|
||||||
|
"system_credentials": system_credentials,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_credentials(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
plugin_id: str,
|
||||||
|
provider: str,
|
||||||
|
system_credentials: Mapping[str, Any],
|
||||||
|
request: Request,
|
||||||
|
) -> PluginOAuthCredentialsResponse:
|
||||||
|
"""
|
||||||
|
Get credentials from the given request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# encode request to raw http request
|
||||||
|
raw_request_bytes = self._convert_request_to_raw_data(request)
|
||||||
|
|
||||||
|
return self._request_with_plugin_daemon_response(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
|
||||||
|
PluginOAuthCredentialsResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": provider,
|
||||||
|
"system_credentials": system_credentials,
|
||||||
|
"raw_request_bytes": raw_request_bytes,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
||||||
|
"""
|
||||||
|
Convert a Request object to raw HTTP data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The Request object to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The raw HTTP data as bytes.
|
||||||
|
"""
|
||||||
|
# Start with the request line
|
||||||
|
method = request.method
|
||||||
|
path = request.path
|
||||||
|
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
|
||||||
|
raw_data = f"{method} {path} {protocol}\r\n".encode()
|
||||||
|
|
||||||
|
# Add headers
|
||||||
|
for header_name, header_value in request.headers.items():
|
||||||
|
raw_data += f"{header_name}: {header_value}\r\n".encode()
|
||||||
|
|
||||||
|
# Add empty line to separate headers from body
|
||||||
|
raw_data += b"\r\n"
|
||||||
|
|
||||||
|
# Add body if exists
|
||||||
|
body = request.get_data(as_text=False)
|
||||||
|
if body:
|
||||||
|
raw_data += body
|
||||||
|
|
||||||
|
return raw_data
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
from werkzeug import Request
|
||||||
|
from werkzeug.datastructures import Headers
|
||||||
|
from werkzeug.test import EnvironBuilder
|
||||||
|
|
||||||
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_convert_request_to_raw_data():
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
builder = EnvironBuilder(
|
||||||
|
method="GET",
|
||||||
|
path="/test",
|
||||||
|
headers=Headers({"Content-Type": "application/json"}),
|
||||||
|
)
|
||||||
|
request = Request(builder.get_environ())
|
||||||
|
raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
|
||||||
|
|
||||||
|
assert b"GET /test HTTP/1.1" in raw_request_bytes
|
||||||
|
assert b"Content-Type: application/json" in raw_request_bytes
|
||||||
|
assert b"\r\n\r\n" in raw_request_bytes
|
Loading…
x
Reference in New Issue
Block a user