diff --git a/cmd/initTestMocks_test.go b/cmd/initTestMocks_test.go index f410e24a..f85f76db 100644 --- a/cmd/initTestMocks_test.go +++ b/cmd/initTestMocks_test.go @@ -1,16 +1,24 @@ package cmd import ( + "context" "crypto/ecdsa" "crypto/rand" + "fmt" + "github.com/avast/retry-go" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "math/big" "razor/cmd/mocks" "razor/path" pathPkgMocks "razor/path/mocks" "razor/utils" utilsPkgMocks "razor/utils/mocks" + "strings" + "testing" + "time" ) var ( @@ -167,3 +175,75 @@ func SetUpMockInterfaces() { var privateKey, _ = ecdsa.GenerateKey(crypto.S256(), rand.Reader) var TxnOpts, _ = bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(31000)) // Used any random big int for chain ID + +func TestInvokeFunctionWithRetryAttempts(t *testing.T) { + tests := []struct { + name string + methodName string + timeout time.Duration + expectError bool + expectedErr string + expectedVals bool + }{ + { + name: "Normal Case - Fast Method", + methodName: "FastMethod", + timeout: 5 * time.Second, + expectError: false, + expectedErr: "", + expectedVals: true, + }, + { + name: "Timeout Case - Slow Method", + methodName: "SlowMethod", + timeout: 0 * time.Second, + expectError: true, + expectedErr: "context deadline exceeded", + expectedVals: false, + }, + } + + // Dummy RPC struct + dummyRPC := &DummyRPC{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up context with timeout for each test case + ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) + defer cancel() + + SetUpMockInterfaces() + retryUtilsMock.On("RetryAttempts", mock.AnythingOfType("uint")).Return(retry.Attempts(4)) + + returnedValues, err := utils.InvokeFunctionWithRetryAttempts(ctx, dummyRPC, tt.methodName) + + if tt.expectError { + assert.Error(t, err) + assert.True(t, strings.Contains(err.Error(), tt.expectedErr), "Expected error to contain: %v, but got: %v", tt.expectedErr, err.Error()) + } else { + assert.NoError(t, err) + } + + if tt.expectedVals { + assert.NotNil(t, returnedValues) + } else { + assert.Nil(t, returnedValues) + } + }) + } +} + +// Dummy interface with methods +type DummyRPC struct{} + +// A fast method that simulates successful execution +func (d *DummyRPC) FastMethod() error { + return nil +} + +// A slow method that simulates a long-running process +func (d *DummyRPC) SlowMethod() error { + fmt.Println("Sleeping...") + time.Sleep(3 * time.Second) // Simulate delay to trigger timeout + return nil +}