diff --git a/api/apps/auth/oauth.py b/api/apps/auth/oauth.py index a908e81ea..6f7e0e5b5 100644 --- a/api/apps/auth/oauth.py +++ b/api/apps/auth/oauth.py @@ -45,7 +45,7 @@ class OAuthClient: self.http_request_timeout = 7 - def get_authorization_url(self): + def get_authorization_url(self, state=None): """ Generate the authorization URL for user login. """ @@ -56,6 +56,8 @@ class OAuthClient: } if self.scope: params["scope"] = self.scope + if state: + params["state"] = state authorization_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}" return authorization_url diff --git a/api/apps/user_app.py b/api/apps/user_app.py index ddca41018..401d6977d 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -146,7 +146,9 @@ def oauth_login(channel): raise ValueError(f"Invalid channel name: {channel}") auth_cli = get_auth_client(channel_config) - auth_url = auth_cli.get_authorization_url() + state = get_uuid() + session["oauth_state"] = state + auth_url = auth_cli.get_authorization_url(state) return redirect(auth_url) @@ -161,6 +163,12 @@ def oauth_callback(channel): raise ValueError(f"Invalid channel name: {channel}") auth_cli = get_auth_client(channel_config) + # Check the state + state = request.args.get("state") + if not state or state != session.get("oauth_state"): + return redirect("/?error=invalid_state") + session.pop("oauth_state", None) + # Obtain the authorization code code = request.args.get("code") if not code: