diff --git a/runtime/failpoint.go b/runtime/failpoint.go index 7559140..4bb3d8e 100644 --- a/runtime/failpoint.go +++ b/runtime/failpoint.go @@ -20,8 +20,7 @@ import ( ) type Failpoint struct { - t *terms - + t *terms failpointMu sync.RWMutex } @@ -54,3 +53,34 @@ func (fp *Failpoint) Acquire() (interface{}, error) { func (fp *Failpoint) BadType(v interface{}, t string) { fmt.Printf("failpoint: %q got value %v of type \"%T\" but expected type %q\n", fp.t.fpath, v, v, t) } + +func (fp *Failpoint) SetTerm(t *terms) { + fp.failpointMu.Lock() + defer fp.failpointMu.Unlock() + + fp.t = t +} + +func (fp *Failpoint) ClearTerm() error { + fp.failpointMu.Lock() + defer fp.failpointMu.Unlock() + + if fp.t == nil { + return ErrDisabled + } + fp.t = nil + + return nil +} + +func (fp *Failpoint) Status() (string, int, error) { + fp.failpointMu.RLock() + defer fp.failpointMu.RUnlock() + + t := fp.t + if t == nil { + return "", 0, ErrDisabled + } + + return t.desc, t.counter, nil +} diff --git a/runtime/runtime.go b/runtime/runtime.go index 1e57932..3230fb6 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -85,10 +85,7 @@ func Enable(name, inTerms string) error { return err } - fp.failpointMu.Lock() - defer fp.failpointMu.Unlock() - - fp.t = t + fp.SetTerm(t) return nil } @@ -102,15 +99,7 @@ func Disable(name string) error { return ErrNoExist } - fp.failpointMu.Lock() - defer fp.failpointMu.Unlock() - - if fp.t == nil { - return ErrDisabled - } - fp.t = nil - - return nil + return fp.ClearTerm() } // Status gives the current setting and execution count for the failpoint @@ -122,15 +111,7 @@ func Status(failpath string) (string, int, error) { return "", 0, ErrNoExist } - fp.failpointMu.RLock() - defer fp.failpointMu.RUnlock() - - t := fp.t - if t == nil { - return "", 0, ErrDisabled - } - - return t.desc, t.counter, nil + return fp.Status() } func List() []string {