diff --git a/pkg/endpoints/manager.go b/pkg/endpoints/manager.go index bd2d3e39..ba41b78e 100644 --- a/pkg/endpoints/manager.go +++ b/pkg/endpoints/manager.go @@ -6,6 +6,7 @@ import ( "log" "sync" + corev1 "k8s.io/api/core/v1" disv1 "k8s.io/api/discovery/v1" ctrl "sigs.k8s.io/controller-runtime" @@ -58,6 +59,8 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, return ctrl.Result{}, fmt.Errorf("listing endpointslices: %w", err) } + deployments := map[string]map[string]struct{}{} + ips := map[string]struct{}{} for _, sliceItem := range sliceList.Items { for _, endpointItem := range sliceItem.Endpoints { @@ -67,9 +70,24 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, } } ready := endpointItem.Conditions.Ready - if ready != nil && *ready { - for _, ip := range endpointItem.Addresses { - ips[ip] = struct{}{} + if ready != nil && *ready && endpointItem.TargetRef.Kind == "Pod" { + podName := endpointItem.TargetRef.Name + namespace := endpointItem.TargetRef.Namespace // Assuming the EndpointSlice and Pod are in the same namespace + // Fetch the Pod using the client + var pod corev1.Pod + if err := r.Get(ctx, client.ObjectKey{Name: podName, Namespace: namespace}, &pod); err != nil { + log.Printf("error fetching pod: %v\n", err) + continue + } + if pod.OwnerReferences != nil { + for _, owner := range pod.OwnerReferences { + if owner.Kind == "ReplicaSet" { + for _, ip := range endpointItem.Addresses { + ips[ip] = struct{}{} + } + deployments[owner.Name] = ips + } + } } } } @@ -86,25 +104,25 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, } } - priorLen := r.getEndpoints(serviceName).lenIPs() - r.getEndpoints(serviceName).setIPs(ips, ports) + for deploymentName, ips := range deployments { + priorLen := r.getEndpoints(deploymentName).lenIPs() + r.getEndpoints(deploymentName).setIPs(ips, ports) + + if priorLen != len(ips) { + r.EndpointSizeCallback(deploymentName, len(ips)) + } - if priorLen != len(ips) { - // TODO: Currently Service name needs to match Deployment name, however - // this shouldn't be the case. We should be able to reference deployment - // replicas by something else. - r.EndpointSizeCallback(serviceName, len(ips)) } return ctrl.Result{}, nil } -func (r *Manager) getEndpoints(service string) *endpointGroup { +func (r *Manager) getEndpoints(deploymentName string) *endpointGroup { r.endpointsMtx.Lock() - e, ok := r.endpoints[service] + e, ok := r.endpoints[deploymentName] if !ok { e = newEndpointGroup() - r.endpoints[service] = e + r.endpoints[deploymentName] = e } r.endpointsMtx.Unlock() return e @@ -114,11 +132,11 @@ func (r *Manager) getEndpoints(service string) *endpointGroup { // becomes available or the context times out. // // It returns a string in the format "host:port" or error on timeout -func (r *Manager) AwaitHostAddress(ctx context.Context, service, portName string) (string, error) { - return r.getEndpoints(service).getBestHost(ctx, portName) +func (r *Manager) AwaitHostAddress(ctx context.Context, deploymentName, portName string) (string, error) { + return r.getEndpoints(deploymentName).getBestHost(ctx, portName) } // GetAllHosts retrieves the list of all hosts for a given service and port. -func (r *Manager) GetAllHosts(service, portName string) []string { - return r.getEndpoints(service).getAllHosts(portName) +func (r *Manager) GetAllHosts(deploymentName, portName string) []string { + return r.getEndpoints(deploymentName).getAllHosts(portName) }