Skip to content

Commit

Permalink
Merge pull request #10 from XIELongDragon/suppor-custom-tag
Browse files Browse the repository at this point in the history
Suppor custom tag
  • Loading branch information
XIELongDragon authored May 17, 2023
2 parents 73b61c8 + 8da9d8c commit ce048d2
Show file tree
Hide file tree
Showing 58 changed files with 761 additions and 669 deletions.
57 changes: 36 additions & 21 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ type (
logger Logger
dialect string
// nolint: stylecheck // keep for backwards compatibility
Db SQLDatabase
qf exec.QueryFactory
qfOnce sync.Once
Db SQLDatabase
tagName string
qf exec.QueryFactory
qfOnce sync.Once
}
)

Expand Down Expand Up @@ -62,11 +63,12 @@ type (
// panic(err.Error())
// }
// fmt.Printf("%+v", ids)
func newDatabase(dialect string, db SQLDatabase) *Database {
func newDatabase(dialect string, tagName string, db SQLDatabase) *Database {
return &Database{
logger: nil,
dialect: dialect,
Db: db,
tagName: tagName,
qf: nil,
qfOnce: sync.Once{},
}
Expand All @@ -83,7 +85,7 @@ func (d *Database) Begin() (*TxDatabase, error) {
if err != nil {
return nil, err
}
tx := NewTx(d.dialect, sqlTx)
tx := NewTxWithTagName(d.dialect, d.tagName, sqlTx)
tx.Logger(d.logger)
return tx, nil
}
Expand All @@ -94,7 +96,7 @@ func (d *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*TxDatabas
if err != nil {
return nil, err
}
tx := NewTx(d.dialect, sqlTx)
tx := NewTxWithTagName(d.dialect, d.tagName, sqlTx)
tx.Logger(d.logger)
return tx, nil
}
Expand All @@ -117,27 +119,27 @@ func (d *Database) WithTx(fn func(*TxDatabase) error) error {
//
// from...: Sources for you dataset, could be table names (strings), a goqu.Literal or another goqu.Dataset
func (d *Database) From(from ...interface{}) *SelectDataset {
return newDataset(d.dialect, d.queryFactory()).From(from...)
return newDataset(d.dialect, d.tagName, d.queryFactory()).From(from...)
}

func (d *Database) Select(cols ...interface{}) *SelectDataset {
return newDataset(d.dialect, d.queryFactory()).Select(cols...)
return newDataset(d.dialect, d.tagName, d.queryFactory()).Select(cols...)
}

func (d *Database) Update(table interface{}) *UpdateDataset {
return newUpdateDataset(d.dialect, d.queryFactory()).Table(table)
return newUpdateDataset(d.dialect, d.tagName, d.queryFactory()).Table(table)
}

func (d *Database) Insert(table interface{}) *InsertDataset {
return newInsertDataset(d.dialect, d.queryFactory()).Into(table)
return newInsertDataset(d.dialect, d.tagName, d.queryFactory()).Into(table)
}

func (d *Database) Delete(table interface{}) *DeleteDataset {
return newDeleteDataset(d.dialect, d.queryFactory()).From(table)
return newDeleteDataset(d.dialect, d.tagName, d.queryFactory()).From(table)
}

func (d *Database) Truncate(table ...interface{}) *TruncateDataset {
return newTruncateDataset(d.dialect, d.queryFactory()).Table(table...)
return newTruncateDataset(d.dialect, d.tagName, d.queryFactory()).Table(table...)
}

// Sets the logger for to use when logging queries
Expand Down Expand Up @@ -332,9 +334,13 @@ func (d *Database) QueryRowContext(ctx context.Context, query string, args ...in
return d.Db.QueryRowContext(ctx, query, args...)
}

func (d *Database) GetTagName() string {
return d.tagName
}

func (d *Database) queryFactory() exec.QueryFactory {
d.qfOnce.Do(func() {
d.qf = exec.NewQueryFactory(d)
d.qf = exec.NewQueryFactory(d.tagName, d)
})
return d.qf
}
Expand Down Expand Up @@ -450,6 +456,7 @@ type (
TxDatabase struct {
logger Logger
dialect string
tagName string
Tx SQLTx
qf exec.QueryFactory
qfOnce sync.Once
Expand All @@ -458,7 +465,11 @@ type (

// Creates a new TxDatabase
func NewTx(dialect string, tx SQLTx) *TxDatabase {
return &TxDatabase{dialect: dialect, Tx: tx}
return &TxDatabase{dialect: dialect, tagName: "db", Tx: tx}
}

func NewTxWithTagName(dialect string, tagName string, tx SQLTx) *TxDatabase {
return &TxDatabase{dialect: dialect, tagName: tagName, Tx: tx}
}

// returns this databases dialect
Expand All @@ -468,27 +479,27 @@ func (td *TxDatabase) Dialect() string {

// Creates a new Dataset for querying a Database.
func (td *TxDatabase) From(cols ...interface{}) *SelectDataset {
return newDataset(td.dialect, td.queryFactory()).From(cols...)
return newDataset(td.dialect, td.tagName, td.queryFactory()).From(cols...)
}

func (td *TxDatabase) Select(cols ...interface{}) *SelectDataset {
return newDataset(td.dialect, td.queryFactory()).Select(cols...)
return newDataset(td.dialect, td.tagName, td.queryFactory()).Select(cols...)
}

func (td *TxDatabase) Update(table interface{}) *UpdateDataset {
return newUpdateDataset(td.dialect, td.queryFactory()).Table(table)
return newUpdateDataset(td.dialect, td.tagName, td.queryFactory()).Table(table)
}

func (td *TxDatabase) Insert(table interface{}) *InsertDataset {
return newInsertDataset(td.dialect, td.queryFactory()).Into(table)
return newInsertDataset(td.dialect, td.tagName, td.queryFactory()).Into(table)
}

func (td *TxDatabase) Delete(table interface{}) *DeleteDataset {
return newDeleteDataset(td.dialect, td.queryFactory()).From(table)
return newDeleteDataset(td.dialect, td.tagName, td.queryFactory()).From(table)
}

func (td *TxDatabase) Truncate(table ...interface{}) *TruncateDataset {
return newTruncateDataset(td.dialect, td.queryFactory()).Table(table...)
return newTruncateDataset(td.dialect, td.tagName, td.queryFactory()).Table(table...)
}

// Sets the logger
Expand Down Expand Up @@ -554,9 +565,13 @@ func (td *TxDatabase) QueryRowContext(ctx context.Context, query string, args ..
return td.Tx.QueryRowContext(ctx, query, args...)
}

func (td *TxDatabase) GetTagName() string {
return td.tagName
}

func (td *TxDatabase) queryFactory() exec.QueryFactory {
td.qfOnce.Do(func() {
td.qf = exec.NewQueryFactory(td)
td.qf = exec.NewQueryFactory(td.tagName, td)
})
return td.qf
}
Expand Down
13 changes: 8 additions & 5 deletions delete_dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,24 @@ type DeleteDataset struct {
clauses exp.DeleteClauses
isPrepared prepared
queryFactory exec.QueryFactory
tagName string
err error
}

// used internally by database to create a database with a specific adapter
func newDeleteDataset(d string, queryFactory exec.QueryFactory) *DeleteDataset {
func newDeleteDataset(d, tagName string, queryFactory exec.QueryFactory) *DeleteDataset {
return &DeleteDataset{
clauses: exp.NewDeleteClauses(),
dialect: GetDialect(d),
dialect: GetDialectWithTag(d, tagName),
queryFactory: queryFactory,
isPrepared: preparedNoPreference,
tagName: tagName,
err: nil,
}
}

func Delete(table interface{}) *DeleteDataset {
return newDeleteDataset("default", nil).From(table)
return newDeleteDataset("default", "db", nil).From(table)
}

func (dd *DeleteDataset) Expression() exp.Expression {
Expand Down Expand Up @@ -58,7 +60,7 @@ func (dd *DeleteDataset) IsPrepared() bool {
// Sets the adapter used to serialize values and create the SQL statement
func (dd *DeleteDataset) WithDialect(dl string) *DeleteDataset {
ds := dd.copy(dd.GetClauses())
ds.dialect = GetDialect(dl)
ds.dialect = GetDialectWithTag(dl, dd.tagName)
return ds
}

Expand Down Expand Up @@ -86,6 +88,7 @@ func (dd *DeleteDataset) copy(clauses exp.DeleteClauses) *DeleteDataset {
clauses: clauses,
isPrepared: dd.isPrepared,
queryFactory: dd.queryFactory,
tagName: dd.tagName,
err: dd.err,
}
}
Expand Down Expand Up @@ -180,7 +183,7 @@ func (dd *DeleteDataset) ClearLimit() *DeleteDataset {

// Adds a RETURNING clause to the dataset if the adapter supports it.
func (dd *DeleteDataset) Returning(returning ...interface{}) *DeleteDataset {
return dd.copy(dd.clauses.SetReturning(exp.NewColumnListExpression(nil, returning...)))
return dd.copy(dd.clauses.SetReturning(exp.NewColumnListExpression(nil, dd.tagName, returning...)))
}

// Get any error that has been set or nil if no error has been set.
Expand Down
14 changes: 7 additions & 7 deletions delete_dataset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ func (dds *deleteDatasetSuite) assertCases(cases ...deleteTestCase) {
func (dds *deleteDatasetSuite) SetupSuite() {
noReturn := goqu.DefaultDialectOptions()
noReturn.SupportsReturn = false
goqu.RegisterDialect("no-return", noReturn)
goqu.RegisterDialect("no-return", "db", noReturn)

limitOnDelete := goqu.DefaultDialectOptions()
limitOnDelete.SupportsLimitOnDelete = true
goqu.RegisterDialect("limit-on-delete", limitOnDelete)
goqu.RegisterDialect("limit-on-delete", "db", limitOnDelete)

orderOnDelete := goqu.DefaultDialectOptions()
orderOnDelete.SupportsOrderByOnDelete = true
goqu.RegisterDialect("order-on-delete", orderOnDelete)
goqu.RegisterDialect("order-on-delete", "db", orderOnDelete)
}

func (dds *deleteDatasetSuite) TearDownSuite() {
Expand Down Expand Up @@ -359,25 +359,25 @@ func (dds *deleteDatasetSuite) TestReturning() {
ds: bd.Returning("a"),
clauses: exp.NewDeleteClauses().
SetFrom(goqu.C("items")).
SetReturning(exp.NewColumnListExpression(nil, "a")),
SetReturning(exp.NewColumnListExpression(nil, "db", "a")),
},
deleteTestCase{
ds: bd.Returning(),
clauses: exp.NewDeleteClauses().
SetFrom(goqu.C("items")).
SetReturning(exp.NewColumnListExpression(nil)),
SetReturning(exp.NewColumnListExpression(nil, "db")),
},
deleteTestCase{
ds: bd.Returning(nil),
clauses: exp.NewDeleteClauses().
SetFrom(goqu.C("items")).
SetReturning(exp.NewColumnListExpression(nil)),
SetReturning(exp.NewColumnListExpression(nil, "db")),
},
deleteTestCase{
ds: bd.Returning("a").Returning("b"),
clauses: exp.NewDeleteClauses().
SetFrom(goqu.C("items")).
SetReturning(exp.NewColumnListExpression(nil, "b")),
SetReturning(exp.NewColumnListExpression(nil, "db", "b")),
},
deleteTestCase{
ds: bd,
Expand Down
11 changes: 9 additions & 2 deletions dialect/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,14 @@ func DialectOptionsV8() *goqu.SQLDialectOptions {
return opts
}

func DialectOptionsV8Nano() *goqu.SQLDialectOptions {
opts := DialectOptionsV8()
opts.TimeFormat = "2006-01-02 15:04:05.999999999"
return opts
}

func init() {
goqu.RegisterDialect("mysql", DialectOptions())
goqu.RegisterDialect("mysql8", DialectOptionsV8())
goqu.RegisterDialect("mysql", "db", DialectOptions())
goqu.RegisterDialect("mysql8", "db", DialectOptionsV8())
goqu.RegisterDialect("mysql8nano", "db", DialectOptionsV8Nano())
}
2 changes: 1 addition & 1 deletion dialect/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ func DialectOptions() *goqu.SQLDialectOptions {
}

func init() {
goqu.RegisterDialect("postgres", DialectOptions())
goqu.RegisterDialect("postgres", "db", DialectOptions())
}
2 changes: 1 addition & 1 deletion dialect/sqlite3/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,5 @@ func DialectOptions() *goqu.SQLDialectOptions {
}

func init() {
goqu.RegisterDialect("sqlite3", DialectOptions())
goqu.RegisterDialect("sqlite3", "db", DialectOptions())
}
2 changes: 1 addition & 1 deletion dialect/sqlserver/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,5 @@ func DialectOptions() *goqu.SQLDialectOptions {
}

func init() {
goqu.RegisterDialect("sqlserver", DialectOptions())
goqu.RegisterDialect("sqlserver", "db", DialectOptions())
}
15 changes: 8 additions & 7 deletions exec/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ import (

type (
QueryExecutor struct {
de DbExecutor
err error
query string
args []interface{}
tagName string
de DbExecutor
err error
query string
args []interface{}
}
)

Expand All @@ -26,8 +27,8 @@ var (
errScanValNonSlice = errors.New("type cannot be a pointer to a slice when scanning into val")
)

func newQueryExecutor(de DbExecutor, err error, query string, args ...interface{}) QueryExecutor {
return QueryExecutor{de: de, err: err, query: query, args: args}
func newQueryExecutor(tagName string, de DbExecutor, err error, query string, args ...interface{}) QueryExecutor {
return QueryExecutor{tagName: tagName, de: de, err: err, query: query, args: args}
}

func (q QueryExecutor) ToSQL() (sql string, args []interface{}, err error) {
Expand Down Expand Up @@ -243,5 +244,5 @@ func (q QueryExecutor) ScannerContext(ctx context.Context) (Scanner, error) {
if err != nil {
return nil, err
}
return NewScanner(rows), nil
return NewScanner(rows, q.tagName), nil
}
Loading

0 comments on commit ce048d2

Please sign in to comment.