Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom error #16

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
8 changes: 6 additions & 2 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (association *Association) Count() int {
}

if err := query.Model(fieldValue).Count(&count).Error; err != nil {
association.Error = err
association.Error = NewGormError(err, query.SQL)
}
return count
}
Expand Down Expand Up @@ -371,7 +371,11 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
// setErr set error when the error is not nil. And return Association.
func (association *Association) setErr(err error) *Association {
if err != nil {
association.Error = err
if association.scope != nil {
association.Error = NewGormError(err, association.scope.SQL)
} else {
association.Error = err
}
}
return association
}
4 changes: 4 additions & 0 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ func createCallback(scope *Scope) {
))
}

if scope.DB() != nil {
scope.DB().SQL = FormatSQL(scope.SQL, scope.SQLVars...)
}

// execute create sql: no primaryField
if primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
Expand Down
4 changes: 4 additions & 0 deletions callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ func queryCallback(scope *Scope) {

scope.prepareQuerySQL()

if scope.DB() != nil {
scope.DB().SQL = FormatSQL(scope.SQL, scope.SQLVars...)
}

if !scope.HasError() {
scope.db.RowsAffected = 0

Expand Down
12 changes: 11 additions & 1 deletion callback_row_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,24 @@ func rowQueryCallback(scope *Scope) {
if result, ok := scope.InstanceGet("row_query_result"); ok {
scope.prepareQuerySQL()

if scope.DB() != nil {
scope.DB().SQL = FormatSQL(scope.SQL, scope.SQLVars...)
}

if str, ok := scope.Get("gorm:query_hint"); ok {
scope.SQL = fmt.Sprint(str) + scope.SQL
}

if rowResult, ok := result.(*RowQueryResult); ok {
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
rowsResult.Rows = rows
if rowsResult.Error != nil {
rowsResult.Error = NewGormError(err, scope.SQL)
} else {
rowsResult.Error = err
}
}
}
}
4 changes: 2 additions & 2 deletions create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,11 @@ func TestFixFullTableScanWhenInsertIgnore(t *testing.T) {

DB.Callback().Query().Register("gorm:fix_full_table_scan", func(scope *gorm.Scope) {
if strings.Contains(scope.SQL, "SELECT") && strings.Contains(scope.SQL, "pandas") && len(scope.SQLVars) == 0 {
t.Error("Should skip force reload when ignore duplicate panda insert")
t.Error("Should skip force reload when ignore duplicate panda insert")
}
})

if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&pandaYuanYuan).Error != nil {
t.Error("Should ignore duplicate panda insert by insert modifier:IGNORE ")
}
}
}
29 changes: 27 additions & 2 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
)

var (
// ErrNoRecordsInResultSetSQL sql native error on querying with .row() function or similar
ErrNoRecordsInResultSetSQL = errors.New("sql: no rows in result set")
// ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error
ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL occurs when you attempt a query with invalid SQL
Expand All @@ -23,14 +25,17 @@ type Errors []error

// IsRecordNotFoundError returns true if error contains a RecordNotFound error
func IsRecordNotFoundError(err error) bool {
if err == nil {
return false
}
if errs, ok := err.(Errors); ok {
for _, err := range errs {
if err == ErrRecordNotFound {
if err.Error() == ErrRecordNotFound.Error() || err.Error() == ErrNoRecordsInResultSetSQL.Error() {
return true
}
}
}
return err == ErrRecordNotFound
return err.Error() == ErrRecordNotFound.Error() || err.Error() == ErrNoRecordsInResultSetSQL.Error()
}

// GetErrors gets all errors that have occurred and returns a slice of errors (Error type)
Expand Down Expand Up @@ -70,3 +75,23 @@ func (errs Errors) Error() string {
}
return strings.Join(errors, "; ")
}

// GormError is a custom error with the error and the SQL executed.
type GormError struct {
Err error
SQL string
}

// New is a construtor of custom error.
func NewGormError(err error, sql string) GormError {
return GormError{err, sql}
}

// Error return the error message.
func (e GormError) Error() string {
if e.Err != nil {
return e.Err.Error()
} else {
return "unexpected error"
}
}
75 changes: 75 additions & 0 deletions formatter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package gorm

import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"strings"
"time"
)

func FormatSQL(sql string, values ...interface{}) string {
if len(values) > 0 {

formattedValues := []string{}

// duration
for _, value := range values {
indirectValue := reflect.Indirect(reflect.ValueOf(value))
if indirectValue.IsValid() {
value = indirectValue.Interface()
if t, ok := value.(time.Time); ok {
if t.IsZero() {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00"))
} else {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
}
} else if b, ok := value.([]byte); ok {
if str := string(b); isPrintable(str) {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
} else {
formattedValues = append(formattedValues, "'<binary>'")
}
} else if r, ok := value.(driver.Valuer); ok {
if value, err := r.Value(); err == nil && value != nil {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
} else {
formattedValues = append(formattedValues, "NULL")
}
} else {
switch value.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
formattedValues = append(formattedValues, fmt.Sprintf("%v", value))
default:
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
}
}
} else {
formattedValues = append(formattedValues, "NULL")
}
}

// differentiate between $n placeholders or else treat like ?
if numericPlaceHolderRegexp.MatchString(sql) {
for index, value := range formattedValues {
placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1)
sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1")
}
} else {
formattedValuesLength := len(formattedValues)
for index, value := range sqlRegexp.Split(sql, -1) {
sql += value
if index < formattedValuesLength {
sql += formattedValues[index]
}
}
}

}

sql = strings.ReplaceAll(sql, "\n", "")
sql = strings.ReplaceAll(sql, "\t", "")

return sql
}
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA=
github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw=
github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down
12 changes: 7 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (

// DB contains information for current db connection
type DB struct {
SQL string

sync.RWMutex
Value interface{}
Error error
Expand Down Expand Up @@ -83,9 +85,9 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
}

db = &DB{
db: dbSQL,
logger: defaultLogger,
db: dbSQL,
logger: defaultLogger,

// Create a clone of the default logger to avoid mutating a shared object when
// multiple gorm connections are created simultaneously.
callbacks: DefaultCallback.clone(defaultLogger),
Expand Down Expand Up @@ -627,7 +629,7 @@ func (s *DB) NewRecord(value interface{}) bool {
// RecordNotFound check if returning ErrRecordNotFound error
func (s *DB) RecordNotFound() bool {
for _, err := range s.GetErrors() {
if err == ErrRecordNotFound {
if err != nil && err.Error() == ErrRecordNotFound.Error() {
return true
}
}
Expand Down Expand Up @@ -825,7 +827,7 @@ func (s *DB) AddError(err error) error {
}
}

s.Error = err
s.Error = NewGormError(err, s.SQL)
}
return err
}
Expand Down
2 changes: 1 addition & 1 deletion main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ func TestRaw(t *testing.T) {
}

DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
if err := DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error; err != nil && err.Error() != gorm.ErrRecordNotFound.Error() {
t.Error("Raw sql to update records")
}
}
Expand Down
4 changes: 2 additions & 2 deletions migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ func getPreparedUser(name string, role string) *User {
}

type Panda struct {
Number int64 `gorm:"unique_index:number"`
Name string `gorm:"column:name;type:varchar(255);default:null"`
Number int64 `gorm:"unique_index:number"`
Name string `gorm:"column:name;type:varchar(255);default:null"`
}

func runMigration() {
Expand Down
6 changes: 3 additions & 3 deletions preload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func TestNestedPreload1(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}

if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != nil && err.Error() != gorm.ErrRecordNotFound.Error() {
t.Error(err)
}
}
Expand Down Expand Up @@ -1104,7 +1104,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}

if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != nil && err.Error() != gorm.ErrRecordNotFound.Error() {
t.Error(err)
}
}
Expand Down Expand Up @@ -1161,7 +1161,7 @@ func TestNestedManyToManyPreload2(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}

if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != nil && err.Error() != gorm.ErrRecordNotFound.Error() {
t.Error(err)
}
}
Expand Down
4 changes: 4 additions & 0 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,10 @@ func (scope *Scope) Raw(sql string) *Scope {
func (scope *Scope) Exec() *Scope {
defer scope.trace(NowFunc())

if scope.DB() != nil {
scope.DB().SQL = FormatSQL(scope.SQL, scope.SQLVars...)
}

if !scope.HasError() {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil {
Expand Down