| |
| |
| |
|
|
| package db |
|
|
| import ( |
| dbsql "database/sql" |
| "errors" |
| "regexp" |
| "strconv" |
| "strings" |
| "sync" |
|
|
| "github.com/GoAdminGroup/go-admin/modules/db/dialect" |
| "github.com/GoAdminGroup/go-admin/modules/logger" |
| ) |
|
|
| |
// SQL is a chainable query builder bound to a database driver.
//
// Builders come from SQLPool. Terminal methods (First, All, Update, Delete,
// Exec, Insert, ShowColumns, ShowTables, ...) hand the builder back to the
// pool through RecycleSQL, so a *SQL must not be used again after one of
// them has been called.
type SQL struct {
	// Query parts (table, fields, wheres, args, statement, ...) filled in by
	// the builder methods and rendered by the dialect.
	dialect.SQLComponent

	// diver (sic: "driver") executes the rendered statements.
	diver Connection

	// dialect renders SQLComponent into Statement for this driver.
	dialect dialect.Dialect

	// conn is the name of the connection to run on, e.g. "default".
	conn string

	// tx is passed to QueryWith/ExecWith; nil unless WithTx was called.
	tx *dbsql.Tx
}
|
|
| |
| var SQLPool = sync.Pool{ |
| New: func() interface{} { |
| return &SQL{ |
| SQLComponent: dialect.SQLComponent{ |
| Fields: make([]string, 0), |
| TableName: "", |
| Args: make([]interface{}, 0), |
| Wheres: make([]dialect.Where, 0), |
| Leftjoins: make([]dialect.Join, 0), |
| UpdateRaws: make([]dialect.RawUpdate, 0), |
| WhereRaws: "", |
| Order: "", |
| Group: "", |
| Limit: "", |
| }, |
| diver: nil, |
| dialect: nil, |
| } |
| }, |
| } |
|
|
| |
// H is a named map type from string keys to arbitrary values.
// NOTE(review): presumably used by callers for column -> value maps; confirm.
type H map[string]interface{}
|
|
| |
| func newSQL() *SQL { |
| return SQLPool.Get().(*SQL) |
| } |
|
|
| |
| |
| |
|
|
| |
| func Table(table string) *SQL { |
| sql := newSQL() |
| sql.TableName = table |
| sql.conn = "default" |
| return sql |
| } |
|
|
| |
| func WithDriver(conn Connection) *SQL { |
| sql := newSQL() |
| sql.diver = conn |
| sql.dialect = dialect.GetDialectByDriver(conn.Name()) |
| sql.conn = "default" |
| return sql |
| } |
|
|
| |
| func WithDriverAndConnection(connName string, conn Connection) *SQL { |
| sql := newSQL() |
| sql.diver = conn |
| sql.dialect = dialect.GetDialectByDriver(conn.Name()) |
| sql.conn = connName |
| return sql |
| } |
|
|
| |
| func (sql *SQL) WithDriver(conn Connection) *SQL { |
| sql.diver = conn |
| sql.dialect = dialect.GetDialectByDriver(conn.Name()) |
| return sql |
| } |
|
|
| |
| func (sql *SQL) WithConnection(conn string) *SQL { |
| sql.conn = conn |
| return sql |
| } |
|
|
| |
| func (sql *SQL) WithTx(tx *dbsql.Tx) *SQL { |
| sql.tx = tx |
| return sql |
| } |
|
|
| |
| func (sql *SQL) Table(table string) *SQL { |
| sql.clean() |
| sql.TableName = table |
| return sql |
| } |
|
|
| |
| func (sql *SQL) Select(fields ...string) *SQL { |
| sql.Fields = fields |
| sql.Functions = make([]string, len(fields)) |
| reg, _ := regexp.Compile(`(.*?)\((.*?)\)`) |
| for k, field := range fields { |
| res := reg.FindAllStringSubmatch(field, -1) |
| if len(res) > 0 && len(res[0]) > 2 { |
| sql.Functions[k] = res[0][1] |
| sql.Fields[k] = res[0][2] |
| } |
| } |
| return sql |
| } |
|
|
| |
| func (sql *SQL) OrderBy(fields ...string) *SQL { |
| if len(fields) == 0 { |
| panic("wrong order field") |
| } |
| for i := 0; i < len(fields); i++ { |
| if i == len(fields)-2 { |
| sql.Order += " " + sql.wrap(fields[i]) + " " + fields[i+1] |
| return sql |
| } |
| sql.Order += " " + sql.wrap(fields[i]) + " and " |
| } |
| return sql |
| } |
|
|
| |
| func (sql *SQL) OrderByRaw(order string) *SQL { |
| if order != "" { |
| sql.Order += " " + order |
| } |
| return sql |
| } |
|
|
| func (sql *SQL) GroupBy(fields ...string) *SQL { |
| if len(fields) == 0 { |
| panic("wrong group by field") |
| } |
| for i := 0; i < len(fields); i++ { |
| if i == len(fields)-1 { |
| sql.Group += " " + sql.wrap(fields[i]) |
| } else { |
| sql.Group += " " + sql.wrap(fields[i]) + "," |
| } |
| } |
| return sql |
| } |
|
|
| |
| func (sql *SQL) GroupByRaw(group string) *SQL { |
| if group != "" { |
| sql.Group += " " + group |
| } |
| return sql |
| } |
|
|
| |
| func (sql *SQL) Skip(offset int) *SQL { |
| sql.Offset = strconv.Itoa(offset) |
| return sql |
| } |
|
|
| |
| func (sql *SQL) Take(take int) *SQL { |
| sql.Limit = strconv.Itoa(take) |
| return sql |
| } |
|
|
| |
| func (sql *SQL) Where(field string, operation string, arg interface{}) *SQL { |
| sql.Wheres = append(sql.Wheres, dialect.Where{ |
| Field: field, |
| Operation: operation, |
| Qmark: "?", |
| }) |
| sql.Args = append(sql.Args, arg) |
| return sql |
| } |
|
|
| |
| func (sql *SQL) WhereIn(field string, arg []interface{}) *SQL { |
| if len(arg) == 0 { |
| panic("wrong parameter") |
| } |
| sql.Wheres = append(sql.Wheres, dialect.Where{ |
| Field: field, |
| Operation: "in", |
| Qmark: "(" + strings.Repeat("?,", len(arg)-1) + "?)", |
| }) |
| sql.Args = append(sql.Args, arg...) |
| return sql |
| } |
|
|
| |
| func (sql *SQL) WhereNotIn(field string, arg []interface{}) *SQL { |
| if len(arg) == 0 { |
| panic("wrong parameter") |
| } |
| sql.Wheres = append(sql.Wheres, dialect.Where{ |
| Field: field, |
| Operation: "not in", |
| Qmark: "(" + strings.Repeat("?,", len(arg)-1) + "?)", |
| }) |
| sql.Args = append(sql.Args, arg...) |
| return sql |
| } |
|
|
| |
| func (sql *SQL) Find(arg interface{}) (map[string]interface{}, error) { |
| return sql.Where("id", "=", arg).First() |
| } |
|
|
| |
| func (sql *SQL) Count() (int64, error) { |
| var ( |
| res map[string]interface{} |
| err error |
| driver = sql.diver.Name() |
| ) |
|
|
| if res, err = sql.Select("count(*)").First(); err != nil { |
| return 0, err |
| } |
|
|
| if driver == DriverPostgresql { |
| return res["count"].(int64), nil |
| } else if driver == DriverMssql { |
| return res[""].(int64), nil |
| } |
|
|
| return res["count(*)"].(int64), nil |
| } |
|
|
| |
| func (sql *SQL) Sum(field string) (float64, error) { |
| var ( |
| res map[string]interface{} |
| err error |
| key = "sum(" + sql.wrap(field) + ")" |
| ) |
| if res, err = sql.Select("sum(" + field + ")").First(); err != nil { |
| return 0, err |
| } |
|
|
| if res == nil { |
| return 0, nil |
| } |
|
|
| if r, ok := res[key].(float64); ok { |
| return r, nil |
| } else if r, ok := res[key].([]uint8); ok { |
| return strconv.ParseFloat(string(r), 64) |
| } else { |
| return 0, nil |
| } |
| } |
|
|
| |
| func (sql *SQL) Max(field string) (interface{}, error) { |
| var ( |
| res map[string]interface{} |
| err error |
| key = "max(" + sql.wrap(field) + ")" |
| ) |
| if res, err = sql.Select("max(" + field + ")").First(); err != nil { |
| return 0, err |
| } |
|
|
| if res == nil { |
| return 0, nil |
| } |
|
|
| return res[key], nil |
| } |
|
|
| |
| func (sql *SQL) Min(field string) (interface{}, error) { |
| var ( |
| res map[string]interface{} |
| err error |
| key = "min(" + sql.wrap(field) + ")" |
| ) |
| if res, err = sql.Select("min(" + field + ")").First(); err != nil { |
| return 0, err |
| } |
|
|
| if res == nil { |
| return 0, nil |
| } |
|
|
| return res[key], nil |
| } |
|
|
| |
| func (sql *SQL) Avg(field string) (interface{}, error) { |
| var ( |
| res map[string]interface{} |
| err error |
| key = "avg(" + sql.wrap(field) + ")" |
| ) |
| if res, err = sql.Select("avg(" + field + ")").First(); err != nil { |
| return 0, err |
| } |
|
|
| if res == nil { |
| return 0, nil |
| } |
|
|
| return res[key], nil |
| } |
|
|
| |
| func (sql *SQL) WhereRaw(raw string, args ...interface{}) *SQL { |
| sql.WhereRaws = raw |
| sql.Args = append(sql.Args, args...) |
| return sql |
| } |
|
|
| |
| func (sql *SQL) UpdateRaw(raw string, args ...interface{}) *SQL { |
| sql.UpdateRaws = append(sql.UpdateRaws, dialect.RawUpdate{ |
| Expression: raw, |
| Args: args, |
| }) |
| return sql |
| } |
|
|
| |
| func (sql *SQL) LeftJoin(table string, fieldA string, operation string, fieldB string) *SQL { |
| sql.Leftjoins = append(sql.Leftjoins, dialect.Join{ |
| FieldA: fieldA, |
| FieldB: fieldB, |
| Table: table, |
| Operation: operation, |
| }) |
| return sql |
| } |
|
|
| |
| |
| |
|
|
| |
// TxFn is the body of a transaction run by WithTransaction or
// WithTransactionByLevel. It receives the open transaction and returns an
// error (non-nil causes a rollback) plus a result map that is handed back to
// the caller.
type TxFn func(tx *dbsql.Tx) (error, map[string]interface{})
|
|
| |
| |
// WithTransaction begins a transaction on the builder's connection, runs fn
// inside it and returns fn's result map.
//
// If fn panics, the transaction is rolled back and the panic is re-raised.
// If fn returns a non-nil error, the transaction is rolled back and that
// error is returned. Otherwise the transaction is committed and the commit
// error, if any, is returned. Rollback errors are discarded. The builder
// itself is not recycled by this method.
func (sql *SQL) WithTransaction(fn TxFn) (res map[string]interface{}, err error) {

	// NOTE(review): how a failed BEGIN is reported is up to the driver
	// implementation; tx is not checked for nil here.
	tx := sql.diver.BeginTxAndConnection(sql.conn)

	// Runs after fn; it reads and may overwrite the named result err.
	defer func() {
		if p := recover(); p != nil {
			// fn panicked: undo its work, then propagate the panic.
			_ = tx.Rollback()
			panic(p)
		} else if err != nil {
			// fn reported failure: undo its work and keep fn's error.
			_ = tx.Rollback()
		} else {
			// Success: commit and surface any commit failure through err.
			err = tx.Commit()
		}
	}()

	err, res = fn(tx)
	return
}
|
|
| |
| |
// WithTransactionByLevel is WithTransaction with an explicit isolation level:
// it begins a transaction at level on the builder's connection, runs fn
// inside it and returns fn's result map.
//
// If fn panics, the transaction is rolled back and the panic is re-raised.
// If fn returns a non-nil error, the transaction is rolled back and that
// error is returned. Otherwise the transaction is committed and the commit
// error, if any, is returned. Rollback errors are discarded. The builder
// itself is not recycled by this method.
func (sql *SQL) WithTransactionByLevel(level dbsql.IsolationLevel, fn TxFn) (res map[string]interface{}, err error) {

	// NOTE(review): how a failed BEGIN is reported is up to the driver
	// implementation; tx is not checked for nil here.
	tx := sql.diver.BeginTxWithLevelAndConnection(sql.conn, level)

	// Runs after fn; it reads and may overwrite the named result err.
	defer func() {
		if p := recover(); p != nil {
			// fn panicked: undo its work, then propagate the panic.
			_ = tx.Rollback()
			panic(p)
		} else if err != nil {
			// fn reported failure: undo its work and keep fn's error.
			_ = tx.Rollback()
		} else {
			// Success: commit and surface any commit failure through err.
			err = tx.Commit()
		}
	}()

	err, res = fn(tx)
	return
}
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| func (sql *SQL) First() (map[string]interface{}, error) { |
| defer RecycleSQL(sql) |
|
|
| sql.dialect.Select(&sql.SQLComponent) |
|
|
| res, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
| if err != nil { |
| return nil, err |
| } |
|
|
| if len(res) < 1 { |
| return nil, errors.New("out of index") |
| } |
| return res[0], nil |
| } |
|
|
| |
| func (sql *SQL) All() ([]map[string]interface{}, error) { |
| defer RecycleSQL(sql) |
|
|
| sql.dialect.Select(&sql.SQLComponent) |
|
|
| return sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
| } |
|
|
| |
| func (sql *SQL) ShowColumns() ([]map[string]interface{}, error) { |
| defer RecycleSQL(sql) |
|
|
| return sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowColumns(sql.TableName)) |
| } |
|
|
| |
| func (sql *SQL) ShowColumnsWithComment(database string) ([]map[string]interface{}, error) { |
| defer RecycleSQL(sql) |
|
|
| return sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowColumnsWithComment(database, sql.TableName)) |
| } |
|
|
| |
| func (sql *SQL) ShowTables() ([]string, error) { |
| defer RecycleSQL(sql) |
|
|
| models, err := sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowTables()) |
|
|
| if err != nil { |
| return []string{}, err |
| } |
|
|
| tables := make([]string, 0) |
| if len(models) == 0 { |
| return tables, nil |
| } |
|
|
| key := "Tables_in_" + sql.TableName |
| if sql.diver.Name() == DriverPostgresql || sql.diver.Name() == DriverSqlite { |
| key = "tablename" |
| } else if sql.diver.Name() == DriverMssql { |
| key = "TABLE_NAME" |
| } else if _, ok := models[0][key].(string); !ok { |
| key = "Tables_in_" + strings.ToLower(sql.TableName) |
| } |
|
|
| for i := 0; i < len(models); i++ { |
| |
| if sql.diver.Name() == DriverSqlite && models[i][key].(string) == "sqlite_sequence" { |
| continue |
| } |
|
|
| tables = append(tables, models[i][key].(string)) |
| } |
|
|
| return tables, nil |
| } |
|
|
| |
| func (sql *SQL) Update(values dialect.H) (int64, error) { |
| defer RecycleSQL(sql) |
|
|
| sql.Values = values |
|
|
| sql.dialect.Update(&sql.SQLComponent) |
|
|
| res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
| if err != nil { |
| return 0, err |
| } |
|
|
| if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
| return 0, errors.New("no affect row") |
| } |
|
|
| return res.LastInsertId() |
| } |
|
|
| |
| func (sql *SQL) Delete() error { |
| defer RecycleSQL(sql) |
|
|
| sql.dialect.Delete(&sql.SQLComponent) |
|
|
| res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
| if err != nil { |
| return err |
| } |
|
|
| if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
| return errors.New("no affect row") |
| } |
|
|
| return nil |
| } |
|
|
| |
| func (sql *SQL) Exec() (int64, error) { |
| defer RecycleSQL(sql) |
|
|
| sql.dialect.Update(&sql.SQLComponent) |
|
|
| res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
| if err != nil { |
| return 0, err |
| } |
|
|
| if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
| return 0, errors.New("no affect row") |
| } |
|
|
| return res.LastInsertId() |
| } |
|
|
// postgresInsertCheckTableName is a "|"-separated list of built-in goadmin
// tables. On PostgreSQL, Insert checks the target table against it and
// appends "RETURNING id" to inserts into these tables to obtain the new id.
const postgresInsertCheckTableName = "goadmin_menu|goadmin_permissions|goadmin_roles|goadmin_users"
|
|
| |
| func (sql *SQL) Insert(values dialect.H) (int64, error) { |
| defer RecycleSQL(sql) |
|
|
| sql.Values = values |
|
|
| sql.dialect.Insert(&sql.SQLComponent) |
|
|
| if sql.diver.Name() == DriverPostgresql && (strings.Contains(postgresInsertCheckTableName, sql.TableName)) { |
|
|
| resMap, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement+" RETURNING id", sql.Args...) |
|
|
| if err != nil { |
|
|
| |
| _, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
| if err != nil { |
| return 0, err |
| } |
|
|
| res, err := sql.diver.QueryWithConnection(sql.conn, `SELECT max("id") as "id" FROM "`+sql.TableName+`"`) |
|
|
| if err != nil { |
| return 0, err |
| } |
|
|
| if len(res) != 0 { |
| return res[0]["id"].(int64), nil |
| } |
|
|
| return 0, err |
| } |
|
|
| if len(resMap) == 0 { |
| return 0, errors.New("no affect row") |
| } |
|
|
| return resMap[0]["id"].(int64), nil |
| } |
|
|
| res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
| if err != nil { |
| return 0, err |
| } |
|
|
| if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
| return 0, errors.New("no affect row") |
| } |
|
|
| return res.LastInsertId() |
| } |
|
|
| func (sql *SQL) wrap(field string) string { |
| return sql.diver.GetDelimiter() + field + sql.diver.GetDelimiter2() |
| } |
|
|
| func (sql *SQL) clean() { |
| sql.Functions = make([]string, 0) |
| sql.Group = "" |
| sql.Values = make(map[string]interface{}) |
| sql.Fields = make([]string, 0) |
| sql.TableName = "" |
| sql.Wheres = make([]dialect.Where, 0) |
| sql.Leftjoins = make([]dialect.Join, 0) |
| sql.Args = make([]interface{}, 0) |
| sql.Order = "" |
| sql.Offset = "" |
| sql.Limit = "" |
| sql.WhereRaws = "" |
| sql.UpdateRaws = make([]dialect.RawUpdate, 0) |
| sql.Statement = "" |
| } |
|
|
| |
| func RecycleSQL(sql *SQL) { |
|
|
| logger.LogSQL(sql.Statement, sql.Args) |
|
|
| sql.clean() |
|
|
| sql.conn = "" |
| sql.diver = nil |
| sql.tx = nil |
| sql.dialect = nil |
|
|
| SQLPool.Put(sql) |
| } |
|
|