| import datetime |
|
|
| import pytz |
| from flask import request |
| from flask_login import current_user |
| from flask_restful import Resource, fields, marshal_with, reqparse |
|
|
| from configs import dify_config |
| from constants.languages import supported_language |
| from controllers.console import api |
| from controllers.console.workspace.error import ( |
| AccountAlreadyInitedError, |
| CurrentPasswordIncorrectError, |
| InvalidInvitationCodeError, |
| RepeatPasswordNotMatchError, |
| ) |
| from controllers.console.wraps import account_initialization_required, setup_required |
| from extensions.ext_database import db |
| from fields.member_fields import account_fields |
| from libs.helper import TimestampField, timezone |
| from libs.login import login_required |
| from models import AccountIntegrate, InvitationCode |
| from services.account_service import AccountService |
| from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError |
|
|
|
|
class AccountInitApi(Resource):
    """Complete first-time setup of the logged-in account.

    Stores the caller's interface language and timezone, forces the light
    theme and sets the account status to ``"active"``. On the CLOUD edition
    an unused invitation code must also be supplied; it is marked ``"used"``
    and committed together with the account changes.
    """

    @setup_required
    @login_required
    def post(self):
        """Initialise ``current_user`` from the JSON request body.

        Returns:
            ``{"result": "success"}`` once the changes are committed.

        Raises:
            AccountAlreadyInitedError: the account status is already ``"active"``.
            ValueError: CLOUD edition and no ``invitation_code`` was provided.
            InvalidInvitationCodeError: no unused invitation matches the code.
        """
        account = current_user
        if account.status == "active":
            raise AccountAlreadyInitedError()

        is_cloud = dify_config.EDITION == "CLOUD"

        parser = reqparse.RequestParser()
        if is_cloud:
            # Only the CLOUD edition gates initialisation behind an invitation.
            parser.add_argument("invitation_code", type=str, location="json")
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        parser.add_argument("timezone", type=timezone, required=True, location="json")
        args = parser.parse_args()

        if is_cloud:
            code = args["invitation_code"]
            if not code:
                raise ValueError("invitation_code is required")

            invitation = (
                db.session.query(InvitationCode)
                .filter(InvitationCode.code == code, InvitationCode.status == "unused")
                .first()
            )
            if not invitation:
                raise InvalidInvitationCodeError()

            # Consume the code; persisted by the single commit below.
            invitation.status = "used"
            invitation.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            invitation.used_by_tenant_id = account.current_tenant_id
            invitation.used_by_account_id = account.id

        account.interface_language = args["interface_language"]
        account.timezone = args["timezone"]
        account.interface_theme = "light"
        account.status = "active"
        # Stored as a naive datetime holding the current UTC time.
        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()

        return {"result": "success"}
|
|
|
|
class AccountProfileApi(Resource):
    """Read-only view of the logged-in account's profile."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def get(self):
        """Return ``current_user``, serialised through ``account_fields``."""
        account = current_user
        return account
|
|
|
|
class AccountNameApi(Resource):
    """Rename the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the account ``name`` from the JSON body and return the updated account.

        Raises:
            ValueError: the name is null or not between 3 and 30 characters long.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, location="json")
        args = parser.parse_args()

        name = args["name"]
        # reqparse arguments are nullable by default, so an explicit JSON
        # ``null`` arrives here as None; guard it before calling len() so the
        # client gets the validation error instead of a TypeError / HTTP 500.
        if name is None or not 3 <= len(name) <= 30:
            raise ValueError("Account name must be between 3 and 30 characters.")

        return AccountService.update_account(current_user, name=name)
|
|
|
|
class AccountAvatarApi(Resource):
    """Change the avatar of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Save the ``avatar`` string from the JSON body and return the updated account."""
        parser = reqparse.RequestParser()
        parser.add_argument("avatar", type=str, required=True, location="json")
        avatar = parser.parse_args()["avatar"]
        return AccountService.update_account(current_user, avatar=avatar)
|
|
|
|
class AccountInterfaceLanguageApi(Resource):
    """Change the UI language of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Save ``interface_language`` (validated by ``supported_language``) and return the account."""
        parser = reqparse.RequestParser()
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        language = parser.parse_args()["interface_language"]
        return AccountService.update_account(current_user, interface_language=language)
|
|
|
|
class AccountInterfaceThemeApi(Resource):
    """Switch the logged-in account between the light and dark UI themes."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Save ``interface_theme`` (``"light"`` or ``"dark"``) and return the updated account."""
        parser = reqparse.RequestParser()
        parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
        theme = parser.parse_args()["interface_theme"]
        return AccountService.update_account(current_user, interface_theme=theme)
|
|
|
|
class AccountTimezoneApi(Resource):
    """Change the timezone of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Save the ``timezone`` from the JSON body and return the updated account.

        Raises:
            ValueError: the value is not one of pytz's timezone names.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("timezone", type=str, required=True, location="json")
        args = parser.parse_args()

        tz_name = args["timezone"]
        # all_timezones_set holds the same names as the all_timezones list but
        # answers membership in O(1) instead of a linear scan.
        if tz_name not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        return AccountService.update_account(current_user, timezone=tz_name)
|
|
|
|
class AccountPasswordApi(Resource):
    """Change the password of the logged-in account."""

    # No @marshal_with(account_fields) here: this endpoint returns a status
    # payload, not an account. Marshalling it through account_fields dropped
    # the "result" key and sent back an all-null account object instead.
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Replace the account password.

        JSON body: ``password`` (current password, optional), ``new_password``
        and ``repeat_new_password`` (both required, must match).

        Returns:
            ``{"result": "success"}`` after the service updates the password.

        Raises:
            RepeatPasswordNotMatchError: the two new-password fields differ.
            CurrentPasswordIncorrectError: the service rejected the current password.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("password", type=str, required=False, location="json")
        parser.add_argument("new_password", type=str, required=True, location="json")
        parser.add_argument("repeat_new_password", type=str, required=True, location="json")
        args = parser.parse_args()

        if args["new_password"] != args["repeat_new_password"]:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(current_user, args["password"], args["new_password"])
        except ServiceCurrentPasswordIncorrectError:
            # Translate the service-layer error into the console API error.
            raise CurrentPasswordIncorrectError()

        return {"result": "success"}
|
|
|
|
class AccountIntegrateApi(Resource):
    """List the OAuth providers an account can be linked to, with their bind state."""

    # Shape of one provider entry in the response.
    integrate_fields = {
        "provider": fields.String,
        "created_at": TimestampField,
        "is_bound": fields.Boolean,
        "link": fields.String,
    }

    # Response envelope: {"data": [<integrate_fields entry>, ...]}.
    integrate_list_fields = {
        "data": fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Return one entry per supported provider ("github", then "google").

        A bound provider carries its binding's ``created_at`` and a null link.
        An unbound provider carries an absolute OAuth login URL built from the
        request's root URL. Entries also include an ``"id"`` key, but it is not
        declared in ``integrate_fields``, so ``marshal_with`` omits it from the
        response.
        """
        account = current_user
        integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()

        # Keep the first record seen for each provider (first-match semantics).
        bound_by_provider = {}
        for integrate in integrates:
            bound_by_provider.setdefault(integrate.provider, integrate)

        login_url_prefix = request.url_root.rstrip("/") + "/console/api/oauth/login"

        data = []
        for provider in ("github", "google"):
            record = bound_by_provider.get(provider)
            if record:
                entry = {
                    "id": record.id,
                    "provider": provider,
                    "created_at": record.created_at,
                    "is_bound": True,
                    "link": None,
                }
            else:
                entry = {
                    "id": None,
                    "provider": provider,
                    "created_at": None,
                    "is_bound": False,
                    "link": f"{login_url_prefix}/{provider}",
                }
            data.append(entry)

        return {"data": data}
|
|
|
|
| |
# Register every account endpoint on the console API, in declaration order.
for _resource, _route in (
    (AccountInitApi, "/account/init"),
    (AccountProfileApi, "/account/profile"),
    (AccountNameApi, "/account/name"),
    (AccountAvatarApi, "/account/avatar"),
    (AccountInterfaceLanguageApi, "/account/interface-language"),
    (AccountInterfaceThemeApi, "/account/interface-theme"),
    (AccountTimezoneApi, "/account/timezone"),
    (AccountPasswordApi, "/account/password"),
    (AccountIntegrateApi, "/account/integrates"),
):
    api.add_resource(_resource, _route)
# Do not leave the loop temporaries behind as module attributes.
del _resource, _route
| |
| |
|
|