| | import logging |
| | from datetime import datetime, timezone |
| | from typing import Optional |
| |
|
| | import requests |
| | from flask import current_app, redirect, request |
| | from flask_restful import Resource |
| | from werkzeug.exceptions import Unauthorized |
| |
|
| | from configs import dify_config |
| | from constants.languages import languages |
| | from events.tenant_event import tenant_was_created |
| | from extensions.ext_database import db |
| | from libs.helper import extract_remote_ip |
| | from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo |
| | from models import Account |
| | from models.account import AccountStatus |
| | from services.account_service import AccountService, RegisterService, TenantService |
| | from services.errors.account import AccountNotFoundError |
| | from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError |
| | from services.feature_service import FeatureService |
| |
|
| | from .. import api |
| |
|
| |
|
def get_oauth_providers():
    """Build the OAuth clients for every provider that is configured.

    Returns:
        A dict with the keys ``"github"`` and ``"google"``. Each value is the
        matching OAuth client, or ``None`` when that provider's client id or
        client secret is missing/empty in ``dify_config``.
    """
    with current_app.app_context():
        # Start from "not configured" and only build a client when both
        # credentials are present.
        github_client = None
        if dify_config.GITHUB_CLIENT_ID and dify_config.GITHUB_CLIENT_SECRET:
            github_client = GitHubOAuth(
                client_id=dify_config.GITHUB_CLIENT_ID,
                client_secret=dify_config.GITHUB_CLIENT_SECRET,
                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
            )

        google_client = None
        if dify_config.GOOGLE_CLIENT_ID and dify_config.GOOGLE_CLIENT_SECRET:
            google_client = GoogleOAuth(
                client_id=dify_config.GOOGLE_CLIENT_ID,
                client_secret=dify_config.GOOGLE_CLIENT_SECRET,
                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
            )

        return {"github": github_client, "google": google_client}
| |
|
| |
|
class OAuthLogin(Resource):
    """Entry point that sends the browser to the OAuth provider's consent page."""

    def get(self, provider: str):
        """Redirect to the authorization URL of ``provider``.

        Args:
            provider: Provider key from the route (looked up in the mapping
                returned by ``get_oauth_providers``).

        Returns:
            A redirect to the provider's authorization URL, or a
            ``({"error": "Invalid provider"}, 400)`` tuple when the provider
            is unknown or not configured.
        """
        # An empty ``invite_token`` query value is normalised to None.
        invite_token = request.args.get("invite_token") or None
        OAUTH_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = OAUTH_PROVIDERS.get(provider)
        # NOTE: a debug ``print(vars(oauth_provider))`` used to sit here. It
        # raised TypeError for unknown/unconfigured providers (vars(None)),
        # turning the intended 400 into a 500, and dumped the client's
        # attributes to stdout for valid ones.
        if not oauth_provider:
            return {"error": "Invalid provider"}, 400

        auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
        return redirect(auth_url)
| |
|
| |
|
class OAuthCallback(Resource):
    """Callback endpoint the OAuth provider redirects to after authorization."""

    def get(self, provider: str):
        """Exchange the authorization code, resolve the account and log it in.

        Args:
            provider: Provider key from the route (looked up in the mapping
                returned by ``get_oauth_providers``).

        Returns:
            * ``({"error": ...}, 400)`` when the provider is unknown or the
              token / user-info request raises ``requests`` ``HTTPError``.
            * A redirect to the console sign-in page with a ``message`` query
              parameter when the invitation e-mail does not match, the account
              cannot be found/created, the account is banned, or no workspace
              is available.
            * A redirect to the invite-settings page when a valid invite token
              was supplied.
            * Otherwise a redirect to the console carrying the access and
              refresh tokens as query parameters.
        """
        providers = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = providers.get(provider)
        if not oauth_provider:
            return {"error": "Invalid provider"}, 400

        code = request.args.get("code")
        # The ``state`` query parameter is used as the invite token; a missing
        # or empty value means "no invitation".
        invite_token = request.args.get("state") or None

        try:
            token = oauth_provider.get_access_token(code)
            user_info = oauth_provider.get_user_info(token)
        except requests.exceptions.HTTPError as e:
            logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
            return {"error": "OAuth process failed"}, 400

        if invite_token and RegisterService.is_valid_invite_token(invite_token):
            invitation = RegisterService._get_invitation_by_token(token=invite_token)
            # Reject when the invitation was issued for a different e-mail
            # than the one reported by the provider.
            if invitation and invitation.get("email", None) != user_info.email:
                return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")

            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")

        try:
            account = _generate_account(provider, user_info)
        except AccountNotFoundError:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
        except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
            return redirect(
                f"{dify_config.CONSOLE_WEB_URL}/signin"
                "?message=Workspace not found, please contact system admin to invite you to join in a workspace."
            )

        # Banned accounts never get a session.
        if account.status == AccountStatus.BANNED.value:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")

        # A pending account is activated on its first successful OAuth login;
        # the timestamp is stored as a naive UTC datetime.
        if account.status == AccountStatus.PENDING.value:
            account.status = AccountStatus.ACTIVE.value
            account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
            db.session.commit()

        try:
            TenantService.create_owner_tenant_if_not_exist(account)
        except Unauthorized:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
        except WorkSpaceNotAllowedCreateError:
            return redirect(
                f"{dify_config.CONSOLE_WEB_URL}/signin"
                "?message=Workspace not found, please contact system admin to invite you to join in a workspace."
            )

        token_pair = AccountService.login(
            account=account,
            ip_address=extract_remote_ip(request),
        )

        console_url = (
            f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
        )
        return redirect(console_url)
| |
|
| |
|
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
    """Find the account for an OAuth identity.

    The provider open id is tried first; when that yields nothing, the lookup
    falls back to the e-mail address reported by the provider.

    Returns:
        The matching ``Account``, or ``None`` when neither lookup finds one.
    """
    linked_account = Account.get_by_openid(provider, user_info.id)
    if linked_account:
        return linked_account

    return Account.query.filter_by(email=user_info.email).first()
| |
|
| |
|
def _generate_account(provider: str, user_info: OAuthUserInfo):
    """Return the account for an OAuth identity, creating it when allowed.

    Args:
        provider: Provider key the identity belongs to.
        user_info: Identity data (id, name, email) reported by the provider.

    Returns:
        The existing or newly registered account, linked to the provider id.

    Raises:
        WorkSpaceNotAllowedCreateError: The account exists but has joined no
            workspace and workspace creation is disabled.
        AccountNotFoundError: No account exists and registration is disabled.
    """
    account = _get_account_by_openid_or_email(provider, user_info)

    if account:
        # Existing account without any workspace: give it one it owns,
        # provided the system allows workspace creation.
        if not TenantService.get_join_tenants(account):
            if not FeatureService.get_system_features().is_allow_create_workspace:
                raise WorkSpaceNotAllowedCreateError()
            workspace = TenantService.create_tenant(f"{account.name}'s Workspace")
            TenantService.create_tenant_member(workspace, account, role="owner")
            account.current_tenant = workspace
            tenant_was_created.send(workspace)
    else:
        if not FeatureService.get_system_features().is_allow_register:
            raise AccountNotFoundError()
        display_name = user_info.name or "Dify"
        account = RegisterService.register(
            email=user_info.email,
            name=display_name,
            password=None,
            open_id=user_info.id,
            provider=provider,
        )

        # Pick the interface language from the request's Accept-Language
        # header, falling back to the first supported language.
        best_lang = request.accept_languages.best_match(languages)
        account.interface_language = best_lang if best_lang and best_lang in languages else languages[0]
        db.session.commit()

    # Associate the provider identity with the account.
    AccountService.link_account_integrate(provider, user_info.id, account)

    return account
| |
|
| |
|
# Register the OAuth login and callback endpoints on the console API.
for _resource, _route in (
    (OAuthLogin, "/oauth/login/<provider>"),
    (OAuthCallback, "/oauth/authorize/<provider>"),
):
    api.add_resource(_resource, _route)
| |
|