From 876ef0705f36af72dce4704a50006344f3022357 Mon Sep 17 00:00:00 2001 From: Maxim Vladimirskiy Date: Fri, 1 Sep 2023 14:45:36 +0300 Subject: [PATCH 1/3] Run Mailgun builds --- .github/workflows/build.yaml | 62 -------------------- .github/workflows/build_and_push.yml | 41 ++++++++++++++ .github/workflows/check_doc.yml | 25 --------- .github/workflows/codeql.yml | 70 ----------------------- .github/workflows/documentation.yml | 52 ----------------- .github/workflows/experimental.yaml | 59 ------------------- .github/workflows/test-integration.yaml | 75 ------------------------- .github/workflows/test-unit.yaml | 31 ---------- .github/workflows/validate.yaml | 68 ---------------------- Dockerfile | 45 +++++++++++++-- 10 files changed, 81 insertions(+), 447 deletions(-) delete mode 100644 .github/workflows/build.yaml create mode 100644 .github/workflows/build_and_push.yml delete mode 100644 .github/workflows/check_doc.yml delete mode 100644 .github/workflows/codeql.yml delete mode 100644 .github/workflows/documentation.yml delete mode 100644 .github/workflows/experimental.yaml delete mode 100644 .github/workflows/test-integration.yaml delete mode 100644 .github/workflows/test-unit.yaml delete mode 100644 .github/workflows/validate.yaml diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml deleted file mode 100644 index 6f73e85ecf..0000000000 --- a/.github/workflows/build.yaml +++ /dev/null @@ -1,62 +0,0 @@ -name: Build Binaries - -on: - pull_request: - branches: - - '*' - -env: - GO_VERSION: '1.22' - CGO_ENABLED: 0 - -jobs: - - build-webui: - runs-on: ubuntu-22.04 - - steps: - - name: Check out code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Build webui - run: | - make clean-webui generate-webui - tar czvf webui.tar.gz ./webui/static/ - - - name: Artifact webui - uses: actions/upload-artifact@v4 - with: - name: webui.tar.gz - path: webui.tar.gz - - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ ubuntu-22.04, macos-latest, windows-latest ] - needs: - - build-webui - - steps: - - name: Check out code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set up Go ${{ env.GO_VERSION }} - uses: actions/setup-go@v5 - with: - go-version: ${{ env.GO_VERSION }} - - - name: Artifact webui - uses: actions/download-artifact@v4 - with: - name: webui.tar.gz - - - name: Untar webui - run: tar xvf webui.tar.gz - - - name: Build - run: make binary diff --git a/.github/workflows/build_and_push.yml b/.github/workflows/build_and_push.yml new file mode 100644 index 0000000000..41a517d735 --- /dev/null +++ b/.github/workflows/build_and_push.yml @@ -0,0 +1,41 @@ +name: On Pull Request + +on: + pull_request: + branches: [v2.11.2-patched] + +jobs: + test: + name: build_and_push + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Setup Docker Layer Cache + uses: actions/cache@v2 + with: + path: /tmp/.buildx-cache + key: ${{ runner.os }}-docker-${{ github.event.number }} + restore-keys: ${{ runner.os }}-docker- + + # Buildx Needs QEMU + - name: Set up QEMU + uses: docker/setup-qemu-action@v1 + + # We use Buildx so we can take advantage of the caching + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ github.token }} + + - name: Build and push + uses: docker/build-push-action@v3 + with: + tags: ghcr.io/${{ github.repository }}:PR${{ github.event.number }} + push: true diff --git a/.github/workflows/check_doc.yml b/.github/workflows/check_doc.yml deleted file mode 100644 index 9851e1a54c..0000000000 --- a/.github/workflows/check_doc.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: Check Documentation - -on: - pull_request: - branches: - - '*' - -jobs: - - docs: - name: Check, verify and build documentation - runs-on: ubuntu-22.04 - - steps: - - name: Check out code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Check documentation - run: make docs-pull-images docs - env: - # These variables are not passed to workflows that are triggered by a pull request from a fork. - DOCS_VERIFY_SKIP: ${{ vars.DOCS_VERIFY_SKIP }} - DOCS_LINT_SKIP: ${{ vars.DOCS_LINT_SKIP }} diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml deleted file mode 100644 index f921d77893..0000000000 --- a/.github/workflows/codeql.yml +++ /dev/null @@ -1,70 +0,0 @@ -name: "CodeQL" - -on: - push: - branches: - - master - - v* - schedule: - - cron: '11 22 * * 1' - -jobs: - analyze: - name: Analyze - runs-on: ubuntu-latest - permissions: - actions: read - contents: read - security-events: write - - strategy: - fail-fast: false - matrix: - language: [ 'javascript', 'go' ] - # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] - # Use only 'java' to analyze code written in Java, Kotlin or both - # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both - # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: setup go - uses: actions/setup-go@v5 - if: ${{ matrix.language == 'go' }} - with: - go-version-file: 'go.mod' - - # Initializes the CodeQL tools for scanning. - - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - languages: ${{ matrix.language }} - # If you wish to specify custom queries, you can do so here or in a config file. - # By default, queries listed here will override any specified in a config file. - # Prefix the list here with "+" to use these queries and those in the config file. - - # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs - # queries: security-extended,security-and-quality - - - # Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift). - # If this step fails, then you should remove it and run the build manually (see below) - - name: Autobuild - uses: github/codeql-action/autobuild@v3 - - # ℹī¸ Command-line programs to run using the OS shell. - # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun - - # If the Autobuild fails above, remove it and uncomment the following three lines. - # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. - - # - run: | - # echo "Run, Build Application using script" - # ./location_of_script_within_repo/buildscript.sh - - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 - with: - category: "/language:${{matrix.language}}" diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml deleted file mode 100644 index fd08fe4de7..0000000000 --- a/.github/workflows/documentation.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: Build and Publish Documentation - -on: - push: - branches: - - master - - v* - -env: - STRUCTOR_VERSION: v1.13.2 - MIXTUS_VERSION: v0.4.1 - -jobs: - - docs: - name: Doc Process - runs-on: ubuntu-22.04 - if: github.repository == 'traefik/traefik' - - steps: - - name: Check out code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Login to DockerHub - uses: docker/login-action@v1 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Install Structor ${{ env.STRUCTOR_VERSION }} - run: curl -sSfL https://raw.githubusercontent.com/traefik/structor/master/godownloader.sh | sh -s -- -b $HOME/bin ${STRUCTOR_VERSION} - - - name: Install Seo-doc - run: curl -sSfL https://raw.githubusercontent.com/traefik/seo-doc/master/godownloader.sh | sh -s -- -b "${HOME}/bin" - - - name: Install Mixtus ${{ env.MIXTUS_VERSION }} - run: curl -sSfL https://raw.githubusercontent.com/traefik/mixtus/master/godownloader.sh | sh -s -- -b $HOME/bin ${MIXTUS_VERSION} - - - name: Build documentation - run: $HOME/bin/structor -o traefik -r traefik --dockerfile-url="https://raw.githubusercontent.com/traefik/traefik/v1.7/docs.Dockerfile" --menu.js-url="https://raw.githubusercontent.com/traefik/structor/master/traefik-menu.js.gotmpl" --rqts-url="https://raw.githubusercontent.com/traefik/structor/master/requirements-override.txt" --force-edit-url --exp-branch=master --debug - env: - STRUCTOR_LATEST_TAG: ${{ vars.STRUCTOR_LATEST_TAG }} - - - name: Apply seo - run: $HOME/bin/seo -path=./site -product=traefik - - - name: Publish documentation - run: $HOME/bin/mixtus --dst-doc-path="./traefik" --dst-owner=traefik --dst-repo-name=doc --git-user-email="30906710+traefiker@users.noreply.github.com" --git-user-name=traefiker --src-doc-path="./site" --src-owner=containous --src-repo-name=traefik - env: - GITHUB_TOKEN: ${{ secrets.GH_TOKEN_REPO }} diff --git a/.github/workflows/experimental.yaml b/.github/workflows/experimental.yaml deleted file mode 100644 index bced49ca89..0000000000 --- a/.github/workflows/experimental.yaml +++ /dev/null @@ -1,59 +0,0 @@ -name: Build experimental image on branch - -on: - push: - branches: - - master - - v* - -env: - GO_VERSION: '1.22' - CGO_ENABLED: 0 - -jobs: - - experimental: - if: github.repository == 'traefik/traefik' - name: Build experimental image on branch - runs-on: ubuntu-22.04 - - steps: - - # https://github.com/marketplace/actions/checkout - - name: Check out code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Build webui - run: | - make clean-webui generate-webui - - - name: Set up Go ${{ env.GO_VERSION }} - uses: actions/setup-go@v5 - with: - go-version: ${{ env.GO_VERSION }} - - - name: Build - run: make generate binary - - - name: Branch name - run: echo ${GITHUB_REF##*/} - - - name: Login to Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Set up QEMU - uses: docker/setup-qemu-action@v2 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - - - name: Build docker experimental image - env: - DOCKER_BUILDX_ARGS: "--push" - run: | - make multi-arch-image-experimental-${GITHUB_REF##*/} diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml deleted file mode 100644 index 0c93c03c2b..0000000000 --- a/.github/workflows/test-integration.yaml +++ /dev/null @@ -1,75 +0,0 @@ -name: Test Integration - -on: - pull_request: - branches: - - '*' - push: - branches: - - 'gh-actions' - -env: - GO_VERSION: '1.22' - CGO_ENABLED: 0 - -jobs: - - build: - runs-on: ubuntu-22.04 - - steps: - - name: Check out code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set up Go ${{ env.GO_VERSION }} - uses: actions/setup-go@v5 - with: - go-version: ${{ env.GO_VERSION }} - - - name: Avoid generating webui - run: touch webui/static/index.html - - - name: Build binary - run: make binary - - test-integration: - runs-on: ubuntu-22.04 - needs: - - build - strategy: - fail-fast: true - matrix: - parallel: [12] - index: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 , 11] - - steps: - - name: Check out code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set up Go ${{ env.GO_VERSION }} - uses: actions/setup-go@v5 - with: - go-version: ${{ env.GO_VERSION }} - - - name: Avoid generating webui - run: touch webui/static/index.html - - - name: Build binary - run: make binary - - - name: Generate go test Slice - id: test_split - uses: hashicorp-forge/go-test-split-action@v1 - with: - packages: ./integration - total: ${{ matrix.parallel }} - index: ${{ matrix.index }} - - - name: Run Integration tests - run: | - TESTS=$(echo "${{ steps.test_split.outputs.run}}" | sed 's/\$/\$\$/g') - TESTFLAGS="-run \"${TESTS}\"" make test-integration diff --git a/.github/workflows/test-unit.yaml b/.github/workflows/test-unit.yaml deleted file mode 100644 index c23db68c3f..0000000000 --- a/.github/workflows/test-unit.yaml +++ /dev/null @@ -1,31 +0,0 @@ -name: Test Unit - -on: - pull_request: - branches: - - '*' - -env: - GO_VERSION: '1.22' - -jobs: - - test-unit: - runs-on: ubuntu-22.04 - - steps: - - name: Check out code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set up Go ${{ env.GO_VERSION }} - uses: actions/setup-go@v5 - with: - go-version: ${{ env.GO_VERSION }} - - - name: Avoid generating webui - run: touch webui/static/index.html - - - name: Tests - run: make test-unit diff --git a/.github/workflows/validate.yaml b/.github/workflows/validate.yaml deleted file mode 100644 index 1ba3a6d618..0000000000 --- a/.github/workflows/validate.yaml +++ /dev/null @@ -1,68 +0,0 @@ -name: Validate - -on: - pull_request: - branches: - - '*' - -env: - GO_VERSION: '1.22' - GOLANGCI_LINT_VERSION: v1.57.0 - MISSSPELL_VERSION: v0.4.1 - -jobs: - - validate: - runs-on: ubuntu-22.04 - - steps: - - name: Check out code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set up Go ${{ env.GO_VERSION }} - uses: actions/setup-go@v5 - with: - go-version: ${{ env.GO_VERSION }} - - - name: Install golangci-lint ${{ env.GOLANGCI_LINT_VERSION }} - run: curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin ${GOLANGCI_LINT_VERSION} - - - name: Install missspell ${{ env.MISSSPELL_VERSION }} - run: curl -sfL https://raw.githubusercontent.com/golangci/misspell/master/install-misspell.sh | sh -s -- -b $(go env GOPATH)/bin ${MISSSPELL_VERSION} - - - name: Avoid generating webui - run: touch webui/static/index.html - - - name: Validate - run: make validate - - validate-generate: - runs-on: ubuntu-22.04 - - steps: - - name: Check out code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set up Go ${{ env.GO_VERSION }} - uses: actions/setup-go@v5 - with: - go-version: ${{ env.GO_VERSION }} - - - name: go generate - run: | - make generate - git diff --exit-code - - - name: go mod tidy - run: | - go mod tidy - git diff --exit-code - - - name: make generate-crd - run: | - make generate-crd - git diff --exit-code diff --git a/Dockerfile b/Dockerfile index fea8092259..d72e899e2e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,11 +1,46 @@ -# syntax=docker/dockerfile:1.2 -FROM alpine:3.19 +# WEBUI +FROM node:20.11 as webui -RUN apk --no-cache --no-progress add ca-certificates tzdata \ +ENV WEBUI_DIR /src/webui +RUN mkdir -p $WEBUI_DIR + +COPY ./webui/ $WEBUI_DIR/ + +WORKDIR $WEBUI_DIR + +RUN yarn install +RUN yarn build + +# BUILD +FROM golang:1.22-alpine as gobuild + +RUN apk --no-cache --no-progress add git mercurial bash gcc musl-dev curl tar ca-certificates tzdata \ + && update-ca-certificates \ + && rm -rf /var/cache/apk/* + +WORKDIR /go/src/github.com/mailgun/traefik + +# Download go modules +COPY go.mod . +COPY go.sum . +RUN GO111MODULE=on GOPROXY=https://proxy.golang.org go mod download + +COPY . /go/src/github.com/mailgun/traefik + +RUN rm -rf /go/src/github.com/mailgun/traefik/webui/static/ +COPY --from=webui /src/webui/static/ /go/src/github.com/mailgun/traefik/webui/static/ + +RUN go generate +RUN CGO_ENABLED=0 GOGC=off go build -ldflags "-s -w" -o dist/traefik ./cmd/traefik + +## IMAGE +FROM alpine:3.14 + +RUN apk --no-cache --no-progress add bash curl ca-certificates tzdata \ + && update-ca-certificates \ && rm -rf /var/cache/apk/* -ARG TARGETPLATFORM -COPY ./dist/$TARGETPLATFORM/traefik / +COPY --from=gobuild /go/src/github.com/mailgun/traefik/dist/traefik / EXPOSE 80 VOLUME ["/tmp"] From 98234a2a5a7fb8e60ba9ca6dc1dd0cb5f4889e8b Mon Sep 17 00:00:00 2001 From: Maxim Vladimirskiy Date: Fri, 1 Sep 2023 12:14:41 +0300 Subject: [PATCH 2/3] Merge necessary changes from master --- pkg/config/dynamic/fixtures/sample.toml | 6 +- pkg/config/dynamic/http_config.go | 31 +- pkg/config/label/label_test.go | 50 +- pkg/config/runtime/runtime.go | 6 + pkg/config/runtime/runtime_test.go | 8 +- pkg/healthcheck/healthcheck.go | 475 ++++------- pkg/healthcheck/healthcheck_test.go | 750 +++++++----------- pkg/healthcheck/mock_test.go | 182 +++++ .../empty_backend_handler.go | 34 - .../empty_backend_handler_test.go | 84 -- .../kubernetes/crd/kubernetes_test.go | 2 +- pkg/provider/kv/kv_test.go | 12 +- pkg/redactor/redactor_config_test.go | 6 +- .../testdata/anonymized-dynamic-config.json | 8 +- .../testdata/secured-dynamic-config.json | 8 +- pkg/server/router/router.go | 2 +- pkg/server/router/router_test.go | 4 +- pkg/server/routerfactory.go | 11 +- pkg/server/service/internalhandler.go | 2 +- pkg/server/service/loadbalancer/wrr/wrr.go | 15 +- .../service/loadbalancer/wrr/wrr_test.go | 112 +-- pkg/server/service/proxy.go | 165 ++-- pkg/server/service/proxy_test.go | 2 +- pkg/server/service/proxy_websocket_test.go | 124 +-- pkg/server/service/service.go | 279 ++----- pkg/server/service/service_test.go | 63 +- 26 files changed, 968 insertions(+), 1473 deletions(-) create mode 100644 pkg/healthcheck/mock_test.go delete mode 100644 pkg/middlewares/emptybackendhandler/empty_backend_handler.go delete mode 100644 pkg/middlewares/emptybackendhandler/empty_backend_handler_test.go diff --git a/pkg/config/dynamic/fixtures/sample.toml b/pkg/config/dynamic/fixtures/sample.toml index b3c84571ce..9ac084b3a6 100644 --- a/pkg/config/dynamic/fixtures/sample.toml +++ b/pkg/config/dynamic/fixtures/sample.toml @@ -425,14 +425,14 @@ scheme = "foobar" path = "foobar" port = 42 - interval = "foobar" - timeout = "foobar" + interval = "1s" + timeout = "1s" hostname = "foobar" [http.services.Service0.loadBalancer.healthCheck.headers] name0 = "foobar" name1 = "foobar" [http.services.Service0.loadBalancer.responseForwarding] - flushInterval = "foobar" + flushInterval = "100ms" [tcp] [tcp.routers] diff --git a/pkg/config/dynamic/http_config.go b/pkg/config/dynamic/http_config.go index af97325d7c..cdc713de49 100644 --- a/pkg/config/dynamic/http_config.go +++ b/pkg/config/dynamic/http_config.go @@ -9,6 +9,19 @@ import ( "github.com/traefik/traefik/v2/pkg/types" ) +const ( + // DefaultHealthCheckInterval is the default value for the ServerHealthCheck interval. + DefaultHealthCheckInterval = ptypes.Duration(30 * time.Second) + // DefaultHealthCheckTimeout is the default value for the ServerHealthCheck timeout. + DefaultHealthCheckTimeout = ptypes.Duration(5 * time.Second) + + // DefaultPassHostHeader is the default value for the ServersLoadBalancer passHostHeader. + DefaultPassHostHeader = true + + // DefaultFlushInterval is the default value for the ResponseForwarding flush interval. + DefaultFlushInterval = ptypes.Duration(100 * time.Millisecond) +) + // +k8s:deepcopy-gen=true // HTTPConfiguration contains all the HTTP configuration parameters. @@ -192,7 +205,7 @@ type ResponseForwarding struct { // This configuration is ignored when ReverseProxy recognizes a response as a streaming response; // for such responses, writes are flushed to the client immediately. // Default: 100ms - FlushInterval string `json:"flushInterval,omitempty" toml:"flushInterval,omitempty" yaml:"flushInterval,omitempty" export:"true"` + FlushInterval ptypes.Duration `json:"flushInterval,omitempty" toml:"flushInterval,omitempty" yaml:"flushInterval,omitempty" export:"true"` } // +k8s:deepcopy-gen=true @@ -213,14 +226,14 @@ func (s *Server) SetDefaults() { // ServerHealthCheck holds the HealthCheck configuration. type ServerHealthCheck struct { - Scheme string `json:"scheme,omitempty" toml:"scheme,omitempty" yaml:"scheme,omitempty" export:"true"` - Path string `json:"path,omitempty" toml:"path,omitempty" yaml:"path,omitempty" export:"true"` - Method string `json:"method,omitempty" toml:"method,omitempty" yaml:"method,omitempty" export:"true"` - Port int `json:"port,omitempty" toml:"port,omitempty,omitzero" yaml:"port,omitempty" export:"true"` - // TODO change string to ptypes.Duration - Interval string `json:"interval,omitempty" toml:"interval,omitempty" yaml:"interval,omitempty" export:"true"` - // TODO change string to ptypes.Duration - Timeout string `json:"timeout,omitempty" toml:"timeout,omitempty" yaml:"timeout,omitempty" export:"true"` + Scheme string `json:"scheme,omitempty" toml:"scheme,omitempty" yaml:"scheme,omitempty" export:"true"` + Mode string `json:"mode,omitempty" toml:"mode,omitempty" yaml:"mode,omitempty" export:"true"` + Path string `json:"path,omitempty" toml:"path,omitempty" yaml:"path,omitempty" export:"true"` + Method string `json:"method,omitempty" toml:"method,omitempty" yaml:"method,omitempty" export:"true"` + Status int `json:"status,omitempty" toml:"status,omitempty" yaml:"status,omitempty" export:"true"` + Port int `json:"port,omitempty" toml:"port,omitempty,omitzero" yaml:"port,omitempty" export:"true"` + Interval ptypes.Duration `json:"interval,omitempty" toml:"interval,omitempty" yaml:"interval,omitempty" export:"true"` + Timeout ptypes.Duration `json:"timeout,omitempty" toml:"timeout,omitempty" yaml:"timeout,omitempty" export:"true"` Hostname string `json:"hostname,omitempty" toml:"hostname,omitempty" yaml:"hostname,omitempty"` FollowRedirects *bool `json:"followRedirects" toml:"followRedirects" yaml:"followRedirects" export:"true"` Headers map[string]string `json:"headers,omitempty" toml:"headers,omitempty" yaml:"headers,omitempty" export:"true"` diff --git a/pkg/config/label/label_test.go b/pkg/config/label/label_test.go index a11fb3976f..004825beed 100644 --- a/pkg/config/label/label_test.go +++ b/pkg/config/label/label_test.go @@ -149,15 +149,15 @@ func TestDecodeConfiguration(t *testing.T) { "traefik.http.services.Service0.loadbalancer.healthcheck.headers.name0": "foobar", "traefik.http.services.Service0.loadbalancer.healthcheck.headers.name1": "foobar", "traefik.http.services.Service0.loadbalancer.healthcheck.hostname": "foobar", - "traefik.http.services.Service0.loadbalancer.healthcheck.interval": "foobar", + "traefik.http.services.Service0.loadbalancer.healthcheck.interval": "1s", "traefik.http.services.Service0.loadbalancer.healthcheck.path": "foobar", "traefik.http.services.Service0.loadbalancer.healthcheck.method": "foobar", "traefik.http.services.Service0.loadbalancer.healthcheck.port": "42", "traefik.http.services.Service0.loadbalancer.healthcheck.scheme": "foobar", - "traefik.http.services.Service0.loadbalancer.healthcheck.timeout": "foobar", + "traefik.http.services.Service0.loadbalancer.healthcheck.timeout": "1s", "traefik.http.services.Service0.loadbalancer.healthcheck.followredirects": "true", "traefik.http.services.Service0.loadbalancer.passhostheader": "true", - "traefik.http.services.Service0.loadbalancer.responseforwarding.flushinterval": "foobar", + "traefik.http.services.Service0.loadbalancer.responseforwarding.flushinterval": "100ms", "traefik.http.services.Service0.loadbalancer.server.scheme": "foobar", "traefik.http.services.Service0.loadbalancer.server.port": "8080", "traefik.http.services.Service0.loadbalancer.sticky.cookie.name": "foobar", @@ -165,15 +165,15 @@ func TestDecodeConfiguration(t *testing.T) { "traefik.http.services.Service1.loadbalancer.healthcheck.headers.name0": "foobar", "traefik.http.services.Service1.loadbalancer.healthcheck.headers.name1": "foobar", "traefik.http.services.Service1.loadbalancer.healthcheck.hostname": "foobar", - "traefik.http.services.Service1.loadbalancer.healthcheck.interval": "foobar", + "traefik.http.services.Service1.loadbalancer.healthcheck.interval": "1s", "traefik.http.services.Service1.loadbalancer.healthcheck.path": "foobar", "traefik.http.services.Service1.loadbalancer.healthcheck.method": "foobar", "traefik.http.services.Service1.loadbalancer.healthcheck.port": "42", "traefik.http.services.Service1.loadbalancer.healthcheck.scheme": "foobar", - "traefik.http.services.Service1.loadbalancer.healthcheck.timeout": "foobar", + "traefik.http.services.Service1.loadbalancer.healthcheck.timeout": "1s", "traefik.http.services.Service1.loadbalancer.healthcheck.followredirects": "true", "traefik.http.services.Service1.loadbalancer.passhostheader": "true", - "traefik.http.services.Service1.loadbalancer.responseforwarding.flushinterval": "foobar", + "traefik.http.services.Service1.loadbalancer.responseforwarding.flushinterval": "100ms", "traefik.http.services.Service1.loadbalancer.server.scheme": "foobar", "traefik.http.services.Service1.loadbalancer.server.port": "8080", "traefik.http.services.Service1.loadbalancer.sticky": "false", @@ -658,8 +658,8 @@ func TestDecodeConfiguration(t *testing.T) { Path: "foobar", Method: "foobar", Port: 42, - Interval: "foobar", - Timeout: "foobar", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foobar", Headers: map[string]string{ "name0": "foobar", @@ -669,7 +669,7 @@ func TestDecodeConfiguration(t *testing.T) { }, PassHostHeader: func(v bool) *bool { return &v }(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foobar", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, }, }, @@ -686,8 +686,8 @@ func TestDecodeConfiguration(t *testing.T) { Path: "foobar", Method: "foobar", Port: 42, - Interval: "foobar", - Timeout: "foobar", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foobar", Headers: map[string]string{ "name0": "foobar", @@ -697,7 +697,7 @@ func TestDecodeConfiguration(t *testing.T) { }, PassHostHeader: func(v bool) *bool { return &v }(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foobar", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, }, }, @@ -1162,8 +1162,8 @@ func TestEncodeConfiguration(t *testing.T) { Path: "foobar", Method: "foobar", Port: 42, - Interval: "foobar", - Timeout: "foobar", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foobar", Headers: map[string]string{ "name0": "foobar", @@ -1172,7 +1172,7 @@ func TestEncodeConfiguration(t *testing.T) { }, PassHostHeader: func(v bool) *bool { return &v }(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foobar", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, }, }, @@ -1189,8 +1189,8 @@ func TestEncodeConfiguration(t *testing.T) { Path: "foobar", Method: "foobar", Port: 42, - Interval: "foobar", - Timeout: "foobar", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foobar", Headers: map[string]string{ "name0": "foobar", @@ -1199,7 +1199,7 @@ func TestEncodeConfiguration(t *testing.T) { }, PassHostHeader: func(v bool) *bool { return &v }(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foobar", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, }, }, @@ -1359,14 +1359,15 @@ func TestEncodeConfiguration(t *testing.T) { "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Headers.name1": "foobar", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Hostname": "foobar", - "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Interval": "foobar", + "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Interval": "1000000000", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Path": "foobar", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Method": "foobar", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Port": "42", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Scheme": "foobar", - "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Timeout": "foobar", + "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Status": "0", + "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Timeout": "1000000000", "traefik.HTTP.Services.Service0.LoadBalancer.PassHostHeader": "true", - "traefik.HTTP.Services.Service0.LoadBalancer.ResponseForwarding.FlushInterval": "foobar", + "traefik.HTTP.Services.Service0.LoadBalancer.ResponseForwarding.FlushInterval": "100000000", "traefik.HTTP.Services.Service0.LoadBalancer.server.Port": "8080", "traefik.HTTP.Services.Service0.LoadBalancer.server.Scheme": "foobar", "traefik.HTTP.Services.Service0.LoadBalancer.Sticky.Cookie.Name": "foobar", @@ -1375,14 +1376,15 @@ func TestEncodeConfiguration(t *testing.T) { "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Headers.name0": "foobar", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Headers.name1": "foobar", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Hostname": "foobar", - "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Interval": "foobar", + "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Interval": "1000000000", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Path": "foobar", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Method": "foobar", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Port": "42", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Scheme": "foobar", - "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Timeout": "foobar", + "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Status": "0", + "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Timeout": "1000000000", "traefik.HTTP.Services.Service1.LoadBalancer.PassHostHeader": "true", - "traefik.HTTP.Services.Service1.LoadBalancer.ResponseForwarding.FlushInterval": "foobar", + "traefik.HTTP.Services.Service1.LoadBalancer.ResponseForwarding.FlushInterval": "100000000", "traefik.HTTP.Services.Service1.LoadBalancer.server.Port": "8080", "traefik.HTTP.Services.Service1.LoadBalancer.server.Scheme": "foobar", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Headers.name0": "foobar", diff --git a/pkg/config/runtime/runtime.go b/pkg/config/runtime/runtime.go index 48d527f65e..8f9f2a4355 100644 --- a/pkg/config/runtime/runtime.go +++ b/pkg/config/runtime/runtime.go @@ -15,6 +15,12 @@ const ( StatusWarning = "warning" ) +// Status of the servers. +const ( + StatusUp = "UP" + StatusDown = "DOWN" +) + // Configuration holds the information about the currently running traefik instance. type Configuration struct { Routers map[string]*RouterInfo `json:"routers,omitempty"` diff --git a/pkg/config/runtime/runtime_test.go b/pkg/config/runtime/runtime_test.go index 7b309b53df..e959ca1629 100644 --- a/pkg/config/runtime/runtime_test.go +++ b/pkg/config/runtime/runtime_test.go @@ -2,9 +2,11 @@ package runtime_test import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + ptypes "github.com/traefik/paerser/types" "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/config/runtime" ) @@ -49,7 +51,7 @@ func TestPopulateUsedBy(t *testing.T) { {URL: "http://127.0.0.1:8086"}, }, HealthCheck: &dynamic.ServerHealthCheck{ - Interval: "500ms", + Interval: ptypes.Duration(500 * time.Millisecond), Path: "/health", }, }, @@ -159,7 +161,7 @@ func TestPopulateUsedBy(t *testing.T) { }, }, HealthCheck: &dynamic.ServerHealthCheck{ - Interval: "500ms", + Interval: ptypes.Duration(500 * time.Millisecond), Path: "/health", }, }, @@ -177,7 +179,7 @@ func TestPopulateUsedBy(t *testing.T) { }, }, HealthCheck: &dynamic.ServerHealthCheck{ - Interval: "500ms", + Interval: ptypes.Duration(500 * time.Millisecond), Path: "/health", }, }, diff --git a/pkg/healthcheck/healthcheck.go b/pkg/healthcheck/healthcheck.go index 8bd8d1e17c..0fc0d56bf4 100644 --- a/pkg/healthcheck/healthcheck.go +++ b/pkg/healthcheck/healthcheck.go @@ -8,425 +8,258 @@ import ( "net/http" "net/url" "strconv" - "strings" - "sync" "time" gokitmetrics "github.com/go-kit/kit/metrics" "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/log" - "github.com/traefik/traefik/v2/pkg/metrics" - "github.com/traefik/traefik/v2/pkg/safe" - "github.com/vulcand/oxy/v2/roundrobin" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" ) -const ( - serverUp = "UP" - serverDown = "DOWN" -) - -var ( - singleton *HealthCheck - once sync.Once -) +const modeGRPC = "grpc" -// Balancer is the set of operations required to manage the list of servers in a load-balancer. -type Balancer interface { - Servers() []*url.URL - RemoveServer(u *url.URL) error - UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error +// StatusSetter should be implemented by a service that, when the status of a +// registered target change, needs to be notified of that change. +type StatusSetter interface { + SetStatus(ctx context.Context, childName string, up bool) } -// BalancerHandler includes functionality for load-balancing management. -type BalancerHandler interface { - ServeHTTP(w http.ResponseWriter, req *http.Request) - Balancer -} - -// BalancerStatusHandler is an http Handler that does load-balancing, -// and updates its parents of its status. -type BalancerStatusHandler interface { - BalancerHandler - StatusUpdater +// StatusUpdater should be implemented by a service that, when its status +// changes (e.g. all if its children are down), needs to propagate upwards (to +// their parent(s)) that change. +type StatusUpdater interface { + RegisterStatusUpdater(fn func(up bool)) error } -type metricsHealthcheck struct { - serverUpGauge gokitmetrics.Gauge +type metricsHealthCheck interface { + ServiceServerUpGauge() gokitmetrics.Gauge } -// Options are the public health check options. -type Options struct { - Headers map[string]string - Hostname string - Scheme string - Path string - Method string - Port int - FollowRedirects bool - Transport http.RoundTripper - Interval time.Duration - Timeout time.Duration - LB Balancer -} +type ServiceHealthChecker struct { + balancer StatusSetter + info *runtime.ServiceInfo -func (opt Options) String() string { - return fmt.Sprintf("[Hostname: %s Headers: %v Path: %s Method: %s Port: %d Interval: %s Timeout: %s FollowRedirects: %v]", opt.Hostname, opt.Headers, opt.Path, opt.Method, opt.Port, opt.Interval, opt.Timeout, opt.FollowRedirects) -} + config *dynamic.ServerHealthCheck + interval time.Duration + timeout time.Duration -type backendURL struct { - url *url.URL - weight int -} + metrics metricsHealthCheck -// BackendConfig HealthCheck configuration for a backend. -type BackendConfig struct { - Options - name string - disabledURLs []backendURL + client *http.Client + targets map[string]*url.URL } -func (b *BackendConfig) newRequest(serverURL *url.URL) (*http.Request, error) { - u, err := serverURL.Parse(b.Path) - if err != nil { - return nil, err - } +func NewServiceHealthChecker(ctx context.Context, metrics metricsHealthCheck, config *dynamic.ServerHealthCheck, service StatusSetter, info *runtime.ServiceInfo, transport http.RoundTripper, targets map[string]*url.URL) *ServiceHealthChecker { + logger := log.FromContext(ctx) - if len(b.Scheme) > 0 { - u.Scheme = b.Scheme + interval := time.Duration(config.Interval) + if interval <= 0 { + logger.Error("Health check interval smaller than zero") + interval = time.Duration(dynamic.DefaultHealthCheckInterval) } - if b.Port != 0 { - u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(b.Port)) + timeout := time.Duration(config.Timeout) + if timeout <= 0 { + logger.Error("Health check timeout smaller than zero") + timeout = time.Duration(dynamic.DefaultHealthCheckTimeout) } - return http.NewRequest(http.MethodGet, u.String(), http.NoBody) -} - -// setRequestOptions sets all request options present on the BackendConfig. -func (b *BackendConfig) setRequestOptions(req *http.Request) *http.Request { - if b.Options.Hostname != "" { - req.Host = b.Options.Hostname + if timeout >= interval { + logger.Warnf("Health check timeout should be lower than the health check interval. Interval set to timeout + 1 second (%s).", interval) + interval = timeout + time.Second } - for k, v := range b.Options.Headers { - req.Header.Set(k, v) + client := &http.Client{ + Transport: transport, } - if b.Options.Method != "" { - req.Method = strings.ToUpper(b.Options.Method) - } - - return req -} - -// HealthCheck struct. -type HealthCheck struct { - Backends map[string]*BackendConfig - metrics metricsHealthcheck - cancel context.CancelFunc -} - -// SetBackendsConfiguration set backends configuration. -func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backends map[string]*BackendConfig) { - hc.Backends = backends - if hc.cancel != nil { - hc.cancel() + if config.FollowRedirects != nil && !*config.FollowRedirects { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } } - ctx, cancel := context.WithCancel(parentCtx) - hc.cancel = cancel - for _, backend := range backends { - safe.Go(func() { - hc.execute(ctx, backend) - }) + return &ServiceHealthChecker{ + balancer: service, + info: info, + config: config, + interval: interval, + timeout: timeout, + targets: targets, + client: client, + metrics: metrics, } } -func (hc *HealthCheck) execute(ctx context.Context, backend *BackendConfig) { - logger := log.FromContext(ctx) - - logger.Debugf("Initial health check for backend: %q", backend.name) - hc.checkServersLB(ctx, backend) - - ticker := time.NewTicker(backend.Interval) +func (shc *ServiceHealthChecker) Launch(ctx context.Context) { + ticker := time.NewTicker(shc.interval) defer ticker.Stop() + for { select { case <-ctx.Done(): - logger.Debugf("Stopping current health check goroutines of backend: %s", backend.name) return - case <-ticker.C: - logger.Debugf("Routine health check refresh for backend: %s", backend.name) - hc.checkServersLB(ctx, backend) - } - } -} - -func (hc *HealthCheck) checkServersLB(ctx context.Context, backend *BackendConfig) { - logger := log.FromContext(ctx) - enabledURLs := backend.LB.Servers() + case <-ticker.C: + for proxyName, target := range shc.targets { + select { + case <-ctx.Done(): + return + default: + } - var newDisabledURLs []backendURL - for _, disabledURL := range backend.disabledURLs { - serverUpMetricValue := float64(0) + up := true + serverUpMetricValue := float64(1) - if err := checkHealth(disabledURL.url, backend); err == nil { - logger.Warnf("Health check up: returning to server list. Backend: %q URL: %q Weight: %d", - backend.name, disabledURL.url.String(), disabledURL.weight) - if err = backend.LB.UpsertServer(disabledURL.url, roundrobin.Weight(disabledURL.weight)); err != nil { - logger.Error(err) - } - serverUpMetricValue = 1 - } else { - logger.Warnf("Health check still failing. Backend: %q URL: %q Reason: %s", backend.name, disabledURL.url.String(), err) - newDisabledURLs = append(newDisabledURLs, disabledURL) - } + if err := shc.executeHealthCheck(ctx, shc.config, target); err != nil { + // The context is canceled when the dynamic configuration is refreshed. + if errors.Is(err, context.Canceled) { + return + } - labelValues := []string{"service", backend.name, "url", disabledURL.url.String()} - hc.metrics.serverUpGauge.With(labelValues...).Set(serverUpMetricValue) - } + log.FromContext(ctx).WithError(err).WithField("targetURL", target.String()).Error("Health check failed.") - backend.disabledURLs = newDisabledURLs + up = false + serverUpMetricValue = float64(0) + } - for _, enabledURL := range enabledURLs { - serverUpMetricValue := float64(1) + shc.balancer.SetStatus(ctx, proxyName, up) - if err := checkHealth(enabledURL, backend); err != nil { - weight := 1 - rr, ok := backend.LB.(*roundrobin.RoundRobin) - if ok { - var gotWeight bool - weight, gotWeight = rr.ServerWeight(enabledURL) - if !gotWeight { - weight = 1 + statusStr := runtime.StatusDown + if up { + statusStr = runtime.StatusUp } - } - logger.Warnf("Health check failed, removing from server list. Backend: %q URL: %q Weight: %d Reason: %s", - backend.name, enabledURL.String(), weight, err) - if err := backend.LB.RemoveServer(enabledURL); err != nil { - logger.Error(err) - } + shc.info.UpdateServerStatus(target.String(), statusStr) - backend.disabledURLs = append(backend.disabledURLs, backendURL{enabledURL, weight}) - serverUpMetricValue = 0 + shc.metrics.ServiceServerUpGauge(). + With("service", proxyName, "url", target.String()). + Set(serverUpMetricValue) + } } - - labelValues := []string{"service", backend.name, "url", enabledURL.String()} - hc.metrics.serverUpGauge.With(labelValues...).Set(serverUpMetricValue) } } -// GetHealthCheck returns the health check which is guaranteed to be a singleton. -func GetHealthCheck(registry metrics.Registry) *HealthCheck { - once.Do(func() { - singleton = newHealthCheck(registry) - }) - return singleton -} +func (shc *ServiceHealthChecker) executeHealthCheck(ctx context.Context, config *dynamic.ServerHealthCheck, target *url.URL) error { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(shc.timeout)) + defer cancel() -func newHealthCheck(registry metrics.Registry) *HealthCheck { - return &HealthCheck{ - Backends: make(map[string]*BackendConfig), - metrics: metricsHealthcheck{ - serverUpGauge: registry.ServiceServerUpGauge(), - }, + if config.Mode == modeGRPC { + return shc.checkHealthGRPC(ctx, target) } + return shc.checkHealthHTTP(ctx, target) } -// NewBackendConfig Instantiate a new BackendConfig. -func NewBackendConfig(options Options, backendName string) *BackendConfig { - return &BackendConfig{ - Options: options, - name: backendName, - } -} - -// checkHealth returns a nil error in case it was successful and otherwise -// a non-nil error with a meaningful description why the health check failed. -func checkHealth(serverURL *url.URL, backend *BackendConfig) error { - req, err := backend.newRequest(serverURL) +// checkHealthHTTP returns an error with a meaningful description if the health check failed. +// Dedicated to HTTP servers. +func (shc *ServiceHealthChecker) checkHealthHTTP(ctx context.Context, target *url.URL) error { + req, err := shc.newRequest(ctx, target) if err != nil { - return fmt.Errorf("failed to create HTTP request: %w", err) + return fmt.Errorf("create HTTP request: %w", err) } - req = backend.setRequestOptions(req) - - client := http.Client{ - Timeout: backend.Options.Timeout, - Transport: backend.Options.Transport, - } - - if !backend.FollowRedirects { - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - } - - resp, err := client.Do(req) + resp, err := shc.client.Do(req) if err != nil { return fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { + if shc.config.Status == 0 && (resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest) { return fmt.Errorf("received error status code: %v", resp.StatusCode) } - return nil -} + if shc.config.Status != 0 && shc.config.Status != resp.StatusCode { + return fmt.Errorf("received error status code: %v expected status code: %v", resp.StatusCode, shc.config.Status) + } -// StatusUpdater should be implemented by a service that, when its status -// changes (e.g. all if its children are down), needs to propagate upwards (to -// their parent(s)) that change. -type StatusUpdater interface { - RegisterStatusUpdater(fn func(up bool)) error + return nil } -// NewLBStatusUpdater returns a new LbStatusUpdater. -func NewLBStatusUpdater(bh BalancerHandler, info *runtime.ServiceInfo, hc *dynamic.ServerHealthCheck) *LbStatusUpdater { - return &LbStatusUpdater{ - BalancerHandler: bh, - serviceInfo: info, - wantsHealthCheck: hc != nil, +func (shc *ServiceHealthChecker) newRequest(ctx context.Context, target *url.URL) (*http.Request, error) { + u, err := target.Parse(shc.config.Path) + if err != nil { + return nil, err } -} -// LbStatusUpdater wraps a BalancerHandler and a ServiceInfo, -// so it can keep track of the status of a server in the ServiceInfo. -type LbStatusUpdater struct { - BalancerHandler - serviceInfo *runtime.ServiceInfo // can be nil - updaters []func(up bool) - wantsHealthCheck bool -} - -// RegisterStatusUpdater adds fn to the list of hooks that are run when the -// status of the Balancer changes. -// Not thread safe. -func (lb *LbStatusUpdater) RegisterStatusUpdater(fn func(up bool)) error { - if !lb.wantsHealthCheck { - return errors.New("healthCheck not enabled in config for this loadbalancer service") + if len(shc.config.Scheme) > 0 { + u.Scheme = shc.config.Scheme } - lb.updaters = append(lb.updaters, fn) - return nil -} + if shc.config.Port != 0 { + u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(shc.config.Port)) + } -// RemoveServer removes the given server from the BalancerHandler, -// and updates the status of the server to "DOWN". -func (lb *LbStatusUpdater) RemoveServer(u *url.URL) error { - // TODO(mpl): when we have the freedom to change the signature of RemoveServer - // (kinda stuck because of oxy for now), let's pass around a context to improve - // logging. - ctx := context.TODO() - upBefore := len(lb.BalancerHandler.Servers()) > 0 - err := lb.BalancerHandler.RemoveServer(u) + req, err := http.NewRequestWithContext(ctx, shc.config.Method, u.String(), http.NoBody) if err != nil { - return err + return nil, fmt.Errorf("failed to create HTTP request: %w", err) } - if lb.serviceInfo != nil { - lb.serviceInfo.UpdateServerStatus(u.String(), serverDown) - } - log.FromContext(ctx).Debugf("child %s now %s", u.String(), serverDown) - if !upBefore { - // we were already down, and we still are, no need to propagate. - log.FromContext(ctx).Debugf("Still %s, no need to propagate", serverDown) - return nil - } - if len(lb.BalancerHandler.Servers()) > 0 { - // we were up, and we still are, no need to propagate - log.FromContext(ctx).Debugf("Still %s, no need to propagate", serverUp) - return nil + if shc.config.Hostname != "" { + req.Host = shc.config.Hostname } - log.FromContext(ctx).Debugf("Propagating new %s status", serverDown) - for _, fn := range lb.updaters { - fn(false) + for k, v := range shc.config.Headers { + req.Header.Set(k, v) } - return nil + + return req, nil } -// UpsertServer adds the given server to the BalancerHandler, -// and updates the status of the server to "UP". -func (lb *LbStatusUpdater) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { - ctx := context.TODO() - upBefore := len(lb.BalancerHandler.Servers()) > 0 - err := lb.BalancerHandler.UpsertServer(u, options...) +// checkHealthGRPC returns an error with a meaningful description if the health check failed. +// Dedicated to gRPC servers implementing gRPC Health Checking Protocol v1. +func (shc *ServiceHealthChecker) checkHealthGRPC(ctx context.Context, serverURL *url.URL) error { + u, err := serverURL.Parse(shc.config.Path) if err != nil { - return err + return fmt.Errorf("failed to parse server URL: %w", err) } - if lb.serviceInfo != nil { - lb.serviceInfo.UpdateServerStatus(u.String(), serverUp) - } - log.FromContext(ctx).Debugf("child %s now %s", u.String(), serverUp) - if upBefore { - // we were up, and we still are, no need to propagate - log.FromContext(ctx).Debugf("Still %s, no need to propagate", serverUp) - return nil + port := u.Port() + if shc.config.Port != 0 { + port = strconv.Itoa(shc.config.Port) } - log.FromContext(ctx).Debugf("Propagating new %s status", serverUp) - for _, fn := range lb.updaters { - fn(true) - } - return nil -} + serverAddr := net.JoinHostPort(u.Hostname(), port) -// Balancers is a list of Balancers(s) that implements the Balancer interface. -type Balancers []Balancer - -// Servers returns the deduplicated server URLs from all the Balancer. -// Note that the deduplication is only possible because all the underlying -// balancers are of the same kind (the oxy implementation). -// The comparison property is the same as the one found at: -// https://github.com/vulcand/oxy/blob/fb2728c857b7973a27f8de2f2190729c0f22cf49/roundrobin/rr.go#L347. -func (b Balancers) Servers() []*url.URL { - seen := make(map[string]struct{}) - - var servers []*url.URL - for _, lb := range b { - for _, server := range lb.Servers() { - key := serverKey(server) - if _, ok := seen[key]; ok { - continue - } + var opts []grpc.DialOption + switch shc.config.Scheme { + case "http", "h2c", "": + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } - servers = append(servers, server) - seen[key] = struct{}{} + conn, err := grpc.DialContext(ctx, serverAddr, opts...) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("fail to connect to %s within %s: %w", serverAddr, shc.config.Timeout, err) } + return fmt.Errorf("fail to connect to %s: %w", serverAddr, err) } + defer func() { _ = conn.Close() }() - return servers -} - -// RemoveServer removes the given server from all the Balancer, -// and updates the status of the server to "DOWN". -func (b Balancers) RemoveServer(u *url.URL) error { - for _, lb := range b { - if err := lb.RemoveServer(u); err != nil { - return err + resp, err := healthpb.NewHealthClient(conn).Check(ctx, &healthpb.HealthCheckRequest{}) + if err != nil { + if stat, ok := status.FromError(err); ok { + switch stat.Code() { + case codes.Unimplemented: + return fmt.Errorf("gRPC server does not implement the health protocol: %w", err) + case codes.DeadlineExceeded: + return fmt.Errorf("gRPC health check timeout: %w", err) + case codes.Canceled: + return context.Canceled + } } + + return fmt.Errorf("gRPC health check failed: %w", err) } - return nil -} -// UpsertServer adds the given server to all the Balancer, -// and updates the status of the server to "UP". -func (b Balancers) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { - for _, lb := range b { - if err := lb.UpsertServer(u, options...); err != nil { - return err - } + if resp.Status != healthpb.HealthCheckResponse_SERVING { + return fmt.Errorf("received gRPC status code: %v", resp.Status) } - return nil -} -func serverKey(u *url.URL) string { - return u.Path + u.Host + u.Scheme + return nil } diff --git a/pkg/healthcheck/healthcheck_test.go b/pkg/healthcheck/healthcheck_test.go index 73254b27e9..3e80ea3ad7 100644 --- a/pkg/healthcheck/healthcheck_test.go +++ b/pkg/healthcheck/healthcheck_test.go @@ -11,347 +11,181 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + ptypes "github.com/traefik/paerser/types" + "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/testhelpers" - "github.com/vulcand/oxy/v2/roundrobin" + healthpb "google.golang.org/grpc/health/grpc_health_v1" ) -const ( - healthCheckInterval = 200 * time.Millisecond - healthCheckTimeout = 100 * time.Millisecond -) - -const delta float64 = 1e-10 - -type testHandler struct { - done func() - healthSequence []int -} - -func TestSetBackendsConfiguration(t *testing.T) { - testCases := []struct { - desc string - startHealthy bool - healthSequence []int - expectedNumRemovedServers int - expectedNumUpsertedServers int - expectedGaugeValue float64 - }{ - { - desc: "healthy server staying healthy", - startHealthy: true, - healthSequence: []int{http.StatusOK}, - expectedNumRemovedServers: 0, - expectedNumUpsertedServers: 0, - expectedGaugeValue: 1, - }, - { - desc: "healthy server staying healthy (StatusNoContent)", - startHealthy: true, - healthSequence: []int{http.StatusNoContent}, - expectedNumRemovedServers: 0, - expectedNumUpsertedServers: 0, - expectedGaugeValue: 1, - }, - { - desc: "healthy server staying healthy (StatusPermanentRedirect)", - startHealthy: true, - healthSequence: []int{http.StatusPermanentRedirect}, - expectedNumRemovedServers: 0, - expectedNumUpsertedServers: 0, - expectedGaugeValue: 1, - }, - { - desc: "healthy server becoming sick", - startHealthy: true, - healthSequence: []int{http.StatusServiceUnavailable}, - expectedNumRemovedServers: 1, - expectedNumUpsertedServers: 0, - expectedGaugeValue: 0, - }, - { - desc: "sick server becoming healthy", - startHealthy: false, - healthSequence: []int{http.StatusOK}, - expectedNumRemovedServers: 0, - expectedNumUpsertedServers: 1, - expectedGaugeValue: 1, - }, - { - desc: "sick server staying sick", - startHealthy: false, - healthSequence: []int{http.StatusServiceUnavailable}, - expectedNumRemovedServers: 0, - expectedNumUpsertedServers: 0, - expectedGaugeValue: 0, - }, - { - desc: "healthy server toggling to sick and back to healthy", - startHealthy: true, - healthSequence: []int{http.StatusServiceUnavailable, http.StatusOK}, - expectedNumRemovedServers: 1, - expectedNumUpsertedServers: 1, - expectedGaugeValue: 1, - }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - // The context is passed to the health check and canonically canceled by - // the test server once all expected requests have been received. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - ts := newTestServer(cancel, test.healthSequence) - defer ts.Close() - - lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}} - backend := NewBackendConfig(Options{ - Path: "/path", - Interval: healthCheckInterval, - Timeout: healthCheckTimeout, - LB: lb, - }, "backendName") - - serverURL := testhelpers.MustParseURL(ts.URL) - if test.startHealthy { - lb.servers = append(lb.servers, serverURL) - } else { - backend.disabledURLs = append(backend.disabledURLs, backendURL{url: serverURL, weight: 1}) - } - - collectingMetrics := &testhelpers.CollectingGauge{} - check := HealthCheck{ - Backends: make(map[string]*BackendConfig), - metrics: metricsHealthcheck{serverUpGauge: collectingMetrics}, - } - - wg := sync.WaitGroup{} - wg.Add(1) - - go func() { - check.execute(ctx, backend) - wg.Done() - }() - - // Make test timeout dependent on number of expected requests, health - // check interval, and a safety margin. - timeout := time.Duration(len(test.healthSequence)*int(healthCheckInterval) + 500) - select { - case <-time.After(timeout): - t.Fatal("test did not complete in time") - case <-ctx.Done(): - wg.Wait() - } - - lb.Lock() - defer lb.Unlock() - - assert.Equal(t, test.expectedNumRemovedServers, lb.numRemovedServers, "removed servers") - assert.Equal(t, test.expectedNumUpsertedServers, lb.numUpsertedServers, "upserted servers") - assert.InDelta(t, test.expectedGaugeValue, collectingMetrics.GaugeValue, delta, "ServerUp Gauge") - }) - } -} - -func TestNewRequest(t *testing.T) { - type expected struct { - err bool - value string - } - +func TestServiceHealthChecker_newRequest(t *testing.T) { testCases := []struct { - desc string - serverURL string - options Options - expected expected + desc string + targetURL string + config dynamic.ServerHealthCheck + expTarget string + expError bool + expHostname string + expHeader string + expMethod string }{ { desc: "no port override", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Path: "/test", Port: 0, }, - expected: expected{ - err: false, - value: "http://backend1:80/test", - }, + expError: false, + expTarget: "http://backend1:80/test", + expHostname: "backend1:80", + expMethod: http.MethodGet, }, { desc: "port override", - serverURL: "http://backend2:80", - options: Options{ + targetURL: "http://backend2:80", + config: dynamic.ServerHealthCheck{ Path: "/test", Port: 8080, }, - expected: expected{ - err: false, - value: "http://backend2:8080/test", - }, + expError: false, + expTarget: "http://backend2:8080/test", + expHostname: "backend2:8080", + expMethod: http.MethodGet, }, { desc: "no port override with no port in server URL", - serverURL: "http://backend1", - options: Options{ + targetURL: "http://backend1", + config: dynamic.ServerHealthCheck{ Path: "/health", Port: 0, }, - expected: expected{ - err: false, - value: "http://backend1/health", - }, + expError: false, + expTarget: "http://backend1/health", + expHostname: "backend1", + expMethod: http.MethodGet, }, { desc: "port override with no port in server URL", - serverURL: "http://backend2", - options: Options{ + targetURL: "http://backend2", + config: dynamic.ServerHealthCheck{ Path: "/health", Port: 8080, }, - expected: expected{ - err: false, - value: "http://backend2:8080/health", - }, + expError: false, + expTarget: "http://backend2:8080/health", + expHostname: "backend2:8080", + expMethod: http.MethodGet, }, { desc: "scheme override", - serverURL: "https://backend1:80", - options: Options{ + targetURL: "https://backend1:80", + config: dynamic.ServerHealthCheck{ Scheme: "http", Path: "/test", Port: 0, }, - expected: expected{ - err: false, - value: "http://backend1:80/test", - }, + expError: false, + expTarget: "http://backend1:80/test", + expHostname: "backend1:80", + expMethod: http.MethodGet, }, { desc: "path with param", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Path: "/health?powpow=do", Port: 0, }, - expected: expected{ - err: false, - value: "http://backend1:80/health?powpow=do", - }, + expError: false, + expTarget: "http://backend1:80/health?powpow=do", + expHostname: "backend1:80", + expMethod: http.MethodGet, }, { desc: "path with params", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Path: "/health?powpow=do&do=powpow", Port: 0, }, - expected: expected{ - err: false, - value: "http://backend1:80/health?powpow=do&do=powpow", - }, + expError: false, + expTarget: "http://backend1:80/health?powpow=do&do=powpow", + expHostname: "backend1:80", + expMethod: http.MethodGet, }, { desc: "path with invalid path", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Path: ":", Port: 0, }, - expected: expected{ - err: true, - value: "", - }, + expError: true, + expTarget: "", + expHostname: "backend1:80", + expMethod: http.MethodGet, }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - backend := NewBackendConfig(test.options, "backendName") - - u := testhelpers.MustParseURL(test.serverURL) - - req, err := backend.newRequest(u) - - if test.expected.err { - require.Error(t, err) - assert.Nil(t, nil) - } else { - require.NoError(t, err, "failed to create new backend request") - require.NotNil(t, req) - assert.Equal(t, test.expected.value, req.URL.String()) - } - }) - } -} - -func TestRequestOptions(t *testing.T) { - testCases := []struct { - desc string - serverURL string - options Options - expectedHostname string - expectedHeader string - expectedMethod string - }{ { desc: "override hostname", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Hostname: "myhost", Path: "/", }, - expectedHostname: "myhost", - expectedHeader: "", - expectedMethod: http.MethodGet, + expTarget: "http://backend1:80/", + expHostname: "myhost", + expHeader: "", + expMethod: http.MethodGet, }, { desc: "not override hostname", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Hostname: "", Path: "/", }, - expectedHostname: "backend1:80", - expectedHeader: "", - expectedMethod: http.MethodGet, + expTarget: "http://backend1:80/", + expHostname: "backend1:80", + expHeader: "", + expMethod: http.MethodGet, }, { desc: "custom header", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Headers: map[string]string{"Custom-Header": "foo"}, Hostname: "", Path: "/", }, - expectedHostname: "backend1:80", - expectedHeader: "foo", - expectedMethod: http.MethodGet, + expTarget: "http://backend1:80/", + expHostname: "backend1:80", + expHeader: "foo", + expMethod: http.MethodGet, }, { desc: "custom header with hostname override", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Headers: map[string]string{"Custom-Header": "foo"}, Hostname: "myhost", Path: "/", }, - expectedHostname: "myhost", - expectedHeader: "foo", - expectedMethod: http.MethodGet, + expTarget: "http://backend1:80/", + expHostname: "myhost", + expHeader: "foo", + expMethod: http.MethodGet, }, { desc: "custom method", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Path: "/", Method: http.MethodHead, }, - expectedHostname: "backend1:80", - expectedMethod: http.MethodHead, + expTarget: "http://backend1:80/", + expHostname: "backend1:80", + expMethod: http.MethodHead, }, } @@ -359,259 +193,215 @@ func TestRequestOptions(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - backend := NewBackendConfig(test.options, "backendName") - - u, err := url.Parse(test.serverURL) - require.NoError(t, err) + shc := ServiceHealthChecker{config: &test.config} - req, err := backend.newRequest(u) - require.NoError(t, err, "failed to create new backend request") + u := testhelpers.MustParseURL(test.targetURL) + req, err := shc.newRequest(context.Background(), u) - req = backend.setRequestOptions(req) + if test.expError { + require.Error(t, err) + assert.Nil(t, req) + } else { + require.NoError(t, err, "failed to create new request") + require.NotNil(t, req) - assert.Equal(t, "http://backend1:80/", req.URL.String()) - assert.Equal(t, test.expectedHostname, req.Host) - assert.Equal(t, test.expectedHeader, req.Header.Get("Custom-Header")) - assert.Equal(t, test.expectedMethod, req.Method) + assert.Equal(t, test.expTarget, req.URL.String()) + assert.Equal(t, test.expHeader, req.Header.Get("Custom-Header")) + assert.Equal(t, test.expHostname, req.Host) + assert.Equal(t, test.expMethod, req.Method) + } }) } } -func TestBalancers_Servers(t *testing.T) { - server1, err := url.Parse("http://foo.com") - require.NoError(t, err) - - balancer1, err := roundrobin.New(nil) - require.NoError(t, err) - - err = balancer1.UpsertServer(server1) - require.NoError(t, err) - - server2, err := url.Parse("http://foo.com") - require.NoError(t, err) - - balancer2, err := roundrobin.New(nil) - require.NoError(t, err) - - err = balancer2.UpsertServer(server2) - require.NoError(t, err) - - balancers := Balancers([]Balancer{balancer1, balancer2}) - - want, err := url.Parse("http://foo.com") - require.NoError(t, err) - - assert.Len(t, balancers.Servers(), 1) - assert.Equal(t, want, balancers.Servers()[0]) -} - -func TestBalancers_UpsertServer(t *testing.T) { - balancer1, err := roundrobin.New(nil) - require.NoError(t, err) - - balancer2, err := roundrobin.New(nil) - require.NoError(t, err) - - want, err := url.Parse("http://foo.com") - require.NoError(t, err) - - balancers := Balancers([]Balancer{balancer1, balancer2}) - - err = balancers.UpsertServer(want) - require.NoError(t, err) - - assert.Len(t, balancer1.Servers(), 1) - assert.Equal(t, want, balancer1.Servers()[0]) - - assert.Len(t, balancer2.Servers(), 1) - assert.Equal(t, want, balancer2.Servers()[0]) -} - -func TestBalancers_RemoveServer(t *testing.T) { - server, err := url.Parse("http://foo.com") - require.NoError(t, err) - - balancer1, err := roundrobin.New(nil) - require.NoError(t, err) - - err = balancer1.UpsertServer(server) - require.NoError(t, err) - - balancer2, err := roundrobin.New(nil) - require.NoError(t, err) - - err = balancer2.UpsertServer(server) - require.NoError(t, err) - - balancers := Balancers([]Balancer{balancer1, balancer2}) - - err = balancers.RemoveServer(server) - require.NoError(t, err) - - assert.Empty(t, balancer1.Servers()) - assert.Empty(t, balancer2.Servers()) -} - -type testLoadBalancer struct { - // RWMutex needed due to parallel test execution: Both the system-under-test - // and the test assertions reference the counters. - *sync.RWMutex - numRemovedServers int - numUpsertedServers int - servers []*url.URL - // options is just to make sure that LBStatusUpdater forwards options on Upsert to its BalancerHandler - options []roundrobin.ServerOption -} - -func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - // noop -} - -func (lb *testLoadBalancer) RemoveServer(u *url.URL) error { - lb.Lock() - defer lb.Unlock() - lb.numRemovedServers++ - lb.removeServer(u) - return nil -} - -func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { - lb.Lock() - defer lb.Unlock() - lb.numUpsertedServers++ - lb.servers = append(lb.servers, u) - lb.options = append(lb.options, options...) - return nil -} - -func (lb *testLoadBalancer) Servers() []*url.URL { - return lb.servers -} - -func (lb *testLoadBalancer) Options() []roundrobin.ServerOption { - return lb.options -} - -func (lb *testLoadBalancer) removeServer(u *url.URL) { - var i int - var serverURL *url.URL - found := false - for i, serverURL = range lb.servers { - if *serverURL == *u { - found = true - break - } - } - if !found { - return - } - - lb.servers = append(lb.servers[:i], lb.servers[i+1:]...) -} - -func newTestServer(done func(), healthSequence []int) *httptest.Server { - handler := &testHandler{ - done: done, - healthSequence: healthSequence, - } - return httptest.NewServer(handler) -} - -// ServeHTTP returns HTTP response codes following a status sequences. -// It calls the given 'done' function once all request health indicators have been depleted. -func (th *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if len(th.healthSequence) == 0 { - panic("received unexpected request") - } - - w.WriteHeader(th.healthSequence[0]) - - th.healthSequence = th.healthSequence[1:] - if len(th.healthSequence) == 0 { - th.done() - } -} - -func TestLBStatusUpdater(t *testing.T) { - lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}} - svInfo := &runtime.ServiceInfo{} - lbsu := NewLBStatusUpdater(lb, svInfo, nil) - newServer, err := url.Parse("http://foo.com") - assert.NoError(t, err) - err = lbsu.UpsertServer(newServer, roundrobin.Weight(1)) - assert.NoError(t, err) - assert.Len(t, lbsu.Servers(), 1) - assert.Len(t, lbsu.BalancerHandler.(*testLoadBalancer).Options(), 1) - statuses := svInfo.GetAllStatus() - assert.Len(t, statuses, 1) - for k, v := range statuses { - assert.Equal(t, newServer.String(), k) - assert.Equal(t, serverUp, v) - break - } - err = lbsu.RemoveServer(newServer) - assert.NoError(t, err) - assert.Empty(t, lbsu.Servers()) - statuses = svInfo.GetAllStatus() - assert.Len(t, statuses, 1) - for k, v := range statuses { - assert.Equal(t, newServer.String(), k) - assert.Equal(t, serverDown, v) - break - } -} - -func TestNotFollowingRedirects(t *testing.T) { +func TestServiceHealthChecker_checkHealthHTTP_NotFollowingRedirects(t *testing.T) { redirectServerCalled := false redirectTestServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { redirectServerCalled = true })) defer redirectTestServer.Close() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(dynamic.DefaultHealthCheckTimeout)) defer cancel() server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Add("location", redirectTestServer.URL) rw.WriteHeader(http.StatusSeeOther) - cancel() })) defer server.Close() - lb := &testLoadBalancer{ - RWMutex: &sync.RWMutex{}, - servers: []*url.URL{testhelpers.MustParseURL(server.URL)}, + config := &dynamic.ServerHealthCheck{ + Path: "/path", + FollowRedirects: Bool(false), + Interval: dynamic.DefaultHealthCheckInterval, + Timeout: dynamic.DefaultHealthCheckTimeout, } + healthChecker := NewServiceHealthChecker(ctx, nil, config, nil, nil, http.DefaultTransport, nil) - backend := NewBackendConfig(Options{ - Path: "/path", - Interval: healthCheckInterval, - Timeout: healthCheckTimeout, - LB: lb, - FollowRedirects: false, - }, "backendName") + err := healthChecker.checkHealthHTTP(ctx, testhelpers.MustParseURL(server.URL)) + require.NoError(t, err) - collectingMetrics := &testhelpers.CollectingGauge{} - check := HealthCheck{ - Backends: make(map[string]*BackendConfig), - metrics: metricsHealthcheck{serverUpGauge: collectingMetrics}, + assert.False(t, redirectServerCalled, "HTTP redirect must not be followed") +} + +func TestServiceHealthChecker_Launch(t *testing.T) { + testCases := []struct { + desc string + mode string + status int + server StartTestServer + expNumRemovedServers int + expNumUpsertedServers int + expGaugeValue float64 + targetStatus string + }{ + { + desc: "healthy server staying healthy", + server: newHTTPServer(http.StatusOK), + expNumRemovedServers: 0, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy server staying healthy, with custom code status check", + server: newHTTPServer(http.StatusNotFound), + status: http.StatusNotFound, + expNumRemovedServers: 0, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy server staying healthy (StatusNoContent)", + server: newHTTPServer(http.StatusNoContent), + expNumRemovedServers: 0, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy server staying healthy (StatusPermanentRedirect)", + server: newHTTPServer(http.StatusPermanentRedirect), + expNumRemovedServers: 0, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy server becoming sick", + server: newHTTPServer(http.StatusServiceUnavailable), + expNumRemovedServers: 1, + expNumUpsertedServers: 0, + expGaugeValue: 0, + targetStatus: runtime.StatusDown, + }, + { + desc: "healthy server becoming sick, with custom code status check", + server: newHTTPServer(http.StatusOK), + status: http.StatusServiceUnavailable, + expNumRemovedServers: 1, + expNumUpsertedServers: 0, + expGaugeValue: 0, + targetStatus: runtime.StatusDown, + }, + { + desc: "healthy server toggling to sick and back to healthy", + server: newHTTPServer(http.StatusServiceUnavailable, http.StatusOK), + expNumRemovedServers: 1, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy server toggling to healthy and go to sick", + server: newHTTPServer(http.StatusOK, http.StatusServiceUnavailable), + expNumRemovedServers: 1, + expNumUpsertedServers: 1, + expGaugeValue: 0, + targetStatus: runtime.StatusDown, + }, + { + desc: "healthy grpc server staying healthy", + mode: "grpc", + server: newGRPCServer(healthpb.HealthCheckResponse_SERVING), + expNumRemovedServers: 0, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy grpc server becoming sick", + mode: "grpc", + server: newGRPCServer(healthpb.HealthCheckResponse_NOT_SERVING), + expNumRemovedServers: 1, + expNumUpsertedServers: 0, + expGaugeValue: 0, + targetStatus: runtime.StatusDown, + }, + { + desc: "healthy grpc server toggling to sick and back to healthy", + mode: "grpc", + server: newGRPCServer(healthpb.HealthCheckResponse_NOT_SERVING, healthpb.HealthCheckResponse_SERVING), + expNumRemovedServers: 1, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, } - wg := sync.WaitGroup{} - wg.Add(1) + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + // The context is passed to the health check and + // canonically canceled by the test server once all expected requests have been received. + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + targetURL, timeout := test.server.Start(t, cancel) + + lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}} + + config := &dynamic.ServerHealthCheck{ + Mode: test.mode, + Status: test.status, + Path: "/path", + Interval: ptypes.Duration(500 * time.Millisecond), + Timeout: ptypes.Duration(499 * time.Millisecond), + } + + gauge := &testhelpers.CollectingGauge{} + serviceInfo := &runtime.ServiceInfo{} + hc := NewServiceHealthChecker(ctx, &MetricsMock{gauge}, config, lb, serviceInfo, http.DefaultTransport, map[string]*url.URL{"test": targetURL}) + + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + hc.Launch(ctx) + wg.Done() + }() + + select { + case <-time.After(timeout): + t.Fatal("test did not complete in time") + case <-ctx.Done(): + wg.Wait() + } - go func() { - check.execute(ctx, backend) - wg.Done() - }() + lb.Lock() + defer lb.Unlock() - timeout := time.Duration(int(healthCheckInterval) + 500) - select { - case <-time.After(timeout): - t.Fatal("test did not complete in time") - case <-ctx.Done(): - wg.Wait() + assert.Equal(t, test.expNumRemovedServers, lb.numRemovedServers, "removed servers") + assert.Equal(t, test.expNumUpsertedServers, lb.numUpsertedServers, "upserted servers") + assert.Equal(t, test.expGaugeValue, gauge.GaugeValue, "ServerUp Gauge") + assert.Equal(t, serviceInfo.GetAllStatus(), map[string]string{targetURL.String(): test.targetStatus}) + }) } +} - assert.False(t, redirectServerCalled, "HTTP redirect must not be followed") +func Bool(b bool) *bool { + return &b } diff --git a/pkg/healthcheck/mock_test.go b/pkg/healthcheck/mock_test.go new file mode 100644 index 0000000000..d2a9b52e11 --- /dev/null +++ b/pkg/healthcheck/mock_test.go @@ -0,0 +1,182 @@ +package healthcheck + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + "time" + + gokitmetrics "github.com/go-kit/kit/metrics" + "github.com/stretchr/testify/assert" + "github.com/traefik/traefik/v2/pkg/config/dynamic" + "github.com/traefik/traefik/v2/pkg/testhelpers" + "google.golang.org/grpc" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +type StartTestServer interface { + Start(t *testing.T, done func()) (*url.URL, time.Duration) +} + +type Status interface { + ~int | ~int32 +} + +type HealthSequence[T Status] struct { + sequenceMu sync.Mutex + sequence []T +} + +func (s *HealthSequence[T]) Pop() T { + s.sequenceMu.Lock() + defer s.sequenceMu.Unlock() + + stat := s.sequence[0] + + s.sequence = s.sequence[1:] + + return stat +} + +func (s *HealthSequence[T]) IsEmpty() bool { + s.sequenceMu.Lock() + defer s.sequenceMu.Unlock() + + return len(s.sequence) == 0 +} + +type GRPCServer struct { + status HealthSequence[healthpb.HealthCheckResponse_ServingStatus] + done func() +} + +func newGRPCServer(healthSequence ...healthpb.HealthCheckResponse_ServingStatus) *GRPCServer { + gRPCService := &GRPCServer{ + status: HealthSequence[healthpb.HealthCheckResponse_ServingStatus]{ + sequence: healthSequence, + }, + } + + return gRPCService +} + +func (s *GRPCServer) Check(_ context.Context, _ *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + if s.status.IsEmpty() { + s.done() + return &healthpb.HealthCheckResponse{ + Status: healthpb.HealthCheckResponse_SERVICE_UNKNOWN, + }, nil + } + stat := s.status.Pop() + + return &healthpb.HealthCheckResponse{ + Status: stat, + }, nil +} + +func (s *GRPCServer) Watch(_ *healthpb.HealthCheckRequest, server healthpb.Health_WatchServer) error { + if s.status.IsEmpty() { + s.done() + return server.Send(&healthpb.HealthCheckResponse{ + Status: healthpb.HealthCheckResponse_SERVICE_UNKNOWN, + }) + } + stat := s.status.Pop() + + return server.Send(&healthpb.HealthCheckResponse{ + Status: stat, + }) +} + +func (s *GRPCServer) Start(t *testing.T, done func()) (*url.URL, time.Duration) { + t.Helper() + + listener, err := net.Listen("tcp4", "127.0.0.1:0") + assert.NoError(t, err) + t.Cleanup(func() { _ = listener.Close() }) + + server := grpc.NewServer() + t.Cleanup(server.Stop) + + s.done = done + + healthpb.RegisterHealthServer(server, s) + + go func() { + err := server.Serve(listener) + assert.NoError(t, err) + }() + + // Make test timeout dependent on number of expected requests, health check interval, and a safety margin. + return testhelpers.MustParseURL("http://" + listener.Addr().String()), time.Duration(len(s.status.sequence)*int(dynamic.DefaultHealthCheckInterval) + 500) +} + +type HTTPServer struct { + status HealthSequence[int] + done func() +} + +func newHTTPServer(healthSequence ...int) *HTTPServer { + handler := &HTTPServer{ + status: HealthSequence[int]{ + sequence: healthSequence, + }, + } + + return handler +} + +// ServeHTTP returns HTTP response codes following a status sequences. +// It calls the given 'done' function once all request health indicators have been depleted. +func (s *HTTPServer) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + if s.status.IsEmpty() { + s.done() + // This ensures that the health-checker will handle the context cancellation error before receiving the HTTP response. + time.Sleep(500 * time.Millisecond) + return + } + + stat := s.status.Pop() + + w.WriteHeader(stat) +} + +func (s *HTTPServer) Start(t *testing.T, done func()) (*url.URL, time.Duration) { + t.Helper() + + s.done = done + + ts := httptest.NewServer(s) + t.Cleanup(ts.Close) + + // Make test timeout dependent on number of expected requests, health check interval, and a safety margin. + return testhelpers.MustParseURL(ts.URL), time.Duration(len(s.status.sequence)*int(dynamic.DefaultHealthCheckInterval) + 500) +} + +type testLoadBalancer struct { + // RWMutex needed due to parallel test execution: Both the system-under-test + // and the test assertions reference the counters. + *sync.RWMutex + numRemovedServers int + numUpsertedServers int +} + +func (lb *testLoadBalancer) SetStatus(ctx context.Context, childName string, up bool) { + if up { + lb.numUpsertedServers++ + } else { + lb.numRemovedServers++ + } +} + +type MetricsMock struct { + Gauge gokitmetrics.Gauge +} + +func (m *MetricsMock) ServiceServerUpGauge() gokitmetrics.Gauge { + return m.Gauge +} diff --git a/pkg/middlewares/emptybackendhandler/empty_backend_handler.go b/pkg/middlewares/emptybackendhandler/empty_backend_handler.go deleted file mode 100644 index 3331bf3e36..0000000000 --- a/pkg/middlewares/emptybackendhandler/empty_backend_handler.go +++ /dev/null @@ -1,34 +0,0 @@ -package emptybackendhandler - -import ( - "net/http" - - "github.com/traefik/traefik/v2/pkg/healthcheck" -) - -// EmptyBackend is a middleware that checks whether the current Backend -// has at least one active Server in respect to the healthchecks and if this -// is not the case, it will stop the middleware chain and respond with 503. -type emptyBackend struct { - healthcheck.BalancerStatusHandler -} - -// New creates a new EmptyBackend middleware. -func New(lb healthcheck.BalancerStatusHandler) http.Handler { - return &emptyBackend{BalancerStatusHandler: lb} -} - -// ServeHTTP responds with 503 when there is no active Server and otherwise -// invokes the next handler in the middleware chain. -func (e *emptyBackend) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - if len(e.BalancerStatusHandler.Servers()) != 0 { - e.BalancerStatusHandler.ServeHTTP(rw, req) - return - } - - rw.WriteHeader(http.StatusServiceUnavailable) - if _, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } -} diff --git a/pkg/middlewares/emptybackendhandler/empty_backend_handler_test.go b/pkg/middlewares/emptybackendhandler/empty_backend_handler_test.go deleted file mode 100644 index 299fe9c316..0000000000 --- a/pkg/middlewares/emptybackendhandler/empty_backend_handler_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package emptybackendhandler - -import ( - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/traefik/traefik/v2/pkg/testhelpers" - "github.com/vulcand/oxy/v2/roundrobin" -) - -func TestEmptyBackendHandler(t *testing.T) { - testCases := []struct { - amountServer int - expectedStatusCode int - }{ - { - amountServer: 0, - expectedStatusCode: http.StatusServiceUnavailable, - }, - { - amountServer: 1, - expectedStatusCode: http.StatusOK, - }, - } - - for _, test := range testCases { - t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) { - t.Parallel() - - handler := New(&healthCheckLoadBalancer{amountServer: test.amountServer}) - - recorder := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) - - handler.ServeHTTP(recorder, req) - - assert.Equal(t, test.expectedStatusCode, recorder.Result().StatusCode) - }) - } -} - -type healthCheckLoadBalancer struct { - amountServer int -} - -func (lb *healthCheckLoadBalancer) RegisterStatusUpdater(fn func(up bool)) error { - return nil -} - -func (lb *healthCheckLoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) -} - -func (lb *healthCheckLoadBalancer) Servers() []*url.URL { - servers := make([]*url.URL, lb.amountServer) - for range lb.amountServer { - servers = append(servers, testhelpers.MustParseURL("http://localhost")) - } - return servers -} - -func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error { - return nil -} - -func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { - return nil -} - -func (lb *healthCheckLoadBalancer) ServerWeight(u *url.URL) (int, bool) { - return 0, false -} - -func (lb *healthCheckLoadBalancer) NextServer() (*url.URL, error) { - return nil, nil -} - -func (lb *healthCheckLoadBalancer) Next() http.Handler { - return nil -} diff --git a/pkg/provider/kubernetes/crd/kubernetes_test.go b/pkg/provider/kubernetes/crd/kubernetes_test.go index 6724e5c631..1037b6c678 100644 --- a/pkg/provider/kubernetes/crd/kubernetes_test.go +++ b/pkg/provider/kubernetes/crd/kubernetes_test.go @@ -3573,7 +3573,7 @@ func TestLoadIngressRoutes(t *testing.T) { }, }, PassHostHeader: Bool(false), - ResponseForwarding: &dynamic.ResponseForwarding{FlushInterval: "10s"}, + ResponseForwarding: &dynamic.ResponseForwarding{FlushInterval: ptypes.Duration(10 * time.Second)}, }, }, }, diff --git a/pkg/provider/kv/kv_test.go b/pkg/provider/kv/kv_test.go index 535a48bab6..b0b9653637 100644 --- a/pkg/provider/kv/kv_test.go +++ b/pkg/provider/kv/kv_test.go @@ -42,14 +42,14 @@ func Test_buildConfiguration(t *testing.T) { "traefik/http/routers/Router1/service": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/path": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/port": "42", - "traefik/http/services/Service01/loadBalancer/healthCheck/interval": "foobar", - "traefik/http/services/Service01/loadBalancer/healthCheck/timeout": "foobar", + "traefik/http/services/Service01/loadBalancer/healthCheck/interval": "1s", + "traefik/http/services/Service01/loadBalancer/healthCheck/timeout": "1s", "traefik/http/services/Service01/loadBalancer/healthCheck/hostname": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/headers/name0": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/headers/name1": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/scheme": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/followredirects": "true", - "traefik/http/services/Service01/loadBalancer/responseForwarding/flushInterval": "foobar", + "traefik/http/services/Service01/loadBalancer/responseForwarding/flushInterval": "100ms", "traefik/http/services/Service01/loadBalancer/passHostHeader": "true", "traefik/http/services/Service01/loadBalancer/sticky/cookie/name": "foobar", "traefik/http/services/Service01/loadBalancer/sticky/cookie/secure": "true", @@ -644,8 +644,8 @@ func Test_buildConfiguration(t *testing.T) { Scheme: "foobar", Path: "foobar", Port: 42, - Interval: "foobar", - Timeout: "foobar", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foobar", FollowRedirects: func(v bool) *bool { return &v }(true), Headers: map[string]string{ @@ -655,7 +655,7 @@ func Test_buildConfiguration(t *testing.T) { }, PassHostHeader: func(v bool) *bool { return &v }(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foobar", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, }, }, diff --git a/pkg/redactor/redactor_config_test.go b/pkg/redactor/redactor_config_test.go index 368f2d9564..019eb5247a 100644 --- a/pkg/redactor/redactor_config_test.go +++ b/pkg/redactor/redactor_config_test.go @@ -82,8 +82,8 @@ func init() { Scheme: "foo", Path: "foo", Port: 42, - Interval: "foo", - Timeout: "foo", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foo", FollowRedirects: boolPtr(true), Headers: map[string]string{ @@ -92,7 +92,7 @@ func init() { }, PassHostHeader: boolPtr(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foo", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, ServersTransport: "foo", Servers: []dynamic.Server{ diff --git a/pkg/redactor/testdata/anonymized-dynamic-config.json b/pkg/redactor/testdata/anonymized-dynamic-config.json index 8339f58c3c..6b967f91f3 100644 --- a/pkg/redactor/testdata/anonymized-dynamic-config.json +++ b/pkg/redactor/testdata/anonymized-dynamic-config.json @@ -75,8 +75,8 @@ "scheme": "foo", "path": "foo", "port": 42, - "interval": "foo", - "timeout": "foo", + "interval": "1s", + "timeout": "1s", "hostname": "xxxx", "followRedirects": true, "headers": { @@ -85,7 +85,7 @@ }, "passHostHeader": true, "responseForwarding": { - "flushInterval": "foo" + "flushInterval": "100ms" }, "serversTransport": "foo" } @@ -475,4 +475,4 @@ } } } -} \ No newline at end of file +} diff --git a/pkg/redactor/testdata/secured-dynamic-config.json b/pkg/redactor/testdata/secured-dynamic-config.json index 352421884d..4447b1698e 100644 --- a/pkg/redactor/testdata/secured-dynamic-config.json +++ b/pkg/redactor/testdata/secured-dynamic-config.json @@ -75,8 +75,8 @@ "scheme": "foo", "path": "foo", "port": 42, - "interval": "foo", - "timeout": "foo", + "interval": "1s", + "timeout": "1s", "hostname": "foo", "followRedirects": true, "headers": { @@ -85,7 +85,7 @@ }, "passHostHeader": true, "responseForwarding": { - "flushInterval": "foo" + "flushInterval": "100ms" }, "serversTransport": "foo" } @@ -483,4 +483,4 @@ } } } -} \ No newline at end of file +} diff --git a/pkg/server/router/router.go b/pkg/server/router/router.go index 0b88b0fafa..eabf5cbc84 100644 --- a/pkg/server/router/router.go +++ b/pkg/server/router/router.go @@ -31,7 +31,7 @@ type middlewareBuilder interface { type serviceManager interface { BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) - LaunchHealthCheck() + LaunchHealthCheck(ctx context.Context) } // Manager A route/router manager. diff --git a/pkg/server/router/router_test.go b/pkg/server/router/router_test.go index d5fba5af1a..2d47f64c92 100644 --- a/pkg/server/router/router_test.go +++ b/pkg/server/router/router_test.go @@ -8,10 +8,12 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/containous/alice" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + ptypes "github.com/traefik/paerser/types" "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/metrics" @@ -482,7 +484,7 @@ func TestRuntimeConfiguration(t *testing.T) { }, }, HealthCheck: &dynamic.ServerHealthCheck{ - Interval: "500ms", + Interval: ptypes.Duration(500 * time.Millisecond), Path: "/health", }, }, diff --git a/pkg/server/routerfactory.go b/pkg/server/routerfactory.go index 6b7b80ff4b..7a7e588b34 100644 --- a/pkg/server/routerfactory.go +++ b/pkg/server/routerfactory.go @@ -31,6 +31,8 @@ type RouterFactory struct { chainBuilder *middleware.ChainBuilder tlsManager *tls.Manager + + cancelPrevState func() } // NewRouterFactory creates a new RouterFactory. @@ -65,7 +67,12 @@ func NewRouterFactory(staticConfiguration static.Configuration, managerFactory * // CreateRouters creates new TCPRouters and UDPRouters. func (f *RouterFactory) CreateRouters(rtConf *runtime.Configuration) (map[string]*tcprouter.Router, map[string]udptypes.Handler) { - ctx := context.Background() + if f.cancelPrevState != nil { + f.cancelPrevState() + } + + var ctx context.Context + ctx, f.cancelPrevState = context.WithCancel(context.Background()) // HTTP serviceManager := f.managerFactory.Build(rtConf) @@ -77,7 +84,7 @@ func (f *RouterFactory) CreateRouters(rtConf *runtime.Configuration) (map[string handlersNonTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, false) handlersTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, true) - serviceManager.LaunchHealthCheck() + serviceManager.LaunchHealthCheck(ctx) // TCP svcTCPManager := tcp.NewManager(rtConf) diff --git a/pkg/server/service/internalhandler.go b/pkg/server/service/internalhandler.go index 4bd6fd2585..e5c3ab26d7 100644 --- a/pkg/server/service/internalhandler.go +++ b/pkg/server/service/internalhandler.go @@ -10,7 +10,7 @@ import ( type serviceManager interface { BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) - LaunchHealthCheck() + LaunchHealthCheck(ctx context.Context) } // InternalHandlers is the internal HTTP handlers builder. diff --git a/pkg/server/service/loadbalancer/wrr/wrr.go b/pkg/server/service/loadbalancer/wrr/wrr.go index ecca04b624..720945ffa3 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr.go +++ b/pkg/server/service/loadbalancer/wrr/wrr.go @@ -42,7 +42,7 @@ type Balancer struct { curDeadline float64 // status is a record of which child services of the Balancer are healthy, keyed // by name of child service. A service is initially added to the map when it is - // created via AddService, and it is later removed or added to the map as needed, + // created via Add, and it is later removed or added to the map as needed, // through the SetStatus method. status map[string]struct{} // updaters is the list of hooks that are run (to update the Balancer @@ -51,11 +51,11 @@ type Balancer struct { } // New creates a new load balancer. -func New(sticky *dynamic.Sticky, hc *dynamic.HealthCheck) *Balancer { +func New(sticky *dynamic.Sticky, wantHealthCheck bool) *Balancer { balancer := &Balancer{ status: make(map[string]struct{}), handlerMap: make(map[string]*namedHandler), - wantsHealthCheck: hc != nil, + wantsHealthCheck: wantHealthCheck, } if sticky != nil && sticky.Cookie != nil { balancer.stickyCookie = &stickyCookie{ @@ -155,10 +155,7 @@ func (b *Balancer) nextServer() (*namedHandler, error) { b.handlersMu.Lock() defer b.handlersMu.Unlock() - if len(b.handlers) == 0 { - return nil, errors.New("no servers in the pool") - } - if len(b.status) == 0 { + if len(b.handlers) == 0 || len(b.status) == 0 { return nil, errNoAvailableServer } @@ -224,9 +221,9 @@ func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { server.ServeHTTP(w, req) } -// AddService adds a handler. +// Add adds a handler. // A handler with a non-positive weight is ignored. -func (b *Balancer) AddService(name string, handler http.Handler, weight *int) { +func (b *Balancer) Add(name string, handler http.Handler, weight *int) { w := 1 if weight != nil { w = *weight diff --git a/pkg/server/service/loadbalancer/wrr/wrr_test.go b/pkg/server/service/loadbalancer/wrr/wrr_test.go index 32068504e3..19f2cf38ef 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr_test.go +++ b/pkg/server/service/loadbalancer/wrr/wrr_test.go @@ -10,31 +10,15 @@ import ( "github.com/traefik/traefik/v2/pkg/config/dynamic" ) -func Int(v int) *int { return &v } - -type responseRecorder struct { - *httptest.ResponseRecorder - save map[string]int - sequence []string - status []int -} - -func (r *responseRecorder) WriteHeader(statusCode int) { - r.save[r.Header().Get("server")]++ - r.sequence = append(r.sequence, r.Header().Get("server")) - r.status = append(r.status, statusCode) - r.ResponseRecorder.WriteHeader(statusCode) -} - func TestBalancer(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(3)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "second") rw.WriteHeader(http.StatusOK) }), Int(1)) @@ -49,23 +33,23 @@ func TestBalancer(t *testing.T) { } func TestBalancerNoService(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) recorder := httptest.NewRecorder() balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - assert.Equal(t, http.StatusInternalServerError, recorder.Result().StatusCode) + assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode) } func TestBalancerOneServerZeroWeight(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} for range 3 { @@ -80,13 +64,13 @@ type key string const serviceName key = "serviceName" func TestBalancerNoServiceUp(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) }), Int(1)) @@ -100,14 +84,14 @@ func TestBalancerNoServiceUp(t *testing.T) { } func TestBalancerOneServerDown(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) }), Int(1)) balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", false) @@ -121,14 +105,14 @@ func TestBalancerOneServerDown(t *testing.T) { } func TestBalancerDownThenUp(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "second") rw.WriteHeader(http.StatusOK) }), Int(1)) @@ -150,35 +134,35 @@ func TestBalancerDownThenUp(t *testing.T) { } func TestBalancerPropagate(t *testing.T) { - balancer1 := New(nil, &dynamic.HealthCheck{}) + balancer1 := New(nil, true) - balancer1.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer1.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer1.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer1.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "second") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer2 := New(nil, &dynamic.HealthCheck{}) - balancer2.AddService("third", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer2 := New(nil, true) + balancer2.Add("third", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "third") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer2.AddService("fourth", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer2.Add("fourth", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "fourth") rw.WriteHeader(http.StatusOK) }), Int(1)) - topBalancer := New(nil, &dynamic.HealthCheck{}) - topBalancer.AddService("balancer1", balancer1, Int(1)) + topBalancer := New(nil, true) + topBalancer.Add("balancer1", balancer1, Int(1)) _ = balancer1.RegisterStatusUpdater(func(up bool) { topBalancer.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "balancer1", up) // TODO(mpl): if test gets flaky, add channel or something here to signal that // propagation is done, and wait on it before sending request. }) - topBalancer.AddService("balancer2", balancer2, Int(1)) + topBalancer.Add("balancer2", balancer2, Int(1)) _ = balancer2.RegisterStatusUpdater(func(up bool) { topBalancer.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "balancer2", up) }) @@ -223,28 +207,28 @@ func TestBalancerPropagate(t *testing.T) { } func TestBalancerAllServersZeroWeight(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("test", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) - balancer.AddService("test2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) + balancer.Add("test", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) + balancer.Add("test2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) recorder := httptest.NewRecorder() balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - assert.Equal(t, http.StatusInternalServerError, recorder.Result().StatusCode) + assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode) } func TestSticky(t *testing.T) { balancer := New(&dynamic.Sticky{ Cookie: &dynamic.Cookie{Name: "test"}, - }, nil) + }, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "second") rw.WriteHeader(http.StatusOK) }), Int(2)) @@ -270,14 +254,14 @@ func TestSticky(t *testing.T) { func TestSticky_FallBack(t *testing.T) { balancer := New(&dynamic.Sticky{ Cookie: &dynamic.Cookie{Name: "test"}, - }, nil) + }, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "second") rw.WriteHeader(http.StatusOK) }), Int(2)) @@ -299,21 +283,21 @@ func TestSticky_FallBack(t *testing.T) { // TestBalancerBias makes sure that the WRR algorithm spreads elements evenly right from the start, // and that it does not "over-favor" the high-weighted ones with a biased start-up regime. func TestBalancerBias(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "A") rw.WriteHeader(http.StatusOK) }), Int(11)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "B") rw.WriteHeader(http.StatusOK) }), Int(3)) recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 14 { + for i := 0; i < 14; i++ { balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) } @@ -321,3 +305,19 @@ func TestBalancerBias(t *testing.T) { assert.Equal(t, wantSequence, recorder.sequence) } + +func Int(v int) *int { return &v } + +type responseRecorder struct { + *httptest.ResponseRecorder + save map[string]int + sequence []string + status []int +} + +func (r *responseRecorder) WriteHeader(statusCode int) { + r.save[r.Header().Get("server")]++ + r.sequence = append(r.sequence, r.Header().Get("server")) + r.status = append(r.status, statusCode) + r.ResponseRecorder.WriteHeader(statusCode) +} diff --git a/pkg/server/service/proxy.go b/pkg/server/service/proxy.go index 257c4ce0d4..267b97b707 100644 --- a/pkg/server/service/proxy.go +++ b/pkg/server/service/proxy.go @@ -3,7 +3,6 @@ package service import ( "context" "errors" - "fmt" "io" "net" "net/http" @@ -12,8 +11,6 @@ import ( "strings" "time" - ptypes "github.com/traefik/paerser/types" - "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/log" "golang.org/x/net/http/httpguts" ) @@ -24,100 +21,104 @@ const StatusClientClosedRequest = 499 // StatusClientClosedRequestText non-standard HTTP status for client disconnection. const StatusClientClosedRequestText = "Client Closed Request" -func buildProxy(passHostHeader *bool, responseForwarding *dynamic.ResponseForwarding, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) (http.Handler, error) { - var flushInterval ptypes.Duration - if responseForwarding != nil { - err := flushInterval.Set(responseForwarding.FlushInterval) - if err != nil { - return nil, fmt.Errorf("error creating flush interval: %w", err) - } - } - if flushInterval == 0 { - flushInterval = ptypes.Duration(100 * time.Millisecond) +func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval time.Duration, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) http.Handler { + return &httputil.ReverseProxy{ + Director: directorBuilder(target, passHostHeader), + Transport: roundTripper, + FlushInterval: flushInterval, + BufferPool: bufferPool, + ErrorHandler: errorHandler, } +} + +func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Request) { + return func(outReq *http.Request) { + outReq.URL.Scheme = target.Scheme + outReq.URL.Host = target.Host - proxy := &httputil.ReverseProxy{ - Director: func(outReq *http.Request) { - u := outReq.URL - if outReq.RequestURI != "" { - parsedURL, err := url.ParseRequestURI(outReq.RequestURI) - if err == nil { - u = parsedURL - } + u := outReq.URL + if outReq.RequestURI != "" { + parsedURL, err := url.ParseRequestURI(outReq.RequestURI) + if err == nil { + u = parsedURL } + } - outReq.URL.Path = u.Path - outReq.URL.RawPath = u.RawPath - // If a plugin/middleware adds semicolons in query params, they should be urlEncoded. - outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&") - outReq.RequestURI = "" // Outgoing request should not have RequestURI + outReq.URL.Path = u.Path + outReq.URL.RawPath = u.RawPath + // If a plugin/middleware adds semicolons in query params, they should be urlEncoded. + outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&") + outReq.RequestURI = "" // Outgoing request should not have RequestURI - outReq.Proto = "HTTP/1.1" - outReq.ProtoMajor = 1 - outReq.ProtoMinor = 1 + outReq.Proto = "HTTP/1.1" + outReq.ProtoMajor = 1 + outReq.ProtoMinor = 1 - // Do not pass client Host header unless optsetter PassHostHeader is set. - if passHostHeader != nil && !*passHostHeader { - outReq.Host = outReq.URL.Host - } - - // Even if the websocket RFC says that headers should be case-insensitive, - // some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept, - // Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive. - // https://tools.ietf.org/html/rfc6455#page-20 - if isWebSocketUpgrade(outReq) { - outReq.Header["Sec-WebSocket-Key"] = outReq.Header["Sec-Websocket-Key"] - outReq.Header["Sec-WebSocket-Extensions"] = outReq.Header["Sec-Websocket-Extensions"] - outReq.Header["Sec-WebSocket-Accept"] = outReq.Header["Sec-Websocket-Accept"] - outReq.Header["Sec-WebSocket-Protocol"] = outReq.Header["Sec-Websocket-Protocol"] - outReq.Header["Sec-WebSocket-Version"] = outReq.Header["Sec-Websocket-Version"] - delete(outReq.Header, "Sec-Websocket-Key") - delete(outReq.Header, "Sec-Websocket-Extensions") - delete(outReq.Header, "Sec-Websocket-Accept") - delete(outReq.Header, "Sec-Websocket-Protocol") - delete(outReq.Header, "Sec-Websocket-Version") - } - }, - Transport: roundTripper, - FlushInterval: time.Duration(flushInterval), - BufferPool: bufferPool, - ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) { - statusCode := http.StatusInternalServerError + // Do not pass client Host header unless optsetter PassHostHeader is set. + if !passHostHeader { + outReq.Host = outReq.URL.Host + } - switch { - case errors.Is(err, io.EOF): - statusCode = http.StatusBadGateway - case errors.Is(err, context.Canceled): - statusCode = StatusClientClosedRequest - default: - var netErr net.Error - if errors.As(err, &netErr) { - if netErr.Timeout() { - statusCode = http.StatusGatewayTimeout - } else { - statusCode = http.StatusBadGateway - } - } - } + cleanWebSocketHeaders(outReq) + } +} - log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err) - w.WriteHeader(statusCode) - _, werr := w.Write([]byte(statusText(statusCode))) - if werr != nil { - log.Debugf("Error while writing status code", werr) - } - }, +// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive, +// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept, +// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive. +// https://tools.ietf.org/html/rfc6455#page-20 +func cleanWebSocketHeaders(req *http.Request) { + if !isWebSocketUpgrade(req) { + return } - return proxy, nil + req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"] + delete(req.Header, "Sec-Websocket-Key") + + req.Header["Sec-WebSocket-Extensions"] = req.Header["Sec-Websocket-Extensions"] + delete(req.Header, "Sec-Websocket-Extensions") + + req.Header["Sec-WebSocket-Accept"] = req.Header["Sec-Websocket-Accept"] + delete(req.Header, "Sec-Websocket-Accept") + + req.Header["Sec-WebSocket-Protocol"] = req.Header["Sec-Websocket-Protocol"] + delete(req.Header, "Sec-Websocket-Protocol") + + req.Header["Sec-WebSocket-Version"] = req.Header["Sec-Websocket-Version"] + delete(req.Header, "Sec-Websocket-Version") } func isWebSocketUpgrade(req *http.Request) bool { - if !httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") { - return false + return httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") && + strings.EqualFold(req.Header.Get("Upgrade"), "websocket") +} + +func errorHandler(w http.ResponseWriter, req *http.Request, err error) { + statusCode := http.StatusInternalServerError + + switch { + case errors.Is(err, io.EOF): + statusCode = http.StatusBadGateway + case errors.Is(err, context.Canceled): + statusCode = StatusClientClosedRequest + default: + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + statusCode = http.StatusGatewayTimeout + } else { + statusCode = http.StatusBadGateway + } + } } - return strings.EqualFold(req.Header.Get("Upgrade"), "websocket") + logger := log.FromContext(req.Context()) + logger.WithError(err).Debugf("%d %s", statusCode, statusText(statusCode)) + + w.WriteHeader(statusCode) + if _, werr := w.Write([]byte(statusText(statusCode))); werr != nil { + logger.WithError(werr).Debug("Error while writing status code") + } } func statusText(statusCode int) string { diff --git a/pkg/server/service/proxy_test.go b/pkg/server/service/proxy_test.go index 567ba61057..afaff38809 100644 --- a/pkg/server/service/proxy_test.go +++ b/pkg/server/service/proxy_test.go @@ -28,7 +28,7 @@ func BenchmarkProxy(b *testing.B) { req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil) pool := newBufferPool() - handler, _ := buildProxy(Bool(false), nil, &staticTransport{res}, pool) + handler := buildSingleHostProxy(req.URL, false, 0, &staticTransport{res}, pool) b.ReportAllocs() for range b.N { diff --git a/pkg/server/service/proxy_websocket_test.go b/pkg/server/service/proxy_websocket_test.go index 108133c37e..1d2293aeae 100644 --- a/pkg/server/service/proxy_websocket_test.go +++ b/pkg/server/service/proxy_websocket_test.go @@ -21,9 +21,6 @@ import ( func Bool(v bool) *bool { return &v } func TestWebSocketTCPClose(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - errChan := make(chan error, 1) upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -42,7 +39,7 @@ func TestWebSocketTCPClose(t *testing.T) { })) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) proxyAddr := proxy.Listener.Addr().String() _, conn, err := newWebsocketRequest( @@ -61,10 +58,6 @@ func TestWebSocketTCPClose(t *testing.T) { } func TestWebSocketPingPong(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - - require.NoError(t, err) - upgrader := gorillawebsocket.Upgrader{ HandshakeTimeout: 10 * time.Second, CheckOrigin: func(*http.Request) bool { @@ -86,17 +79,10 @@ func TestWebSocketPingPong(t *testing.T) { _, _, _ = ws.ReadMessage() }) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() - proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - req.URL = parseURI(t, srv.URL) - f.ServeHTTP(w, req) - })) - defer proxy.Close() - + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) serverAddr := proxy.Listener.Addr().String() headers := http.Header{} @@ -127,9 +113,6 @@ func TestWebSocketPingPong(t *testing.T) { } func TestWebSocketEcho(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { msg := make([]byte, 4) @@ -145,17 +128,10 @@ func TestWebSocketEcho(t *testing.T) { require.NoError(t, err) })) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() - proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - req.URL = parseURI(t, srv.URL) - f.ServeHTTP(w, req) - })) - defer proxy.Close() - + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) serverAddr := proxy.Listener.Addr().String() headers := http.Header{} @@ -193,10 +169,6 @@ func TestWebSocketPassHost(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - f, err := buildProxy(Bool(test.passHost), nil, http.DefaultTransport, nil) - - require.NoError(t, err) - mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { req := conn.Request() @@ -208,7 +180,7 @@ func TestWebSocketPassHost(t *testing.T) { } msg := make([]byte, 4) - _, err = conn.Read(msg) + _, err := conn.Read(msg) require.NoError(t, err) fmt.Println(string(msg)) @@ -219,16 +191,10 @@ func TestWebSocketPassHost(t *testing.T) { require.NoError(t, err) })) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() - proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - req.URL = parseURI(t, srv.URL) - f.ServeHTTP(w, req) - })) - defer proxy.Close() + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) serverAddr := proxy.Listener.Addr().String() @@ -252,9 +218,6 @@ func TestWebSocketPassHost(t *testing.T) { } func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} @@ -277,7 +240,7 @@ func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { })) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() @@ -293,9 +256,6 @@ func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { } func TestWebSocketRequestWithOrigin(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := upgrader.Upgrade(w, r, nil) @@ -316,11 +276,11 @@ func TestWebSocketRequestWithOrigin(t *testing.T) { })) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() - _, err = newWebsocketRequest( + _, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("echo"), @@ -339,9 +299,6 @@ func TestWebSocketRequestWithOrigin(t *testing.T) { } func TestWebSocketRequestWithQueryParams(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) @@ -363,7 +320,7 @@ func TestWebSocketRequestWithQueryParams(t *testing.T) { })) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() @@ -379,18 +336,14 @@ func TestWebSocketRequestWithQueryParams(t *testing.T) { } func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { conn.Close() })) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() + f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil) proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = parseURI(t, srv.URL) w.Header().Set("HEADER-KEY", "HEADER-VALUE") @@ -403,6 +356,7 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) + conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", err, resp) defer conn.Close() @@ -411,9 +365,6 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { } func TestWebSocketRequestWithEncodedChar(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) @@ -435,7 +386,7 @@ func TestWebSocketRequestWithEncodedChar(t *testing.T) { })) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() @@ -451,18 +402,14 @@ func TestWebSocketRequestWithEncodedChar(t *testing.T) { } func TestWebSocketUpgradeFailed(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - mux := http.NewServeMux() mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusBadRequest) }) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() + f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil) proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { path := req.URL.Path // keep the original path @@ -501,9 +448,6 @@ func TestWebSocketUpgradeFailed(t *testing.T) { } func TestForwardsWebsocketTraffic(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { _, err := conn.Write([]byte("ok")) @@ -512,12 +456,10 @@ func TestForwardsWebsocketTraffic(t *testing.T) { err = conn.Close() require.NoError(t, err) })) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() @@ -557,15 +499,12 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { srv := createTLSWebsocketServer() defer srv.Close() - forwarderWithoutTLSConfig, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - - proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL) + proxyWithoutTLSConfig := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxyWithoutTLSConfig.Close() proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String() - _, err = newWebsocketRequest( + _, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), @@ -576,10 +515,8 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { transport := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - forwarderWithTLSConfig, err := buildProxy(Bool(true), nil, transport, nil) - require.NoError(t, err) - proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL) + proxyWithTLSConfig := createProxyWithForwarder(t, srv.URL, transport) defer proxyWithTLSConfig.Close() proxyAddr = proxyWithTLSConfig.Listener.Addr().String() @@ -597,10 +534,7 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { defaultTransport := http.DefaultTransport.(*http.Transport).Clone() defaultTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(Bool(true), nil, defaultTransport, nil) - require.NoError(t, err) - - proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL) + proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, srv.URL, defaultTransport) defer proxyWithTLSConfig.Close() proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String() @@ -705,15 +639,19 @@ func parseURI(t *testing.T, uri string) *url.URL { return out } -func createProxyWithForwarder(t *testing.T, proxy http.Handler, url string) *httptest.Server { +func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTripper) *httptest.Server { t.Helper() - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + u := parseURI(t, uri) + proxy := buildSingleHostProxy(u, true, 0, transport, nil) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { path := req.URL.Path // keep the original path // Set new backend URL - req.URL = parseURI(t, url) + req.URL = u req.URL.Path = path proxy.ServeHTTP(w, req) })) + t.Cleanup(srv.Close) + return srv } diff --git a/pkg/server/service/service.go b/pkg/server/service/service.go index cc43e76a2a..dc8013d515 100644 --- a/pkg/server/service/service.go +++ b/pkg/server/service/service.go @@ -4,36 +4,28 @@ import ( "context" "errors" "fmt" + "hash/fnv" "math/rand" "net/http" "net/http/httputil" "net/url" "reflect" + "strings" "time" - "github.com/containous/alice" "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/healthcheck" "github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/metrics" "github.com/traefik/traefik/v2/pkg/middlewares/accesslog" - "github.com/traefik/traefik/v2/pkg/middlewares/emptybackendhandler" metricsMiddle "github.com/traefik/traefik/v2/pkg/middlewares/metrics" - "github.com/traefik/traefik/v2/pkg/middlewares/pipelining" "github.com/traefik/traefik/v2/pkg/safe" "github.com/traefik/traefik/v2/pkg/server/cookie" "github.com/traefik/traefik/v2/pkg/server/provider" "github.com/traefik/traefik/v2/pkg/server/service/loadbalancer/failover" "github.com/traefik/traefik/v2/pkg/server/service/loadbalancer/mirror" "github.com/traefik/traefik/v2/pkg/server/service/loadbalancer/wrr" - "github.com/vulcand/oxy/v2/roundrobin" - "github.com/vulcand/oxy/v2/roundrobin/stickycookie" -) - -const ( - defaultHealthCheckInterval = 30 * time.Second - defaultHealthCheckTimeout = 5 * time.Second ) const defaultMaxBodySize int64 = -1 @@ -43,6 +35,19 @@ type RoundTripperGetter interface { Get(name string) (http.RoundTripper, error) } +// Manager The service manager. +type Manager struct { + routinePool *safe.Pool + metricsRegistry metrics.Registry + bufferPool httputil.BufferPool + roundTripperManager RoundTripperGetter + + services map[string]http.Handler + configs map[string]*runtime.ServiceInfo + healthCheckers map[string]*healthcheck.ServiceHealthChecker + rand *rand.Rand // For the initial shuffling of load-balancers. +} + // NewManager creates a new Manager. func NewManager(configs map[string]*runtime.ServiceInfo, metricsRegistry metrics.Registry, routinePool *safe.Pool, roundTripperManager RoundTripperGetter) *Manager { return &Manager{ @@ -50,27 +55,13 @@ func NewManager(configs map[string]*runtime.ServiceInfo, metricsRegistry metrics metricsRegistry: metricsRegistry, bufferPool: newBufferPool(), roundTripperManager: roundTripperManager, - balancers: make(map[string]healthcheck.Balancers), + services: make(map[string]http.Handler), configs: configs, + healthCheckers: make(map[string]*healthcheck.ServiceHealthChecker), rand: rand.New(rand.NewSource(time.Now().UnixNano())), } } -// Manager The service manager. -type Manager struct { - routinePool *safe.Pool - metricsRegistry metrics.Registry - bufferPool httputil.BufferPool - roundTripperManager RoundTripperGetter - // balancers is the map of all Balancers, keyed by service name. - // There is one Balancer per service handler, and there is one service handler per reference to a service - // (e.g. if 2 routers refer to the same service name, 2 service handlers are created), - // which is why there is not just one Balancer per service name. - balancers map[string]healthcheck.Balancers - configs map[string]*runtime.ServiceInfo - rand *rand.Rand // For the initial shuffling of load-balancers. -} - // BuildHTTP Creates a http.Handler for a service configuration. func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) { ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName)) @@ -78,11 +69,20 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.H serviceName = provider.GetQualifiedName(ctx, serviceName) ctx = provider.AddInContext(ctx, serviceName) + handler, ok := m.services[serviceName] + if ok { + return handler, nil + } + conf, ok := m.configs[serviceName] if !ok { return nil, fmt.Errorf("the service %q does not exist", serviceName) } + if conf.Status == runtime.StatusDisabled { + return nil, errors.New(strings.Join(conf.Err, ", ")) + } + value := reflect.ValueOf(*conf.Service) var count int for i := range value.NumField() { @@ -101,7 +101,7 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.H switch { case conf.LoadBalancer != nil: var err error - lb, err = m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer) + lb, err = m.getLoadBalancerServiceHandler(ctx, serviceName, conf) if err != nil { conf.AddError(err, true) return nil, err @@ -133,6 +133,8 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.H return nil, sErr } + m.services[serviceName] = lb + return lb, nil } @@ -214,14 +216,14 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, config.Sticky.Cookie.Name = cookie.GetName(config.Sticky.Cookie.Name, serviceName) } - balancer := wrr.New(config.Sticky, config.HealthCheck) + balancer := wrr.New(config.Sticky, config.HealthCheck != nil) for _, service := range shuffle(config.Services, m.rand) { serviceHandler, err := m.BuildHTTP(ctx, service.Name) if err != nil { return nil, err } - balancer.AddService(service.Name, serviceHandler, service.Weight) + balancer.Add(service.Name, serviceHandler, service.Weight) if config.HealthCheck == nil { continue @@ -245,201 +247,94 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, return balancer, nil } -func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName string, service *dynamic.ServersLoadBalancer) (http.Handler, error) { - if service.PassHostHeader == nil { - defaultPassHostHeader := true - service.PassHostHeader = &defaultPassHostHeader - } +func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName string, info *runtime.ServiceInfo) (http.Handler, error) { + service := info.LoadBalancer - if len(service.ServersTransport) > 0 { - service.ServersTransport = provider.GetQualifiedName(ctx, service.ServersTransport) - } + logger := log.FromContext(ctx) + logger.Debug("Creating load-balancer") - roundTripper, err := m.roundTripperManager.Get(service.ServersTransport) - if err != nil { - return nil, err + // TODO: should we keep this config value as Go is now handling stream response correctly? + flushInterval := dynamic.DefaultFlushInterval + if service.ResponseForwarding != nil { + flushInterval = service.ResponseForwarding.FlushInterval } - fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, roundTripper, m.bufferPool) - if err != nil { - return nil, err + if len(service.ServersTransport) > 0 { + service.ServersTransport = provider.GetQualifiedName(ctx, service.ServersTransport) } - alHandler := func(next http.Handler) (http.Handler, error) { - return accesslog.NewFieldHandler(next, accesslog.ServiceName, serviceName, accesslog.AddServiceFields), nil - } - chain := alice.New() - if m.metricsRegistry != nil && m.metricsRegistry.IsSvcEnabled() { - chain = chain.Append(metricsMiddle.WrapServiceHandler(ctx, m.metricsRegistry, serviceName)) + if service.Sticky != nil && service.Sticky.Cookie != nil { + service.Sticky.Cookie.Name = cookie.GetName(service.Sticky.Cookie.Name, serviceName) } - handler, err := chain.Append(alHandler).Then(pipelining.New(ctx, fwd, "pipelining")) - if err != nil { - return nil, err + // We make sure that the PassHostHeader value is defined to avoid panics. + passHostHeader := dynamic.DefaultPassHostHeader + if service.PassHostHeader != nil { + passHostHeader = *service.PassHostHeader } - balancer, err := m.getLoadBalancer(ctx, serviceName, service, handler) + roundTripper, err := m.roundTripperManager.Get(service.ServersTransport) if err != nil { return nil, err } - // TODO rename and checks - m.balancers[serviceName] = append(m.balancers[serviceName], balancer) + lb := wrr.New(service.Sticky, service.HealthCheck != nil) + healthCheckTargets := make(map[string]*url.URL) - // Empty (backend with no servers) - return emptybackendhandler.New(balancer), nil -} - -// LaunchHealthCheck launches the health checks. -func (m *Manager) LaunchHealthCheck() { - backendConfigs := make(map[string]*healthcheck.BackendConfig) + for _, server := range shuffle(service.Servers, m.rand) { + hasher := fnv.New64a() + _, _ = hasher.Write([]byte(server.URL)) // this will never return an error. - for serviceName, balancers := range m.balancers { - ctx := log.With(context.Background(), log.Str(log.ServiceName, serviceName)) + proxyName := fmt.Sprintf("%x", hasher.Sum(nil)) - service := m.configs[serviceName].LoadBalancer - - // Health Check - hcOpts := buildHealthCheckOptions(ctx, balancers, serviceName, service.HealthCheck) - if hcOpts == nil { - continue - } - hcOpts.Transport, _ = m.roundTripperManager.Get(service.ServersTransport) - log.FromContext(ctx).Debugf("Setting up healthcheck for service %s with %s", serviceName, *hcOpts) - - backendConfigs[serviceName] = healthcheck.NewBackendConfig(*hcOpts, serviceName) - } - - healthcheck.GetHealthCheck(m.metricsRegistry).SetBackendsConfiguration(context.Background(), backendConfigs) -} - -func buildHealthCheckOptions(ctx context.Context, lb healthcheck.Balancer, backend string, hc *dynamic.ServerHealthCheck) *healthcheck.Options { - if hc == nil { - return nil - } - - logger := log.FromContext(ctx) - - if hc.Path == "" { - logger.Errorf("Ignoring heath check configuration for '%s': no path provided", backend) - return nil - } - - interval := defaultHealthCheckInterval - if hc.Interval != "" { - intervalOverride, err := time.ParseDuration(hc.Interval) - switch { - case err != nil: - logger.Errorf("Illegal health check interval for '%s': %s", backend, err) - case intervalOverride <= 0: - logger.Errorf("Health check interval smaller than zero for service '%s'", backend) - default: - interval = intervalOverride - } - } - - timeout := defaultHealthCheckTimeout - if hc.Timeout != "" { - timeoutOverride, err := time.ParseDuration(hc.Timeout) - switch { - case err != nil: - logger.Errorf("Illegal health check timeout for backend '%s': %s", backend, err) - case timeoutOverride <= 0: - logger.Errorf("Health check timeout smaller than zero for backend '%s', backend", backend) - default: - timeout = timeoutOverride + target, err := url.Parse(server.URL) + if err != nil { + return nil, fmt.Errorf("error parsing server URL %s: %w", server.URL, err) } - } - - followRedirects := true - if hc.FollowRedirects != nil { - followRedirects = *hc.FollowRedirects - } - - return &healthcheck.Options{ - Scheme: hc.Scheme, - Path: hc.Path, - Method: hc.Method, - Port: hc.Port, - Interval: interval, - Timeout: timeout, - LB: lb, - Hostname: hc.Hostname, - Headers: hc.Headers, - FollowRedirects: followRedirects, - } -} - -func (m *Manager) getLoadBalancer(ctx context.Context, serviceName string, service *dynamic.ServersLoadBalancer, fwd http.Handler) (healthcheck.BalancerStatusHandler, error) { - logger := log.FromContext(ctx) - logger.Debug("Creating load-balancer") - var options []roundrobin.LBOption + logger.WithField(log.ServerName, proxyName).WithField("target", target).Info("Creating server") - var cookieName string - if service.Sticky != nil && service.Sticky.Cookie != nil { - cookieName = cookie.GetName(service.Sticky.Cookie.Name, serviceName) + proxy := buildSingleHostProxy(target, passHostHeader, time.Duration(flushInterval), roundTripper, m.bufferPool) - opts := roundrobin.CookieOptions{ - HTTPOnly: service.Sticky.Cookie.HTTPOnly, - Secure: service.Sticky.Cookie.Secure, - SameSite: convertSameSite(service.Sticky.Cookie.SameSite), - } + // FIXME: Traefik v2.6.x that is currently in prod defines ServiceURL + // field as object where as the v3.x.x. branch converted it to string. + // to allow both versions coexist in production ServiceURL is renamed + // to ServiceURLRaw. + proxy = accesslog.NewFieldHandler(proxy, "ServiceURLRaw", target.String(), nil) + proxy = accesslog.NewFieldHandler(proxy, accesslog.ServiceAddr, target.Host, nil) + proxy = accesslog.NewFieldHandler(proxy, accesslog.ServiceName, serviceName, nil) - // Sticky Cookie Value - cv, err := stickycookie.NewFallbackValue(&stickycookie.RawValue{}, &stickycookie.HashValue{}) - if err != nil { - return nil, err + if m.metricsRegistry != nil && m.metricsRegistry.IsSvcEnabled() { + proxy = metricsMiddle.NewServiceMiddleware(ctx, proxy, m.metricsRegistry, serviceName) } - options = append(options, roundrobin.EnableStickySession(roundrobin.NewStickySessionWithOptions(cookieName, opts).SetCookieValue(cv))) + lb.Add(proxyName, proxy, nil) - logger.Debugf("Sticky session cookie name: %v", cookieName) - } + // servers are considered UP by default. + info.UpdateServerStatus(target.String(), runtime.StatusUp) - lb, err := roundrobin.New(fwd, options...) - if err != nil { - return nil, err + healthCheckTargets[proxyName] = target } - lbsu := healthcheck.NewLBStatusUpdater(lb, m.configs[serviceName], service.HealthCheck) - if err := m.upsertServers(ctx, lbsu, service.Servers); err != nil { - return nil, fmt.Errorf("error configuring load balancer for service %s: %w", serviceName, err) + if service.HealthCheck != nil { + m.healthCheckers[serviceName] = healthcheck.NewServiceHealthChecker( + ctx, + m.metricsRegistry, + service.HealthCheck, + lb, + info, + roundTripper, + healthCheckTargets, + ) } - return lbsu, nil -} - -func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHandler, servers []dynamic.Server) error { - logger := log.FromContext(ctx) - - for name, srv := range shuffle(servers, m.rand) { - u, err := url.Parse(srv.URL) - if err != nil { - return fmt.Errorf("error parsing server URL %s: %w", srv.URL, err) - } - - logger.WithField(log.ServerName, name).Debugf("Creating server %d %s", name, u) - - if err := lb.UpsertServer(u, roundrobin.Weight(1)); err != nil { - return fmt.Errorf("error adding server %s to load balancer: %w", srv.URL, err) - } - - // TODO Handle Metrics - } - return nil + return lb, nil } -func convertSameSite(sameSite string) http.SameSite { - switch sameSite { - case "none": - return http.SameSiteNoneMode - case "lax": - return http.SameSiteLaxMode - case "strict": - return http.SameSiteStrictMode - default: - return 0 +// LaunchHealthCheck launches the health checks. +func (m *Manager) LaunchHealthCheck(ctx context.Context) { + for serviceName, hc := range m.healthCheckers { + go hc.Launch(log.With(ctx, log.Str(log.ServiceName, serviceName))) } } diff --git a/pkg/server/service/service_test.go b/pkg/server/service/service_test.go index b761b410e9..8fa27c429b 100644 --- a/pkg/server/service/service_test.go +++ b/pkg/server/service/service_test.go @@ -24,63 +24,6 @@ func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) { panic("implement me") } -func TestGetLoadBalancer(t *testing.T) { - sm := Manager{} - - testCases := []struct { - desc string - serviceName string - service *dynamic.ServersLoadBalancer - fwd http.Handler - expectError bool - }{ - { - desc: "Fails when provided an invalid URL", - serviceName: "test", - service: &dynamic.ServersLoadBalancer{ - Servers: []dynamic.Server{ - { - URL: ":", - }, - }, - }, - fwd: &MockForwarder{}, - expectError: true, - }, - { - desc: "Succeeds when there are no servers", - serviceName: "test", - service: &dynamic.ServersLoadBalancer{}, - fwd: &MockForwarder{}, - expectError: false, - }, - { - desc: "Succeeds when sticky.cookie is set", - serviceName: "test", - service: &dynamic.ServersLoadBalancer{ - Sticky: &dynamic.Sticky{Cookie: &dynamic.Cookie{}}, - }, - fwd: &MockForwarder{}, - expectError: false, - }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - handler, err := sm.getLoadBalancer(context.Background(), test.serviceName, test.service, test.fwd) - if test.expectError { - require.Error(t, err) - assert.Nil(t, handler) - } else { - require.NoError(t, err) - assert.NotNil(t, handler) - } - }) - } -} - func TestGetLoadBalancerServiceHandler(t *testing.T) { sm := NewManager(nil, nil, nil, &RoundTripperManager{ roundTrippers: map[string]http.RoundTripper{ @@ -336,7 +279,8 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service) + serviceInfo := runtime.ServiceInfo{Service: &dynamic.Service{LoadBalancer: test.service}} + handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, &serviceInfo) assert.NoError(t, err) assert.NotNil(t, handler) @@ -414,7 +358,8 @@ func Test1xxResponses(t *testing.T) { }, }, } - handler, err := sm.getLoadBalancerServiceHandler(context.Background(), "foobar", config) + serviceInfo := runtime.ServiceInfo{Service: &dynamic.Service{LoadBalancer: config}} + handler, err := sm.getLoadBalancerServiceHandler(context.Background(), "foobar", &serviceInfo) assert.NoError(t, err) frontend := httptest.NewServer(handler) From db233056cb84344dba06bf32833b916518a800f1 Mon Sep 17 00:00:00 2001 From: Maxim Vladimirskiy Date: Fri, 1 Sep 2023 17:44:03 +0300 Subject: [PATCH 3/3] Change WRR LB to ensure even pending requests --- pkg/server/service/loadbalancer/wrr/wrr.go | 272 ++++++----- .../service/loadbalancer/wrr/wrr_test.go | 451 +++++++----------- 2 files changed, 315 insertions(+), 408 deletions(-) diff --git a/pkg/server/service/loadbalancer/wrr/wrr.go b/pkg/server/service/loadbalancer/wrr/wrr.go index 720945ffa3..009835fbf6 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr.go +++ b/pkg/server/service/loadbalancer/wrr/wrr.go @@ -4,12 +4,11 @@ import ( "container/heap" "context" "errors" - "hash/fnv" "net/http" - "strconv" "sync" "github.com/traefik/traefik/v2/pkg/config/dynamic" + "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/log" ) @@ -17,7 +16,9 @@ type namedHandler struct { http.Handler name string weight float64 - deadline float64 + pending uint64 + healthy bool + queueIdx int } type stickyCookie struct { @@ -34,27 +35,20 @@ type stickyCookie struct { type Balancer struct { stickyCookie *stickyCookie wantsHealthCheck bool - - handlersMu sync.RWMutex - // References all the handlers by name and also by the hashed value of the name. - handlerMap map[string]*namedHandler - handlers []*namedHandler - curDeadline float64 - // status is a record of which child services of the Balancer are healthy, keyed - // by name of child service. A service is initially added to the map when it is - // created via Add, and it is later removed or added to the map as needed, - // through the SetStatus method. - status map[string]struct{} // updaters is the list of hooks that are run (to update the Balancer // parent(s)), whenever the Balancer status changes. updaters []func(bool) + + mutex sync.RWMutex + enabledHandlers priorityQueue + handlersByName map[string]*namedHandler + healthyCount int } // New creates a new load balancer. func New(sticky *dynamic.Sticky, wantHealthCheck bool) *Balancer { balancer := &Balancer{ - status: make(map[string]struct{}), - handlerMap: make(map[string]*namedHandler), + handlersByName: make(map[string]*namedHandler), wantsHealthCheck: wantHealthCheck, } if sticky != nil && sticky.Cookie != nil { @@ -64,78 +58,53 @@ func New(sticky *dynamic.Sticky, wantHealthCheck bool) *Balancer { httpOnly: sticky.Cookie.HTTPOnly, } } - return balancer } -// Len implements heap.Interface/sort.Interface. -func (b *Balancer) Len() int { return len(b.handlers) } - -// Less implements heap.Interface/sort.Interface. -func (b *Balancer) Less(i, j int) bool { - return b.handlers[i].deadline < b.handlers[j].deadline -} - -// Swap implements heap.Interface/sort.Interface. -func (b *Balancer) Swap(i, j int) { - b.handlers[i], b.handlers[j] = b.handlers[j], b.handlers[i] -} - -// Push implements heap.Interface for pushing an item into the heap. -func (b *Balancer) Push(x interface{}) { - h, ok := x.(*namedHandler) - if !ok { - return - } - - b.handlers = append(b.handlers, h) -} - -// Pop implements heap.Interface for popping an item from the heap. -// It panics if b.Len() < 1. -func (b *Balancer) Pop() interface{} { - h := b.handlers[len(b.handlers)-1] - b.handlers = b.handlers[0 : len(b.handlers)-1] - return h -} - // SetStatus sets on the balancer that its given child is now of the given -// status. balancerName is only needed for logging purposes. -func (b *Balancer) SetStatus(ctx context.Context, childName string, up bool) { - b.handlersMu.Lock() - defer b.handlersMu.Unlock() - - upBefore := len(b.status) > 0 - - status := "DOWN" - if up { - status = "UP" - } - log.FromContext(ctx).Debugf("Setting status of %s to %v", childName, status) - if up { - b.status[childName] = struct{}{} - } else { - delete(b.status, childName) +// status. +func (b *Balancer) SetStatus(ctx context.Context, childName string, healthy bool) { + log.FromContext(ctx).Debugf("Setting status of %s to %v", childName, statusAsStr(healthy)) + + b.mutex.Lock() + nh := b.handlersByName[childName] + if nh == nil { + b.mutex.Unlock() + return } - upAfter := len(b.status) > 0 - status = "DOWN" - if upAfter { - status = "UP" + healthyBefore := b.healthyCount > 0 + if nh.healthy != healthy { + nh.healthy = healthy + if healthy { + b.healthyCount++ + b.enabledHandlers.push(nh) + } else { + b.healthyCount-- + } } + healthyAfter := b.healthyCount > 0 + b.mutex.Unlock() // No Status Change - if upBefore == upAfter { + if healthyBefore == healthyAfter { // We're still with the same status, no need to propagate - log.FromContext(ctx).Debugf("Still %s, no need to propagate", status) + log.FromContext(ctx).Debugf("Still %s, no need to propagate", statusAsStr(healthyBefore)) return } // Status Change - log.FromContext(ctx).Debugf("Propagating new %s status", status) + log.FromContext(ctx).Debugf("Propagating new %s status", statusAsStr(healthyAfter)) for _, fn := range b.updaters { - fn(upAfter) + fn(healthyAfter) + } +} + +func statusAsStr(healthy bool) string { + if healthy { + return runtime.StatusUp } + return runtime.StatusDown } // RegisterStatusUpdater adds fn to the list of hooks that are run when the @@ -151,59 +120,61 @@ func (b *Balancer) RegisterStatusUpdater(fn func(up bool)) error { var errNoAvailableServer = errors.New("no available server") -func (b *Balancer) nextServer() (*namedHandler, error) { - b.handlersMu.Lock() - defer b.handlersMu.Unlock() - - if len(b.handlers) == 0 || len(b.status) == 0 { - return nil, errNoAvailableServer +func (b *Balancer) acquireHandler(preferredName string) (*namedHandler, error) { + b.mutex.Lock() + defer b.mutex.Unlock() + var nh *namedHandler + // Check the preferred handler fist if provided. + if preferredName != "" { + nh = b.handlersByName[preferredName] + if nh != nil && nh.healthy { + nh.pending++ + b.enabledHandlers.fix(nh) + return nh, nil + } } - - var handler *namedHandler + // Pick the handler with the least number of pending requests. for { - // Pick handler with closest deadline. - handler = heap.Pop(b).(*namedHandler) - - // curDeadline should be handler's deadline so that new added entry would have a fair competition environment with the old ones. - b.curDeadline = handler.deadline - handler.deadline += 1 / handler.weight - - heap.Push(b, handler) - if _, ok := b.status[handler.name]; ok { - break + nh = b.enabledHandlers.pop() + if nh == nil { + return nil, errNoAvailableServer + } + // If the handler is marked as unhealthy, then continue with the next + // best option. It will be put back into the priority queue once its + // status changes to healthy. + if !nh.healthy { + continue } + // Otherwise increment the number of pending requests, put it back into + // the priority queue, and return it as a selected for the request. + nh.pending++ + b.enabledHandlers.push(nh) + log.WithoutContext().Debugf("Service selected by WRR: %s", nh.name) + return nh, nil } +} - log.WithoutContext().Debugf("Service selected by WRR: %s", handler.name) - return handler, nil +func (b *Balancer) releaseHandler(nh *namedHandler) { + b.mutex.Lock() + defer b.mutex.Unlock() + nh.pending-- + if nh.healthy { + b.enabledHandlers.fix(nh) + } } func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + var preferredName string if b.stickyCookie != nil { cookie, err := req.Cookie(b.stickyCookie.name) - if err != nil && !errors.Is(err, http.ErrNoCookie) { log.WithoutContext().Warnf("Error while reading cookie: %v", err) } - if err == nil && cookie != nil { - b.handlersMu.RLock() - handler, ok := b.handlerMap[cookie.Value] - b.handlersMu.RUnlock() - - if ok && handler != nil { - b.handlersMu.RLock() - _, isHealthy := b.status[handler.name] - b.handlersMu.RUnlock() - if isHealthy { - handler.ServeHTTP(w, req) - return - } - } + preferredName = cookie.Value } } - - server, err := b.nextServer() + nh, err := b.acquireHandler(preferredName) if err != nil { if errors.Is(err, errNoAvailableServer) { http.Error(w, errNoAvailableServer.Error(), http.StatusServiceUnavailable) @@ -214,11 +185,18 @@ func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if b.stickyCookie != nil { - cookie := &http.Cookie{Name: b.stickyCookie.name, Value: hash(server.name), Path: "/", HttpOnly: b.stickyCookie.httpOnly, Secure: b.stickyCookie.secure} + cookie := &http.Cookie{ + Name: b.stickyCookie.name, + Value: nh.name, + Path: "/", + HttpOnly: b.stickyCookie.httpOnly, + Secure: b.stickyCookie.secure, + } http.SetCookie(w, cookie) } - server.ServeHTTP(w, req) + nh.ServeHTTP(w, req) + b.releaseHandler(nh) } // Add adds a handler. @@ -233,21 +211,67 @@ func (b *Balancer) Add(name string, handler http.Handler, weight *int) { return } - h := &namedHandler{Handler: handler, name: name, weight: float64(w)} + nh := &namedHandler{ + Handler: handler, + name: name, + weight: float64(w), + pending: 1, + healthy: true, + } + b.mutex.Lock() + b.enabledHandlers.push(nh) + b.handlersByName[nh.name] = nh + b.healthyCount++ + b.mutex.Unlock() +} + +type priorityQueue struct { + heap []*namedHandler +} - b.handlersMu.Lock() - h.deadline = b.curDeadline + 1/h.weight - heap.Push(b, h) - b.status[name] = struct{}{} - b.handlerMap[name] = h - b.handlerMap[hash(name)] = h - b.handlersMu.Unlock() +func (pq *priorityQueue) push(nh *namedHandler) { + heap.Push(pq, nh) } -func hash(input string) string { - hasher := fnv.New64() - // We purposely ignore the error because the implementation always returns nil. - _, _ = hasher.Write([]byte(input)) +func (pq *priorityQueue) pop() *namedHandler { + if len(pq.heap) < 1 { + return nil + } + return heap.Pop(pq).(*namedHandler) +} - return strconv.FormatUint(hasher.Sum64(), 16) +func (pq *priorityQueue) fix(nh *namedHandler) { + heap.Fix(pq, nh.queueIdx) +} + +// Len implements heap.Interface/sort.Interface. +func (pq *priorityQueue) Len() int { return len(pq.heap) } + +// Less implements heap.Interface/sort.Interface. +func (pq *priorityQueue) Less(i, j int) bool { + nhi, nhj := pq.heap[i], pq.heap[j] + return float64(nhi.pending)/nhi.weight < float64(nhj.pending)/nhj.weight +} + +// Swap implements heap.Interface/sort.Interface. +func (pq *priorityQueue) Swap(i, j int) { + pq.heap[i], pq.heap[j] = pq.heap[j], pq.heap[i] + pq.heap[i].queueIdx = i + pq.heap[j].queueIdx = j +} + +// Push implements heap.Interface for pushing an item into the heap. +func (pq *priorityQueue) Push(x interface{}) { + nh := x.(*namedHandler) + nh.queueIdx = len(pq.heap) + pq.heap = append(pq.heap, nh) +} + +// Pop implements heap.Interface for popping an item from the heap. +// It panics if b.Len() < 1. +func (pq *priorityQueue) Pop() interface{} { + lastIdx := len(pq.heap) - 1 + nh := pq.heap[lastIdx] + pq.heap = pq.heap[0:lastIdx] + return nh } diff --git a/pkg/server/service/loadbalancer/wrr/wrr_test.go b/pkg/server/service/loadbalancer/wrr/wrr_test.go index 19f2cf38ef..e08fc9a3a3 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr_test.go +++ b/pkg/server/service/loadbalancer/wrr/wrr_test.go @@ -2,322 +2,205 @@ package wrr import ( "context" + "fmt" "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" - "github.com/traefik/traefik/v2/pkg/config/dynamic" + "github.com/stretchr/testify/require" ) -func TestBalancer(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(3)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "second") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 4 { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - } +const ( + handlerAny = "" +) - assert.Equal(t, 3, recorder.save["first"]) - assert.Equal(t, 1, recorder.save["second"]) +func TestBalancerWeights(t *testing.T) { + b := New(nil, false) + addDummyHandler(b, "A", 3) + addDummyHandler(b, "B", 1) + + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 1, "B": 0}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 2, "B": 0}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 2, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 3, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 4, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 5, "B": 1}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 5, "B": 2}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 6, "B": 2}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 7, "B": 2}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 8, "B": 2}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 8, "B": 3}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 9, "B": 3}) + assertRelease(t, b, "B", map[string]int{"A": 9, "B": 2}) + assertRelease(t, b, "B", map[string]int{"A": 9, "B": 1}) + assertRelease(t, b, "B", map[string]int{"A": 9, "B": 0}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 9, "B": 1}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 9, "B": 2}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 9, "B": 3}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 10, "B": 3}) } -func TestBalancerNoService(t *testing.T) { - balancer := New(nil, false) - - recorder := httptest.NewRecorder() - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - - assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode) +func TestBalancerUpAndDown(t *testing.T) { + b := New(nil, false) + addDummyHandler(b, "A", 1) + addDummyHandler(b, "B", 1) + + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 1, "B": 0}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 1, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 2, "B": 1}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 2, "B": 2}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 3, "B": 2}) + b.SetStatus(context.Background(), "B", false) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 4, "B": 2}) + b.SetStatus(context.Background(), "B", false) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 5, "B": 2}) + b.SetStatus(context.Background(), "A", false) + _, err := b.acquireHandler(handlerAny) + assert.Equal(t, errNoAvailableServer, err) + assertRelease(t, b, "B", map[string]int{"A": 5, "B": 1}) + assertRelease(t, b, "A", map[string]int{"A": 4, "B": 1}) + assertRelease(t, b, "A", map[string]int{"A": 3, "B": 1}) + _, err = b.acquireHandler(handlerAny) + assert.Equal(t, errNoAvailableServer, err) + b.SetStatus(context.Background(), "A", true) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 4, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 5, "B": 1}) + b.SetStatus(context.Background(), "B", true) + b.SetStatus(context.Background(), "B", true) + b.SetStatus(context.Background(), "A", true) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 5, "B": 2}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 5, "B": 3}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 5, "B": 4}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 5, "B": 5}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 6, "B": 5}) } -func TestBalancerOneServerZeroWeight(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(1)) +func TestBalancerZeroWeight(t *testing.T) { + b := New(nil, false) + addDummyHandler(b, "A", 0) + addDummyHandler(b, "B", 1) - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 3 { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - } - - assert.Equal(t, 3, recorder.save["first"]) + assertAcquire(t, b, handlerAny, "B", map[string]int{"B": 1}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"B": 2}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"B": 3}) } -type key string - -const serviceName key = "serviceName" - -func TestBalancerNoServiceUp(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusInternalServerError) - }), Int(1)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusInternalServerError) - }), Int(1)) - - balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "first", false) - balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", false) - - recorder := httptest.NewRecorder() - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - - assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode) +func TestBalancerPropagate(t *testing.T) { + b := New(nil, true) + addDummyHandler(b, "A", 1) + addDummyHandler(b, "B", 1) + updates := []bool{} + err := b.RegisterStatusUpdater(func(healthy bool) { + updates = append(updates, healthy) + }) + require.NoError(t, err) + + b.SetStatus(context.Background(), "A", false) + assert.Equal(t, []bool{}, updates) + b.SetStatus(context.Background(), "A", false) + assert.Equal(t, []bool{}, updates) + b.SetStatus(context.Background(), "B", false) + assert.Equal(t, []bool{false}, updates) + b.SetStatus(context.Background(), "A", false) + assert.Equal(t, []bool{false}, updates) + b.SetStatus(context.Background(), "B", false) + assert.Equal(t, []bool{false}, updates) + b.SetStatus(context.Background(), "B", true) + assert.Equal(t, []bool{false, true}, updates) + b.SetStatus(context.Background(), "B", true) + assert.Equal(t, []bool{false, true}, updates) + b.SetStatus(context.Background(), "A", true) + assert.Equal(t, []bool{false, true}, updates) + b.SetStatus(context.Background(), "A", false) + assert.Equal(t, []bool{false, true}, updates) + b.SetStatus(context.Background(), "A", false) + assert.Equal(t, []bool{false, true}, updates) + b.SetStatus(context.Background(), "B", false) + assert.Equal(t, []bool{false, true, false}, updates) } -func TestBalancerOneServerDown(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusInternalServerError) - }), Int(1)) - balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", false) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 3 { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - } - - assert.Equal(t, 3, recorder.save["first"]) +func TestBalancerSticky(t *testing.T) { + b := New(nil, false) + addDummyHandler(b, "A", 1) + addDummyHandler(b, "B", 1) + + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 1, "B": 0}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 1, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 2, "B": 1}) + assertAcquire(t, b, "A", "A", map[string]int{"A": 3, "B": 1}) + assertAcquire(t, b, "A", "A", map[string]int{"A": 4, "B": 1}) + assertAcquire(t, b, "A", "A", map[string]int{"A": 5, "B": 1}) + b.SetStatus(context.Background(), "A", false) + // Even though A is preferred B is allocated when A is not available. + assertAcquire(t, b, "A", "B", map[string]int{"A": 5, "B": 2}) + assertAcquire(t, b, "A", "B", map[string]int{"A": 5, "B": 3}) + b.SetStatus(context.Background(), "A", true) + assertAcquire(t, b, "A", "A", map[string]int{"A": 6, "B": 3}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 6, "B": 4}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 6, "B": 5}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 6, "B": 6}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 7, "B": 6}) } -func TestBalancerDownThenUp(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "second") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", false) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 3 { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - } - assert.Equal(t, 3, recorder.save["first"]) - - balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", true) - recorder = &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 2 { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) +// When sticky sessions are allocated that does not mess up selection order. +// Internally heap is used and sticky allocation has to maintain correct +// ordering of handlers in the priority queue. +func TestBalancerMany(t *testing.T) { + b := New(nil, false) + for _, handlerName := range "ABCDEFGH" { + addDummyHandler(b, fmt.Sprintf("%c", handlerName), 1) } - assert.Equal(t, 1, recorder.save["first"]) - assert.Equal(t, 1, recorder.save["second"]) -} - -func TestBalancerPropagate(t *testing.T) { - balancer1 := New(nil, true) - - balancer1.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - balancer1.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "second") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - balancer2 := New(nil, true) - balancer2.Add("third", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "third") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - balancer2.Add("fourth", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "fourth") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - topBalancer := New(nil, true) - topBalancer.Add("balancer1", balancer1, Int(1)) - _ = balancer1.RegisterStatusUpdater(func(up bool) { - topBalancer.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "balancer1", up) - // TODO(mpl): if test gets flaky, add channel or something here to signal that - // propagation is done, and wait on it before sending request. - }) - topBalancer.Add("balancer2", balancer2, Int(1)) - _ = balancer2.RegisterStatusUpdater(func(up bool) { - topBalancer.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "balancer2", up) - }) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 8 { - topBalancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + for i := 0; i < 100; i++ { + _, err := b.acquireHandler(handlerAny) + require.NoError(t, err) } - assert.Equal(t, 2, recorder.save["first"]) - assert.Equal(t, 2, recorder.save["second"]) - assert.Equal(t, 2, recorder.save["third"]) - assert.Equal(t, 2, recorder.save["fourth"]) - wantStatus := []int{200, 200, 200, 200, 200, 200, 200, 200} - assert.Equal(t, wantStatus, recorder.status) - - // fourth gets downed, but balancer2 still up since third is still up. - balancer2.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "fourth", false) - recorder = &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 8 { - topBalancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + assert.Equal(t, map[string]int{"A": 13, "B": 13, "C": 12, "D": 13, "E": 12, "F": 12, "G": 12, "H": 13}, pendingCounts(b)) + for i := 0; i < 10; i++ { + _, err := b.acquireHandler("D") + require.NoError(t, err) } - assert.Equal(t, 2, recorder.save["first"]) - assert.Equal(t, 2, recorder.save["second"]) - assert.Equal(t, 4, recorder.save["third"]) - assert.Equal(t, 0, recorder.save["fourth"]) - wantStatus = []int{200, 200, 200, 200, 200, 200, 200, 200} - assert.Equal(t, wantStatus, recorder.status) - - // third gets downed, and the propagation triggers balancer2 to be marked as - // down as well for topBalancer. - balancer2.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "third", false) - recorder = &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 8 { - topBalancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + assert.Equal(t, map[string]int{"A": 13, "B": 13, "C": 12, "D": 23, "E": 12, "F": 12, "G": 12, "H": 13}, pendingCounts(b)) + for i := 0; i < 74; i++ { + _, err := b.acquireHandler(handlerAny) + require.NoError(t, err) } - assert.Equal(t, 4, recorder.save["first"]) - assert.Equal(t, 4, recorder.save["second"]) - assert.Equal(t, 0, recorder.save["third"]) - assert.Equal(t, 0, recorder.save["fourth"]) - wantStatus = []int{200, 200, 200, 200, 200, 200, 200, 200} - assert.Equal(t, wantStatus, recorder.status) -} - -func TestBalancerAllServersZeroWeight(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("test", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) - balancer.Add("test2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) - - recorder := httptest.NewRecorder() - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - - assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode) -} - -func TestSticky(t *testing.T) { - balancer := New(&dynamic.Sticky{ - Cookie: &dynamic.Cookie{Name: "test"}, - }, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "second") - rw.WriteHeader(http.StatusOK) - }), Int(2)) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - - req := httptest.NewRequest(http.MethodGet, "/", nil) - for range 3 { - for _, cookie := range recorder.Result().Cookies() { - assert.NotContains(t, "test=first", cookie.Value) - assert.NotContains(t, "test=second", cookie.Value) - req.AddCookie(cookie) - } - recorder.ResponseRecorder = httptest.NewRecorder() - - balancer.ServeHTTP(recorder, req) + assert.Equal(t, map[string]int{"A": 23, "B": 23, "C": 23, "D": 23, "E": 23, "F": 23, "G": 23, "H": 23}, pendingCounts(b)) + for i := 0; i < 8; i++ { + _, err := b.acquireHandler(handlerAny) + require.NoError(t, err) } - - assert.Equal(t, 0, recorder.save["first"]) - assert.Equal(t, 3, recorder.save["second"]) + assert.Equal(t, map[string]int{"A": 24, "B": 24, "C": 24, "D": 24, "E": 24, "F": 24, "G": 24, "H": 24}, pendingCounts(b)) } -func TestSticky_FallBack(t *testing.T) { - balancer := New(&dynamic.Sticky{ - Cookie: &dynamic.Cookie{Name: "test"}, - }, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") +func addDummyHandler(b *Balancer, handlerName string, weight int) { + h := func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("server", handlerName) rw.WriteHeader(http.StatusOK) - }), Int(1)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "second") - rw.WriteHeader(http.StatusOK) - }), Int(2)) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.AddCookie(&http.Cookie{Name: "test", Value: "second"}) - for range 3 { - recorder.ResponseRecorder = httptest.NewRecorder() - - balancer.ServeHTTP(recorder, req) } - - assert.Equal(t, 0, recorder.save["first"]) - assert.Equal(t, 3, recorder.save["second"]) + b.Add(handlerName, http.HandlerFunc(h), &weight) } -// TestBalancerBias makes sure that the WRR algorithm spreads elements evenly right from the start, -// and that it does not "over-favor" the high-weighted ones with a biased start-up regime. -func TestBalancerBias(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "A") - rw.WriteHeader(http.StatusOK) - }), Int(11)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "B") - rw.WriteHeader(http.StatusOK) - }), Int(3)) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - - for i := 0; i < 14; i++ { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) +func pendingCounts(b *Balancer) map[string]int { + countsByName := make(map[string]int) + b.mutex.Lock() + for handlerName, handler := range b.handlersByName { + countsByName[handlerName] = int(handler.pending) - 1 } - - wantSequence := []string{"A", "A", "A", "B", "A", "A", "A", "A", "B", "A", "A", "A", "B", "A"} - - assert.Equal(t, wantSequence, recorder.sequence) + b.mutex.Unlock() + return countsByName } -func Int(v int) *int { return &v } - -type responseRecorder struct { - *httptest.ResponseRecorder - save map[string]int - sequence []string - status []int +func assertAcquire(t *testing.T, b *Balancer, preferredName, acquiredName string, want map[string]int) { + nh, err := b.acquireHandler(preferredName) + require.NoError(t, err) + assert.Equal(t, acquiredName, nh.name) + assert.Equal(t, want, pendingCounts(b)) } -func (r *responseRecorder) WriteHeader(statusCode int) { - r.save[r.Header().Get("server")]++ - r.sequence = append(r.sequence, r.Header().Get("server")) - r.status = append(r.status, statusCode) - r.ResponseRecorder.WriteHeader(statusCode) +func assertRelease(t *testing.T, b *Balancer, acquiredName string, want map[string]int) { + b.mutex.Lock() + nh := b.handlersByName[acquiredName] + b.mutex.Unlock() + b.releaseHandler(nh) + assert.Equal(t, want, pendingCounts(b)) }