mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-10 19:18:57 +08:00
Feat: Add OAuth state
parameter for CSRF protection (#7709)
### What problem does this PR solve? Add OAuth `state` parameter for CSRF protection: - Updated `get_authorization_url()` to accept an optional state parameter - Generated a unique state value during OAuth login and stored in session - Verified state parameter in callback to ensure request legitimacy This PR follows OAuth 2.0 security best practices by ensuring that the authorization request originates from the same user who initiated the flow. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
115850945e
commit
6ed81d6774
@ -45,7 +45,7 @@ class OAuthClient:
|
||||
self.http_request_timeout = 7
|
||||
|
||||
|
||||
def get_authorization_url(self):
|
||||
def get_authorization_url(self, state=None):
|
||||
"""
|
||||
Generate the authorization URL for user login.
|
||||
"""
|
||||
@ -56,6 +56,8 @@ class OAuthClient:
|
||||
}
|
||||
if self.scope:
|
||||
params["scope"] = self.scope
|
||||
if state:
|
||||
params["state"] = state
|
||||
authorization_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
return authorization_url
|
||||
|
||||
|
@ -146,7 +146,9 @@ def oauth_login(channel):
|
||||
raise ValueError(f"Invalid channel name: {channel}")
|
||||
auth_cli = get_auth_client(channel_config)
|
||||
|
||||
auth_url = auth_cli.get_authorization_url()
|
||||
state = get_uuid()
|
||||
session["oauth_state"] = state
|
||||
auth_url = auth_cli.get_authorization_url(state)
|
||||
return redirect(auth_url)
|
||||
|
||||
|
||||
@ -161,6 +163,12 @@ def oauth_callback(channel):
|
||||
raise ValueError(f"Invalid channel name: {channel}")
|
||||
auth_cli = get_auth_client(channel_config)
|
||||
|
||||
# Check the state
|
||||
state = request.args.get("state")
|
||||
if not state or state != session.get("oauth_state"):
|
||||
return redirect("/?error=invalid_state")
|
||||
session.pop("oauth_state", None)
|
||||
|
||||
# Obtain the authorization code
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
|
Loading…
x
Reference in New Issue
Block a user