From c830b8aaf7a8889e51a30972ad7b3a0e13e23a34 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sun, 28 Jul 2024 10:07:37 +0800 Subject: [PATCH] feat: support convert `mrs` format back to `text` format --- component/cidr/ipcidr_set.go | 10 ++++++++ component/trie/domain.go | 19 +++++++++++----- component/trie/domain_set.go | 38 ++++++++++++++++++++++++++++++- component/trie/domain_set_test.go | 20 ++++++++++++++++ component/trie/domain_test.go | 3 ++- docs/config.yaml | 15 ++++++++---- rules/provider/domain_strategy.go | 22 ++++++++++++++++++ rules/provider/ipcidr_strategy.go | 9 ++++++++ rules/provider/mrs_converter.go | 12 ++++++++++ rules/provider/provider.go | 1 + 10 files changed, 137 insertions(+), 12 deletions(-) diff --git a/component/cidr/ipcidr_set.go b/component/cidr/ipcidr_set.go index 521fabab13..4907146039 100644 --- a/component/cidr/ipcidr_set.go +++ b/component/cidr/ipcidr_set.go @@ -57,6 +57,16 @@ func (set *IpCidrSet) Merge() error { return nil } +func (set *IpCidrSet) Foreach(f func(prefix netip.Prefix) bool) { + for _, r := range set.rr { + for _, prefix := range r.Prefixes() { + if !f(prefix) { + return + } + } + } +} + // ToIPSet not safe convert to *netipx.IPSet // be careful, must be used after Merge func (set *IpCidrSet) ToIPSet() *netipx.IPSet { diff --git a/component/trie/domain.go b/component/trie/domain.go index 3decbb0255..db30402ede 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -123,16 +123,18 @@ func (t *DomainTrie[T]) Optimize() { t.root.optimize() } -func (t *DomainTrie[T]) Foreach(print func(domain string, data T)) { +func (t *DomainTrie[T]) Foreach(fn func(domain string, data T) bool) { for key, data := range t.root.getChildren() { - recursion([]string{key}, data, print) + recursion([]string{key}, data, fn) if data != nil && data.inited { - print(joinDomain([]string{key}), data.data) + if !fn(joinDomain([]string{key}), data.data) { + return + } } } } -func recursion[T any](items []string, node *Node[T], fn func(domain string, data T)) { +func recursion[T any](items []string, node *Node[T], fn func(domain string, data T) bool) bool { for key, data := range node.getChildren() { newItems := append([]string{key}, items...) if data != nil && data.inited { @@ -140,10 +142,15 @@ func recursion[T any](items []string, node *Node[T], fn func(domain string, data if domain[0] == domainStepByte { domain = complexWildcard + domain } - fn(domain, data.Data()) + if !fn(domain, data.Data()) { + return false + } + } + if !recursion(newItems, data, fn) { + return false } - recursion(newItems, data, fn) } + return true } func joinDomain(items []string) string { diff --git a/component/trie/domain_set.go b/component/trie/domain_set.go index 860d1235d5..7778d13379 100644 --- a/component/trie/domain_set.go +++ b/component/trie/domain_set.go @@ -28,8 +28,9 @@ type qElt struct{ s, e, col int } // NewDomainSet creates a new *DomainSet struct, from a DomainTrie. func (t *DomainTrie[T]) NewDomainSet() *DomainSet { reserveDomains := make([]string, 0) - t.Foreach(func(domain string, data T) { + t.Foreach(func(domain string, data T) bool { reserveDomains = append(reserveDomains, utils.Reverse(domain)) + return true }) // ensure that the same prefix is continuous // and according to the ascending sequence of length @@ -136,6 +137,41 @@ func (ss *DomainSet) Has(key string) bool { } +func (ss *DomainSet) keys(f func(key string) bool) { + var currentKey []byte + var traverse func(int, int) bool + traverse = func(nodeId, bmIdx int) bool { + if getBit(ss.leaves, nodeId) != 0 { + if !f(string(currentKey)) { + return false + } + } + + for ; ; bmIdx++ { + if getBit(ss.labelBitmap, bmIdx) != 0 { + return true + } + nextLabel := ss.labels[bmIdx-nodeId] + currentKey = append(currentKey, nextLabel) + nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1) + nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1 + if !traverse(nextNodeId, nextBmIdx) { + return false + } + currentKey = currentKey[:len(currentKey)-1] + } + } + + traverse(0, 0) + return +} + +func (ss *DomainSet) Foreach(f func(key string) bool) { + ss.keys(func(key string) bool { + return f(utils.Reverse(key)) + }) +} + func setBit(bm *[]uint64, i int, v int) { for i>>6 >= len(*bm) { *bm = append(*bm, 0) diff --git a/component/trie/domain_set_test.go b/component/trie/domain_set_test.go index 77106d5ffc..e343d11d1c 100644 --- a/component/trie/domain_set_test.go +++ b/component/trie/domain_set_test.go @@ -1,12 +1,29 @@ package trie_test import ( + "golang.org/x/exp/slices" "testing" "github.com/metacubex/mihomo/component/trie" "github.com/stretchr/testify/assert" ) +func testDump(t *testing.T, tree *trie.DomainTrie[struct{}], set *trie.DomainSet) { + var dataSrc []string + tree.Foreach(func(domain string, data struct{}) bool { + dataSrc = append(dataSrc, domain) + return true + }) + slices.Sort(dataSrc) + var dataSet []string + set.Foreach(func(key string) bool { + dataSet = append(dataSet, key) + return true + }) + slices.Sort(dataSet) + assert.Equal(t, dataSrc, dataSet) +} + func TestDomainSet(t *testing.T) { tree := trie.New[struct{}]() domainSet := []string{ @@ -33,6 +50,7 @@ func TestDomainSet(t *testing.T) { assert.True(t, set.Has("google.com")) assert.False(t, set.Has("qq.com")) assert.False(t, set.Has("www.baidu.com")) + testDump(t, tree, set) } func TestDomainSetComplexWildcard(t *testing.T) { @@ -55,6 +73,7 @@ func TestDomainSetComplexWildcard(t *testing.T) { assert.False(t, set.Has("google.com")) assert.True(t, set.Has("www.baidu.com")) assert.True(t, set.Has("test.test.baidu.com")) + testDump(t, tree, set) } func TestDomainSetWildcard(t *testing.T) { @@ -82,4 +101,5 @@ func TestDomainSetWildcard(t *testing.T) { assert.False(t, set.Has("a.www.google.com")) assert.False(t, set.Has("test.qq.com")) assert.False(t, set.Has("test.test.test.qq.com")) + testDump(t, tree, set) } diff --git a/component/trie/domain_test.go b/component/trie/domain_test.go index 4c5d8002d8..916f61076d 100644 --- a/component/trie/domain_test.go +++ b/component/trie/domain_test.go @@ -121,8 +121,9 @@ func TestTrie_Foreach(t *testing.T) { assert.NoError(t, tree.Insert(domain, localIP)) } count := 0 - tree.Foreach(func(domain string, data netip.Addr) { + tree.Foreach(func(domain string, data netip.Addr) bool { count++ + return true }) assert.Equal(t, 7, count) } diff --git a/docs/config.yaml b/docs/config.yaml index 669c8be7ef..d7c686d01f 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -944,10 +944,17 @@ rule-providers: type: file rule3: # mrs类型ruleset,目前仅支持domain和ipcidr(即不支持classical), - # behavior=domain,format=yaml 可以通过“mihomo convert-ruleset domain yaml XXX.yaml XXX.mrs”转换得到 - # behavior=domain,format=text 可以通过“mihomo convert-ruleset domain text XXX.text XXX.mrs”转换得到 - # behavior=ipcidr,format=yaml 可以通过“mihomo convert-ruleset ipcidr yaml XXX.yaml XXX.mrs”转换得到 - # behavior=ipcidr,format=text 可以通过“mihomo convert-ruleset ipcidr text XXX.text XXX.mrs”转换得到 + # + # 对于behavior=domain: + # - format=yaml 可以通过“mihomo convert-ruleset domain yaml XXX.yaml XXX.mrs”转换到mrs格式 + # - format=text 可以通过“mihomo convert-ruleset domain text XXX.text XXX.mrs”转换到mrs格式 + # - XXX.mrs 可以通过"mihomo convert-ruleset domain mrs XXX.mrs XXX.text"转换回text格式(暂不支持转换回ymal格式) + # + # 对于behavior=ipcidr: + # - format=yaml 可以通过“mihomo convert-ruleset ipcidr yaml XXX.yaml XXX.mrs”转换到mrs格式 + # - format=text 可以通过“mihomo convert-ruleset ipcidr text XXX.text XXX.mrs”转换到mrs格式 + # - XXX.mrs 可以通过"mihomo convert-ruleset ipcidr mrs XXX.mrs XXX.text"转换回text格式(暂不支持转换回ymal格式) + # type: http url: "url" format: mrs diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go index a999f5bd1c..b893f038b2 100644 --- a/rules/provider/domain_strategy.go +++ b/rules/provider/domain_strategy.go @@ -9,6 +9,8 @@ import ( C "github.com/metacubex/mihomo/constant" P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/log" + + "golang.org/x/exp/slices" ) type domainStrategy struct { @@ -78,6 +80,26 @@ func (d *domainStrategy) WriteMrs(w io.Writer) error { return d.domainSet.WriteBin(w) } +func (d *domainStrategy) DumpMrs(f func(key string) bool) { + if d.domainSet != nil { + var keys []string + d.domainSet.Foreach(func(key string) bool { + keys = append(keys, key) + return true + }) + slices.Sort(keys) + + for _, key := range keys { + if _, ok := slices.BinarySearch(keys, "+."+key); ok { + continue // ignore the rules added by trie internal processing + } + if !f(key) { + return + } + } + } +} + var _ mrsRuleStrategy = (*domainStrategy)(nil) func NewDomainStrategy() *domainStrategy { diff --git a/rules/provider/ipcidr_strategy.go b/rules/provider/ipcidr_strategy.go index 87cf7a2d8e..9efffed96b 100644 --- a/rules/provider/ipcidr_strategy.go +++ b/rules/provider/ipcidr_strategy.go @@ -3,6 +3,7 @@ package provider import ( "errors" "io" + "net/netip" "github.com/metacubex/mihomo/component/cidr" C "github.com/metacubex/mihomo/constant" @@ -82,6 +83,14 @@ func (i *ipcidrStrategy) WriteMrs(w io.Writer) error { return i.cidrSet.WriteBin(w) } +func (i *ipcidrStrategy) DumpMrs(f func(key string) bool) { + if i.cidrSet != nil { + i.cidrSet.Foreach(func(prefix netip.Prefix) bool { + return f(prefix.String()) + }) + } +} + func (i *ipcidrStrategy) ToIpCidr() *netipx.IPSet { return i.cidrSet.ToIPSet() } diff --git a/rules/provider/mrs_converter.go b/rules/provider/mrs_converter.go index a08301982e..edc24e7eea 100644 --- a/rules/provider/mrs_converter.go +++ b/rules/provider/mrs_converter.go @@ -3,6 +3,7 @@ package provider import ( "encoding/binary" "errors" + "fmt" "io" "os" @@ -21,6 +22,17 @@ func ConvertToMrs(buf []byte, behavior P.RuleBehavior, format P.RuleFormat, w io return errors.New("empty rule") } if _strategy, ok := strategy.(mrsRuleStrategy); ok { + if format == P.MrsRule { // export to TextRule + _strategy.DumpMrs(func(key string) bool { + _, err = fmt.Fprintln(w, key) + if err != nil { + return false + } + return true + }) + return nil + } + var encoder *zstd.Encoder encoder, err = zstd.NewWriter(w) if err != nil { diff --git a/rules/provider/provider.go b/rules/provider/provider.go index 8c5d7f9407..b9524c35e6 100644 --- a/rules/provider/provider.go +++ b/rules/provider/provider.go @@ -58,6 +58,7 @@ type mrsRuleStrategy interface { ruleStrategy FromMrs(r io.Reader, count int) error WriteMrs(w io.Writer) error + DumpMrs(f func(key string) bool) } func (rp *ruleSetProvider) Type() P.ProviderType {