| package model |
|
|
| import ( |
| "encoding/json" |
| "errors" |
| "fmt" |
| "strconv" |
| "strings" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/logger" |
|
|
| "github.com/bytedance/gopkg/util/gopool" |
| "gorm.io/gorm" |
| ) |
|
|
| |
| |
// User is the GORM model for a user account row.
//
// Fields tagged `gorm:"-:all"` (OriginalPassword, VerificationCode) are never
// read from or written to the database; presumably they only carry request
// input — confirm against the controllers.
type User struct {
	Id int `json:"id"`
	Username string `json:"username" gorm:"unique;index" validate:"max=20"`
	// Password is hashed with common.Password2Hash before being persisted by
	// Insert, Update(true), Edit(true) and ResetUserPasswordByEmail.
	Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
	OriginalPassword string `json:"original_password" gorm:"-:all"`
	DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
	// Role is compared against common.RoleAdminUser / common.RoleRootUser.
	Role int `json:"role" gorm:"type:int;default:1"`
	// Status must equal common.UserStatusEnabled for ValidateAndFill to succeed.
	Status int `json:"status" gorm:"type:int;default:1"`
	Email string `json:"email" gorm:"index" validate:"max=50"`
	// External identity links. These columns are indexed but not unique.
	GitHubId string `json:"github_id" gorm:"column:github_id;index"`
	DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
	OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
	WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
	TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
	VerificationCode string `json:"verification_code" gorm:"-:all"`
	// AccessToken is a pointer; GetAccessToken maps nil to "".
	// NOTE(review): presumably nullable so users without a token do not collide
	// under the unique index — confirm.
	AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"`
	Quota int `json:"quota" gorm:"type:int;default:0"`
	UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"`
	RequestCount int `json:"request_count" gorm:"type:int;default:0;"`
	Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
	// AffCode is a random 4-character code assigned in Insert (unique index).
	AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
	AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
	// AffQuota is the invitation balance that TransferAffQuotaToQuota moves into Quota.
	AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"`
	// AffHistoryQuota is stored in the column "aff_history".
	AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"`
	InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
	// DeletedAt enables GORM soft deletes; queries without Unscoped() skip
	// soft-deleted rows.
	DeletedAt gorm.DeletedAt `gorm:"index"`
	LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
	// Setting holds a JSON-encoded dto.UserSetting (see GetSetting/SetSetting).
	Setting string `json:"setting" gorm:"type:text;column:setting"`
	Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
	StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
}
|
|
| func (user *User) ToBaseUser() *UserBase { |
| cache := &UserBase{ |
| Id: user.Id, |
| Group: user.Group, |
| Quota: user.Quota, |
| Status: user.Status, |
| Username: user.Username, |
| Setting: user.Setting, |
| Email: user.Email, |
| } |
| return cache |
| } |
|
|
| func (user *User) GetAccessToken() string { |
| if user.AccessToken == nil { |
| return "" |
| } |
| return *user.AccessToken |
| } |
|
|
| func (user *User) SetAccessToken(token string) { |
| user.AccessToken = &token |
| } |
|
|
| func (user *User) GetSetting() dto.UserSetting { |
| setting := dto.UserSetting{} |
| if user.Setting != "" { |
| err := json.Unmarshal([]byte(user.Setting), &setting) |
| if err != nil { |
| common.SysLog("failed to unmarshal setting: " + err.Error()) |
| } |
| } |
| return setting |
| } |
|
|
| func (user *User) SetSetting(setting dto.UserSetting) { |
| settingBytes, err := json.Marshal(setting) |
| if err != nil { |
| common.SysLog("failed to marshal setting: " + err.Error()) |
| return |
| } |
| user.Setting = string(settingBytes) |
| } |
|
|
| |
| func generateDefaultSidebarConfigForRole(userRole int) string { |
| defaultConfig := map[string]interface{}{} |
|
|
| |
| defaultConfig["chat"] = map[string]interface{}{ |
| "enabled": true, |
| "playground": true, |
| "chat": true, |
| } |
|
|
| |
| defaultConfig["console"] = map[string]interface{}{ |
| "enabled": true, |
| "detail": true, |
| "token": true, |
| "log": true, |
| "midjourney": true, |
| "task": true, |
| } |
|
|
| |
| defaultConfig["personal"] = map[string]interface{}{ |
| "enabled": true, |
| "topup": true, |
| "personal": true, |
| } |
|
|
| |
| if userRole == common.RoleAdminUser { |
| |
| defaultConfig["admin"] = map[string]interface{}{ |
| "enabled": true, |
| "channel": true, |
| "models": true, |
| "redemption": true, |
| "user": true, |
| "setting": false, |
| } |
| } else if userRole == common.RoleRootUser { |
| |
| defaultConfig["admin"] = map[string]interface{}{ |
| "enabled": true, |
| "channel": true, |
| "models": true, |
| "redemption": true, |
| "user": true, |
| "setting": true, |
| } |
| } |
| |
|
|
| |
| configBytes, err := json.Marshal(defaultConfig) |
| if err != nil { |
| common.SysLog("生成默认边栏配置失败: " + err.Error()) |
| return "" |
| } |
|
|
| return string(configBytes) |
| } |
|
|
| |
| func CheckUserExistOrDeleted(username string, email string) (bool, error) { |
| var user User |
|
|
| |
| |
| var err error |
| if email == "" { |
| err = DB.Unscoped().First(&user, "username = ?", username).Error |
| } else { |
| err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error |
| } |
| if err != nil { |
| if errors.Is(err, gorm.ErrRecordNotFound) { |
| |
| return false, nil |
| } |
| |
| return false, err |
| } |
| |
| return true, nil |
| } |
|
|
| func GetMaxUserId() int { |
| var user User |
| DB.Unscoped().Last(&user) |
| return user.Id |
| } |
|
|
| func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) { |
| |
| tx := DB.Begin() |
| if tx.Error != nil { |
| return nil, 0, tx.Error |
| } |
| defer func() { |
| if r := recover(); r != nil { |
| tx.Rollback() |
| } |
| }() |
|
|
| |
| err = tx.Unscoped().Model(&User{}).Count(&total).Error |
| if err != nil { |
| tx.Rollback() |
| return nil, 0, err |
| } |
|
|
| |
| err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error |
| if err != nil { |
| tx.Rollback() |
| return nil, 0, err |
| } |
|
|
| |
| if err = tx.Commit().Error; err != nil { |
| return nil, 0, err |
| } |
|
|
| return users, total, nil |
| } |
|
|
| func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) { |
| var users []*User |
| var total int64 |
| var err error |
|
|
| |
| tx := DB.Begin() |
| if tx.Error != nil { |
| return nil, 0, tx.Error |
| } |
| defer func() { |
| if r := recover(); r != nil { |
| tx.Rollback() |
| } |
| }() |
|
|
| |
| query := tx.Unscoped().Model(&User{}) |
|
|
| |
| likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?" |
|
|
| |
| keywordInt, err := strconv.Atoi(keyword) |
| if err == nil { |
| |
| likeCondition = "id = ? OR " + likeCondition |
| if group != "" { |
| query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?", |
| keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group) |
| } else { |
| query = query.Where(likeCondition, |
| keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%") |
| } |
| } else { |
| |
| if group != "" { |
| query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?", |
| "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group) |
| } else { |
| query = query.Where(likeCondition, |
| "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%") |
| } |
| } |
|
|
| |
| err = query.Count(&total).Error |
| if err != nil { |
| tx.Rollback() |
| return nil, 0, err |
| } |
|
|
| |
| err = query.Omit("password").Order("id desc").Limit(num).Offset(startIdx).Find(&users).Error |
| if err != nil { |
| tx.Rollback() |
| return nil, 0, err |
| } |
|
|
| |
| if err = tx.Commit().Error; err != nil { |
| return nil, 0, err |
| } |
|
|
| return users, total, nil |
| } |
|
|
| func GetUserById(id int, selectAll bool) (*User, error) { |
| if id == 0 { |
| return nil, errors.New("id 为空!") |
| } |
| user := User{Id: id} |
| var err error = nil |
| if selectAll { |
| err = DB.First(&user, "id = ?", id).Error |
| } else { |
| err = DB.Omit("password").First(&user, "id = ?", id).Error |
| } |
| return &user, err |
| } |
|
|
| func GetUserIdByAffCode(affCode string) (int, error) { |
| if affCode == "" { |
| return 0, errors.New("affCode 为空!") |
| } |
| var user User |
| err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error |
| return user.Id, err |
| } |
|
|
| func DeleteUserById(id int) (err error) { |
| if id == 0 { |
| return errors.New("id 为空!") |
| } |
| user := User{Id: id} |
| return user.Delete() |
| } |
|
|
| func HardDeleteUserById(id int) error { |
| if id == 0 { |
| return errors.New("id 为空!") |
| } |
| err := DB.Unscoped().Delete(&User{}, "id = ?", id).Error |
| return err |
| } |
|
|
| func inviteUser(inviterId int) (err error) { |
| user, err := GetUserById(inviterId, true) |
| if err != nil { |
| return err |
| } |
| user.AffCount++ |
| user.AffQuota += common.QuotaForInviter |
| user.AffHistoryQuota += common.QuotaForInviter |
| return DB.Save(user).Error |
| } |
|
|
| func (user *User) TransferAffQuotaToQuota(quota int) error { |
| |
| if float64(quota) < common.QuotaPerUnit { |
| return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit))) |
| } |
|
|
| |
| tx := DB.Begin() |
| if tx.Error != nil { |
| return tx.Error |
| } |
| defer tx.Rollback() |
|
|
| |
| err := tx.Set("gorm:query_option", "FOR UPDATE").First(&user, user.Id).Error |
| if err != nil { |
| return err |
| } |
|
|
| |
| if user.AffQuota < quota { |
| return errors.New("邀请额度不足!") |
| } |
|
|
| |
| user.AffQuota -= quota |
| user.Quota += quota |
|
|
| |
| if err := tx.Save(user).Error; err != nil { |
| return err |
| } |
|
|
| |
| return tx.Commit().Error |
| } |
|
|
| func (user *User) Insert(inviterId int) error { |
| var err error |
| if user.Password != "" { |
| user.Password, err = common.Password2Hash(user.Password) |
| if err != nil { |
| return err |
| } |
| } |
| user.Quota = common.QuotaForNewUser |
| |
| user.AffCode = common.GetRandomString(4) |
|
|
| |
| if user.Setting == "" { |
| defaultSetting := dto.UserSetting{} |
| |
| user.SetSetting(defaultSetting) |
| } |
|
|
| result := DB.Create(user) |
| if result.Error != nil { |
| return result.Error |
| } |
|
|
| |
| |
| var createdUser User |
| if err := DB.Where("username = ?", user.Username).First(&createdUser).Error; err == nil { |
| |
| defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role) |
| if defaultSidebarConfig != "" { |
| currentSetting := createdUser.GetSetting() |
| currentSetting.SidebarModules = defaultSidebarConfig |
| createdUser.SetSetting(currentSetting) |
| createdUser.Update(false) |
| common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role)) |
| } |
| } |
|
|
| if common.QuotaForNewUser > 0 { |
| RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser))) |
| } |
| if inviterId != 0 { |
| if common.QuotaForInvitee > 0 { |
| _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true) |
| RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee))) |
| } |
| if common.QuotaForInviter > 0 { |
| |
| RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter))) |
| _ = inviteUser(inviterId) |
| } |
| } |
| return nil |
| } |
|
|
// Update writes the receiver's field values to its database row and then
// refreshes the user cache via updateUserCache.
//
// When updatePassword is true, user.Password is treated as plaintext and hashed
// before anything is written.
//
// NOTE(review): the sequence is order-sensitive. The incoming values are copied
// into newUser. The receiver is then reloaded from the database, overwriting its
// fields; the reload's error is ignored. Finally Updates(newUser) writes the
// copied values. GORM documents that struct-based Updates skips zero-valued
// fields, so this method presumably cannot clear a field to "", 0 or false; use
// a column map (as Edit does) for that. Confirm before relying on it.
func (user *User) Update(updatePassword bool) error {
	var err error
	if updatePassword {
		user.Password, err = common.Password2Hash(user.Password)
		if err != nil {
			return err
		}
	}
	// Snapshot the requested values before the reload below overwrites *user.
	newUser := *user
	DB.First(&user, user.Id)
	if err = DB.Model(user).Updates(newUser).Error; err != nil {
		return err
	}

	// Push the resulting user state into the cache.
	return updateUserCache(*user)
}
|
|
// Edit writes a fixed set of columns from the receiver to its database row:
// username, display_name, group, quota and remark, plus password when
// updatePassword is true. It then refreshes the user cache via updateUserCache.
//
// Unlike Update, the columns are passed as a map, so they are written even when
// the new value is zero or empty (e.g. quota 0 or an empty remark).
//
// When updatePassword is true, user.Password is treated as plaintext and hashed
// before it is written.
func (user *User) Edit(updatePassword bool) error {
	var err error
	if updatePassword {
		user.Password, err = common.Password2Hash(user.Password)
		if err != nil {
			return err
		}
	}

	// Capture the requested values before the reload below overwrites *user.
	newUser := *user
	updates := map[string]interface{}{
		"username":     newUser.Username,
		"display_name": newUser.DisplayName,
		"group":        newUser.Group,
		"quota":        newUser.Quota,
		"remark":       newUser.Remark,
	}
	if updatePassword {
		updates["password"] = newUser.Password
	}

	// Reload the current row into the receiver (error ignored), then apply the
	// column map to it.
	DB.First(&user, user.Id)
	if err = DB.Model(user).Updates(updates).Error; err != nil {
		return err
	}

	// Push the resulting user state into the cache.
	return updateUserCache(*user)
}
|
|
| func (user *User) Delete() error { |
| if user.Id == 0 { |
| return errors.New("id 为空!") |
| } |
| if err := DB.Delete(user).Error; err != nil { |
| return err |
| } |
|
|
| |
| return invalidateUserCache(user.Id) |
| } |
|
|
| func (user *User) HardDelete() error { |
| if user.Id == 0 { |
| return errors.New("id 为空!") |
| } |
| err := DB.Unscoped().Delete(user).Error |
| return err |
| } |
|
|
| |
| func (user *User) ValidateAndFill() (err error) { |
| |
| |
| |
| password := user.Password |
| username := strings.TrimSpace(user.Username) |
| if username == "" || password == "" { |
| return errors.New("用户名或密码为空") |
| } |
| |
| DB.Where("username = ? OR email = ?", username, username).First(user) |
| okay := common.ValidatePasswordAndHash(password, user.Password) |
| if !okay || user.Status != common.UserStatusEnabled { |
| return errors.New("用户名或密码错误,或用户已被封禁") |
| } |
| return nil |
| } |
|
|
| func (user *User) FillUserById() error { |
| if user.Id == 0 { |
| return errors.New("id 为空!") |
| } |
| DB.Where(User{Id: user.Id}).First(user) |
| return nil |
| } |
|
|
| func (user *User) FillUserByEmail() error { |
| if user.Email == "" { |
| return errors.New("email 为空!") |
| } |
| DB.Where(User{Email: user.Email}).First(user) |
| return nil |
| } |
|
|
| func (user *User) FillUserByGitHubId() error { |
| if user.GitHubId == "" { |
| return errors.New("GitHub id 为空!") |
| } |
| DB.Where(User{GitHubId: user.GitHubId}).First(user) |
| return nil |
| } |
|
|
| func (user *User) FillUserByDiscordId() error { |
| if user.DiscordId == "" { |
| return errors.New("discord id 为空!") |
| } |
| DB.Where(User{DiscordId: user.DiscordId}).First(user) |
| return nil |
| } |
|
|
| func (user *User) FillUserByOidcId() error { |
| if user.OidcId == "" { |
| return errors.New("oidc id 为空!") |
| } |
| DB.Where(User{OidcId: user.OidcId}).First(user) |
| return nil |
| } |
|
|
| func (user *User) FillUserByWeChatId() error { |
| if user.WeChatId == "" { |
| return errors.New("WeChat id 为空!") |
| } |
| DB.Where(User{WeChatId: user.WeChatId}).First(user) |
| return nil |
| } |
|
|
| func (user *User) FillUserByTelegramId() error { |
| if user.TelegramId == "" { |
| return errors.New("Telegram id 为空!") |
| } |
| err := DB.Where(User{TelegramId: user.TelegramId}).First(user).Error |
| if errors.Is(err, gorm.ErrRecordNotFound) { |
| return errors.New("该 Telegram 账户未绑定") |
| } |
| return nil |
| } |
|
|
| func IsEmailAlreadyTaken(email string) bool { |
| return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1 |
| } |
|
|
| func IsWeChatIdAlreadyTaken(wechatId string) bool { |
| return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 |
| } |
|
|
| func IsGitHubIdAlreadyTaken(githubId string) bool { |
| return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 |
| } |
|
|
| func IsDiscordIdAlreadyTaken(discordId string) bool { |
| return DB.Unscoped().Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1 |
| } |
|
|
| func IsOidcIdAlreadyTaken(oidcId string) bool { |
| return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 |
| } |
|
|
| func IsTelegramIdAlreadyTaken(telegramId string) bool { |
| return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1 |
| } |
|
|
| func ResetUserPasswordByEmail(email string, password string) error { |
| if email == "" || password == "" { |
| return errors.New("邮箱地址或密码为空!") |
| } |
| hashedPassword, err := common.Password2Hash(password) |
| if err != nil { |
| return err |
| } |
| err = DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error |
| return err |
| } |
|
|
| func IsAdmin(userId int) bool { |
| if userId == 0 { |
| return false |
| } |
| var user User |
| err := DB.Where("id = ?", userId).Select("role").Find(&user).Error |
| if err != nil { |
| common.SysLog("no such user " + err.Error()) |
| return false |
| } |
| return user.Role >= common.RoleAdminUser |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| func ValidateAccessToken(token string) (user *User) { |
| if token == "" { |
| return nil |
| } |
| token = strings.Replace(token, "Bearer ", "", 1) |
| user = &User{} |
| if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 { |
| return user |
| } |
| return nil |
| } |
|
|
| |
| func GetUserQuota(id int, fromDB bool) (quota int, err error) { |
| defer func() { |
| |
| if shouldUpdateRedis(fromDB, err) { |
| gopool.Go(func() { |
| if err := updateUserQuotaCache(id, quota); err != nil { |
| common.SysLog("failed to update user quota cache: " + err.Error()) |
| } |
| }) |
| } |
| }() |
| if !fromDB && common.RedisEnabled { |
| quota, err := getUserQuotaCache(id) |
| if err == nil { |
| return quota, nil |
| } |
| |
| } |
| fromDB = true |
| err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error |
| if err != nil { |
| return 0, err |
| } |
|
|
| return quota, nil |
| } |
|
|
| func GetUserUsedQuota(id int) (quota int, err error) { |
| err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error |
| return quota, err |
| } |
|
|
| func GetUserEmail(id int) (email string, err error) { |
| err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error |
| return email, err |
| } |
|
|
| |
| func GetUserGroup(id int, fromDB bool) (group string, err error) { |
| defer func() { |
| |
| if shouldUpdateRedis(fromDB, err) { |
| gopool.Go(func() { |
| if err := updateUserGroupCache(id, group); err != nil { |
| common.SysLog("failed to update user group cache: " + err.Error()) |
| } |
| }) |
| } |
| }() |
| if !fromDB && common.RedisEnabled { |
| group, err := getUserGroupCache(id) |
| if err == nil { |
| return group, nil |
| } |
| |
| } |
| fromDB = true |
| err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error |
| if err != nil { |
| return "", err |
| } |
|
|
| return group, nil |
| } |
|
|
| |
| func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) { |
| var setting string |
| defer func() { |
| |
| if shouldUpdateRedis(fromDB, err) { |
| gopool.Go(func() { |
| if err := updateUserSettingCache(id, setting); err != nil { |
| common.SysLog("failed to update user setting cache: " + err.Error()) |
| } |
| }) |
| } |
| }() |
| if !fromDB && common.RedisEnabled { |
| setting, err := getUserSettingCache(id) |
| if err == nil { |
| return setting, nil |
| } |
| |
| } |
| fromDB = true |
| err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error |
| if err != nil { |
| return settingMap, err |
| } |
| userBase := &UserBase{ |
| Setting: setting, |
| } |
| return userBase.GetSetting(), nil |
| } |
|
|
| func IncreaseUserQuota(id int, quota int, db bool) (err error) { |
| if quota < 0 { |
| return errors.New("quota 不能为负数!") |
| } |
| gopool.Go(func() { |
| err := cacheIncrUserQuota(id, int64(quota)) |
| if err != nil { |
| common.SysLog("failed to increase user quota: " + err.Error()) |
| } |
| }) |
| if !db && common.BatchUpdateEnabled { |
| addNewRecord(BatchUpdateTypeUserQuota, id, quota) |
| return nil |
| } |
| return increaseUserQuota(id, quota) |
| } |
|
|
| func increaseUserQuota(id int, quota int) (err error) { |
| err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error |
| if err != nil { |
| return err |
| } |
| return err |
| } |
|
|
| func DecreaseUserQuota(id int, quota int) (err error) { |
| if quota < 0 { |
| return errors.New("quota 不能为负数!") |
| } |
| gopool.Go(func() { |
| err := cacheDecrUserQuota(id, int64(quota)) |
| if err != nil { |
| common.SysLog("failed to decrease user quota: " + err.Error()) |
| } |
| }) |
| if common.BatchUpdateEnabled { |
| addNewRecord(BatchUpdateTypeUserQuota, id, -quota) |
| return nil |
| } |
| return decreaseUserQuota(id, quota) |
| } |
|
|
| func decreaseUserQuota(id int, quota int) (err error) { |
| err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error |
| if err != nil { |
| return err |
| } |
| return err |
| } |
|
|
| func DeltaUpdateUserQuota(id int, delta int) (err error) { |
| if delta == 0 { |
| return nil |
| } |
| if delta > 0 { |
| return IncreaseUserQuota(id, delta, false) |
| } else { |
| return DecreaseUserQuota(id, -delta) |
| } |
| } |
|
|
| |
| |
| |
| |
|
|
| func GetRootUser() (user *User) { |
| DB.Where("role = ?", common.RoleRootUser).First(&user) |
| return user |
| } |
|
|
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { |
| if common.BatchUpdateEnabled { |
| addNewRecord(BatchUpdateTypeUsedQuota, id, quota) |
| addNewRecord(BatchUpdateTypeRequestCount, id, 1) |
| return |
| } |
| updateUserUsedQuotaAndRequestCount(id, quota, 1) |
| } |
|
|
| func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { |
| err := DB.Model(&User{}).Where("id = ?", id).Updates( |
| map[string]interface{}{ |
| "used_quota": gorm.Expr("used_quota + ?", quota), |
| "request_count": gorm.Expr("request_count + ?", count), |
| }, |
| ).Error |
| if err != nil { |
| common.SysLog("failed to update user used quota and request count: " + err.Error()) |
| return |
| } |
|
|
| |
| |
| |
| |
| } |
|
|
| func updateUserUsedQuota(id int, quota int) { |
| err := DB.Model(&User{}).Where("id = ?", id).Updates( |
| map[string]interface{}{ |
| "used_quota": gorm.Expr("used_quota + ?", quota), |
| }, |
| ).Error |
| if err != nil { |
| common.SysLog("failed to update user used quota: " + err.Error()) |
| } |
| } |
|
|
| func updateUserRequestCount(id int, count int) { |
| err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error |
| if err != nil { |
| common.SysLog("failed to update user request count: " + err.Error()) |
| } |
| } |
|
|
| |
| func GetUsernameById(id int, fromDB bool) (username string, err error) { |
| defer func() { |
| |
| if shouldUpdateRedis(fromDB, err) { |
| gopool.Go(func() { |
| if err := updateUserNameCache(id, username); err != nil { |
| common.SysLog("failed to update user name cache: " + err.Error()) |
| } |
| }) |
| } |
| }() |
| if !fromDB && common.RedisEnabled { |
| username, err := getUserNameCache(id) |
| if err == nil { |
| return username, nil |
| } |
| |
| } |
| fromDB = true |
| err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error |
| if err != nil { |
| return "", err |
| } |
|
|
| return username, nil |
| } |
|
|
| func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool { |
| var user User |
| err := DB.Unscoped().Where("linux_do_id = ?", linuxDOId).First(&user).Error |
| return !errors.Is(err, gorm.ErrRecordNotFound) |
| } |
|
|
| func (user *User) FillUserByLinuxDOId() error { |
| if user.LinuxDOId == "" { |
| return errors.New("linux do id is empty") |
| } |
| err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error |
| return err |
| } |
|
|
| func RootUserExists() bool { |
| var user User |
| err := DB.Where("role = ?", common.RoleRootUser).First(&user).Error |
| if err != nil { |
| return false |
| } |
| return true |
| } |
|
|