diff --git a/docker/Dockerfile.minimega b/docker/Dockerfile.minimega index 6f689d6f..32f02eb6 100644 --- a/docker/Dockerfile.minimega +++ b/docker/Dockerfile.minimega @@ -9,7 +9,8 @@ ARG PHENIX_REVISION=local-dev LABEL gov.sandia.phenix.revision="${PHENIX_REVISION}" # iptables needed in minimega container for scorch and tap apps -RUN apt update && apt install -y iptables \ +# socat needed in minimega container for serial app +RUN apt update && apt install -y iptables socat \ && apt autoremove -y \ && apt clean -y \ && rm -rf /var/lib/apt/lists/* \ diff --git a/src/go/api/experiment/experiment.go b/src/go/api/experiment/experiment.go index a9fc2e8c..f548474f 100644 --- a/src/go/api/experiment/experiment.go +++ b/src/go/api/experiment/experiment.go @@ -64,6 +64,18 @@ func init() { return fmt.Errorf("initializing experiment: %w", err) } + if common.BridgeMode == common.BRIDGE_MODE_AUTO { + if len(c.Metadata.Name) > 15 { + return fmt.Errorf("experiment name must be 15 characters or less when using auto bridge mode") + } + + exp.Spec.SetDefaultBridge(c.Metadata.Name) + } + + if len(exp.Spec.DefaultBridge()) > 15 { + return fmt.Errorf("default bridge name must be 15 characters or less") + } + exp.Spec.SetUseGREMesh(exp.Spec.UseGREMesh() || common.UseGREMesh) existing, _ := types.Experiments(false) @@ -100,6 +112,19 @@ func init() { return fmt.Errorf("re-initializing experiment (after update): %w", err) } + // Just in case the updated experiment reset the default bridge. + if common.BridgeMode == common.BRIDGE_MODE_AUTO { + if len(c.Metadata.Name) > 15 { + return fmt.Errorf("experiment name must be 15 characters or less when using auto bridge mode") + } + + exp.Spec.SetDefaultBridge(c.Metadata.Name) + } + + if len(exp.Spec.DefaultBridge()) > 15 { + return fmt.Errorf("default bridge name must be 15 characters or less") + } + exp.Spec.SetUseGREMesh(exp.Spec.UseGREMesh() || common.UseGREMesh) existing, _ := types.Experiments(false) diff --git a/src/go/api/soh/soh.go b/src/go/api/soh/soh.go index 24ccaac8..f76fbf38 100644 --- a/src/go/api/soh/soh.go +++ b/src/go/api/soh/soh.go @@ -7,6 +7,7 @@ import ( "phenix/api/experiment" "phenix/api/vm" + "phenix/app" "github.com/mitchellh/mapstructure" ) @@ -15,7 +16,10 @@ var vlanAliasRegex = regexp.MustCompile(`(.*) \(\d*\)`) func Get(expName, statusFilter string) (*Network, error) { // Create an empty network - network := new(Network) + network := &Network{ + Nodes: []Node{}, + Edges: []Edge{}, + } // Create structure to format nodes' font font := Font{ @@ -69,6 +73,7 @@ func Get(expName, statusFilter string) (*Network, error) { // Internally use to track connections, VM's state, and whether or not the // VM is in minimega var ( + vmIDs = make(map[string]int) interfaces = make(map[string]int) ifaceCount = len(vms) + 1 edgeCount int @@ -76,6 +81,8 @@ func Get(expName, statusFilter string) (*Network, error) { // Traverse the experiment VMs and create topology for _, vm := range vms { + vmIDs[vm.Name] = vm.ID + var vmState string /* @@ -164,6 +171,32 @@ func Get(expName, statusFilter string) (*Network, error) { } } + // Check to see if a scenario exists for this experiment and if it contains a + // "serial" app. If so, add edges for all the serial connections. + for _, a := range exp.Apps() { + if a.Name() == "serial" { + var config app.SerialConfig + + if err := a.ParseMetadata(&config); err != nil { + continue // TODO: handle this better? Like warn the user perhaps? + } + + for _, conn := range config.Connections { + // create edge for serial connection + edge := Edge{ + ID: edgeCount, + Source: vmIDs[conn.Src], + Target: vmIDs[conn.Dst], + Length: 150, + Type: "serial", + } + + network.Edges = append(network.Edges, edge) + edgeCount++ + } + } + } + return network, err } diff --git a/src/go/api/soh/types.go b/src/go/api/soh/types.go index 05df56d5..f3fbdcc6 100644 --- a/src/go/api/soh/types.go +++ b/src/go/api/soh/types.go @@ -20,10 +20,11 @@ type Node struct { } type Edge struct { - ID int `json:"id"` - Source int `json:"source"` - Target int `json:"target"` - Length int `json:"length"` + ID int `json:"id"` + Type string `json:"type"` + Source int `json:"source"` + Target int `json:"target"` + Length int `json:"length"` } type Network struct { diff --git a/src/go/app/serial.go b/src/go/app/serial.go index 1280147a..4d7e8ccc 100644 --- a/src/go/app/serial.go +++ b/src/go/app/serial.go @@ -9,8 +9,27 @@ import ( "phenix/tmpl" "phenix/types" ifaces "phenix/types/interfaces" + "phenix/util/mm" ) +var ( + idFormat = "%s_serial_%s_%d" + lfFormat = "/tmp/%s_serial_%s_%s_%d.log" + optFormat = "-chardev socket,id=%[1]s,path=/tmp/%[1]s,server,nowait -device pci-serial,chardev=%[1]s" + + defaultStartPort = 40500 +) + +type SerialConfig struct { + Connections []SerialConnectionConfig `mapstructure:"connections"` +} + +type SerialConnectionConfig struct { + Src string `mapstructure:"src"` + Dst string `mapstructure:"dst"` + Port int `mapstructure:"port"` +} + type Serial struct{} func (Serial) Init(...Option) error { @@ -121,10 +140,91 @@ func (Serial) PreStart(ctx context.Context, exp *types.Experiment) error { } } + // Check to see if a scenario exists for this experiment and if it contains a + // "serial" app. If so, configure serial ports according to the app config. + for _, app := range exp.Apps() { + if app.Name() == "serial" { + var config SerialConfig + + if err := app.ParseMetadata(&config); err != nil { + continue // TODO: handle this better? Like warn the user perhaps? + } + + for i, conn := range config.Connections { + src := exp.Spec.Topology().FindNodeByName(conn.Src) + + if src == nil { + continue // TODO: handle this better? Like warn the user perhaps? + } + + appendQEMUFlags(exp.Metadata.Name, src, i) + + dst := exp.Spec.Topology().FindNodeByName(conn.Dst) + + if src == nil { + continue // TODO: handle this better? Like warn the user perhaps? + } + + appendQEMUFlags(exp.Metadata.Name, dst, i) + } + } + } + return nil } func (Serial) PostStart(ctx context.Context, exp *types.Experiment) error { + // Check to see if a scenario exists for this experiment and if it contains a + // "serial" app. If so, configure serial ports according to the app config. + for _, app := range exp.Apps() { + if app.Name() == "serial" { + var ( + schedule = exp.Status.Schedules() + config SerialConfig + ) + + if err := app.ParseMetadata(&config); err != nil { + continue // TODO: handle this better? Like warn the user perhaps? + } + + for i, conn := range config.Connections { + var ( + logFile = fmt.Sprintf(lfFormat, exp.Metadata.Name, conn.Src, conn.Dst, i) + srcID = fmt.Sprintf(idFormat, exp.Metadata.Name, conn.Src, i) + dstID = fmt.Sprintf(idFormat, exp.Metadata.Name, conn.Dst, i) + srcHost = schedule[conn.Src] + dstHost = schedule[conn.Dst] + ) + + if srcHost == dstHost { // single socat process on host connecting unix sockets + socat := fmt.Sprintf("socat -lf%s -d -d -d -d UNIX-CONNECT:/tmp/%s UNIX-CONNECT:/tmp/%s", logFile, srcID, dstID) + + if err := mm.MeshBackground(srcHost, socat); err != nil { + return fmt.Errorf("starting socat on %s: %w", srcHost, err) + } + } else { // single socat process on each host connected via TCP + port := conn.Port + + if port == 0 { + port = defaultStartPort + i + } + + srcSocat := fmt.Sprintf("socat -lf%s -d -d -d -d UNIX-CONNECT:/tmp/%s TCP-LISTEN:%d", logFile, srcID, port) + + if err := mm.MeshBackground(srcHost, srcSocat); err != nil { + return fmt.Errorf("starting socat on %s: %w", srcHost, err) + } + + dstSocat := fmt.Sprintf("socat -lf%s -d -d -d -d UNIX-CONNECT:/tmp/%s TCP-CONNECT:%s:%d", logFile, dstID, srcHost, port) + + if err := mm.MeshBackground(dstHost, dstSocat); err != nil { + return fmt.Errorf("starting socat on %s: %w", dstHost, err) + } + } + } + } + } + return nil } @@ -135,3 +235,27 @@ func (Serial) Running(ctx context.Context, exp *types.Experiment) error { func (Serial) Cleanup(ctx context.Context, exp *types.Experiment) error { return nil } + +func appendQEMUFlags(exp string, node ifaces.NodeSpec, idx int) error { + var ( + id = fmt.Sprintf(idFormat, exp, node.General().Hostname(), idx) + options = fmt.Sprintf(optFormat, id) + ) + + var qemuAppend []string + + if advanced := node.Advanced(); advanced != nil { + if v, ok := advanced["qemu-append"]; ok { + if strings.Contains(v, options) { + return nil + } + + qemuAppend = []string{v} + } + } + + qemuAppend = append(qemuAppend, options) + node.AddAdvanced("qemu-append", strings.Join(qemuAppend, " ")) + + return nil +} diff --git a/src/go/app/tap.go b/src/go/app/tap.go index 4c05296d..d07eb716 100644 --- a/src/go/app/tap.go +++ b/src/go/app/tap.go @@ -59,9 +59,9 @@ func (this *Tap) PostStart(ctx context.Context, exp *types.Experiment) error { return fmt.Errorf("decoding %s app metadata: %w", this.Name(), err) } - hosts, err := mm.GetClusterHosts(true) + hosts, err := mm.GetNamespaceHosts(exp.Metadata.Name) if err != nil { - return fmt.Errorf("getting list of cluster hosts: %w", err) + return fmt.Errorf("getting list of experiment hosts: %w", err) } rand.Seed(time.Now().UnixNano()) diff --git a/src/go/app/user.go b/src/go/app/user.go index dc3ddd14..3ddd9561 100644 --- a/src/go/app/user.go +++ b/src/go/app/user.go @@ -154,8 +154,14 @@ func (this UserApp) shellOut(ctx context.Context, action Action, exp *types.Expe } switch action { - case ACTIONCONFIG, ACTIONPRESTART: + case ACTIONCONFIG: exp.SetSpec(result.Spec) + case ACTIONPRESTART: + exp.SetSpec(result.Spec) + + if metadata, ok := result.Status.AppStatus()[this.options.Name]; ok { + exp.Status.SetAppStatus(this.options.Name, metadata) + } case ACTIONPOSTSTART, ACTIONRUNNING: if metadata, ok := result.Status.AppStatus()[this.options.Name]; ok { exp.Status.SetAppStatus(this.options.Name, metadata) diff --git a/src/go/cmd/root.go b/src/go/cmd/root.go index 4e135410..c785b831 100644 --- a/src/go/cmd/root.go +++ b/src/go/cmd/root.go @@ -39,9 +39,15 @@ var rootCmd = &cobra.Command{ PersistentPreRunE: func(cmd *cobra.Command, args []string) error { common.UnixSocket = viper.GetString("unix-socket") - // Initialize use GRE mesh with option set locally by user. Later it will be - // forcefully enabled if it's enabled at the server. This must be done - // before getting options from the server (unlike deploy mode option). + // Initialize bridge mode and use GRE mesh options with values set locally + // by user. Later they will be forcefully enabled if they're enabled at the + // server. This must be done before getting options from the server (unlike + // deploy mode option). + + if err := common.SetBridgeMode(viper.GetString("bridge-mode")); err != nil { + return fmt.Errorf("setting user-specified bridge mode: %w", err) + } + common.UseGREMesh = viper.GetBool("use-gre-mesh") // check for global options set by UI server @@ -61,7 +67,17 @@ var rootCmd = &cobra.Command{ var options map[string]any json.Unmarshal(body, &options) - mode, _ := options["deploy-mode"].(string) + mode, _ := options["bridge-mode"].(string) + + // Only override value locally set by user (above) if auto mode is set + // on the server. + if mode == string(common.BRIDGE_MODE_AUTO) { + if err := common.SetBridgeMode(mode); err != nil { + return fmt.Errorf("setting server-specified bridge mode: %w", err) + } + } + + mode, _ = options["deploy-mode"].(string) if err := common.SetDeployMode(mode); err != nil { return fmt.Errorf("setting server-specified deploy mode: %w", err) } @@ -178,6 +194,7 @@ func init() { rootCmd.PersistentFlags().StringVar(&hostnameSuffixes, "hostname-suffixes", "-minimega,-phenix", "hostname suffixes to strip") rootCmd.PersistentFlags().Bool("log.error-stderr", true, "log fatal errors to STDERR") rootCmd.PersistentFlags().String("log.level", "info", "level to log messages at") + rootCmd.PersistentFlags().String("bridge-mode", "", "bridge naming mode for experiments ('auto' uses experiment name for bridge; 'manual' uses user-specified bridge name, or 'phenix' if not specified) (options: manual | auto)") rootCmd.PersistentFlags().String("deploy-mode", "", "deploy mode for minimega VMs (options: all | no-headnode | only-headnode)") rootCmd.PersistentFlags().Bool("use-gre-mesh", false, "use GRE tunnels between mesh nodes for VLAN trunking") rootCmd.PersistentFlags().String("unix-socket", "/tmp/phenix.sock", "phēnix unix socket to listen on (ui subcommand) or connect to") diff --git a/src/go/tunneler/main.go b/src/go/tunneler/main.go index 87178b6f..bef4790e 100644 --- a/src/go/tunneler/main.go +++ b/src/go/tunneler/main.go @@ -15,6 +15,7 @@ import ( bt "phenix/web/broker/brokertypes" ft "phenix/web/forward/forwardtypes" + jwtutil "phenix/web/util/jwt" "github.com/dgrijalva/jwt-go" "github.com/olekukonko/tablewriter" @@ -91,23 +92,32 @@ var serveCmd = &cobra.Command{ } if token != "" { + cookie, err := cmd.Flags().GetString("use-cookie") + if err != nil { + return fmt.Errorf("unable to get --use-cookie flag") + } + var claims jwt.MapClaims - _, _, err := new(jwt.Parser).ParseUnverified(token, &claims) + _, _, err = new(jwt.Parser).ParseUnverified(token, &claims) if err != nil { return fmt.Errorf("parsing phenix auth token for username: %w", err) } - sub, ok := claims["sub"].(string) - if !ok { - return fmt.Errorf("username missing from phenix auth token") + username, err = jwtutil.UsernameFromClaims(claims) + if err != nil { + return fmt.Errorf("username missing from token") } - if username != "" && sub != username { - return fmt.Errorf("provided username does not match token subject") + if err := jwtutil.ValidateExpirationClaim(claims); err != nil { + return fmt.Errorf("validating token expiration: %w", err) } headers.Set("X-phenix-auth-token", "Bearer "+token) + + if cookie != "" { + headers.Set("Cookie", fmt.Sprintf("%s=%s", cookie, token)) + } } else if username != "" { fmt.Printf("Password for %s: ", username) @@ -388,6 +398,7 @@ var deactivateCmd = &cobra.Command{ func main() { serveCmd.Flags().StringP("username", "u", "", "username to log into phēnix with") serveCmd.Flags().StringP("auth-token", "t", "", "phēnix API token (skip login process)") + serveCmd.Flags().StringP("use-cookie", "c", "", "name of cookie to use for auth token") rootCmd.AddCommand(listCmd, activateCmd, deactivateCmd, moveCmd, serveCmd) diff --git a/src/go/util/common/common.go b/src/go/util/common/common.go index 6d24afeb..5acd7cf1 100644 --- a/src/go/util/common/common.go +++ b/src/go/util/common/common.go @@ -5,7 +5,16 @@ import ( "strings" ) -type DeploymentMode string +type ( + BridgingMode string + DeploymentMode string +) + +const ( + BRIDGE_MODE_UNSET BridgingMode = "" + BRIDGE_MODE_MANUAL BridgingMode = "manual" + BRIDGE_MODE_AUTO BridgingMode = "auto" +) const ( DEPLOY_MODE_UNSET DeploymentMode = "" @@ -18,6 +27,7 @@ var ( PhenixBase = "/phenix" MinimegaBase = "/tmp/minimega" + BridgeMode = BRIDGE_MODE_MANUAL DeployMode = DEPLOY_MODE_NO_HEADNODE LogFile = "/var/log/phenix/phenix.log" @@ -38,6 +48,30 @@ func TrimHostnameSuffixes(str string) string { return str } +func ParseBridgeMode(mode string) (BridgingMode, error) { + switch strings.ToLower(mode) { + case "manual": + return BRIDGE_MODE_MANUAL, nil + case "auto": + return BRIDGE_MODE_AUTO, nil + case "": // default to current setting + return BridgeMode, nil + } + + return BRIDGE_MODE_UNSET, fmt.Errorf("unknown bridge mode provided: %s", mode) +} + +func SetBridgeMode(mode string) error { + parsed, err := ParseBridgeMode(mode) + if err != nil { + return fmt.Errorf("setting bridge mode: %w", err) + } + + BridgeMode = parsed + + return nil +} + func ParseDeployMode(mode string) (DeploymentMode, error) { switch strings.ToLower(mode) { case "no-headnode": diff --git a/src/go/util/mm/minimega.go b/src/go/util/mm/minimega.go index d99cfce7..eefbe996 100644 --- a/src/go/util/mm/minimega.go +++ b/src/go/util/mm/minimega.go @@ -729,10 +729,6 @@ func (this Minimega) GetClusterHosts(schedOnly bool) (Hosts, error) { // This will happen if the headnode is included as a compute node // (ie. when there's only one node in the cluster). if host.Name == head.Name { - // Add disk info - head.DiskUsage.Phenix = this.getDiskUsage(head.Name, common.PhenixBase) - head.DiskUsage.Minimega = this.getDiskUsage(head.Name, common.MinimegaBase) - head.Schedulable = true continue } @@ -762,6 +758,28 @@ func (this Minimega) GetClusterHosts(schedOnly bool) (Hosts, error) { return cluster, nil } +func (this Minimega) GetNamespaceHosts(ns string) (Hosts, error) { + var hosts []Host + + // Get namespace nodes details + processed, err := processNamespaceHosts(ns) + if err != nil { + return nil, fmt.Errorf("processing namespace nodes details: %w", err) + } + + for _, host := range processed { + host.Name = common.TrimHostnameSuffixes(host.Name) + + // Add disk info + host.DiskUsage.Phenix = this.getDiskUsage(host.Name, common.PhenixBase) + host.DiskUsage.Minimega = this.getDiskUsage(host.Name, common.MinimegaBase) + + hosts = append(hosts, host) + } + + return hosts, nil +} + func (Minimega) Headnode() string { // Get headnode details hosts, _ := processNamespaceHosts("minimega") @@ -1186,6 +1204,26 @@ func (Minimega) MeshShellResponse(host, command string) (string, error) { return "", fmt.Errorf("error running MeshShellResponse()") } +func (Minimega) MeshBackground(host, command string) error { + cmd := mmcli.NewCommand() + + if host == "" { + host = Headnode() + } + + if IsHeadnode(host) { + cmd.Command = fmt.Sprintf("background %s", command) + } else { + cmd.Command = fmt.Sprintf("mesh send %s background %s", host, command) + } + + if err := mmcli.ErrorResponse(mmcli.Run(cmd)); err != nil { + return fmt.Errorf("backgrounding shell command (host %s) %s: %w", host, command, err) + } + + return nil +} + func (Minimega) MeshSend(ns, host, command string) error { var cmd *mmcli.Command diff --git a/src/go/util/mm/mm.go b/src/go/util/mm/mm.go index eee41019..36f0de20 100644 --- a/src/go/util/mm/mm.go +++ b/src/go/util/mm/mm.go @@ -32,6 +32,7 @@ type MM interface { GetVMCaptures(...Option) []Capture GetClusterHosts(bool) (Hosts, error) + GetNamespaceHosts(string) (Hosts, error) Headnode() string IsHeadnode(string) bool GetVLANs(...Option) (map[string]int, error) @@ -44,5 +45,6 @@ type MM interface { TapVLAN(...TapOption) error MeshShell(string, string) error + MeshBackground(string, string) error MeshSend(string, string, string) error } diff --git a/src/go/util/mm/package.go b/src/go/util/mm/package.go index 71997158..8713b36f 100644 --- a/src/go/util/mm/package.go +++ b/src/go/util/mm/package.go @@ -92,6 +92,10 @@ func GetClusterHosts(schedOnly bool) (Hosts, error) { return DefaultMM.GetClusterHosts(schedOnly) } +func GetNamespaceHosts(ns string) (Hosts, error) { + return DefaultMM.GetNamespaceHosts(ns) +} + func Headnode() string { return DefaultMM.Headnode() } @@ -132,6 +136,10 @@ func MeshShell(host, cmd string) error { return DefaultMM.MeshShell(host, cmd) } +func MeshBackground(host, cmd string) error { + return DefaultMM.MeshBackground(host, cmd) +} + func MeshSend(ns, host, command string) error { return DefaultMM.MeshSend(ns, host, command) } diff --git a/src/go/web/auth.go b/src/go/web/auth.go index cb24c7db..fef04b7b 100644 --- a/src/go/web/auth.go +++ b/src/go/web/auth.go @@ -11,6 +11,7 @@ import ( "phenix/util/plog" "phenix/web/rbac" "phenix/web/util" + jwtutil "phenix/web/util/jwt" "github.com/dgrijalva/jwt-go" "github.com/gorilla/mux" @@ -33,7 +34,28 @@ func Signup(w http.ResponseWriter, r *http.Request) { return } - if o.proxyAuthHeader != "" { + var ( + ctx = r.Context() + token *jwt.Token + ) + + // Will only be present when this function is called if proxy JWT is enabled. + if userToken := ctx.Value("user"); userToken != nil { + token = userToken.(*jwt.Token) + claims := token.Claims.(*jwt.MapClaims) + + jwtUser, err := jwtutil.UsernameFromClaims(*claims) + if err != nil { + plog.Error("proxy user missing from JWT", "path", r.URL.Path, "err", err) + http.Error(w, "proxy user missing", http.StatusUnauthorized) + return + } + + if req.Username != jwtUser { + http.Error(w, "proxy user mismatch", http.StatusUnauthorized) + return + } + } else if o.proxyAuthHeader != "" { if user := r.Header.Get(o.proxyAuthHeader); user != req.Username { http.Error(w, "proxy user mismatch", http.StatusUnauthorized) return @@ -45,26 +67,37 @@ func Signup(w http.ResponseWriter, r *http.Request) { u.Spec.FirstName = req.FirstName u.Spec.LastName = req.LastName - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "sub": u.Username(), - "exp": time.Now().Add(o.jwtLifetime).Unix(), - }) + var raw string - // Sign and get the complete encoded token as a string using the secret - signed, err := token.SignedString([]byte(o.jwtKey)) - if err != nil { - http.Error(w, "failed to sign JWT", http.StatusInternalServerError) - return - } + if token == nil { // not using proxy JWT + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": u.Username(), + "exp": time.Now().Add(o.jwtLifetime).Unix(), + }) - if err := u.AddToken(signed, time.Now().Format(time.RFC3339)); err != nil { - http.Error(w, "", http.StatusInternalServerError) - return + // Sign and get the complete encoded token as a string using the secret + raw, err = token.SignedString([]byte(o.jwtKey)) + if err != nil { + http.Error(w, "failed to sign JWT", http.StatusInternalServerError) + return + } + + if err := u.AddToken(raw, time.Now().Format(time.RFC3339)); err != nil { + http.Error(w, "", http.StatusInternalServerError) + return + } + } else { // using proxy JWT + raw = token.Raw + + if err := u.AddToken(raw, "proxied"); err != nil { + http.Error(w, "", http.StatusInternalServerError) + return + } } resp := LoginResponse{ User: userFromRBAC(*u), - Token: signed, + Token: raw, } body, err = json.Marshal(resp) @@ -84,68 +117,92 @@ func Login(w http.ResponseWriter, r *http.Request) { proxied bool ) - switch r.Method { - case "GET": - if o.proxyAuthHeader == "" { - var ok bool + var ( + ctx = r.Context() + token *jwt.Token + ) - user, pass, ok = r.BasicAuth() + // Will only be present when this function is called if proxy JWT is enabled. + if userToken := ctx.Value("user"); userToken != nil { + token = userToken.(*jwt.Token) - if !ok { - query := r.URL.Query() + var ( + claims = token.Claims.(*jwt.MapClaims) + err error + ) - user = query.Get("user") - if user == "" { - http.Error(w, "no username provided", http.StatusBadRequest) - return + user, err = jwtutil.UsernameFromClaims(*claims) + if err != nil { + plog.Error("proxy user missing from JWT", "path", r.URL.Path, "token", token.Raw, "err", err) + http.Error(w, "proxy user missing", http.StatusUnauthorized) + return + } + + proxied = true + } else { + switch r.Method { + case "GET": + if o.proxyAuthHeader == "" { + var ok bool + + user, pass, ok = r.BasicAuth() + + if !ok { + query := r.URL.Query() + + user = query.Get("user") + if user == "" { + http.Error(w, "no username provided", http.StatusBadRequest) + return + } + + pass = query.Get("pass") + if pass == "" { + http.Error(w, "no password provided", http.StatusBadRequest) + return + } } + } else { + user = r.Header.Get(o.proxyAuthHeader) - pass = query.Get("pass") - if pass == "" { - http.Error(w, "no password provided", http.StatusBadRequest) + if user == "" { + http.Error(w, "proxy authentication failed", http.StatusUnauthorized) return } - } - } else { - user = r.Header.Get(o.proxyAuthHeader) - if user == "" { - http.Error(w, "proxy authentication failed", http.StatusUnauthorized) + proxied = true + } + case "POST": + if o.proxyAuthHeader != "" { + http.Error(w, "proxy auth enabled -- must login via GET request", http.StatusBadRequest) return } - proxied = true - } - case "POST": - if o.proxyAuthHeader != "" { - http.Error(w, "proxy auth enabled -- must login via GET request", http.StatusBadRequest) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "no data provided in POST", http.StatusBadRequest) - return - } + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "no data provided in POST", http.StatusBadRequest) + return + } - var req LoginRequest - if err := json.Unmarshal(body, &req); err != nil { - http.Error(w, "invalid data provided in POST", http.StatusBadRequest) - return - } + var req LoginRequest + if err := json.Unmarshal(body, &req); err != nil { + http.Error(w, "invalid data provided in POST", http.StatusBadRequest) + return + } - if user = req.Username; user == "" { - http.Error(w, "invalid username provided in POST", http.StatusBadRequest) - return - } + if user = req.Username; user == "" { + http.Error(w, "invalid username provided in POST", http.StatusBadRequest) + return + } - if pass = req.Password; pass == "" { - http.Error(w, "invalid password provided in POST", http.StatusBadRequest) + if pass = req.Password; pass == "" { + http.Error(w, "invalid password provided in POST", http.StatusBadRequest) + return + } + default: + http.Error(w, "invalid method", http.StatusBadRequest) return } - default: - http.Error(w, "invalid method", http.StatusBadRequest) - return } u, err := rbac.GetUser(user) @@ -161,21 +218,32 @@ func Login(w http.ResponseWriter, r *http.Request) { } } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "sub": u.Username(), - "exp": time.Now().Add(o.jwtLifetime).Unix(), - }) + var signed string - // Sign and get the complete encoded token as a string using the secret - signed, err := token.SignedString([]byte(o.jwtKey)) - if err != nil { - http.Error(w, "failed to sign JWT", http.StatusInternalServerError) - return - } + if token == nil { + token = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": u.Username(), + "exp": time.Now().Add(o.jwtLifetime).Unix(), + }) - if err := u.AddToken(signed, time.Now().Format(time.RFC3339)); err != nil { - http.Error(w, "", http.StatusInternalServerError) - return + // Sign and get the complete encoded token as a string using the secret + signed, err = token.SignedString([]byte(o.jwtKey)) + if err != nil { + http.Error(w, "failed to sign JWT", http.StatusInternalServerError) + return + } + + if err := u.AddToken(signed, time.Now().Format(time.RFC3339)); err != nil { + http.Error(w, "", http.StatusInternalServerError) + return + } + } else { + signed = token.Raw + + if err := u.AddToken(signed, "proxied"); err != nil { + http.Error(w, "", http.StatusInternalServerError) + return + } } resp := LoginResponse{ diff --git a/src/go/web/middleware/auth.go b/src/go/web/middleware/auth.go index 118f2dd7..8cb45cf5 100644 --- a/src/go/web/middleware/auth.go +++ b/src/go/web/middleware/auth.go @@ -6,6 +6,7 @@ import ( "net/http" "phenix/util/plog" "phenix/web/rbac" + jwtutil "phenix/web/util/jwt" "strings" jwtmiddleware "github.com/cescoferraro/go-jwt-middleware" @@ -70,6 +71,33 @@ func Auth(jwtKey, proxyAuthHeader string) mux.MiddlewareFunc { }, ) + validTokenMiddleware := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, err := fromPhenixAuthTokenHeader(r) + if err != nil { + plog.Error("getting raw JWT from X-phenix-auth-token header", "err", err) + + http.Error(w, "missing phenix auth token header", http.StatusBadRequest) + return + } + + var claims jwt.MapClaims + + token, _, err := new(jwt.Parser).ParseUnverified(raw, &claims) + if err != nil { + plog.Error("parsing valid JWT", "token", raw, "err", err) + + http.Error(w, "parsing auth token", http.StatusBadRequest) + return + } + + ctx := r.Context() + ctx = context.WithValue(ctx, "user", token) + + h.ServeHTTP(w, r.WithContext(ctx)) + }) + } + userMiddleware := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.HasSuffix(r.URL.Path, "/signup") { @@ -91,18 +119,27 @@ func Auth(jwtKey, proxyAuthHeader string) mux.MiddlewareFunc { return } - token := userToken.(*jwt.Token) - claim := token.Claims.(jwt.MapClaims) + var ( + token = userToken.(*jwt.Token) + claims = token.Claims.(*jwt.MapClaims) + ) + + jwtUser, err := jwtutil.UsernameFromClaims(*claims) + if err != nil { + plog.Error("rejecting unauthorized request", "path", r.URL.Path, "err", err) + http.Error(w, "Forbidden", http.StatusUnauthorized) + return + } if proxyAuthHeader != "" { - if user := r.Header.Get(proxyAuthHeader); user != claim["sub"].(string) { - plog.Error("proxy user mismatch", "user", user, "token", claim["sub"].(string)) + if user := r.Header.Get(proxyAuthHeader); user != jwtUser { + plog.Error("proxy user mismatch", "user", user, "token", jwtUser) http.Error(w, "proxy user mismatch", http.StatusUnauthorized) return } } - user, err := rbac.GetUser(claim["sub"].(string)) + user, err := rbac.GetUser(jwtUser) if err != nil { http.Error(w, "user error", http.StatusUnauthorized) return @@ -150,6 +187,9 @@ func Auth(jwtKey, proxyAuthHeader string) mux.MiddlewareFunc { if jwtKey == "" { plog.Info("no JWT signing key provided -- disabling auth") return func(h http.Handler) http.Handler { return NoAuth(h) } + } else if jwtKey == "proxy-jwt" { + plog.Info("using JWTs from proxy") + return func(h http.Handler) http.Handler { return validTokenMiddleware(userMiddleware(h)) } } else if strings.HasPrefix(jwtKey, "dev|") { plog.Debug("development JWT key provided -- enabling dev auth") return func(h http.Handler) http.Handler { return devAuthMiddleware(h) } diff --git a/src/go/web/option.go b/src/go/web/option.go index 913d84ba..2f172971 100644 --- a/src/go/web/option.go +++ b/src/go/web/option.go @@ -188,6 +188,7 @@ func GetOptions(w http.ResponseWriter, r *http.Request) error { } options := map[string]any{ + "bridge-mode": common.BridgeMode, "deploy-mode": common.DeployMode, "use-gre-mesh": common.UseGREMesh, } diff --git a/src/go/web/rbac/known_policy.go b/src/go/web/rbac/known_policy.go index bb20bba6..c535876f 100644 --- a/src/go/web/rbac/known_policy.go +++ b/src/go/web/rbac/known_policy.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// This file was generated at build time 2024-05-29 11:25:27.212902293 -0600 MDT m=+0.101882971 +// This file was generated at build time 2024-06-07 13:45:11.094626942 -0600 MDT m=+0.099505897 // This contains all known role checks used in codebase package rbac diff --git a/src/go/web/rbac/user.go b/src/go/web/rbac/user.go index 9fe17b7a..0982da31 100644 --- a/src/go/web/rbac/user.go +++ b/src/go/web/rbac/user.go @@ -179,11 +179,30 @@ func (this User) UpdatePassword(old, new string) error { return nil } +func (this User) GetProxyToken() string { + for token, note := range this.Spec.Tokens { + if note == "proxied" { + return token + } + } + + return "" +} + func (this User) AddToken(token, note string) error { if this.Spec.Tokens == nil { this.Spec.Tokens = make(map[string]string) } + if note == "proxied" { + // we only want to keep one proxy JWT + for k, v := range this.Spec.Tokens { + if v == "proxied" { + delete(this.Spec.Tokens, k) + } + } + } + enc := base64.StdEncoding.EncodeToString([]byte(token)) this.Spec.Tokens[enc] = note diff --git a/src/go/web/server.go b/src/go/web/server.go index 42c21afb..7e14f772 100644 --- a/src/go/web/server.go +++ b/src/go/web/server.go @@ -261,6 +261,7 @@ func Start(opts ...ServerOption) error { addRoutesToRouter(api, workflowRoutes...) addRoutesToRouter(api, errorRoutes...) + addRoutesToRouter(api, optionRoutes...) if o.allowCORS { plog.Info("CORS is enabled on HTTP API endpoints") diff --git a/src/go/web/types.go b/src/go/web/types.go index 056d51ef..e1504308 100644 --- a/src/go/web/types.go +++ b/src/go/web/types.go @@ -55,6 +55,7 @@ type User struct { LastName string `json:"last_name"` ResourceNames []string `json:"resource_names"` Role Role `json:"role"` + ProxyToken string `json:"proxy_token,omitempty"` } type Policy struct { @@ -70,6 +71,7 @@ type Role struct { func userFromRBAC(u rbac.User) User { role, _ := u.Role() + user := User{ Username: u.Username(), FirstName: u.FirstName(), diff --git a/src/go/web/users.go b/src/go/web/users.go index 9fbf2cab..c464f15b 100644 --- a/src/go/web/users.go +++ b/src/go/web/users.go @@ -35,19 +35,28 @@ func GetUsers(w http.ResponseWriter, r *http.Request) { return } - for _, user := range users { - if role.Allowed("users", "list", user.Username()) { - resp = append(resp, userFromRBAC(*user)) + for _, rbacUser := range users { + if role.Allowed("users", "list", rbacUser.Username()) { + user := userFromRBAC(*rbacUser) + + if rbacUser.Username() == uname { + user.ProxyToken = rbacUser.GetProxyToken() + } + + resp = append(resp, user) } } } else if role.Allowed("users", "get", uname) { - user, err := rbac.GetUser(uname) + rbacUser, err := rbac.GetUser(uname) if err != nil { http.Error(w, "", http.StatusInternalServerError) return } - resp = append(resp, userFromRBAC(*user)) + user := userFromRBAC(*rbacUser) + user.ProxyToken = rbacUser.GetProxyToken() + + resp = append(resp, user) } else { http.Error(w, "forbidden", http.StatusForbidden) return @@ -138,28 +147,33 @@ func GetUser(w http.ResponseWriter, r *http.Request) { plog.Debug("HTTP handler called", "GetUser") var ( - ctx = r.Context() - role = ctx.Value("role").(rbac.Role) - vars = mux.Vars(r) - uname = vars["username"] + ctx = r.Context() + uname = ctx.Value("user").(string) + role = ctx.Value("role").(rbac.Role) + vars = mux.Vars(r) + username = vars["username"] ) - if !role.Allowed("users", "get", uname) { + if !role.Allowed("users", "get", username) { http.Error(w, "forbidden", http.StatusForbidden) return } - user, err := rbac.GetUser(uname) + rbacUser, err := rbac.GetUser(username) if err != nil { http.Error(w, "unable to get user", http.StatusInternalServerError) return } - resp := userFromRBAC(*user) + user := userFromRBAC(*rbacUser) - body, err := json.Marshal(resp) + if rbacUser.Username() == uname { + user.ProxyToken = rbacUser.GetProxyToken() + } + + body, err := json.Marshal(user) if err != nil { - plog.Error("marshaling user", "user", user.Username(), "err", err) + plog.Error("marshaling user", "user", rbacUser.Username(), "err", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } diff --git a/src/go/web/util/jwt/jwt.go b/src/go/web/util/jwt/jwt.go new file mode 100644 index 00000000..e1f587ee --- /dev/null +++ b/src/go/web/util/jwt/jwt.go @@ -0,0 +1,40 @@ +package jwt + +import ( + "fmt" + "time" + + "github.com/dgrijalva/jwt-go" +) + +var userClaims = []string{"sub", "username", "user"} + +func UsernameFromClaims(claims jwt.MapClaims) (string, error) { + for _, claim := range userClaims { + if user, ok := claims[claim].(string); ok && user != "" { + return user, nil + } + } + + return "", fmt.Errorf("username not found in JWT claims") +} + +func ValidateExpirationClaim(claims jwt.MapClaims) error { + exp, ok := claims["exp"] + if !ok { + return fmt.Errorf("expiration (exp) missing from token claims") + } + + epoch, ok := exp.(float64) + if !ok { + return fmt.Errorf("expiration (exp) claim is formatted incorrectly") + } + + expires := time.Unix(int64(epoch), 0) + + if time.Now().After(expires) { + return fmt.Errorf("token expired at %v", expires) + } + + return nil +} diff --git a/src/go/web/util/protobuf.go b/src/go/web/util/protobuf.go index 72df2fe2..3866b051 100644 --- a/src/go/web/util/protobuf.go +++ b/src/go/web/util/protobuf.go @@ -1,13 +1,14 @@ package util import ( + "sort" + "phenix/types" ifaces "phenix/types/interfaces" "phenix/util/mm" "phenix/web/cache" "phenix/web/proto" "phenix/web/rbac" - "sort" ) func ExperimentToProtobuf(exp types.Experiment, status cache.Status, vms []mm.VM) *proto.Experiment { diff --git a/src/js/src/App.vue b/src/js/src/App.vue index 24b6da43..1778f7d0 100644 --- a/src/js/src/App.vue +++ b/src/js/src/App.vue @@ -44,6 +44,15 @@ login and returns a user to Experiments component if successful. console.log(`ERROR getting features: ${err}`); } + try { + let resp = await fetch(this.$router.resolve({ name: 'options'}).href); + let options = await resp.json(); + + this.$store.commit( 'OPTIONS', options ); + } catch (err) { + console.log(`ERROR getting options: ${err}`); + } + this.wsConnect(); this.unwatch = this.$store.watch( diff --git a/src/js/src/components/Disabled.vue b/src/js/src/components/Disabled.vue index 0f9a9fd6..2eccbb47 100644 --- a/src/js/src/components/Disabled.vue +++ b/src/js/src/components/Disabled.vue @@ -5,8 +5,7 @@