diff --git a/adapter/inbound.go b/adapter/inbound.go index 2d24083c4a..f373f0fa0a 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -46,7 +46,6 @@ type InboundContext struct { SourceGeoIPCode string GeoIPCode string ProcessInfo *process.Info - FakeIP bool // dns cache diff --git a/route/router.go b/route/router.go index 2d2f2cde57..d1d7c0ac1c 100644 --- a/route/router.go +++ b/route/router.go @@ -625,7 +625,6 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad Fqdn: domain, Port: metadata.Destination.Port, } - metadata.FakeIP = true r.logger.DebugContext(ctx, "found fakeip domain: ", domain) } @@ -703,6 +702,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad } func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + var rewriteDestination bool if metadata.InboundDetour != "" { if metadata.LastInbound == metadata.InboundDetour { return E.New("routing loop on detour: ", metadata.InboundDetour) @@ -740,7 +740,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m Fqdn: domain, Port: metadata.Destination.Port, } - metadata.FakeIP = true + rewriteDestination = true r.logger.DebugContext(ctx, "found fakeip domain: ", domain) } @@ -766,6 +766,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m metadata.Protocol = sniffMetadata.Protocol metadata.Domain = sniffMetadata.Domain if metadata.InboundOptions.SniffOverrideDestination && M.IsDomainName(metadata.Domain) { + rewriteDestination = true metadata.Destination = M.Socksaddr{ Fqdn: metadata.Domain, Port: metadata.Destination.Port, @@ -817,7 +818,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m conn = statsService.RoutedPacketConnection(metadata.Inbound, detour.Tag(), metadata.User, conn) } } - if metadata.FakeIP { + if rewriteDestination { conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination) } return detour.NewPacketConnection(ctx, conn, metadata) diff --git a/route/router_dns.go b/route/router_dns.go index 1532df94f8..800da5bfd7 100644 --- a/route/router_dns.go +++ b/route/router_dns.go @@ -37,7 +37,7 @@ func (m *DNSReverseMapping) Query(address netip.Addr) (string, bool) { return domain, loaded } -func (r *Router) matchDNS(ctx context.Context) (context.Context, dns.Transport, dns.DomainStrategy) { +func (r *Router) matchDNS(ctx context.Context, useFakeIP bool) (context.Context, dns.Transport, dns.DomainStrategy) { metadata := adapter.ContextFrom(ctx) if metadata == nil { panic("no context") @@ -50,7 +50,7 @@ func (r *Router) matchDNS(ctx context.Context) (context.Context, dns.Transport, r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour) continue } - if _, isFakeIP := transport.(adapter.FakeIPTransport); isFakeIP && metadata.FakeIP { + if _, isFakeIP := transport.(adapter.FakeIPTransport); isFakeIP && !useFakeIP { continue } r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) @@ -67,10 +67,20 @@ func (r *Router) matchDNS(ctx context.Context) (context.Context, dns.Transport, } } } - if domainStrategy, dsLoaded := r.transportDomainStrategy[r.defaultTransport]; dsLoaded { - return ctx, r.defaultTransport, domainStrategy + transport := r.defaultTransport + if _, isFakeIP := transport.(adapter.FakeIPTransport); isFakeIP && !useFakeIP { + for _, t := range r.transports { + if _, isFakeIP := t.(adapter.FakeIPTransport); isFakeIP { + continue + } + transport = t + break + } + } + if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded { + return ctx, transport, domainStrategy } else { - return ctx, r.defaultTransport, r.defaultDomainStrategy + return ctx, transport, r.defaultDomainStrategy } } @@ -96,7 +106,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er } metadata.Domain = fqdnToDomain(message.Question[0].Name) } - ctx, transport, strategy := r.matchDNS(ctx) + ctx, transport, strategy := r.matchDNS(ctx, true) ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout) defer cancel() response, err = r.dnsClient.Exchange(ctx, transport, message, strategy) @@ -124,7 +134,7 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS r.dnsLogger.DebugContext(ctx, "lookup domain ", domain) ctx, metadata := adapter.AppendContext(ctx) metadata.Domain = domain - ctx, transport, transportStrategy := r.matchDNS(ctx) + ctx, transport, transportStrategy := r.matchDNS(ctx, false) if strategy == dns.DomainStrategyAsIS { strategy = transportStrategy }