| | |
| | |
| | |
| |
|
| | package db |
| |
|
| | import ( |
| | dbsql "database/sql" |
| | "errors" |
| | "regexp" |
| | "strconv" |
| | "strings" |
| | "sync" |
| |
|
| | "github.com/GoAdminGroup/go-admin/modules/db/dialect" |
| | "github.com/GoAdminGroup/go-admin/modules/logger" |
| | ) |
| |
|
| | |
// SQL is a chainable query builder. Builders are obtained from SQLPool (via
// Table, WithDriver or WithDriverAndConnection), configured with the
// chainable methods, and returned to the pool by RecycleSQL once a terminal
// method (First, All, Update, Insert, ...) has run.
type SQL struct {
	// SQLComponent holds the parts of the statement being built (table,
	// fields, wheres, args, ...) and receives the rendered Statement from the
	// dialect.
	dialect.SQLComponent
	// diver is the connection used to execute statements (the misspelled name
	// is referenced throughout the package).
	diver Connection
	// dialect renders SQLComponent into SQL; it is looked up from the
	// driver's name by the WithDriver* helpers.
	dialect dialect.Dialect
	// conn names the connection to use; Table and WithDriver set "default".
	conn string
	// tx, when non-nil, is passed to the driver's QueryWith/ExecWith calls,
	// presumably so the statement runs inside that transaction.
	tx *dbsql.Tx
}
| |
|
| | |
| | var SQLPool = sync.Pool{ |
| | New: func() interface{} { |
| | return &SQL{ |
| | SQLComponent: dialect.SQLComponent{ |
| | Fields: make([]string, 0), |
| | TableName: "", |
| | Args: make([]interface{}, 0), |
| | Wheres: make([]dialect.Where, 0), |
| | Leftjoins: make([]dialect.Join, 0), |
| | UpdateRaws: make([]dialect.RawUpdate, 0), |
| | WhereRaws: "", |
| | Order: "", |
| | Group: "", |
| | Limit: "", |
| | }, |
| | diver: nil, |
| | dialect: nil, |
| | } |
| | }, |
| | } |
| |
|
| | |
// H is a string-keyed map of arbitrary values.
// NOTE(review): not used in this chunk (Insert/Update take dialect.H);
// presumably kept as a convenience alias for callers — confirm.
type H map[string]interface{}
| |
|
| | |
| | func newSQL() *SQL { |
| | return SQLPool.Get().(*SQL) |
| | } |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | func Table(table string) *SQL { |
| | sql := newSQL() |
| | sql.TableName = table |
| | sql.conn = "default" |
| | return sql |
| | } |
| |
|
| | |
| | func WithDriver(conn Connection) *SQL { |
| | sql := newSQL() |
| | sql.diver = conn |
| | sql.dialect = dialect.GetDialectByDriver(conn.Name()) |
| | sql.conn = "default" |
| | return sql |
| | } |
| |
|
| | |
| | func WithDriverAndConnection(connName string, conn Connection) *SQL { |
| | sql := newSQL() |
| | sql.diver = conn |
| | sql.dialect = dialect.GetDialectByDriver(conn.Name()) |
| | sql.conn = connName |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) WithDriver(conn Connection) *SQL { |
| | sql.diver = conn |
| | sql.dialect = dialect.GetDialectByDriver(conn.Name()) |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) WithConnection(conn string) *SQL { |
| | sql.conn = conn |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) WithTx(tx *dbsql.Tx) *SQL { |
| | sql.tx = tx |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) Table(table string) *SQL { |
| | sql.clean() |
| | sql.TableName = table |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) Select(fields ...string) *SQL { |
| | sql.Fields = fields |
| | sql.Functions = make([]string, len(fields)) |
| | reg, _ := regexp.Compile(`(.*?)\((.*?)\)`) |
| | for k, field := range fields { |
| | res := reg.FindAllStringSubmatch(field, -1) |
| | if len(res) > 0 && len(res[0]) > 2 { |
| | sql.Functions[k] = res[0][1] |
| | sql.Fields[k] = res[0][2] |
| | } |
| | } |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) OrderBy(fields ...string) *SQL { |
| | if len(fields) == 0 { |
| | panic("wrong order field") |
| | } |
| | for i := 0; i < len(fields); i++ { |
| | if i == len(fields)-2 { |
| | sql.Order += " " + sql.wrap(fields[i]) + " " + fields[i+1] |
| | return sql |
| | } |
| | sql.Order += " " + sql.wrap(fields[i]) + " and " |
| | } |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) OrderByRaw(order string) *SQL { |
| | if order != "" { |
| | sql.Order += " " + order |
| | } |
| | return sql |
| | } |
| |
|
| | func (sql *SQL) GroupBy(fields ...string) *SQL { |
| | if len(fields) == 0 { |
| | panic("wrong group by field") |
| | } |
| | for i := 0; i < len(fields); i++ { |
| | if i == len(fields)-1 { |
| | sql.Group += " " + sql.wrap(fields[i]) |
| | } else { |
| | sql.Group += " " + sql.wrap(fields[i]) + "," |
| | } |
| | } |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) GroupByRaw(group string) *SQL { |
| | if group != "" { |
| | sql.Group += " " + group |
| | } |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) Skip(offset int) *SQL { |
| | sql.Offset = strconv.Itoa(offset) |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) Take(take int) *SQL { |
| | sql.Limit = strconv.Itoa(take) |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) Where(field string, operation string, arg interface{}) *SQL { |
| | sql.Wheres = append(sql.Wheres, dialect.Where{ |
| | Field: field, |
| | Operation: operation, |
| | Qmark: "?", |
| | }) |
| | sql.Args = append(sql.Args, arg) |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) WhereIn(field string, arg []interface{}) *SQL { |
| | if len(arg) == 0 { |
| | panic("wrong parameter") |
| | } |
| | sql.Wheres = append(sql.Wheres, dialect.Where{ |
| | Field: field, |
| | Operation: "in", |
| | Qmark: "(" + strings.Repeat("?,", len(arg)-1) + "?)", |
| | }) |
| | sql.Args = append(sql.Args, arg...) |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) WhereNotIn(field string, arg []interface{}) *SQL { |
| | if len(arg) == 0 { |
| | panic("wrong parameter") |
| | } |
| | sql.Wheres = append(sql.Wheres, dialect.Where{ |
| | Field: field, |
| | Operation: "not in", |
| | Qmark: "(" + strings.Repeat("?,", len(arg)-1) + "?)", |
| | }) |
| | sql.Args = append(sql.Args, arg...) |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) Find(arg interface{}) (map[string]interface{}, error) { |
| | return sql.Where("id", "=", arg).First() |
| | } |
| |
|
| | |
| | func (sql *SQL) Count() (int64, error) { |
| | var ( |
| | res map[string]interface{} |
| | err error |
| | driver = sql.diver.Name() |
| | ) |
| |
|
| | if res, err = sql.Select("count(*)").First(); err != nil { |
| | return 0, err |
| | } |
| |
|
| | if driver == DriverPostgresql { |
| | return res["count"].(int64), nil |
| | } else if driver == DriverMssql { |
| | return res[""].(int64), nil |
| | } |
| |
|
| | return res["count(*)"].(int64), nil |
| | } |
| |
|
| | |
| | func (sql *SQL) Sum(field string) (float64, error) { |
| | var ( |
| | res map[string]interface{} |
| | err error |
| | key = "sum(" + sql.wrap(field) + ")" |
| | ) |
| | if res, err = sql.Select("sum(" + field + ")").First(); err != nil { |
| | return 0, err |
| | } |
| |
|
| | if res == nil { |
| | return 0, nil |
| | } |
| |
|
| | if r, ok := res[key].(float64); ok { |
| | return r, nil |
| | } else if r, ok := res[key].([]uint8); ok { |
| | return strconv.ParseFloat(string(r), 64) |
| | } else { |
| | return 0, nil |
| | } |
| | } |
| |
|
| | |
| | func (sql *SQL) Max(field string) (interface{}, error) { |
| | var ( |
| | res map[string]interface{} |
| | err error |
| | key = "max(" + sql.wrap(field) + ")" |
| | ) |
| | if res, err = sql.Select("max(" + field + ")").First(); err != nil { |
| | return 0, err |
| | } |
| |
|
| | if res == nil { |
| | return 0, nil |
| | } |
| |
|
| | return res[key], nil |
| | } |
| |
|
| | |
| | func (sql *SQL) Min(field string) (interface{}, error) { |
| | var ( |
| | res map[string]interface{} |
| | err error |
| | key = "min(" + sql.wrap(field) + ")" |
| | ) |
| | if res, err = sql.Select("min(" + field + ")").First(); err != nil { |
| | return 0, err |
| | } |
| |
|
| | if res == nil { |
| | return 0, nil |
| | } |
| |
|
| | return res[key], nil |
| | } |
| |
|
| | |
| | func (sql *SQL) Avg(field string) (interface{}, error) { |
| | var ( |
| | res map[string]interface{} |
| | err error |
| | key = "avg(" + sql.wrap(field) + ")" |
| | ) |
| | if res, err = sql.Select("avg(" + field + ")").First(); err != nil { |
| | return 0, err |
| | } |
| |
|
| | if res == nil { |
| | return 0, nil |
| | } |
| |
|
| | return res[key], nil |
| | } |
| |
|
| | |
| | func (sql *SQL) WhereRaw(raw string, args ...interface{}) *SQL { |
| | sql.WhereRaws = raw |
| | sql.Args = append(sql.Args, args...) |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) UpdateRaw(raw string, args ...interface{}) *SQL { |
| | sql.UpdateRaws = append(sql.UpdateRaws, dialect.RawUpdate{ |
| | Expression: raw, |
| | Args: args, |
| | }) |
| | return sql |
| | } |
| |
|
| | |
| | func (sql *SQL) LeftJoin(table string, fieldA string, operation string, fieldB string) *SQL { |
| | sql.Leftjoins = append(sql.Leftjoins, dialect.Join{ |
| | FieldA: fieldA, |
| | FieldB: fieldB, |
| | Table: table, |
| | Operation: operation, |
| | }) |
| | return sql |
| | } |
| |
|
| | |
| | |
| | |
| |
|
| | |
// TxFn is the callback executed by WithTransaction and WithTransactionByLevel.
// It receives the open transaction and returns an error (first, unlike the
// usual Go convention) and a result map. A non-nil error makes the caller
// roll the transaction back; a nil error makes it commit.
type TxFn func(tx *dbsql.Tx) (error, map[string]interface{})
| |
|
| | |
| | |
| | func (sql *SQL) WithTransaction(fn TxFn) (res map[string]interface{}, err error) { |
| |
|
| | tx := sql.diver.BeginTxAndConnection(sql.conn) |
| |
|
| | defer func() { |
| | if p := recover(); p != nil { |
| | |
| | _ = tx.Rollback() |
| | panic(p) |
| | } else if err != nil { |
| | |
| | _ = tx.Rollback() |
| | } else { |
| | |
| | err = tx.Commit() |
| | } |
| | }() |
| |
|
| | err, res = fn(tx) |
| | return |
| | } |
| |
|
| | |
| | |
| | func (sql *SQL) WithTransactionByLevel(level dbsql.IsolationLevel, fn TxFn) (res map[string]interface{}, err error) { |
| |
|
| | tx := sql.diver.BeginTxWithLevelAndConnection(sql.conn, level) |
| |
|
| | defer func() { |
| | if p := recover(); p != nil { |
| | |
| | _ = tx.Rollback() |
| | panic(p) |
| | } else if err != nil { |
| | |
| | _ = tx.Rollback() |
| | } else { |
| | |
| | err = tx.Commit() |
| | } |
| | }() |
| |
|
| | err, res = fn(tx) |
| | return |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | func (sql *SQL) First() (map[string]interface{}, error) { |
| | defer RecycleSQL(sql) |
| |
|
| | sql.dialect.Select(&sql.SQLComponent) |
| |
|
| | res, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
| |
|
| | if err != nil { |
| | return nil, err |
| | } |
| |
|
| | if len(res) < 1 { |
| | return nil, errors.New("out of index") |
| | } |
| | return res[0], nil |
| | } |
| |
|
| | |
| | func (sql *SQL) All() ([]map[string]interface{}, error) { |
| | defer RecycleSQL(sql) |
| |
|
| | sql.dialect.Select(&sql.SQLComponent) |
| |
|
| | return sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
| | } |
| |
|
| | |
| | func (sql *SQL) ShowColumns() ([]map[string]interface{}, error) { |
| | defer RecycleSQL(sql) |
| |
|
| | return sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowColumns(sql.TableName)) |
| | } |
| |
|
| | |
| | func (sql *SQL) ShowColumnsWithComment(database string) ([]map[string]interface{}, error) { |
| | defer RecycleSQL(sql) |
| |
|
| | return sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowColumnsWithComment(database, sql.TableName)) |
| | } |
| |
|
| | |
| | func (sql *SQL) ShowTables() ([]string, error) { |
| | defer RecycleSQL(sql) |
| |
|
| | models, err := sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowTables()) |
| |
|
| | if err != nil { |
| | return []string{}, err |
| | } |
| |
|
| | tables := make([]string, 0) |
| | if len(models) == 0 { |
| | return tables, nil |
| | } |
| |
|
| | key := "Tables_in_" + sql.TableName |
| | if sql.diver.Name() == DriverPostgresql || sql.diver.Name() == DriverSqlite { |
| | key = "tablename" |
| | } else if sql.diver.Name() == DriverMssql { |
| | key = "TABLE_NAME" |
| | } else if _, ok := models[0][key].(string); !ok { |
| | key = "Tables_in_" + strings.ToLower(sql.TableName) |
| | } |
| |
|
| | for i := 0; i < len(models); i++ { |
| | |
| | if sql.diver.Name() == DriverSqlite && models[i][key].(string) == "sqlite_sequence" { |
| | continue |
| | } |
| |
|
| | tables = append(tables, models[i][key].(string)) |
| | } |
| |
|
| | return tables, nil |
| | } |
| |
|
| | |
| | func (sql *SQL) Update(values dialect.H) (int64, error) { |
| | defer RecycleSQL(sql) |
| |
|
| | sql.Values = values |
| |
|
| | sql.dialect.Update(&sql.SQLComponent) |
| |
|
| | res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
| |
|
| | if err != nil { |
| | return 0, err |
| | } |
| |
|
| | if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
| | return 0, errors.New("no affect row") |
| | } |
| |
|
| | return res.LastInsertId() |
| | } |
| |
|
| | |
| | func (sql *SQL) Delete() error { |
| | defer RecycleSQL(sql) |
| |
|
| | sql.dialect.Delete(&sql.SQLComponent) |
| |
|
| | res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
| |
|
| | if err != nil { |
| | return err |
| | } |
| |
|
| | if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
| | return errors.New("no affect row") |
| | } |
| |
|
| | return nil |
| | } |
| |
|
| | |
| | func (sql *SQL) Exec() (int64, error) { |
| | defer RecycleSQL(sql) |
| |
|
| | sql.dialect.Update(&sql.SQLComponent) |
| |
|
| | res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
| |
|
| | if err != nil { |
| | return 0, err |
| | } |
| |
|
| | if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
| | return 0, errors.New("no affect row") |
| | } |
| |
|
| | return res.LastInsertId() |
| | } |
| |
|
// postgresInsertCheckTableName is a "|"-joined list of tables for which
// Insert, on PostgreSQL, appends "RETURNING id" to obtain the new row's id.
// NOTE(review): Insert checks membership with
// strings.Contains(postgresInsertCheckTableName, table), i.e. a substring
// match, so any table name contained in this string (e.g. "users", or an
// empty name) also matches.
const postgresInsertCheckTableName = "goadmin_menu|goadmin_permissions|goadmin_roles|goadmin_users"
| |
|
| | |
| | func (sql *SQL) Insert(values dialect.H) (int64, error) { |
| | defer RecycleSQL(sql) |
| |
|
| | sql.Values = values |
| |
|
| | sql.dialect.Insert(&sql.SQLComponent) |
| |
|
| | if sql.diver.Name() == DriverPostgresql && (strings.Contains(postgresInsertCheckTableName, sql.TableName)) { |
| |
|
| | resMap, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement+" RETURNING id", sql.Args...) |
| |
|
| | if err != nil { |
| |
|
| | |
| | _, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
| |
|
| | if err != nil { |
| | return 0, err |
| | } |
| |
|
| | res, err := sql.diver.QueryWithConnection(sql.conn, `SELECT max("id") as "id" FROM "`+sql.TableName+`"`) |
| |
|
| | if err != nil { |
| | return 0, err |
| | } |
| |
|
| | if len(res) != 0 { |
| | return res[0]["id"].(int64), nil |
| | } |
| |
|
| | return 0, err |
| | } |
| |
|
| | if len(resMap) == 0 { |
| | return 0, errors.New("no affect row") |
| | } |
| |
|
| | return resMap[0]["id"].(int64), nil |
| | } |
| |
|
| | res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
| |
|
| | if err != nil { |
| | return 0, err |
| | } |
| |
|
| | if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
| | return 0, errors.New("no affect row") |
| | } |
| |
|
| | return res.LastInsertId() |
| | } |
| |
|
| | func (sql *SQL) wrap(field string) string { |
| | return sql.diver.GetDelimiter() + field + sql.diver.GetDelimiter2() |
| | } |
| |
|
| | func (sql *SQL) clean() { |
| | sql.Functions = make([]string, 0) |
| | sql.Group = "" |
| | sql.Values = make(map[string]interface{}) |
| | sql.Fields = make([]string, 0) |
| | sql.TableName = "" |
| | sql.Wheres = make([]dialect.Where, 0) |
| | sql.Leftjoins = make([]dialect.Join, 0) |
| | sql.Args = make([]interface{}, 0) |
| | sql.Order = "" |
| | sql.Offset = "" |
| | sql.Limit = "" |
| | sql.WhereRaws = "" |
| | sql.UpdateRaws = make([]dialect.RawUpdate, 0) |
| | sql.Statement = "" |
| | } |
| |
|
| | |
| | func RecycleSQL(sql *SQL) { |
| |
|
| | logger.LogSQL(sql.Statement, sql.Args) |
| |
|
| | sql.clean() |
| |
|
| | sql.conn = "" |
| | sql.diver = nil |
| | sql.tx = nil |
| | sql.dialect = nil |
| |
|
| | SQLPool.Put(sql) |
| | } |
| |
|