forked from ido50/sqlz
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinsert.go
391 lines (326 loc) · 12 KB
/
insert.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
package sqlz
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/jmoiron/sqlx"
)
// InsertStmt represents an INSERT statement. It is assembled with the
// chainable builder methods (Columns, Values, ValueMap, ValueMultiple,
// FromSelect, Returning, OnConflict, Or*) and rendered by ToSQL.
type InsertStmt struct {
	// Statement carries the error handlers copied from the originating DB or
	// Tx; Exec and ExecContext pass their error to its HandleError method.
	*Statement

	// InsCols lists the target columns, rendered as "(col1, col2, ...)".
	InsCols []string

	// InsVals holds a single row of values, expected to line up
	// positionally with InsCols.
	InsVals []interface{}

	// InsMultipleVals holds rows for a multi-row VALUES list. ToSQL only
	// renders it when SelectStmt is nil and InsVals is empty.
	InsMultipleVals [][]interface{}

	// SelectStmt, when non-nil, supplies the inserted rows
	// (INSERT ... SELECT) and takes precedence over InsVals/InsMultipleVals.
	SelectStmt *SelectStmt

	// Table is the table being inserted into.
	Table string

	// Return lists the columns of the RETURNING clause (omitted when empty).
	Return []string

	// Conflicts holds the ON CONFLICT clauses, rendered in order.
	Conflicts []*ConflictClause

	// execer runs the statement; InsertInto sets it to the DB's or Tx's
	// underlying handle.
	execer Ext

	// sqliteConflict, when non-empty, turns the leading "INSERT" into
	// "INSERT OR <value>" (set by OrIgnore, OrReplace, OrAbort, OrRollback
	// and OrFail).
	sqliteConflict string
}
// InsertInto starts building an INSERT statement targeting table. The
// statement executes through the database handle db.DB and shares the
// database's error handlers.
func (db *DB) InsertInto(table string) *InsertStmt {
	stmt := &InsertStmt{Statement: &Statement{db.ErrHandlers}}
	stmt.Table = table
	stmt.execer = db.DB
	return stmt
}
// InsertInto starts building an INSERT statement targeting table. The
// statement executes inside the transaction (through tx.Tx) and shares the
// transaction's error handlers.
func (tx *Tx) InsertInto(table string) *InsertStmt {
	stmt := &InsertStmt{Statement: &Statement{tx.ErrHandlers}}
	stmt.Table = table
	stmt.execer = tx.Tx
	return stmt
}
// Columns appends column names to the statement's column list. It can be
// combined with ValueMap in the same statement, as long as the matching
// Values call comes right after Columns so the positional pairing of
// columns and values stays intact.
func (stmt *InsertStmt) Columns(cols ...string) *InsertStmt {
	stmt.InsCols = append(stmt.InsCols, cols...)
	return stmt
}
// Values appends values to the single row being inserted. They are paired
// positionally with the columns registered through Columns.
func (stmt *InsertStmt) Values(vals ...interface{}) *InsertStmt {
	stmt.InsVals = append(stmt.InsVals, vals...)
	return stmt
}
// ValueMap adds every column/value pair of vals to the row being inserted.
// Keys are visited in the order returned by sortKeys, so the generated SQL
// does not depend on Go's random map iteration order.
func (stmt *InsertStmt) ValueMap(vals map[string]interface{}) *InsertStmt {
	for _, col := range sortKeys(vals) {
		val := vals[col]
		stmt.InsCols, stmt.InsVals = append(stmt.InsCols, col), append(stmt.InsVals, val)
	}
	return stmt
}
// ValueMultiple appends whole rows for a multi-row insert; each inner slice
// becomes one parenthesized tuple of the VALUES list. These rows are only
// rendered when neither FromSelect nor Values/ValueMap supplied data.
func (stmt *InsertStmt) ValueMultiple(vals [][]interface{}) *InsertStmt {
	stmt.InsMultipleVals = append(stmt.InsMultipleVals, vals...)
	return stmt
}
// FromSelect makes the statement insert the rows produced by selStmt
// (INSERT ... SELECT). A later call replaces an earlier one.
func (stmt *InsertStmt) FromSelect(selStmt *SelectStmt) *InsertStmt {
	stmt.SelectStmt = selStmt
	return stmt
}
// Returning appends columns to the statement's RETURNING clause, which
// makes the database send values back after the insert. To read those
// values, run the statement with GetRow or GetAll rather than Exec.
func (stmt *InsertStmt) Returning(cols ...string) *InsertStmt {
	stmt.Return = append(stmt.Return, cols...)
	return stmt
}
// OnConflictDoNothing adds an "ON CONFLICT DO NOTHING" clause, without a
// conflict target, to the statement.
//
// Deprecated: use stmt.OnConflict(OnConflict().DoNothing()) instead.
func (stmt *InsertStmt) OnConflictDoNothing() *InsertStmt {
	return stmt.OnConflict(OnConflict().DoNothing())
}
// OrIgnore makes the statement render as "INSERT OR IGNORE", SQLite's
// conflict-resolution form. It replaces any previously chosen Or* mode.
func (stmt *InsertStmt) OrIgnore() *InsertStmt {
	stmt.sqliteConflict = "IGNORE"
	return stmt
}
// OrReplace makes the statement render as "INSERT OR REPLACE", SQLite's
// conflict-resolution form. It replaces any previously chosen Or* mode.
func (stmt *InsertStmt) OrReplace() *InsertStmt {
	stmt.sqliteConflict = "REPLACE"
	return stmt
}
// OrAbort makes the statement render as "INSERT OR ABORT", SQLite's
// conflict-resolution form. It replaces any previously chosen Or* mode.
func (stmt *InsertStmt) OrAbort() *InsertStmt {
	stmt.sqliteConflict = "ABORT"
	return stmt
}
// OrRollback makes the statement render as "INSERT OR ROLLBACK", SQLite's
// conflict-resolution form. It replaces any previously chosen Or* mode.
func (stmt *InsertStmt) OrRollback() *InsertStmt {
	stmt.sqliteConflict = "ROLLBACK"
	return stmt
}
// OrFail makes the statement render as "INSERT OR FAIL", SQLite's
// conflict-resolution form. It replaces any previously chosen Or* mode.
func (stmt *InsertStmt) OrFail() *InsertStmt {
	stmt.sqliteConflict = "FAIL"
	return stmt
}
// OnConflict appends clause to the statement's ON CONFLICT clauses. Each
// appended clause is rendered, in call order, after the row source.
func (stmt *InsertStmt) OnConflict(clause *ConflictClause) *InsertStmt {
	stmt.Conflicts = append(stmt.Conflicts, clause)
	return stmt
}
// ToSQL generates the INSERT statement's SQL and returns a list of
// bindings. It is used internally by Exec, GetRow and GetAll, but is
// exported if you wish to use it directly.
//
// The clauses are emitted in this order:
//   - "INSERT [OR <mode>] INTO <table>"
//   - the column list, when any columns were given
//   - the row source; only the first present one is used, checked as
//     SelectStmt, then InsVals (one row), then InsMultipleVals (many rows)
//   - every ON CONFLICT clause, in the order added
//   - the RETURNING clause, when any columns were given
//
// Placeholders are generated as "?". When rebind is true and the
// statement's executor has a Rebind(string) string method (as *sqlx.DB and
// *sqlx.Tx do), the whole statement is rewritten into that executor's bind
// style; otherwise the SQL is returned unchanged.
func (stmt *InsertStmt) ToSQL(rebind bool) (asSQL string, bindings []interface{}) {
	clauses := []string{"INSERT", "INTO", stmt.Table}
	if stmt.sqliteConflict != "" {
		clauses[0] = fmt.Sprintf("INSERT OR %s", stmt.sqliteConflict)
	}

	if len(stmt.InsCols) > 0 {
		clauses = append(clauses, "("+strings.Join(stmt.InsCols, ", ")+")")
	}

	switch {
	case stmt.SelectStmt != nil:
		// Not rebound here: the complete statement is rebound once below.
		selectSQL, selectBindings := stmt.SelectStmt.ToSQL(false)
		clauses = append(clauses, selectSQL)
		bindings = append(bindings, selectBindings...)
	case len(stmt.InsVals) > 0:
		placeholders, rowBindings := parseInsertValues(stmt.InsVals)
		bindings = append(bindings, rowBindings...)
		clauses = append(clauses, "VALUES ("+strings.Join(placeholders, ", ")+")")
	case len(stmt.InsMultipleVals) > 0:
		rows := make([]string, 0, len(stmt.InsMultipleVals))
		for _, row := range stmt.InsMultipleVals {
			placeholders, rowBindings := parseInsertValues(row)
			bindings = append(bindings, rowBindings...)
			rows = append(rows, "("+strings.Join(placeholders, ", ")+")")
		}
		clauses = append(clauses, "VALUES "+strings.Join(rows, ", "))
	}

	for _, conflict := range stmt.Conflicts {
		conflictSQL, conflictBindings := conflict.ToSQL()
		clauses = append(clauses, conflictSQL)
		bindings = append(bindings, conflictBindings...)
	}

	if len(stmt.Return) > 0 {
		clauses = append(clauses, "RETURNING "+strings.Join(stmt.Return, ", "))
	}

	asSQL = strings.Join(clauses, " ")
	if rebind {
		// Match on the method rather than on concrete sqlx types, so any
		// executor that knows its driver's bind style is honored.
		if binder, ok := stmt.execer.(interface{ Rebind(string) string }); ok {
			asSQL = binder.Rebind(asSQL)
		}
	}
	return asSQL, bindings
}
// Exec renders the statement (with placeholders rebound for the driver),
// runs it through the statement's executor and returns the resulting
// sql.Result. The error, nil or not, is handed to the statement's
// HandleError before being returned.
func (stmt *InsertStmt) Exec() (res sql.Result, err error) {
	query, args := stmt.ToSQL(true)
	res, err = stmt.execer.Exec(query, args...)
	stmt.Statement.HandleError(err)
	return res, err
}
// ExecContext is the context-aware variant of Exec: it renders the
// statement, runs it through the executor under ctx and returns the
// resulting sql.Result. The error, nil or not, is handed to the statement's
// HandleError before being returned.
func (stmt *InsertStmt) ExecContext(ctx context.Context) (res sql.Result, err error) {
	query, args := stmt.ToSQL(true)
	res, err = stmt.execer.ExecContext(ctx, query, args...)
	stmt.Statement.HandleError(err)
	return res, err
}
// GetRow executes an INSERT statement with a RETURNING clause
// expected to return one row, and loads the result into
// the provided variable (which may be a simple variable if
// only one column is returned, or a struct if multiple columns
// are returned).
//
// As with Exec, the resulting error (nil or not) is passed to the
// statement's HandleError before being returned. This lets the error
// handlers inherited from the DB or Tx observe failed INSERT ... RETURNING
// queries too.
func (stmt *InsertStmt) GetRow(into interface{}) error {
	asSQL, bindings := stmt.ToSQL(true)
	err := sqlx.Get(stmt.execer, into, asSQL, bindings...)
	stmt.Statement.HandleError(err)
	return err
}
// GetRowContext executes an INSERT statement with a RETURNING clause
// expected to return one row, and loads the result into
// the provided variable (which may be a simple variable if
// only one column is returned, or a struct if multiple columns
// are returned). The query runs under ctx.
//
// As with ExecContext, the resulting error (nil or not) is passed to the
// statement's HandleError before being returned.
func (stmt *InsertStmt) GetRowContext(ctx context.Context, into interface{}) error {
	asSQL, bindings := stmt.ToSQL(true)
	err := sqlx.GetContext(ctx, stmt.execer, into, asSQL, bindings...)
	stmt.Statement.HandleError(err)
	return err
}
// GetAll executes an INSERT statement with a RETURNING clause
// expected to return multiple rows, and loads the result into
// the provided slice variable.
//
// As with Exec, the resulting error (nil or not) is passed to the
// statement's HandleError before being returned.
func (stmt *InsertStmt) GetAll(into interface{}) error {
	asSQL, bindings := stmt.ToSQL(true)
	err := sqlx.Select(stmt.execer, into, asSQL, bindings...)
	stmt.Statement.HandleError(err)
	return err
}
// GetAllContext executes an INSERT statement with a RETURNING clause
// expected to return multiple rows, and loads the result into
// the provided slice variable. The query runs under ctx.
//
// As with ExecContext, the resulting error (nil or not) is passed to the
// statement's HandleError before being returned.
func (stmt *InsertStmt) GetAllContext(ctx context.Context, into interface{}) error {
	asSQL, bindings := stmt.ToSQL(true)
	err := sqlx.SelectContext(ctx, stmt.execer, into, asSQL, bindings...)
	stmt.Statement.HandleError(err)
	return err
}
// ConflictAction names what an ON CONFLICT clause does when an INSERT
// collides with an existing row. ConflictClause.ToSQL understands the two
// values declared below.
type ConflictAction string

// Supported conflict actions.
const (
	// DoNothing selects the "DO NOTHING" conflict action.
	DoNothing ConflictAction = "nothing"
	// DoUpdate selects the "DO UPDATE" conflict action.
	DoUpdate ConflictAction = "update"
)
// ConflictClause represents an ON CONFLICT clause in an INSERT statement.
// Create one with OnConflict, choose an action with DoNothing or DoUpdate
// (then add assignments with Set, SetIf or SetMap), and render it with
// ToSQL.
type ConflictClause struct {
	// Targets are the conflict target columns, rendered as "(a, b)"; when
	// empty, no target is emitted.
	Targets []string

	// Action selects DO NOTHING or DO UPDATE; ToSQL emits no action text for
	// any other value, including the zero value.
	Action ConflictAction

	// SetCols and SetVals are parallel lists holding the DO UPDATE
	// assignments. ToSQL indexes SetVals with SetCols' positions, so SetVals
	// must be at least as long as SetCols.
	SetCols []string
	SetVals []interface{}

	// Updates is not read by any code in this file.
	// NOTE(review): possibly unused — confirm against the rest of the package.
	Updates map[string]interface{}
}
// OnConflict creates a ConflictClause whose conflict target is the given
// column list (possibly empty). An action must still be chosen with
// DoNothing or DoUpdate.
func OnConflict(targets ...string) *ConflictClause {
	clause := new(ConflictClause)
	clause.Targets = targets
	return clause
}
// DoNothing makes the clause resolve conflicts with DO NOTHING, replacing
// any previously selected action.
func (conflict *ConflictClause) DoNothing() *ConflictClause {
	conflict.Action = DoNothing
	return conflict
}
// DoUpdate makes the clause resolve conflicts with DO UPDATE. Call it
// before Set, SetIf or SetMap: those methods ignore their arguments unless
// the action is already DoUpdate.
func (conflict *ConflictClause) DoUpdate() *ConflictClause {
	conflict.Action = DoUpdate
	return conflict
}
// Set records the assignment "col = val" for the DO UPDATE action. It is
// SetIf with an always-true condition, so it only has an effect after
// DoUpdate.
func (conflict *ConflictClause) Set(col string, val interface{}) *ConflictClause {
	const always = true
	return conflict.SetIf(col, val, always)
}
// SetMap records one DO UPDATE assignment per entry of vals, in the key
// order returned by sortKeys. Like Set, it does nothing unless the clause's
// action is DoUpdate.
func (conflict *ConflictClause) SetMap(vals map[string]interface{}) *ConflictClause {
	if conflict.Action == DoUpdate {
		for _, col := range sortKeys(vals) {
			conflict.SetCols = append(conflict.SetCols, col)
			conflict.SetVals = append(conflict.SetVals, vals[col])
		}
	}
	return conflict
}
// SetIf records the DO UPDATE assignment "col = val" only when b is true and
// the clause's action is DoUpdate. It allows conditional assignments inside
// a method chain, without storing the clause in a variable and wrapping
// calls in if statements.
func (conflict *ConflictClause) SetIf(col string, val interface{}, b bool) *ConflictClause {
	if b && conflict.Action == DoUpdate {
		conflict.SetCols = append(conflict.SetCols, col)
		conflict.SetVals = append(conflict.SetVals, val)
	}
	return conflict
}
// ToSQL generates the SQL code for the conflict clause and returns it with
// the values bound to its "?" placeholders, in placeholder order.
//
// The output is "ON CONFLICT", then "(target, ...)" when Targets is
// non-empty, then the action:
//   - DoNothing: "DO NOTHING"
//   - DoUpdate:  "DO UPDATE SET col = <value>, ..." for each SetCols entry
//   - any other Action (including the zero value): nothing
//
// SET values are rendered by the same rules as INSERT values
// (parseInsertValues): IndirectValue and JSONBBuilder values are inlined as
// SQL, and anything else becomes a "?" placeholder. An UpdateFunction
// renders as Name(arg, ...), with every argument following those same rules.
// SetVals is indexed by SetCols' positions, so it must be at least as long.
func (conflict *ConflictClause) ToSQL() (asSQL string, bindings []interface{}) {
	words := []string{"ON CONFLICT"}
	if len(conflict.Targets) > 0 {
		words = append(words, "("+strings.Join(conflict.Targets, ", ")+")")
	}

	switch conflict.Action {
	case DoNothing:
		words = append(words, "DO NOTHING")
	case DoUpdate:
		words = append(words, "DO UPDATE SET")
		updates := make([]string, 0, len(conflict.SetCols))
		for i, col := range conflict.SetCols {
			val := conflict.SetVals[i]
			var rhs string
			if fn, isFn := val.(UpdateFunction); isFn {
				args := make([]string, 0, len(fn.Arguments))
				for _, arg := range fn.Arguments {
					// parseInsertValues yields exactly one fragment per value.
					argSQL, argBindings := parseInsertValues([]interface{}{arg})
					args = append(args, argSQL[0])
					bindings = append(bindings, argBindings...)
				}
				rhs = fn.Name + "(" + strings.Join(args, ", ") + ")"
			} else {
				valSQL, valBindings := parseInsertValues([]interface{}{val})
				rhs = valSQL[0]
				bindings = append(bindings, valBindings...)
			}
			updates = append(updates, col+" = "+rhs)
		}
		words = append(words, strings.Join(updates, ", "))
	}

	return strings.Join(words, " "), bindings
}
// parseInsertValues renders each value as a SQL fragment and gathers the
// arguments those fragments bind, preserving input order. It produces
// exactly one fragment per input value:
//   - IndirectValue: its Reference, binding its Bindings
//   - JSONBBuilder:  the SQL from Parse, binding the arguments Parse returns
//   - anything else: a "?" placeholder, binding the value itself
func parseInsertValues(insVals []interface{}) (placeholders []string, bindingsToAdd []interface{}) {
	for _, raw := range insVals {
		switch v := raw.(type) {
		case IndirectValue:
			placeholders = append(placeholders, v.Reference)
			bindingsToAdd = append(bindingsToAdd, v.Bindings...)
		case JSONBBuilder:
			fragment, args := v.Parse()
			placeholders = append(placeholders, fragment)
			bindingsToAdd = append(bindingsToAdd, args...)
		default:
			placeholders = append(placeholders, "?")
			bindingsToAdd = append(bindingsToAdd, raw)
		}
	}
	return placeholders, bindingsToAdd
}