Skip to content

Commit

Permalink
Merge pull request #44 from Versent/feat_osxkeychain_storage
Browse files Browse the repository at this point in the history
feat(storage) Keychain storage for password.
  • Loading branch information
wolfeidau authored Jul 27, 2017
2 parents fdb8337 + 9698e41 commit 02c1a1c
Show file tree
Hide file tree
Showing 15 changed files with 832 additions and 84 deletions.
32 changes: 32 additions & 0 deletions aws_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@ import (
"net/http"
"net/url"

"fmt"

"github.com/PuerkitoBio/goquery"
"github.com/pkg/errors"
)

// AWSAccount holds the AWS account name and roles
type AWSAccount struct {
Name string
Roles []*AWSRole
}

// ParseAWSAccounts extract the aws accounts from the saml assertion
func ParseAWSAccounts(samlAssertion string) ([]*AWSAccount, error) {
awsURL := "https://signin.aws.amazon.com/saml"

Expand All @@ -31,6 +35,7 @@ func ParseAWSAccounts(samlAssertion string) ([]*AWSAccount, error) {
return ExtractAWSAccounts(data)
}

// ExtractAWSAccounts extract the accounts from the AWS html page
func ExtractAWSAccounts(data []byte) ([]*AWSAccount, error) {
accounts := []*AWSAccount{}

Expand All @@ -53,3 +58,30 @@ func ExtractAWSAccounts(data []byte) ([]*AWSAccount, error) {

return accounts, nil
}

// AssignPrincipals assign principal from roles
func AssignPrincipals(awsRoles []*AWSRole, awsAccounts []*AWSAccount) {

awsPrincipalARNs := make(map[string]string)
for _, awsRole := range awsRoles {
awsPrincipalARNs[awsRole.RoleARN] = awsRole.PrincipalARN
}

for _, awsAccount := range awsAccounts {
for _, awsRole := range awsAccount.Roles {
awsRole.PrincipalARN = awsPrincipalARNs[awsRole.RoleARN]
}
}

}

// LocateRole locate role by name
func LocateRole(awsRoles []*AWSRole, roleName string) (*AWSRole, error) {
for _, awsRole := range awsRoles {
if awsRole.RoleARN == roleName {
return awsRole, nil
}
}

return nil, fmt.Errorf("Supplied RoleArn not found in saml assertion: %s", roleName)
}
42 changes: 42 additions & 0 deletions aws_account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,45 @@ func TestExtractAWSAccounts(t *testing.T) {
assert.Equal(t, role.RoleARN, "arn:aws:iam::000000000002:role/Production")
assert.Equal(t, role.Name, "Production")
}

func TestAssignPrincipals(t *testing.T) {
awsRoles := []*AWSRole{
&AWSRole{
PrincipalARN: "arn:aws:iam::000000000001:saml-provider/test-idp",
RoleARN: "arn:aws:iam::000000000001:role/Development",
},
}

awsAccounts := []*AWSAccount{
&AWSAccount{
Roles: []*AWSRole{
&AWSRole{
RoleARN: "arn:aws:iam::000000000001:role/Development",
},
},
},
}

AssignPrincipals(awsRoles, awsAccounts)

assert.Equal(t, "arn:aws:iam::000000000001:saml-provider/test-idp", awsAccounts[0].Roles[0].PrincipalARN)
}

func TestLocateRole(t *testing.T) {
awsRoles := []*AWSRole{
&AWSRole{
PrincipalARN: "arn:aws:iam::000000000001:saml-provider/test-idp",
RoleARN: "arn:aws:iam::000000000001:role/Development",
},
&AWSRole{
PrincipalARN: "arn:aws:iam::000000000002:saml-provider/test-idp",
RoleARN: "arn:aws:iam::000000000002:role/Development",
},
}

role, err := LocateRole(awsRoles, "arn:aws:iam::000000000001:role/Development")

assert.Empty(t, err)

assert.Equal(t, "arn:aws:iam::000000000001:role/Development", role.RoleARN)
}
8 changes: 4 additions & 4 deletions cmd/saml2aws/commands/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,25 @@ import (
)

// Exec execute the supplied command after seeding the environment
func Exec(profile string, providerName string, skipVerify bool, cmdline []string) error {
func Exec(loginFlags *LoginFlags, cmdline []string) error {

if len(cmdline) < 1 {
return fmt.Errorf("Command to execute required.")
}

ok, err := checkToken(profile)
ok, err := checkToken(loginFlags.Profile)
if err != nil {
return errors.Wrap(err, "error validating token")
}

if !ok {
err = Login(profile, providerName, skipVerify)
err = Login(loginFlags)
}
if err != nil {
return errors.Wrap(err, "error logging in")
}

sharedCreds := saml2aws.NewSharedCredentials(profile)
sharedCreds := saml2aws.NewSharedCredentials(loginFlags.Profile)

id, secret, token, err := sharedCreds.Load()
if err != nil {
Expand Down
154 changes: 111 additions & 43 deletions cmd/saml2aws/commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,46 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
"github.com/pkg/errors"
"github.com/versent/saml2aws"
"github.com/versent/saml2aws/helper/credentials"
)

// Login login to ADFS
func Login(profile, providerName string, skipVerify bool) error {
// LoginFlags login specific command flags
type LoginFlags struct {
Provider string
Profile string
Hostname string
Username string
Password string
RoleArn string
SkipVerify bool
SkipPrompt bool
}

config := saml2aws.NewConfigLoader(providerName)
// RoleSupplied role arn has been passed as a flag
func (lf *LoginFlags) RoleSupplied() bool {
return lf.RoleArn != ""
}

username, err := config.LoadUsername()
if err != nil {
return errors.Wrap(err, "error loading config file")
}
// Login login to ADFS
func Login(loginFlags *LoginFlags) error {

config := saml2aws.NewConfigLoader(loginFlags.Provider)

hostname, err := config.LoadHostname()
if err != nil {
return errors.Wrap(err, "error loading config file")
}

loginDetails, err := saml2aws.PromptForLoginDetails(username, hostname)
fmt.Println("LookupCredentials", hostname)

loginDetails, err := resolveLoginDetails(hostname, loginFlags)
if err != nil {
return errors.Wrap(err, "error accepting password")
}

fmt.Printf("%s https://%s\n", providerName, loginDetails.Hostname)
fmt.Printf("Authenticating to %s with URL https://%s\n", loginFlags.Provider, loginDetails.Hostname)

fmt.Printf("Authenticating to %s...\n", providerName)

opts := &saml2aws.SAMLOptions{Provider: providerName, SkipVerify: skipVerify}
opts := &saml2aws.SAMLOptions{Provider: loginFlags.Provider, SkipVerify: loginFlags.SkipVerify}

provider, err := saml2aws.NewSAMLClient(opts)
if err != nil {
Expand All @@ -55,6 +68,11 @@ func Login(profile, providerName string, skipVerify bool) error {
os.Exit(1)
}

err = credentials.SaveCredentials(loginDetails.Hostname, loginDetails.Username, loginDetails.Password)
if err != nil {
return errors.Wrap(err, "error storing password in keychain")
}

data, err := base64.StdEncoding.DecodeString(samlAssertion)
if err != nil {
return errors.Wrap(err, "error decoding saml assertion")
Expand All @@ -76,34 +94,7 @@ func Login(profile, providerName string, skipVerify bool) error {
return errors.Wrap(err, "error parsing aws roles")
}

var role = new(saml2aws.AWSRole)

if len(awsRoles) == 1 {
role = awsRoles[0]
} else if len(awsRoles) == 0 {
return errors.Wrap(err, "no roles available")
} else {
awsPrincipalARNs := make(map[string]string)
for _, awsRole := range awsRoles {
awsPrincipalARNs[awsRole.RoleARN] = awsRole.PrincipalARN
}

awsAccounts, err := saml2aws.ParseAWSAccounts(samlAssertion)
if err != nil {
return errors.Wrap(err, "error parsing aws role accounts")
}

for _, awsAccount := range awsAccounts {
for _, awsRole := range awsAccount.Roles {
awsRole.PrincipalARN = awsPrincipalARNs[awsRole.RoleARN]
}
}

role, err = saml2aws.PromptForAWSRoleSelection(awsAccounts)
if err != nil {
return errors.Wrap(err, "error selecting role")
}
}
role, err := resolveRole(awsRoles, samlAssertion, loginFlags)

fmt.Println("Selected role:", role.RoleARN)

Expand All @@ -125,12 +116,12 @@ func Login(profile, providerName string, skipVerify bool) error {

resp, err := svc.AssumeRoleWithSAML(params)
if err != nil {
return errors.Wrap(err, "error retieving sts credentials using SAML")
return errors.Wrap(err, "error retrieving STS credentials using SAML")
}

fmt.Println("Saving credentials")

sharedCreds := saml2aws.NewSharedCredentials(profile)
sharedCreds := saml2aws.NewSharedCredentials(loginFlags.Profile)

err = sharedCreds.Save(aws.StringValue(resp.Credentials.AccessKeyId), aws.StringValue(resp.Credentials.SecretAccessKey), aws.StringValue(resp.Credentials.SessionToken))
if err != nil {
Expand All @@ -141,11 +132,88 @@ func Login(profile, providerName string, skipVerify bool) error {
fmt.Println("")
fmt.Println("Your new access key pair has been stored in the AWS configuration")
fmt.Printf("Note that it will expire at %v\n", resp.Credentials.Expiration.Local())
fmt.Println("To use this credential, call the AWS CLI with the --profile option (e.g. aws --profile", profile, "ec2 describe-instances).")
fmt.Println("To use this credential, call the AWS CLI with the --profile option (e.g. aws --profile", loginFlags.Profile, "ec2 describe-instances).")

fmt.Println("Saving config:", config.Filename)
config.SaveUsername(loginDetails.Username)
config.SaveHostname(loginDetails.Hostname)

return nil
}

func resolveLoginDetails(hostname string, loginFlags *LoginFlags) (*saml2aws.LoginDetails, error) {

loginDetails := new(saml2aws.LoginDetails)

fmt.Println("hostname", hostname)

savedUsername, savedPassword, err := credentials.LookupCredentials(hostname)
if err != nil {
if !credentials.IsErrCredentialsNotFound(err) {
return nil, errors.Wrap(err, "error loading saved password")
}
}

// if you supply a username in a flag it takes precedence
if loginFlags.Username != "" {
loginDetails.Username = loginFlags.Username
} else {
fmt.Println("Using saved username")
loginDetails.Username = savedUsername
}

// if you supply a password in a flag it takes precedence
if loginFlags.Password != "" {
loginDetails.Password = loginFlags.Password
} else {
fmt.Println("Using saved password")
loginDetails.Password = savedPassword
}

fmt.Println("savedUsername", savedUsername)

// if skip prompt was passed just pass back the flag values
if loginFlags.SkipPrompt {
return &saml2aws.LoginDetails{
Username: loginDetails.Username,
Password: loginDetails.Password,
Hostname: loginFlags.Hostname,
}, nil
}

return saml2aws.PromptForLoginDetails(savedUsername, hostname, savedPassword)
}

func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string, loginFlags *LoginFlags) (*saml2aws.AWSRole, error) {
var role = new(saml2aws.AWSRole)

if len(awsRoles) == 1 {
if loginFlags.RoleSupplied() {
return saml2aws.LocateRole(awsRoles, loginFlags.RoleArn)
}
role = awsRoles[0]
} else if len(awsRoles) == 0 {
return nil, errors.New("no roles available")
}

awsAccounts, err := saml2aws.ParseAWSAccounts(samlAssertion)
if err != nil {
return nil, errors.Wrap(err, "error parsing aws role accounts")
}

saml2aws.AssignPrincipals(awsRoles, awsAccounts)

if loginFlags.RoleSupplied() {
return saml2aws.LocateRole(awsRoles, loginFlags.RoleArn)
}

for {
role, err = saml2aws.PromptForAWSRoleSelection(awsAccounts)
if err == nil {
break
}
fmt.Println("error selecting role")
}

return role, nil
}
13 changes: 13 additions & 0 deletions cmd/saml2aws/commands/login_darwin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package commands

import (
"fmt"

"github.com/versent/saml2aws/helper/credentials"
"github.com/versent/saml2aws/helper/osxkeychain"
)

func init() {
fmt.Println("adding osx helper")
credentials.CurrentHelper = &osxkeychain.Osxkeychain{}
}
18 changes: 18 additions & 0 deletions cmd/saml2aws/commands/login_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package commands

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/versent/saml2aws"
)

func TestResolveLoginDetails(t *testing.T) {

loginFlags := &LoginFlags{Hostname: "id.example.com", Username: "wolfeidau", Password: "testtestlol", SkipPrompt: true}

loginDetails, err := resolveLoginDetails("id.example.com", loginFlags)

assert.Empty(t, err)
assert.Equal(t, loginDetails, &saml2aws.LoginDetails{Username: "wolfeidau", Password: "testtestlol", Hostname: "id.example.com"})
}
Loading

0 comments on commit 02c1a1c

Please sign in to comment.