Skip to content

Commit

Permalink
Merge pull request containerd#8735 from iain-macdonald/iain-macdonald…
Browse files Browse the repository at this point in the history
…/issue-6377

remotes/docker/authorizer.go: refresh OAuth tokens when they expire
  • Loading branch information
estesp authored Jan 29, 2024
2 parents 1b6019b + af6a90b commit f5f84a9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
20 changes: 10 additions & 10 deletions core/remotes/docker/auth/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ type TokenOptions struct {

// OAuthTokenResponse is response from fetching token with a OAuth POST request
type OAuthTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
Scope string `json:"scope"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresInSeconds int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
Scope string `json:"scope"`
}

// FetchTokenWithOAuth fetches a token using a POST request
Expand Down Expand Up @@ -152,11 +152,11 @@ func FetchTokenWithOAuth(ctx context.Context, client *http.Client, headers http.

// FetchTokenResponse is response from fetching token with GET request
type FetchTokenResponse struct {
Token string `json:"token"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
RefreshToken string `json:"refresh_token"`
Token string `json:"token"`
AccessToken string `json:"access_token"`
ExpiresInSeconds int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
RefreshToken string `json:"refresh_token"`
}

// FetchToken fetches a token using a GET request
Expand Down
27 changes: 22 additions & 5 deletions core/remotes/docker/authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"net/http"
"strings"
"sync"
"time"

"github.com/containerd/containerd/v2/core/remotes/docker/auth"
remoteerrors "github.com/containerd/containerd/v2/core/remotes/errors"
Expand Down Expand Up @@ -205,9 +206,10 @@ func (a *dockerAuthorizer) AddResponses(ctx context.Context, responses []*http.R
// authResult is used to control limit rate.
type authResult struct {
sync.WaitGroup
token string
refreshToken string
err error
token string
refreshToken string
expirationTime *time.Time
err error
}

// authHandler is used to handle auth request per registry server.
Expand Down Expand Up @@ -270,8 +272,12 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
// Docs: https://docs.docker.com/registry/spec/auth/scope
scoped := strings.Join(to.Scopes, " ")

// Keep track of the expiration time of cached bearer tokens so they can be
// refreshed when they expire without a server roundtrip.
var expirationTime *time.Time

ah.Lock()
if r, exist := ah.scopedTokens[scoped]; exist {
if r, exist := ah.scopedTokens[scoped]; exist && (r.expirationTime == nil || r.expirationTime.After(time.Now())) {
ah.Unlock()
r.Wait()
return r.token, r.refreshToken, r.err
Expand All @@ -285,7 +291,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st

defer func() {
token = fmt.Sprintf("Bearer %s", token)
r.token, r.refreshToken, r.err = token, refreshToken, err
r.token, r.refreshToken, r.err, r.expirationTime = token, refreshToken, err, expirationTime
r.Done()
}()

Expand All @@ -311,6 +317,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
if err != nil {
return "", "", err
}
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.Token, resp.RefreshToken, nil
}
log.G(ctx).WithFields(log.Fields{
Expand All @@ -320,16 +327,26 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
}
return "", "", err
}
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.AccessToken, resp.RefreshToken, nil
}
// do request anonymously
resp, err := auth.FetchToken(ctx, ah.client, ah.header, to)
if err != nil {
return "", "", fmt.Errorf("failed to fetch anonymous token: %w", err)
}
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.Token, resp.RefreshToken, nil
}

func getExpirationTime(expiresInSeconds int) *time.Time {
if expiresInSeconds <= 0 {
return nil
}
expirationTime := time.Now().Add(time.Duration(expiresInSeconds) * time.Second)
return &expirationTime
}

func invalidAuthorization(ctx context.Context, c auth.Challenge, responses []*http.Response) (retry bool, _ error) {
errStr := c.Parameters["error"]
if errStr == "" {
Expand Down

0 comments on commit f5f84a9

Please sign in to comment.