diff --git a/.gitignore b/.gitignore index 2b267ab..d94d6cd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.swp +debug.log bin/* dist/ diff --git a/README.md b/README.md index 88b892c..5af4599 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,6 @@ Available Commands: Flags: -a, --approve Pass an optional approve flag as an argument which will only approve and not merge selected repos. --close Pass an optional argument to close a pull request. - --commit-msg string Add a custom message when approving a pull request. -c, --config string Pass an optional config file as an argument with list of repositories. -d, --delay int Set the value of delay, which will determine how long to wait between mergeing pull requests. Default is (6) seconds. (default 6) -e, --enterprise-base-url string For Github Enterprise users, you can pass your enterprise base. Format: http(s)://[hostname]/ diff --git a/go.mod b/go.mod index 855f494..6714def 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.19 require ( github.com/AlecAivazis/survey/v2 v2.3.6 - github.com/google/go-github/v45 v45.2.0 github.com/olekukonko/tablewriter v0.0.5 + github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 github.com/spf13/cobra v1.6.1 github.com/spf13/viper v1.14.0 github.com/stretchr/testify v1.9.0 @@ -16,7 +16,6 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/go-querystring v1.1.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.1 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect @@ -30,12 +29,12 @@ require ( github.com/pelletier/go-toml/v2 v2.0.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 // indirect github.com/spf13/afero v1.9.2 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.4.1 // indirect - golang.org/x/crypto v0.13.0 // indirect golang.org/x/net v0.15.0 // indirect golang.org/x/sys v0.12.0 // indirect golang.org/x/term v0.12.0 // indirect diff --git a/go.sum b/go.sum index cbcef3c..4f73b9a 100644 --- a/go.sum +++ b/go.sum @@ -109,10 +109,6 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-github/v45 v45.2.0 h1:5oRLszbrkvxDDqBCNj2hjDZMKmvexaZ1xw/FCD+K3FI= -github.com/google/go-github/v45 v45.2.0/go.mod h1:FObaZJEDSTa/WGCzZ2Z3eoCDXWJKMenWWTrd8jrta28= -github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= -github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -184,6 +180,10 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 h1:cYCy18SHPKRkvclm+pWm1Lk4YrREb4IOIb/YdFO0p2M= +github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7/go.mod h1:zqMwyHmnN/eDOZOdiTohqIUKUrTFX62PNlu7IJdu0q8= +github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 h1:17JxqqJY66GmZVHkmAsGEkcIu0oCe3AM420QDgGwZx0= +github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466/go.mod h1:9dIRpgIY7hVhoqfe0/FcYp0bpInZaT7dc3BYOprrIUE= github.com/spf13/afero v1.9.2 h1:j49Hj62F0n+DaZ1dDCvhABaPNSGNkt32oRFxI33IEMw= github.com/spf13/afero v1.9.2/go.mod h1:iUV7ddyEEZPO5gA3zD4fJt6iStLlL+Lg4m2cihcDf8Y= github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w= @@ -226,8 +226,6 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= diff --git a/pkg/cli/gomerge/gomerge.go b/pkg/cli/gomerge/gomerge.go index 4dfe229..b6b3251 100644 --- a/pkg/cli/gomerge/gomerge.go +++ b/pkg/cli/gomerge/gomerge.go @@ -1,10 +1,11 @@ package gomerge import ( - "github.com/cian911/go-merge/pkg/cli/list" - "github.com/cian911/go-merge/pkg/cli/version" "github.com/spf13/cobra" "github.com/spf13/viper" + + "github.com/cian911/go-merge/pkg/cli/list" + "github.com/cian911/go-merge/pkg/cli/version" ) func New() (c *cobra.Command) { @@ -13,20 +14,30 @@ func New() (c *cobra.Command) { Short: "Gomerge makes it simple to merge an open pull request from your terminal.", } - c.PersistentFlags().StringP("repo", "r", "", "Pass name of repository as argument (organization/repo).") + c.PersistentFlags(). + StringP("repo", "r", "", "Pass name of repository as argument (organization/repo).") + c.PersistentFlags(). + StringArrayP("label", "l", []string{}, "Pass an optional list of labels to filter pull requests. (label1,label2,label3)") c.PersistentFlags().StringP("token", "t", "", "Pass your github personal access token (PAT).") - c.PersistentFlags().StringP("config", "c", "", "Pass an optional config file as an argument with list of repositories.") - c.PersistentFlags().BoolP("approve", "a", false, "Pass an optional approve flag as an argument which will only approve and not merge selected repos.") - c.PersistentFlags().StringP("merge-method", "m", "", "Pass an optional merge method for the pull request (merge [default], squash, rebase).") - c.PersistentFlags().BoolP("skip", "s", false, "Pass an optional flag to skip a pull request and continue if one or more are not mergable.") - c.PersistentFlags().BoolP("close", "", false, "Pass an optional argument to close a pull request.") - c.PersistentFlags().IntP("delay", "d", 6, "Set the value of delay, which will determine how long to wait between mergeing pull requests. Default is (6) seconds.") - c.PersistentFlags().StringP("enterprise-base-url", "e", "", "For Github Enterprise users, you can pass your enterprise base. Format: http(s)://[hostname]/") - c.PersistentFlags().StringP("commit-msg", "", "", "Add a custom message when approving a pull request.") + c.PersistentFlags(). + StringP("config", "c", "", "Pass an optional config file as an argument with list of repositories.") + c.PersistentFlags(). + BoolP("approve", "a", false, "Pass an optional approve flag as an argument which will only approve and not merge selected repos.") + c.PersistentFlags(). + StringP("merge-method", "m", "", "Pass an optional merge method for the pull request (merge [default], squash, rebase).") + c.PersistentFlags(). + BoolP("skip", "s", false, "Pass an optional flag to skip a pull request and continue if one or more are not mergable.") + c.PersistentFlags(). + BoolP("close", "", false, "Pass an optional argument to close a pull request.") + c.PersistentFlags(). + IntP("delay", "d", 6, "Set the value of delay, which will determine how long to wait between mergeing pull requests. Default is (6) seconds.") + c.PersistentFlags(). + StringP("enterprise-base-url", "e", "", "For Github Enterprise users, you can pass your enterprise base. Format: http(s)://[hostname]/") c.MarkFlagRequired("token") viper.BindPFlag("repo", c.PersistentFlags().Lookup("repo")) + viper.BindPFlag("label", c.PersistentFlags().Lookup("label")) viper.BindPFlag("token", c.PersistentFlags().Lookup("token")) viper.BindPFlag("config", c.PersistentFlags().Lookup("config")) viper.BindPFlag("approve", c.PersistentFlags().Lookup("approve")) @@ -35,7 +46,6 @@ func New() (c *cobra.Command) { viper.BindPFlag("delay", c.PersistentFlags().Lookup("delay")) viper.BindPFlag("close", c.PersistentFlags().Lookup("close")) viper.BindPFlag("enterprise-base-url", c.PersistentFlags().Lookup("enterprise-base-url")) - viper.BindPFlag("commit-msg", c.PersistentFlags().Lookup("commit-msg")) c.AddCommand(list.NewCommand()) c.AddCommand(version.NewCommand()) diff --git a/pkg/cli/list/list.go b/pkg/cli/list/list.go index 95e1953..1300412 100644 --- a/pkg/cli/list/list.go +++ b/pkg/cli/list/list.go @@ -1,22 +1,21 @@ package list import ( - "context" "fmt" "log" "os" - "strconv" "strings" "time" "github.com/AlecAivazis/survey/v2" - "github.com/cian911/go-merge/pkg/gitclient" - "github.com/cian911/go-merge/pkg/printer" - "github.com/cian911/go-merge/pkg/utils" - "github.com/google/go-github/v45/github" "github.com/olekukonko/tablewriter" + "github.com/shurcooL/githubv4" "github.com/spf13/cobra" "github.com/spf13/viper" + + "github.com/cian911/go-merge/pkg/gitclient" + "github.com/cian911/go-merge/pkg/printer" + "github.com/cian911/go-merge/pkg/utils" ) var ( @@ -24,14 +23,15 @@ var ( repo = "" approveOnly = false configPresent = false - mergeMethod = "merge" ) const ( - TokenEnvVar = "GITHUB_TOKEN" + TokenEnvVar = "GITHUB_TOKEN" + STATUS_SUCCESS = 0 + STATUS_WAITING = 1 + STATUS_FAILED = 2 ) -// TODO: Refactor NewCommnd func NewCommand() (c *cobra.Command) { c = &cobra.Command{ Use: "list", @@ -39,9 +39,10 @@ func NewCommand() (c *cobra.Command) { Run: func(cmd *cobra.Command, args []string) { ctx := cmd.Context() orgRepo := viper.GetString("repo") + labels := getLabels() configFile := viper.GetString("config") approveOnly = viper.GetBool("approve") - mergeMethod := viper.GetString("merge-method") + mergeMethod := getMergeMethod() flagToken := viper.GetString("token") skip := viper.GetBool("skip") closePr := viper.GetBool("close") @@ -54,7 +55,9 @@ func NewCommand() (c *cobra.Command) { } if !configPresent && len(orgRepo) <= 0 { - log.Fatal("You must pass either a config file or repository as argument to continue.") + log.Fatal( + "You must pass either a config file or repository as argument to continue.", + ) } configToken := viper.GetString("token") @@ -68,72 +71,66 @@ func NewCommand() (c *cobra.Command) { isEnterprise = true } - ghClient := gitclient.Client(token, ctx, isEnterprise) - pullRequestsArray := []*github.PullRequest{} + ghClient := gitclient.ClientV4(token, ctx, isEnterprise) + + pullRequestsArray := []*gitclient.PullRequest{} table := initTable() - ctx = commitMsg(ctx, viper.GetString("commit-msg")) - // If user has passed a config file + var org string + var repositories []string = nil + if configPresent { org = viper.GetString("organization") - - for _, v := range viper.GetStringSlice("repositories") { - pullRequests, _, err := ghClient.PullRequests.List(ctx, org, v, nil) - if err != nil { - log.Fatal(err) - } - - // Use variadic notation to append to array here... - pullRequestsArray = append(pullRequestsArray, pullRequests...) + if len(viper.GetStringSlice("repositories")) > 0 { + repositories = viper.GetStringSlice("repositories") + } else { + repositories = append(repositories, "") } - - if len(pullRequestsArray) == 0 { - fmt.Println("No open pull requests found for configured repositories.") - os.Exit(0) + } else { + parts := strings.Split(orgRepo, "/") + + if len(parts) == 1 { + org = parts[0] + repositories = append(repositories, "") + } else if len(parts) == 2 { + org = parts[0] + repositories = append(repositories, parts[1]) + } else { + log.Fatal("You must pass your repo name like so: organization/repository to continue.") } + } - selectedIds := promptAndFormat(pullRequestsArray, table) - for x, id := range selectedIds { - p := parsePrId(id) - prId, _ := strconv.Atoi(p[0]) - if approveOnly { - gitclient.ApprovePullRequest(ghClient, ctx, org, p[1], prId, skip) - } else if closePr { - gitclient.ClosePullRequest(ghClient, ctx, org, p[1], prId, pullRequestsArray[x]) - } else { - gitclient.MergePullRequest(ghClient, ctx, org, p[1], prId, mergeMethod, skip) - - // delay between merges to allow other active PRs to get synced - time.Sleep(time.Duration(delay) * time.Second) - } - } - } else { - org, repo = parseOrgRepo(orgRepo, configPresent) - // if user has NOT passed a config file - pullRequests, _, err := ghClient.PullRequests.List(ctx, org, repo, nil) + for _, v := range repositories { + pullRequests, err := gitclient.GetPullRequests(ghClient, ctx, org, v, &labels) if err != nil { log.Fatal(err) } - if len(pullRequests) == 0 { - fmt.Println("No open pull requests found for given repository.") - os.Exit(0) - } + // Use variadic notation to append to array here... + pullRequestsArray = append(pullRequestsArray, pullRequests...) + } + + if len(pullRequestsArray) == 0 { + fmt.Println("No open pull requests found for configured repositories.") + os.Exit(0) + } - selectedIds := promptAndFormat(pullRequests, table) - for x, id := range selectedIds { - p := parsePrId(id) - prId, _ := strconv.Atoi(p[0]) - if approveOnly { - gitclient.ApprovePullRequest(ghClient, ctx, org, repo, prId, skip) - } else if closePr { - gitclient.ClosePullRequest(ghClient, ctx, org, repo, prId, pullRequests[x]) - } else { - gitclient.MergePullRequest(ghClient, ctx, org, repo, prId, mergeMethod, skip) - - // delay between merges to allow other active PRs to get synced + selectedIds := promptAndFormat(pullRequestsArray, table) + for i, pr := range selectedIds { + if approveOnly { + gitclient.ApprovePullRequest(ghClient, ctx, pr, skip) + } else if closePr { + gitclient.ClosePullRequest(ghClient, ctx, pr, skip) + } else { + // delay between merges to allow other active PRs to get synced + if i > 0 { time.Sleep(time.Duration(delay) * time.Second) } + if pr.NeedsReview { + gitclient.ApprovePullRequest(ghClient, ctx, pr, skip) + } + gitclient.MergePullRequest(ghClient, ctx, pr, &mergeMethod, skip) + } } }, @@ -142,30 +139,46 @@ func NewCommand() (c *cobra.Command) { return } -func promptAndFormat(pullRequests []*github.PullRequest, table *tablewriter.Table) []string { +func promptAndFormat( + pullRequests []*gitclient.PullRequest, + table *tablewriter.Table, +) (selectedPullRequests []*gitclient.PullRequest) { prIds := []string{} - data := []string{} - repoName := "" for _, pr := range pullRequests { - if pr.Head.Repo == nil { - repoName = "Forked Likely Repository Removed." - } else { - repoName = *pr.Head.Repo.Name - } - prIds = append(prIds, fmt.Sprintf("%d | %s", *pr.Number, repoName)) - data = formatTable(pr, org, repoName) + prIds = append( + prIds, + fmt.Sprintf("%d | %s/%s", pr.Number, pr.RepositoryOwner, pr.RepositoryName), + ) + + data, status := formatTable(pr) if len(data) == 0 { // If there is an issue with the pr, skip continue } - table = printer.SuccessStyle(table, data) + switch status { + case STATUS_SUCCESS: + table = printer.SuccessStyle(table, data) + case STATUS_WAITING: + table = printer.WaitingStyle(table, data) + case STATUS_FAILED: + table = printer.ErrorStyle(table, data) + } } table.Render() prompt, selectedIds := selectPrIds(prIds) survey.AskOne(prompt, &selectedIds) - return selectedIds + selectedPullRequests = make([]*gitclient.PullRequest, len(selectedIds)) + for idIndex, id := range selectedIds { + for prIndex, prId := range prIds { + if id == prId { + selectedPullRequests[idIndex] = pullRequests[prIndex] + break + } + } + } + return } func initTable() (table *tablewriter.Table) { @@ -182,39 +195,37 @@ func initTable() (table *tablewriter.Table) { return } -func formatTable(pr *github.PullRequest, org, repo string) (data []string) { - if (pr.Number == nil) || (pr.State == nil) || (pr.Title == nil) || (pr.CreatedAt == nil) { - return - } - data = []string{ - fmt.Sprintf("#%s", printer.FormatID(pr.Number)), - printer.FormatString(pr.State), - printer.FormatString(pr.Title), - fmt.Sprintf("%s/%s", org, repo), - printer.FormatTime(pr.CreatedAt), +func statusIcon(state string) (icon string, status int) { + switch state { + case "SUCCESS": + icon = "" + status = STATUS_SUCCESS + case "IN_PROGRESS": + icon = "" + status = STATUS_WAITING + case "FAILURE": + icon = "󰅙" + status = STATUS_FAILED + default: + icon = "" } return } -func parseOrgRepo(repo string, configPresent bool) (org, repository string) { - str := strings.Split(repo, "/") - - if len(str) <= 1 { - log.Fatal("You must pass your repo name like so: organization/repository to continue.") +func formatTable(pr *gitclient.PullRequest) (data []string, status int) { + icon, status := statusIcon(pr.StatusRollup) + data = []string{ + fmt.Sprintf("#%d", pr.Number), + fmt.Sprintf("%s %s", pr.State, icon), + pr.Title, + fmt.Sprintf("%s/%s", pr.RepositoryOwner, pr.RepositoryName), + printer.FormatTime(&pr.CreatedAt), } - org = str[0] - repository = str[1] - return } -func parsePrId(prId string) []string { - str := strings.Split(strings.ReplaceAll(prId, " ", ""), "|") - return str -} - func getToken(flag, config string) (str string, err error) { if flag != str { return flag, nil @@ -247,10 +258,30 @@ func selectPrIds(prIds []string) (*survey.MultiSelect, []string) { return prompt, selectedIds } -func commitMsg(ctx context.Context, msg string) context.Context { - if len(msg) != 0 { - return context.WithValue(ctx, "message", msg) - } +func getMergeMethod() githubv4.PullRequestMergeMethod { + method := viper.GetString("merge-method") + switch method { + case "merge": + return githubv4.PullRequestMergeMethodMerge + case "rebase": + return githubv4.PullRequestMergeMethodRebase + case "squash": + return githubv4.PullRequestMergeMethodSquash + } + if method != "" { + log.Fatalf( + "Unknown merge method %s. Please use one of the following: merge, rebase, squash", + method, + ) + } + return githubv4.PullRequestMergeMethodMerge +} - return context.WithValue(ctx, "message", gitclient.DefaultApproveMsg()) +func getLabels() (labels []githubv4.String) { + raw_labels := viper.GetStringSlice("label") + labels = make([]githubv4.String, len(raw_labels)) + for i, label := range raw_labels { + labels[i] = githubv4.String(label) + } + return } diff --git a/pkg/cli/list/list_test.go b/pkg/cli/list/list_test.go index db47b95..b8d1bf6 100644 --- a/pkg/cli/list/list_test.go +++ b/pkg/cli/list/list_test.go @@ -5,27 +5,12 @@ import ( "testing" "time" - "github.com/cian911/go-merge/pkg/printer" - "github.com/google/go-github/v45/github" "github.com/olekukonko/tablewriter" "github.com/stretchr/testify/assert" -) - -func TestParseOrgRepo(t *testing.T) { - t.Run("It returns a valid tuple when no config is present", func(t *testing.T) { - repo := "Cian911/syncwave" - configPresent := false - - want1 := "Cian911" - want2 := "syncwave" - - got1, got2 := parseOrgRepo(repo, configPresent) - if got1 != want1 || got2 != want2 { - t.Errorf("got1: %s, got2: %s, want1: %s, want2: %s", got1, got2, want1, want2) - } - }) -} + "github.com/cian911/go-merge/pkg/gitclient" + "github.com/cian911/go-merge/pkg/printer" +) func TestInitTable(t *testing.T) { t.Run("It returns a tablewriter pointer", func(t *testing.T) { @@ -37,50 +22,33 @@ func TestInitTable(t *testing.T) { } func TestFormatTable(t *testing.T) { - var ( - org = "Cian911" - repo = "syncwave" - ) - t.Run("It returns a string array", func(t *testing.T) { number := 1 state := "#open" title := "My Pr" createdAt := time.Now() - pr := &github.PullRequest{ - Number: &number, - State: &state, - Title: &title, - CreatedAt: &createdAt, + pr := &gitclient.PullRequest{ + RepositoryOwner: "Cian911", + RepositoryName: "syncwave", + Number: number, + State: state, + Title: title, + StatusRollup: "SUCCESS", + CreatedAt: createdAt, } - got := formatTable(pr, org, repo) + got, _ := formatTable(pr) want := []string{ "#1", - "#open", + "#open ", "My Pr", "Cian911/syncwave", - printer.FormatTime(pr.CreatedAt), + printer.FormatTime(&pr.CreatedAt), } assert.Equal(t, got, want) }) - - t.Run("It returns an empty string array when attrs are not present in pr struct", func(t *testing.T) { - state := "#open" - title := "My Pr" - - pr := &github.PullRequest{ - State: &state, - Title: &title, - } - - got := formatTable(pr, org, repo) - want := []string(nil) - - assert.Equal(t, got, want) - }) } func TestListGetToken(t *testing.T) { @@ -90,62 +58,83 @@ func TestListGetToken(t *testing.T) { envVar = "env@token" ) - t.Run("When a given token is set by flag, it should return token as the flag value", func(t *testing.T) { - got, err := getToken(flag, "") - want := flag - assert.Nil(t, err) - assert.Equal(t, want, got) - }) + t.Run( + "When a given token is set by flag, it should return token as the flag value", + func(t *testing.T) { + got, err := getToken(flag, "") + want := flag + assert.Nil(t, err) + assert.Equal(t, want, got) + }, + ) - t.Run("When a given token is set by config, it should return token as defined on the configuration file", func(t *testing.T) { - got, err := getToken("", config) - want := config - assert.Nil(t, err) - assert.Equal(t, want, got) - }) + t.Run( + "When a given token is set by config, it should return token as defined on the configuration file", + func(t *testing.T) { + got, err := getToken("", config) + want := config + assert.Nil(t, err) + assert.Equal(t, want, got) + }, + ) - t.Run("When a given token is set by environment variable, it should return token as defined on the environment", func(t *testing.T) { - os.Setenv(TokenEnvVar, envVar) - got, err := getToken("", "") - want := envVar - assert.Nil(t, err) - assert.Equal(t, want, got) - os.Unsetenv(TokenEnvVar) - }) + t.Run( + "When a given token is set by environment variable, it should return token as defined on the environment", + func(t *testing.T) { + os.Setenv(TokenEnvVar, envVar) + got, err := getToken("", "") + want := envVar + assert.Nil(t, err) + assert.Equal(t, want, got) + os.Unsetenv(TokenEnvVar) + }, + ) - t.Run("When a given token is set on flag and config file, it should return the value set on flag", func(t *testing.T) { - got, err := getToken(flag, config) - want := flag - assert.Nil(t, err) - assert.Equal(t, want, got) - }) + t.Run( + "When a given token is set on flag and config file, it should return the value set on flag", + func(t *testing.T) { + got, err := getToken(flag, config) + want := flag + assert.Nil(t, err) + assert.Equal(t, want, got) + }, + ) - t.Run("When a given token is set on flag and environment, it should return the value set on the flag", func(t *testing.T) { - os.Setenv(TokenEnvVar, envVar) - got, err := getToken(flag, "") - want := flag - assert.Nil(t, err) - assert.Equal(t, want, got) - os.Unsetenv(TokenEnvVar) - }) + t.Run( + "When a given token is set on flag and environment, it should return the value set on the flag", + func(t *testing.T) { + os.Setenv(TokenEnvVar, envVar) + got, err := getToken(flag, "") + want := flag + assert.Nil(t, err) + assert.Equal(t, want, got) + os.Unsetenv(TokenEnvVar) + }, + ) - t.Run("When a given token is set on config file and environment, it should return the value set on the config file", func(t *testing.T) { - os.Setenv(TokenEnvVar, envVar) - got, err := getToken("", config) - want := config - assert.Nil(t, err) - assert.Equal(t, want, got) - os.Unsetenv(TokenEnvVar) - }) + t.Run( + "When a given token is set on config file and environment, it should return the value set on the config file", + func(t *testing.T) { + os.Setenv(TokenEnvVar, envVar) + got, err := getToken("", config) + want := config + assert.Nil(t, err) + assert.Equal(t, want, got) + os.Unsetenv(TokenEnvVar) + }, + ) - t.Run("When a given token is set on flag, config file, and environment, it should return the value set on flag", func(t *testing.T) { - os.Setenv(TokenEnvVar, envVar) - got, err := getToken(flag, config) - want := flag - assert.Nil(t, err) - assert.Equal(t, want, got) - os.Unsetenv(TokenEnvVar) - }) + t.Run( + "When a given token is set on flag, config file, and environment, it should return the value set on flag", + func(t *testing.T) { + os.Setenv(TokenEnvVar, envVar) + got, err := getToken(flag, config) + want := flag + assert.Nil(t, err) + assert.Equal(t, want, got) + os.Unsetenv(TokenEnvVar) + }, + ) t.Run("When no token is passed should return error", func(t *testing.T) { got, err := getToken("", "") diff --git a/pkg/gitclient/client.go b/pkg/gitclient/client.go index 02305fd..1377088 100644 --- a/pkg/gitclient/client.go +++ b/pkg/gitclient/client.go @@ -4,81 +4,288 @@ import ( "context" "fmt" "log" + "time" + + "github.com/shurcooL/githubv4" - "github.com/google/go-github/v45/github" "github.com/spf13/viper" "golang.org/x/oauth2" ) -func Client(githubToken string, ctx context.Context, isEnterprise bool) (client *github.Client) { +type PullRequest struct { + RepositoryOwner string + RepositoryName string + Number int + Title string + State string + CreatedAt time.Time + ID githubv4.ID + StatusRollup string + NeedsReview bool +} + +func ClientV4(githubToken string, ctx context.Context, isEnterprise bool) (client *githubv4.Client) { tokenSource := oauth2.StaticTokenSource( &oauth2.Token{ AccessToken: githubToken, }, ) - tokenContext := oauth2.NewClient(ctx, tokenSource) + httpClient := oauth2.NewClient(context.Background(), tokenSource) if isEnterprise { baseUrl := viper.GetString("enterprise-base-url") - c, err := github.NewEnterpriseClient(baseUrl, baseUrl, tokenContext) + client = githubv4.NewEnterpriseClient(baseUrl, httpClient) + } else { + client = githubv4.NewClient(httpClient) - if err != nil { - log.Fatalf("Could not auth enterprise client: %v", err) + } + + return +} + +type commits struct { + Nodes []struct { + Commit struct { + StatusCheckRollup struct { + State githubv4.String + } } + } +} + +type pullRequests struct { + Nodes []struct { + Number githubv4.Int + Title githubv4.String + State githubv4.String + Url githubv4.URI + CreatedAt githubv4.DateTime + ID githubv4.ID + Commits commits `graphql:"commits(last:1)"` + // NB: this only works for classic branch protection rules. Repository + // rulesets don't appear to be visible in the API + BaseRef struct { + BranchProtectionRule struct { + RequiredApprovingReviewCount githubv4.Int + } + } + Reviews struct { + Nodes []struct { + AuthorCanPushToRepository githubv4.Boolean + } + } `graphql:"reviews(first: 100, states: [APPROVED])"` + } +} - client = c +// NB: githubv4 uses struct tags to define the GraphQL query. To omit the labels +// argument to pullRequests(), we need to define a compeltely new type. Luckily +// these are convertible to each other as of go 1.8. +type repository struct { + NameWithOwner githubv4.String + Owner struct { + Login githubv4.String + } + Name githubv4.String + PullRequests pullRequests `graphql:"pullRequests(states: [OPEN], first: $maxPullRequests, orderBy: {field: CREATED_AT, direction: DESC})"` +} + +type repositoryWithPRLabels struct { + NameWithOwner githubv4.String + Owner struct { + Login githubv4.String + } + Name githubv4.String + PullRequests pullRequests `graphql:"pullRequests(states: [OPEN], labels: $labels, first: $maxPullRequests, orderBy: {field: CREATED_AT, direction: DESC})"` +} + +func GetPullRequests(client *githubv4.Client, ctx context.Context, owner string, repo string, labels *[]githubv4.String) ([]*PullRequest, error) { + + vars := map[string]interface{}{ + "maxPullRequests": githubv4.Int(100), + "owner": githubv4.String(owner), + } + if len(*labels) > 0 { + vars["labels"] = labels + } + + repos := []repository{} + + if repo == "" { + vars["maxRepositories"] = githubv4.Int(100) + if len(*labels) > 0 { + var q struct { + RepositoryOwner struct { + Repositories struct { + Nodes []repositoryWithPRLabels + } `graphql:"repositories(first: $maxRepositories, isFork: false, isArchived: false, isLocked: false, orderBy: {field: UPDATED_AT, direction: DESC})"` + } `graphql:"repositoryOwner(login: $owner)"` + } + err := client.Query(ctx, &q, vars) + if err != nil { + return nil, err + } + for _, repo := range q.RepositoryOwner.Repositories.Nodes { + repos = append(repos, repository(repo)) + } + } else { + var q struct { + RepositoryOwner struct { + Repositories struct { + Nodes []repository + } `graphql:"repositories(first: $maxRepositories, isFork: false, isArchived: false, isLocked: false, orderBy: {field: UPDATED_AT, direction: DESC})"` + } `graphql:"repositoryOwner(login: $owner)"` + } + err := client.Query(ctx, &q, vars) + if err != nil { + return nil, err + } + for _, repo := range q.RepositoryOwner.Repositories.Nodes { + repos = append(repos, repository(repo)) + } + } } else { - client = github.NewClient(tokenContext) + vars["name"] = githubv4.String(repo) + if len(*labels) > 0 { + var q struct { + Repository repositoryWithPRLabels `graphql:"repository(owner: $owner, name: $name)"` + } + err := client.Query(ctx, &q, vars) + if err != nil { + return nil, err + } + repos = append(repos, repository(q.Repository)) + } else { + var q struct { + Repository repository `graphql:"repository(owner: $owner, name: $name)"` + } + err := client.Query(ctx, &q, vars) + if err != nil { + return nil, err + } + repos = append(repos, q.Repository) + } } - return + pullRequests := []*PullRequest{} + for _, repo := range repos { + for _, pr := range repo.PullRequests.Nodes { + reviews := 0 + for _, review := range pr.Reviews.Nodes { + if review.AuthorCanPushToRepository { + reviews++ + } + } + pullRequest := &PullRequest{ + RepositoryOwner: string(repo.Owner.Login), + RepositoryName: string(repo.Name), + Number: int(pr.Number), + Title: string(pr.Title), + State: string(pr.State), + CreatedAt: pr.CreatedAt.Time, + ID: pr.ID, + StatusRollup: string(pr.Commits.Nodes[0].Commit.StatusCheckRollup.State), + NeedsReview: reviews < int(pr.BaseRef.BranchProtectionRule.RequiredApprovingReviewCount), + } + + pullRequests = append(pullRequests, pullRequest) + } + } + + return pullRequests, nil } -func ApprovePullRequest(ghClient *github.Client, ctx context.Context, org, repo string, prId int, skip bool) { - // Create review - commitMsg := ctx.Value("message").(string) - e := "APPROVE" - reviewRequest := &github.PullRequestReviewRequest{ - Body: &commitMsg, - Event: &e, +func ApprovePullRequest(ghClient *githubv4.Client, ctx context.Context, pr *PullRequest, skip bool) { + + commitMessage := githubv4.String(DefaultApproveMsg()) + event := githubv4.PullRequestReviewEventApprove + + input := githubv4.AddPullRequestReviewInput{ + PullRequestID: pr.ID, + Body: &commitMessage, + Event: &event, + } + + var m struct { + AddPullRequestReview struct { + PullRequestReview struct { + State githubv4.PullRequestReviewState + } + } `graphql:"addPullRequestReview(input: $input)"` } - review, _, err := ghClient.PullRequests.CreateReview(ctx, org, repo, prId, reviewRequest) + + err := ghClient.Mutate(ctx, &m, input, nil) if err != nil && !skip { - log.Fatalf("Could not approve pull request, did you try to approve your on pull request? - %v", err) - } - - if err != nil && skip { - fmt.Printf("Could not approve pull request, skipping.") - } else { - fmt.Printf("PR #%d: %v\n", prId, *review.State) - } + log.Printf("Could not approve pull request %s/%s#%d - %v\n", pr.RepositoryOwner, pr.RepositoryName, pr.Number, err) + } + + review := m.AddPullRequestReview.PullRequestReview + + if err != nil && skip { + fmt.Printf("Could not approve pull request, skipping.") + } else { + fmt.Printf("%s/%s#%d: %v\n", pr.RepositoryOwner, pr.RepositoryName, pr.Number, review.State) + } } -func MergePullRequest(ghClient *github.Client, ctx context.Context, org, repo string, prId int, mergeMethod string, skip bool) { - result, _, err := ghClient.PullRequests.Merge(ctx, org, repo, prId, defaultCommitMsg(), &github.PullRequestOptions{MergeMethod: mergeMethod}) - if err != nil { - log.Printf("Could not merge PR #%d, skipping: %v\n", prId, err) +func MergePullRequest(ghClient *githubv4.Client, ctx context.Context, pr *PullRequest, mergeMethod *githubv4.PullRequestMergeMethod, skip bool) { - return + input := githubv4.MergePullRequestInput{ + PullRequestID: pr.ID, + MergeMethod: mergeMethod, } - fmt.Sprintf("PR #%d: %v.\n", prId, *result.Message) -} + var m struct { + MergePullRequest struct { + PullRequest struct { + Merged bool + State githubv4.PullRequestState + Number githubv4.Int + Repository struct { + NameWithOwner githubv4.String + } + } + } `graphql:"mergePullRequest(input: $input)"` + } -func ClosePullRequest(ghClient *github.Client, ctx context.Context, org, repo string, prId int, prRef *github.PullRequest) { - // Set Closed state for PR - *prRef.State = "closed" - result, _, err := ghClient.PullRequests.Edit(ctx, org, repo, prId, prRef) + err := ghClient.Mutate(ctx, &m, input, nil) if err != nil { - log.Printf("Could not close PR #%d - %v", prId, err) - } else { - fmt.Sprintf("PR #%d: %v.\n", prId, *result.State) + log.Printf("Could not merge %s/%s#%d - %v\n", pr.RepositoryOwner, pr.RepositoryName, pr.Number, err) } + + prOut := m.MergePullRequest.PullRequest + fmt.Printf("%s/%s#%d merged: %v\n", pr.RepositoryOwner, pr.RepositoryName, pr.Number, prOut.State) + } -func defaultCommitMsg() string { - return "Merged by gomerge CLI." +func ClosePullRequest(ghClient *githubv4.Client, ctx context.Context, pr *PullRequest, skip bool) { + + input := githubv4.ClosePullRequestInput{ + PullRequestID: pr.ID, + } + + var m struct { + ClosePullRequest struct { + PullRequest struct { + Closed bool + State githubv4.PullRequestState + Number githubv4.Int + Repository struct { + NameWithOwner githubv4.String + } + } + } `graphql:"closePullRequest(input: $input)"` + } + + err := ghClient.Mutate(ctx, &m, input, nil) + if err != nil { + log.Printf("Could not close %s/%s#%d - %v\n", pr.RepositoryOwner, pr.RepositoryName, pr.Number, err) + } + + prOut := m.ClosePullRequest.PullRequest + + fmt.Printf("%s/%s#%d merged: %v\n", pr.RepositoryOwner, pr.RepositoryName, pr.Number, prOut.State) + } func DefaultApproveMsg() string { diff --git a/pkg/gitclient/client_test.go b/pkg/gitclient/client_test.go index e7ed18c..1c1469c 100644 --- a/pkg/gitclient/client_test.go +++ b/pkg/gitclient/client_test.go @@ -4,17 +4,6 @@ import ( "testing" ) -func TestDefaultCommitMsg(t *testing.T) { - t.Run("It returns a default commit message", func(t *testing.T) { - got := defaultCommitMsg() - want := "Merged by gomerge CLI." - - if got != want { - t.Errorf("got %s, want %s", got, want) - } - }) -} - func TestDefaultApproveMsg(t *testing.T) { t.Run("It returns a default approve message.", func(t *testing.T) { got := DefaultApproveMsg() diff --git a/pkg/printer/printer.go b/pkg/printer/printer.go index 493ae43..5926fd8 100644 --- a/pkg/printer/printer.go +++ b/pkg/printer/printer.go @@ -24,16 +24,33 @@ func HeaderStyle(t *tablewriter.Table) *tablewriter.Table { func SuccessStyle(t *tablewriter.Table, data []string) *tablewriter.Table { t.Rich(data, []tablewriter.Colors{ - tablewriter.Colors{tablewriter.Bold, tablewriter.FgHiCyanColor}, - tablewriter.Colors{tablewriter.Bold, tablewriter.FgHiGreenColor}, - tablewriter.Colors{tablewriter.Bold}, - tablewriter.Colors{tablewriter.Bold}, - tablewriter.Colors{tablewriter.Bold}, + {tablewriter.Bold, tablewriter.FgHiCyanColor}, + {tablewriter.Bold, tablewriter.FgHiGreenColor}, + {tablewriter.Bold}, + {tablewriter.Bold}, + {tablewriter.Bold}, }) return t } func ErrorStyle(t *tablewriter.Table, data []string) *tablewriter.Table { - t.Rich(data, []tablewriter.Colors{tablewriter.Colors{tablewriter.Bold, tablewriter.FgHiRedColor}, tablewriter.Colors{tablewriter.Bold, tablewriter.FgHiRedColor}}) + t.Rich( + data, + []tablewriter.Colors{ + {tablewriter.Bold, tablewriter.FgHiCyanColor}, + {tablewriter.Bold, tablewriter.FgHiRedColor}, + }, + ) + return t +} + +func WaitingStyle(t *tablewriter.Table, data []string) *tablewriter.Table { + t.Rich( + data, + []tablewriter.Colors{ + {tablewriter.Bold, tablewriter.FgHiCyanColor}, + {tablewriter.Bold, tablewriter.FgHiYellowColor}, + }, + ) return t } diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index 141566c..477fc14 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -35,7 +35,13 @@ func TestReadConfigFile(t *testing.T) { wantExt := "yaml" if gotFilename != wantFilename || gotExt != wantExt { - t.Errorf("got: %s, want: %s, got: %s, want: %s", gotFilename, wantFilename, gotExt, wantExt) + t.Errorf( + "got: %s, want: %s, got: %s, want: %s", + gotFilename, + wantFilename, + gotExt, + wantExt, + ) } }) }