diff --git a/client/alloc_endpoint.go b/client/alloc_endpoint.go index 8eae2da78a3..fb6068a5d0c 100644 --- a/client/alloc_endpoint.go +++ b/client/alloc_endpoint.go @@ -177,7 +177,6 @@ func (a *Allocations) exec(conn io.ReadWriteCloser) { handleStreamResultError(err, code, encoder) return } - a.c.logger.Info("task exec session ended", "exec_id", execID) } @@ -216,6 +215,7 @@ func (a *Allocations) execImpl(encoder *codec.Encoder, decoder *codec.Decoder, e "task", req.Task, "command", req.Cmd, "tty", req.Tty, + "action", req.Action, } if ident != nil { if ident.ACLToken != nil { @@ -238,7 +238,7 @@ func (a *Allocations) execImpl(encoder *codec.Encoder, decoder *codec.Decoder, e // Check alloc-exec permission. if err != nil { - return nil, err + return pointer.Of(int64(400)), err } else if !aclObj.AllowNsOp(alloc.Namespace, acl.NamespaceCapabilityAllocExec) { return nil, nstructs.ErrPermissionDenied } @@ -247,6 +247,20 @@ func (a *Allocations) execImpl(encoder *codec.Encoder, decoder *codec.Decoder, e if req.Task == "" { return pointer.Of(int64(400)), taskNotPresentErr } + + // If an action is present, go find the command and args + if req.Action != "" { + alloc, _ := a.c.GetAlloc(req.AllocID) + jobAction, err := validateActionExists(req.Action, req.Task, alloc) + if err != nil { + return pointer.Of(int64(400)), err + } + if jobAction != nil { + // append both Command and Args + req.Cmd = append([]string{jobAction.Command}, jobAction.Args...) + } + } + if len(req.Cmd) == 0 { return pointer.Of(int64(400)), errors.New("command is not present") } @@ -343,3 +357,14 @@ func (s *execStream) Recv() (*drivers.ExecTaskStreamingRequestMsg, error) { err := s.decoder.Decode(&req) return &req, err } + +func validateActionExists(actionName string, taskName string, alloc *nstructs.Allocation) (*nstructs.Action, error) { + t := alloc.LookupTask(taskName) + + for _, action := range t.Actions { + if action.Name == actionName { + return action, nil + } + } + return nil, fmt.Errorf("action %s not found", actionName) +} diff --git a/client/structs/structs.go b/client/structs/structs.go index 8c878f37e85..31a33c98876 100644 --- a/client/structs/structs.go +++ b/client/structs/structs.go @@ -182,6 +182,9 @@ type AllocExecRequest struct { // Cmd is the command to be executed Cmd []string + // The name of a predefined command to be executed (optional) + Action string + structs.QueryOptions } diff --git a/command/agent/job_endpoint.go b/command/agent/job_endpoint.go index 9540f17379e..8a00a636ecd 100644 --- a/command/agent/job_endpoint.go +++ b/command/agent/job_endpoint.go @@ -12,8 +12,10 @@ import ( "strings" "github.com/golang/snappy" + "github.com/gorilla/websocket" "github.com/hashicorp/nomad/acl" api "github.com/hashicorp/nomad/api" + cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/jobspec" "github.com/hashicorp/nomad/jobspec2" "github.com/hashicorp/nomad/nomad/structs" @@ -113,6 +115,8 @@ func (s *HTTPServer) JobSpecificRequest(resp http.ResponseWriter, req *http.Requ case strings.HasSuffix(path, "/actions"): jobID := strings.TrimSuffix(path, "/actions") return s.jobActions(resp, req, jobID) + case strings.HasSuffix(path, "/action"): + return s.jobRunAction(resp, req) default: return s.jobCRUD(resp, req, path) } @@ -358,6 +362,47 @@ func (s *HTTPServer) jobActions(resp http.ResponseWriter, req *http.Request, job return out.Actions, nil } +func (s *HTTPServer) jobRunAction(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + + s.logger.Info("jobRunAction called") + + // Build the request and parse the ACL token + task := req.URL.Query().Get("task") + action := req.URL.Query().Get("action") + allocID := req.URL.Query().Get("allocID") + isTTY := false + err := error(nil) + if tty := req.URL.Query().Get("tty"); tty != "" { + isTTY, err = strconv.ParseBool(tty) + if err != nil { + return nil, fmt.Errorf("tty value is not a boolean: %v", err) + } + } + + args := cstructs.AllocExecRequest{ + Task: task, + Action: action, + AllocID: allocID, + Tty: isTTY, + } + + s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) + + conn, err := s.wsUpgrader.Upgrade(resp, req, nil) + + if err != nil { + return nil, fmt.Errorf("failed to upgrade connection: %v", err) + } + + if err := readWsHandshake(conn.ReadJSON, req, &args.QueryOptions); err != nil { + conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(toWsCode(400), err.Error())) + return nil, err + } + + return s.execStreamImpl(conn, &args) +} + func (s *HTTPServer) jobSubmissionCRUD(resp http.ResponseWriter, req *http.Request, jobID string) (*structs.JobSubmission, error) { version, err := strconv.ParseUint(req.URL.Query().Get("version"), 10, 64) if err != nil {