Skip to content

Commit

Permalink
feat(session): added RefreshSession and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi committed Apr 9, 2024
1 parent 5250719 commit 27d3dd0
Show file tree
Hide file tree
Showing 12 changed files with 204 additions and 73 deletions.
99 changes: 95 additions & 4 deletions auth_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/pquerna/otp/totp"
"github.com/yaitoo/auth/masker"
"github.com/yaitoo/sqle"
Expand Down Expand Up @@ -145,17 +146,19 @@ func (a *Auth) getUserIDByEmail(ctx context.Context, email string) (shardid.ID,
return userID, nil
}

func (a *Auth) deleteUserToken(ctx context.Context, userID shardid.ID) error {
_, err := a.db.On(userID).
func (a *Auth) deleteUserToken(ctx context.Context, uid shardid.ID, token string) error {
_, err := a.db.On(uid).
ExecBuilder(ctx, a.createBuilder().
Delete("<prefix>user_token").
Where("user_id = {user_id}").
Param("user_id", userID))
If(token != "").And(" hash = {hash}").
Param("hash", hashToken(token)).
Param("user_id", uid))

if err != nil {
a.logger.Error("auth: deleteUserToken",
slog.String("tag", "db"),
slog.Int64("user_id", userID.Int64),
slog.Int64("user_id", uid.Int64),
slog.Any("err", err))
return ErrBadDatabase
}
Expand Down Expand Up @@ -771,3 +774,91 @@ func (a *Auth) checkSignInCode(ctx context.Context, userID shardid.ID, code stri

return nil
}

func (a *Auth) createSession(ctx context.Context, userID shardid.ID) (Session, error) {
s := Session{
UserID: userID.Int64,
}

now := time.Now()
accToken := jwt.NewWithClaims(jwt.SigningMethodHS256, UserClaims{
ID: userID.Int64,
IssuedAt: now.Unix(),
ExpirationTime: now.Add(a.accessTokenTTL).Unix(),
})

exp := time.Now().Add(a.refreshTokenTTL)

refToken := jwt.NewWithClaims(jwt.SigningMethodHS256, UserClaims{
ID: userID.Int64,
Nonce: randStr(12, dicAlphaNumber),
IssuedAt: now.Unix(),
ExpirationTime: exp.Unix(),
})

var err error
s.AccessToken, err = accToken.SignedString(a.jwtSignKey)
if err != nil {
a.logger.Error("auth: createSession",
slog.String("tag", "token"),
slog.String("step", "access_token"),
slog.Any("err", err))
return s, ErrUnknown
}

s.RefreshToken, err = refToken.SignedString(a.jwtSignKey)
if err != nil {
a.logger.Error("auth: createSession",
slog.String("tag", "token"),
slog.String("step", "refresh_token"),
slog.Any("err", err))
return s, ErrUnknown
}

_, err = a.db.On(userID).
ExecBuilder(ctx, a.createBuilder().
Insert("<prefix>user_token").
Set("user_id", userID.Int64).
Set("hash", hashToken(s.RefreshToken)).
Set("expires_on", exp).
Set("created_at", now).
End())

if err != nil {
a.logger.Error("auth: createSession",
slog.String("tag", "db"),
slog.Int64("user_id", userID.Int64),
slog.Any("err", err))
return s, ErrBadDatabase
}

return s, nil
}

func (a *Auth) checkRefreshToken(ctx context.Context, userID shardid.ID, token string) error {
var count int
err := a.db.On(userID).
QueryRowBuilder(ctx, a.createBuilder().
Select("<prefix>user_token", "count(user_id)").
Where("user_id = {user_id} AND hash = {hash}").
Param("user_id", userID.Int64).
Param("hash", hashToken(token))).
Scan(&count)

if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrInvalidRefreshToken
}
a.logger.Error("auth: checkRefreshToken",
slog.Int64("user_id", userID.Int64),
slog.String("token", token),
slog.Any("err", err))
return ErrBadDatabase
}

if count == 0 {
return ErrInvalidRefreshToken
}

return nil
}
69 changes: 18 additions & 51 deletions auth_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,73 +2,40 @@ package auth

import (
"context"
"log/slog"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/yaitoo/sqle/shardid"
)

func (a *Auth) createSession(ctx context.Context, userID shardid.ID) (Session, error) {
s := Session{
UserID: userID.Int64,
}

accToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"id": userID,
"exp": time.Now().Add(a.accessTokenTTL).Unix(),
"ttl": a.accessTokenTTL,
})
// SignOut sign out the user, and delete his refresh token
func (a *Auth) SignOut(ctx context.Context, uid shardid.ID) error {
return a.deleteUserToken(ctx, uid, "")
}

exp := time.Now().Add(a.refreshTokenTTL)
now := time.Now()
refToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"id": userID,
"exp": exp.Unix(),
"ttl": a.refreshTokenTTL,
// RefreshSession refresh access token and refresh token
func (a *Auth) RefreshSession(ctx context.Context, refreshToken string) (Session, error) {
token, err := jwt.ParseWithClaims(refreshToken, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
return a.jwtSignKey, nil
})

var err error
s.AccessToken, err = accToken.SignedString(a.jwtSignKey)
if err != nil {
a.logger.Error("auth: createSession",
slog.String("tag", "token"),
slog.String("step", "access_token"),
slog.Any("err", err))
return s, ErrUnknown
return noSession, err
}

s.RefreshToken, err = refToken.SignedString(a.jwtSignKey)
if err != nil {
a.logger.Error("auth: createSession",
slog.String("tag", "token"),
slog.String("step", "refresh_token"),
slog.Any("err", err))
return s, ErrUnknown
if !token.Valid {
return noSession, ErrInvalidRefreshToken
}

_, err = a.db.On(userID).
ExecBuilder(ctx, a.createBuilder().
Insert("<prefix>user_token").
Set("user_id", userID.Int64).
Set("hash", s.refreshTokenHash()).
Set("expires_on", exp).
Set("created_at", now).
End())
uc := token.Claims.(*UserClaims)

uid := shardid.Parse(uc.ID)

err = a.checkRefreshToken(ctx, uid, refreshToken)
if err != nil {
a.logger.Error("auth: createSession",
slog.String("tag", "db"),
slog.Int64("user_id", userID.Int64),
slog.Any("err", err))
return s, ErrBadDatabase
return noSession, err
}

return s, nil
}

func (a *Auth) RefreshSession(ctx context.Context) (Session, error) {
var s Session
go a.deleteUserToken(ctx, uid, refreshToken) // nolint: errcheck

return s, nil
return a.createSession(ctx, uid)
}
37 changes: 37 additions & 0 deletions auth_session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package auth

import (
"context"
"testing"

"github.com/stretchr/testify/require"
"github.com/yaitoo/sqle/shardid"
)

func TestSession(t *testing.T) {
au := createAuthTest("./tests_session.db")

s, err := au.SignIn(context.TODO(), "[email protected]", "abc123", LoginOption{CreateIfNotExists: true})
require.NoError(t, err)

uid := shardid.Parse(s.UserID)
err = au.checkRefreshToken(context.Background(), uid, s.RefreshToken)
require.NoError(t, err)

// refresh token should be refreshed
rs, err := au.RefreshSession(context.Background(), s.RefreshToken)
require.NoError(t, err)
err = au.checkRefreshToken(context.Background(), uid, rs.RefreshToken)
require.NoError(t, err)
// old token should be deleted
err = au.checkRefreshToken(context.Background(), uid, s.RefreshToken)
require.ErrorIs(t, err, ErrInvalidRefreshToken)

// sign out should delete all tokens
err = au.SignOut(context.Background(), uid)
require.NoError(t, err)

err = au.checkRefreshToken(context.Background(), uid, rs.RefreshToken)
require.ErrorIs(t, err, ErrInvalidRefreshToken)

}
7 changes: 0 additions & 7 deletions auth_signin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package auth
import (
"context"
"errors"

"github.com/yaitoo/sqle/shardid"
)

// SignIn sign in with email and password.
Expand Down Expand Up @@ -55,8 +53,3 @@ func (a *Auth) SignInMobile(ctx context.Context, mobile, passwd string, option L

return noSession, err
}

// SignOut sign out the user, and delete his refresh token
func (a *Auth) SignOut(ctx context.Context, userID shardid.ID) error {
return a.deleteUserToken(ctx, userID)
}
4 changes: 2 additions & 2 deletions auth_signin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func TestSignIn(t *testing.T) {
QueryRowBuilder(context.Background(), authTest.createBuilder().
Select("<prefix>user_token", "user_id").
Where("hash = {hash}").
Param("hash", s.refreshTokenHash())).
Param("hash", hashToken(s.RefreshToken))).
Scan(&id)

r.NoError(err)
Expand Down Expand Up @@ -129,7 +129,7 @@ func TestSignInMobile(t *testing.T) {
QueryRowBuilder(context.Background(), authTest.createBuilder().
Select("<prefix>user_token", "user_id").
Where("hash = {hash}").
Param("hash", s.refreshTokenHash())).
Param("hash", hashToken(s.RefreshToken))).
Scan(&id)

r.NoError(err)
Expand Down
4 changes: 2 additions & 2 deletions auth_signin_with_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestSignInWithCode(t *testing.T) {
QueryRowBuilder(context.Background(), authTest.createBuilder().
Select("<prefix>user_token", "user_id").
Where("hash = {hash}").
Param("hash", s.refreshTokenHash())).
Param("hash", hashToken(s.RefreshToken))).
Scan(&id)

r.NoError(err)
Expand Down Expand Up @@ -145,7 +145,7 @@ func TestSignInMobileWithCode(t *testing.T) {
QueryRowBuilder(context.Background(), authTest.createBuilder().
Select("<prefix>user_token", "user_id").
Where("hash = {hash}").
Param("hash", s.refreshTokenHash())).
Param("hash", hashToken(s.RefreshToken))).
Scan(&id)

r.NoError(err)
Expand Down
4 changes: 2 additions & 2 deletions auth_signin_with_otp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestSignInWithOTP(t *testing.T) {
QueryRowBuilder(context.Background(), authTest.createBuilder().
Select("<prefix>user_token", "user_id").
Where("hash = {hash}").
Param("hash", s.refreshTokenHash())).
Param("hash", hashToken(s.RefreshToken))).
Scan(&id)

r.NoError(err)
Expand Down Expand Up @@ -157,7 +157,7 @@ func TestSignInMobileWithOTP(t *testing.T) {
QueryRowBuilder(context.Background(), authTest.createBuilder().
Select("<prefix>user_token", "user_id").
Where("hash = {hash}").
Param("hash", s.refreshTokenHash())).
Param("hash", hashToken(s.RefreshToken))).
Scan(&id)

r.NoError(err)
Expand Down
4 changes: 4 additions & 0 deletions crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ func verifyHash(h hash.Hash, hash string, source, salt string) bool {
return hash == v
}

func hashToken(token string) string {
return generateHash(sha256.New(), token, "")
}

func getJWTKey(key string) []byte {
return sha256.New().Sum([]byte(key))
}
Expand Down
2 changes: 2 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ var (

ErrOTPNotMatched = errors.New("auth: otp_not_matched")
ErrCodeNotMatched = errors.New("auth: code_not_matched")

ErrInvalidRefreshToken = errors.New("auth: invalid_refresh_token")
)
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ require (
)

require (
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/boombuler/barcode v1.0.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/iancoleman/strcase v0.3.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
Expand Down
3 changes: 2 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/boombuler/barcode v1.0.1 h1:NDBbPmhS+EqABEs5Kg3n/5ZNjy73Pz7SIV+KCeqyXcs=
github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
Loading

0 comments on commit 27d3dd0

Please sign in to comment.