From 6ed81d6774c40d70d39f2ed4fff4310d944ce68a Mon Sep 17 00:00:00 2001 From: Chaoxi Weng Date: Tue, 20 May 2025 09:40:31 +0800 Subject: [PATCH] Feat: Add OAuth `state` parameter for CSRF protection (#7709) ### What problem does this PR solve? Add OAuth `state` parameter for CSRF protection: - Updated `get_authorization_url()` to accept an optional state parameter - Generated a unique state value during OAuth login and stored in session - Verified state parameter in callback to ensure request legitimacy This PR follows OAuth 2.0 security best practices by ensuring that the authorization request originates from the same user who initiated the flow. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/auth/oauth.py | 4 +++- api/apps/user_app.py | 10 +++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/api/apps/auth/oauth.py b/api/apps/auth/oauth.py index a908e81ea..6f7e0e5b5 100644 --- a/api/apps/auth/oauth.py +++ b/api/apps/auth/oauth.py @@ -45,7 +45,7 @@ class OAuthClient: self.http_request_timeout = 7 - def get_authorization_url(self): + def get_authorization_url(self, state=None): """ Generate the authorization URL for user login. """ @@ -56,6 +56,8 @@ class OAuthClient: } if self.scope: params["scope"] = self.scope + if state: + params["state"] = state authorization_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}" return authorization_url diff --git a/api/apps/user_app.py b/api/apps/user_app.py index ddca41018..401d6977d 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -146,7 +146,9 @@ def oauth_login(channel): raise ValueError(f"Invalid channel name: {channel}") auth_cli = get_auth_client(channel_config) - auth_url = auth_cli.get_authorization_url() + state = get_uuid() + session["oauth_state"] = state + auth_url = auth_cli.get_authorization_url(state) return redirect(auth_url) @@ -161,6 +163,12 @@ def oauth_callback(channel): raise ValueError(f"Invalid channel name: {channel}") auth_cli = get_auth_client(channel_config) + # Check the state + state = request.args.get("state") + if not state or state != session.get("oauth_state"): + return redirect("/?error=invalid_state") + session.pop("oauth_state", None) + # Obtain the authorization code code = request.args.get("code") if not code: