diff --git a/solution/0900-0999/0924.Minimize Malware Spread/README.md b/solution/0900-0999/0924.Minimize Malware Spread/README.md index 12fde869b6032..cc6aa79aeb349 100644 --- a/solution/0900-0999/0924.Minimize Malware Spread/README.md +++ b/solution/0900-0999/0924.Minimize Malware Spread/README.md @@ -63,206 +63,358 @@ ## 解法 -### 方法一 +### 方法一:并查集 + +根据题目描述,如果初始时有若干个节点属于同一个连通分量,那么一共可以分为三种情况: + +1. 这些节点中没有一个节点被感染 +1. 这些节点中只有一个节点被感染 +1. 这些节点中有多个节点被感染 + +我们要考虑的是,移除某个感染节点后,剩下的节点中被感染的节点数最少。 + +情况一没有被感染的节点,不需要考虑;情况二只有一个节点被感染,那么移除这个节点后,该连通分量中的其他节点都不会被感染;情况三有多个节点被感染,那么移除任意一个感染节点后,该连通分量中的其他节点还是会被感染,所以我们只需要考虑情况二。 + +我们利用并查集 $uf$ 维护节点的连通关系,用一个变量 $ans$ 记录答案,用一个变量 $mx$ 记录当前能减少感染的最大节点数,初始时 $ans = n$, $mx = 0$。 + +然后遍历数组 $initial$,用一个哈希表或者一个长度为 $n$ 的数组 $cnt$ 统计每个连通分量中被感染节点的个数。 + +接下来,我们再遍历数组 $initial$,对于每个节点 $x$,我们找到其所在的连通分量的根节点 $root$,如果该连通分量中只有一个被感染节点,即 $cnt[root] = 1$,我们就更新答案,更新的条件是该连通分量中的节点数 $sz$ 大于 $mx$ 或者 $sz$ 等于 $mx$ 且 $x$ 的值小于 $ans$。 + +最后,如果 $ans$ 没有被更新,说明所有的连通分量中都有多个被感染节点,那么我们返回 $initial$ 中的最小值,否则返回 $ans$。 + +时间复杂度 $O(n^2 \times \alpha(n))$,空间复杂度 $O(n)$。其中 $n$ 是节点的个数,而 $\alpha(n)$ 是 Ackermann 函数的反函数。 ```python +class UnionFind: + __slots__ = "p", "size" + + def __init__(self, n: int): + self.p = list(range(n)) + self.size = [1] * n + + def find(self, x: int) -> int: + if self.p[x] != x: + self.p[x] = self.find(self.p[x]) + return self.p[x] + + def union(self, a: int, b: int) -> bool: + pa, pb = self.find(a), self.find(b) + if pa == pb: + return False + if self.size[pa] > self.size[pb]: + self.p[pb] = pa + self.size[pa] += self.size[pb] + else: + self.p[pa] = pb + self.size[pb] += self.size[pa] + return True + + def get_size(self, root: int) -> int: + return self.size[root] + + class Solution: def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int: n = len(graph) - p = list(range(n)) - size = [1] * n - - def find(x): - if p[x] != x: - p[x] = find(p[x]) - return p[x] - + uf = UnionFind(n) for i in range(n): for j in range(i + 1, n): - if graph[i][j] == 1: - pa, pb = find(i), find(j) - if pa == pb: - continue - p[pa] = pb - size[pb] += size[pa] - - mi = inf - res = initial[0] - initial.sort() - for i in range(len(initial)): - t = 0 - s = set() - for j in range(len(initial)): - if i == j: - continue - if find(initial[j]) in s: - continue - s.add(find(initial[j])) - t += size[find(initial[j])] - if mi > t: - mi = t - res = initial[i] - return res + graph[i][j] and uf.union(i, j) + cnt = Counter(uf.find(x) for x in initial) + ans, mx = n, 0 + for x in initial: + root = uf.find(x) + if cnt[root] > 1: + continue + sz = uf.get_size(root) + if sz > mx or (sz == mx and x < ans): + ans = x + mx = sz + return min(initial) if ans == n else ans ``` ```java -class Solution { - private int[] p; +class UnionFind { + private final int[] p; + private final int[] size; - public int minMalwareSpread(int[][] graph, int[] initial) { - int n = graph.length; + public UnionFind(int n) { p = new int[n]; + size = new int[n]; for (int i = 0; i < n; ++i) { p[i] = i; + size[i] = 1; + } + } + + public int find(int x) { + if (p[x] != x) { + p[x] = find(p[x]); + } + return p[x]; + } + + public boolean union(int a, int b) { + int pa = find(a), pb = find(b); + if (pa == pb) { + return false; + } + if (size[pa] > size[pb]) { + p[pb] = pa; + size[pa] += size[pb]; + } else { + p[pa] = pb; + size[pb] += size[pa]; } - int[] size = new int[n]; - Arrays.fill(size, 1); + return true; + } + + public int size(int root) { + return size[root]; + } +} + +class Solution { + public int minMalwareSpread(int[][] graph, int[] initial) { + int n = graph.length; + UnionFind uf = new UnionFind(n); for (int i = 0; i < n; ++i) { for (int j = i + 1; j < n; ++j) { if (graph[i][j] == 1) { - int pa = find(i), pb = find(j); - if (pa == pb) { - continue; - } - p[pa] = pb; - size[pb] += size[pa]; + uf.union(i, j); } } } - int mi = Integer.MAX_VALUE; - int res = initial[0]; - Arrays.sort(initial); - for (int i = 0; i < initial.length; ++i) { - int t = 0; - Set s = new HashSet<>(); - for (int j = 0; j < initial.length; ++j) { - if (i == j) { - continue; - } - if (s.contains(find(initial[j]))) { - continue; + int ans = n; + int mi = n, mx = 0; + int[] cnt = new int[n]; + for (int x : initial) { + ++cnt[uf.find(x)]; + mi = Math.min(mi, x); + } + + for (int x : initial) { + int root = uf.find(x); + if (cnt[root] == 1) { + int sz = uf.size(root); + if (sz > mx || (sz == mx && x < ans)) { + ans = x; + mx = sz; } - s.add(find(initial[j])); - t += size[find(initial[j])]; - } - if (mi > t) { - mi = t; - res = initial[i]; } } - return res; + return ans == n ? mi : ans; + } +} +``` + +```cpp +class UnionFind { +public: + UnionFind(int n) { + p = vector(n); + size = vector(n, 1); + iota(p.begin(), p.end(), 0); + } + + bool unite(int a, int b) { + int pa = find(a), pb = find(b); + if (pa == pb) { + return false; + } + if (size[pa] > size[pb]) { + p[pb] = pa; + size[pa] += size[pb]; + } else { + p[pa] = pb; + size[pb] += size[pa]; + } + return true; } - private int find(int x) { + int find(int x) { if (p[x] != x) { p[x] = find(p[x]); } return p[x]; } -} -``` -```cpp + int getSize(int root) { + return size[root]; + } + +private: + vector p, size; +}; + class Solution { public: - vector p; - int minMalwareSpread(vector>& graph, vector& initial) { int n = graph.size(); - p.resize(n); - for (int i = 0; i < n; ++i) p[i] = i; - vector size(n, 1); + UnionFind uf(n); for (int i = 0; i < n; ++i) { for (int j = i + 1; j < n; ++j) { if (graph[i][j]) { - int pa = find(i), pb = find(j); - if (pa == pb) continue; - p[pa] = pb; - size[pb] += size[pa]; + uf.unite(i, j); } } } - int mi = 400; - int res = initial[0]; - sort(initial.begin(), initial.end()); - for (int i = 0; i < initial.size(); ++i) { - int t = 0; - unordered_set s; - for (int j = 0; j < initial.size(); ++j) { - if (i == j) continue; - if (s.count(find(initial[j]))) continue; - s.insert(find(initial[j])); - t += size[find(initial[j])]; - } - if (mi > t) { - mi = t; - res = initial[i]; + int ans = n, mx = 0; + vector cnt(n); + for (int x : initial) { + ++cnt[uf.find(x)]; + } + for (int x : initial) { + int root = uf.find(x); + if (cnt[root] == 1) { + int sz = uf.getSize(root); + if (sz > mx || (sz == mx && ans > x)) { + ans = x; + mx = sz; + } } } - return res; - } - - int find(int x) { - if (p[x] != x) p[x] = find(p[x]); - return p[x]; + return ans == n ? *min_element(initial.begin(), initial.end()) : ans; } }; ``` ```go -var p []int +type unionFind struct { + p, size []int +} -func minMalwareSpread(graph [][]int, initial []int) int { - n := len(graph) - p = make([]int, n) +func newUnionFind(n int) *unionFind { + p := make([]int, n) size := make([]int, n) - for i := 0; i < n; i++ { + for i := range p { p[i] = i size[i] = 1 } - for i := 0; i < n; i++ { + return &unionFind{p, size} +} + +func (uf *unionFind) find(x int) int { + if uf.p[x] != x { + uf.p[x] = uf.find(uf.p[x]) + } + return uf.p[x] +} + +func (uf *unionFind) union(a, b int) bool { + pa, pb := uf.find(a), uf.find(b) + if pa == pb { + return false + } + if uf.size[pa] > uf.size[pb] { + uf.p[pb] = pa + uf.size[pa] += uf.size[pb] + } else { + uf.p[pa] = pb + uf.size[pb] += uf.size[pa] + } + return true +} + +func (uf *unionFind) getSize(root int) int { + return uf.size[root] +} + +func minMalwareSpread(graph [][]int, initial []int) int { + n := len(graph) + uf := newUnionFind(n) + for i := range graph { for j := i + 1; j < n; j++ { if graph[i][j] == 1 { - pa, pb := find(i), find(j) - if pa == pb { - continue - } - p[pa] = pb - size[pb] += size[pa] + uf.union(i, j) } } } - mi := 400 - res := initial[0] - sort.Ints(initial) - for i := 0; i < len(initial); i++ { - t := 0 - s := make(map[int]bool) - for j := 0; j < len(initial); j++ { - if i == j { - continue - } - if s[find(initial[j])] { - continue + cnt := make([]int, n) + ans, mx := n, 0 + for _, x := range initial { + cnt[uf.find(x)]++ + } + for _, x := range initial { + root := uf.find(x) + if cnt[root] == 1 { + sz := uf.getSize(root) + if sz > mx || sz == mx && x < ans { + ans, mx = x, sz } - s[find(initial[j])] = true - t += size[find(initial[j])] - } - if mi > t { - mi = t - res = initial[i] } } - return res + if ans == n { + return slices.Min(initial) + } + return ans } +``` -func find(x int) int { - if p[x] != x { - p[x] = find(p[x]) - } - return p[x] +```ts +class UnionFind { + p: number[]; + size: number[]; + constructor(n: number) { + this.p = Array(n) + .fill(0) + .map((_, i) => i); + this.size = Array(n).fill(1); + } + + find(x: number): number { + if (this.p[x] !== x) { + this.p[x] = this.find(this.p[x]); + } + return this.p[x]; + } + + union(a: number, b: number): boolean { + const [pa, pb] = [this.find(a), this.find(b)]; + if (pa === pb) { + return false; + } + if (this.size[pa] > this.size[pb]) { + this.p[pb] = pa; + this.size[pa] += this.size[pb]; + } else { + this.p[pa] = pb; + this.size[pb] += this.size[pa]; + } + return true; + } + + getSize(root: number): number { + return this.size[root]; + } +} + +function minMalwareSpread(graph: number[][], initial: number[]): number { + const n = graph.length; + const uf = new UnionFind(n); + for (let i = 0; i < n; ++i) { + for (let j = i + 1; j < n; ++j) { + graph[i][j] && uf.union(i, j); + } + } + let [ans, mx] = [n, 0]; + const cnt: number[] = Array(n).fill(0); + for (const x of initial) { + ++cnt[uf.find(x)]; + } + for (const x of initial) { + const root = uf.find(x); + if (cnt[root] === 1) { + const sz = uf.getSize(root); + if (sz > mx || (sz === mx && x < ans)) { + [ans, mx] = [x, sz]; + } + } + } + return ans === n ? Math.min(...initial) : ans; } ``` diff --git a/solution/0900-0999/0924.Minimize Malware Spread/README_EN.md b/solution/0900-0999/0924.Minimize Malware Spread/README_EN.md index 6acd0fd23d894..414a8f02969fa 100644 --- a/solution/0900-0999/0924.Minimize Malware Spread/README_EN.md +++ b/solution/0900-0999/0924.Minimize Malware Spread/README_EN.md @@ -44,206 +44,358 @@ ## Solutions -### Solution 1 +### Solution 1: Union Find + +According to the problem description, if there are several nodes in the same connected component initially, there can be three situations: + +1. None of these nodes are infected. +2. Only one node among these nodes is infected. +3. Multiple nodes among these nodes are infected. + +What we need to consider is to minimize the number of infected nodes left after removing a certain infected node. + +For situation 1, there are no infected nodes, so we don't need to consider it; for situation 2, only one node is infected, so after removing this node, the other nodes in this connected component will not be infected; for situation 3, multiple nodes are infected, so after removing any infected node, the other nodes in this connected component will still be infected. Therefore, we only need to consider situation 2. + +We use a union find set $uf$ to maintain the connectivity of nodes, a variable $ans$ to record the answer, and a variable $mx$ to record the maximum number of infections that can be reduced currently. Initially, $ans = n$, $mx = 0$. + +Then we traverse the array $initial$, use a hash table or an array of length $n$ named $cnt$ to count the number of infected nodes in each connected component. + +Next, we traverse the array $initial$ again. For each node $x$, we find the root node $root$ of its connected component. If there is only one infected node in this connected component, i.e., $cnt[root] = 1$, we update the answer. The update condition is that the number of nodes $sz$ in this connected component is greater than $mx$ or $sz$ equals $mx$ and the value of $x$ is less than $ans$. + +Finally, if $ans$ has not been updated, it means that there are multiple infected nodes in all connected components, so we return the minimum value in $initial$, otherwise, we return $ans$. + +The time complexity is $O(n^2 \times \alpha(n))$, and the space complexity is $O(n)$. Where $n$ is the number of nodes, and $\alpha(n)$ is the inverse of the Ackermann function. ```python +class UnionFind: + __slots__ = "p", "size" + + def __init__(self, n: int): + self.p = list(range(n)) + self.size = [1] * n + + def find(self, x: int) -> int: + if self.p[x] != x: + self.p[x] = self.find(self.p[x]) + return self.p[x] + + def union(self, a: int, b: int) -> bool: + pa, pb = self.find(a), self.find(b) + if pa == pb: + return False + if self.size[pa] > self.size[pb]: + self.p[pb] = pa + self.size[pa] += self.size[pb] + else: + self.p[pa] = pb + self.size[pb] += self.size[pa] + return True + + def get_size(self, root: int) -> int: + return self.size[root] + + class Solution: def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int: n = len(graph) - p = list(range(n)) - size = [1] * n - - def find(x): - if p[x] != x: - p[x] = find(p[x]) - return p[x] - + uf = UnionFind(n) for i in range(n): for j in range(i + 1, n): - if graph[i][j] == 1: - pa, pb = find(i), find(j) - if pa == pb: - continue - p[pa] = pb - size[pb] += size[pa] - - mi = inf - res = initial[0] - initial.sort() - for i in range(len(initial)): - t = 0 - s = set() - for j in range(len(initial)): - if i == j: - continue - if find(initial[j]) in s: - continue - s.add(find(initial[j])) - t += size[find(initial[j])] - if mi > t: - mi = t - res = initial[i] - return res + graph[i][j] and uf.union(i, j) + cnt = Counter(uf.find(x) for x in initial) + ans, mx = n, 0 + for x in initial: + root = uf.find(x) + if cnt[root] > 1: + continue + sz = uf.get_size(root) + if sz > mx or (sz == mx and x < ans): + ans = x + mx = sz + return min(initial) if ans == n else ans ``` ```java -class Solution { - private int[] p; +class UnionFind { + private final int[] p; + private final int[] size; - public int minMalwareSpread(int[][] graph, int[] initial) { - int n = graph.length; + public UnionFind(int n) { p = new int[n]; + size = new int[n]; for (int i = 0; i < n; ++i) { p[i] = i; + size[i] = 1; + } + } + + public int find(int x) { + if (p[x] != x) { + p[x] = find(p[x]); + } + return p[x]; + } + + public boolean union(int a, int b) { + int pa = find(a), pb = find(b); + if (pa == pb) { + return false; } - int[] size = new int[n]; - Arrays.fill(size, 1); + if (size[pa] > size[pb]) { + p[pb] = pa; + size[pa] += size[pb]; + } else { + p[pa] = pb; + size[pb] += size[pa]; + } + return true; + } + + public int size(int root) { + return size[root]; + } +} + +class Solution { + public int minMalwareSpread(int[][] graph, int[] initial) { + int n = graph.length; + UnionFind uf = new UnionFind(n); for (int i = 0; i < n; ++i) { for (int j = i + 1; j < n; ++j) { if (graph[i][j] == 1) { - int pa = find(i), pb = find(j); - if (pa == pb) { - continue; - } - p[pa] = pb; - size[pb] += size[pa]; + uf.union(i, j); } } } - int mi = Integer.MAX_VALUE; - int res = initial[0]; - Arrays.sort(initial); - for (int i = 0; i < initial.length; ++i) { - int t = 0; - Set s = new HashSet<>(); - for (int j = 0; j < initial.length; ++j) { - if (i == j) { - continue; - } - if (s.contains(find(initial[j]))) { - continue; + int ans = n; + int mi = n, mx = 0; + int[] cnt = new int[n]; + for (int x : initial) { + ++cnt[uf.find(x)]; + mi = Math.min(mi, x); + } + + for (int x : initial) { + int root = uf.find(x); + if (cnt[root] == 1) { + int sz = uf.size(root); + if (sz > mx || (sz == mx && x < ans)) { + ans = x; + mx = sz; } - s.add(find(initial[j])); - t += size[find(initial[j])]; - } - if (mi > t) { - mi = t; - res = initial[i]; } } - return res; + return ans == n ? mi : ans; + } +} +``` + +```cpp +class UnionFind { +public: + UnionFind(int n) { + p = vector(n); + size = vector(n, 1); + iota(p.begin(), p.end(), 0); } - private int find(int x) { + bool unite(int a, int b) { + int pa = find(a), pb = find(b); + if (pa == pb) { + return false; + } + if (size[pa] > size[pb]) { + p[pb] = pa; + size[pa] += size[pb]; + } else { + p[pa] = pb; + size[pb] += size[pa]; + } + return true; + } + + int find(int x) { if (p[x] != x) { p[x] = find(p[x]); } return p[x]; } -} -``` -```cpp + int getSize(int root) { + return size[root]; + } + +private: + vector p, size; +}; + class Solution { public: - vector p; - int minMalwareSpread(vector>& graph, vector& initial) { int n = graph.size(); - p.resize(n); - for (int i = 0; i < n; ++i) p[i] = i; - vector size(n, 1); + UnionFind uf(n); for (int i = 0; i < n; ++i) { for (int j = i + 1; j < n; ++j) { if (graph[i][j]) { - int pa = find(i), pb = find(j); - if (pa == pb) continue; - p[pa] = pb; - size[pb] += size[pa]; + uf.unite(i, j); } } } - int mi = 400; - int res = initial[0]; - sort(initial.begin(), initial.end()); - for (int i = 0; i < initial.size(); ++i) { - int t = 0; - unordered_set s; - for (int j = 0; j < initial.size(); ++j) { - if (i == j) continue; - if (s.count(find(initial[j]))) continue; - s.insert(find(initial[j])); - t += size[find(initial[j])]; - } - if (mi > t) { - mi = t; - res = initial[i]; + int ans = n, mx = 0; + vector cnt(n); + for (int x : initial) { + ++cnt[uf.find(x)]; + } + for (int x : initial) { + int root = uf.find(x); + if (cnt[root] == 1) { + int sz = uf.getSize(root); + if (sz > mx || (sz == mx && ans > x)) { + ans = x; + mx = sz; + } } } - return res; - } - - int find(int x) { - if (p[x] != x) p[x] = find(p[x]); - return p[x]; + return ans == n ? *min_element(initial.begin(), initial.end()) : ans; } }; ``` ```go -var p []int +type unionFind struct { + p, size []int +} -func minMalwareSpread(graph [][]int, initial []int) int { - n := len(graph) - p = make([]int, n) +func newUnionFind(n int) *unionFind { + p := make([]int, n) size := make([]int, n) - for i := 0; i < n; i++ { + for i := range p { p[i] = i size[i] = 1 } - for i := 0; i < n; i++ { + return &unionFind{p, size} +} + +func (uf *unionFind) find(x int) int { + if uf.p[x] != x { + uf.p[x] = uf.find(uf.p[x]) + } + return uf.p[x] +} + +func (uf *unionFind) union(a, b int) bool { + pa, pb := uf.find(a), uf.find(b) + if pa == pb { + return false + } + if uf.size[pa] > uf.size[pb] { + uf.p[pb] = pa + uf.size[pa] += uf.size[pb] + } else { + uf.p[pa] = pb + uf.size[pb] += uf.size[pa] + } + return true +} + +func (uf *unionFind) getSize(root int) int { + return uf.size[root] +} + +func minMalwareSpread(graph [][]int, initial []int) int { + n := len(graph) + uf := newUnionFind(n) + for i := range graph { for j := i + 1; j < n; j++ { if graph[i][j] == 1 { - pa, pb := find(i), find(j) - if pa == pb { - continue - } - p[pa] = pb - size[pb] += size[pa] + uf.union(i, j) } } } - mi := 400 - res := initial[0] - sort.Ints(initial) - for i := 0; i < len(initial); i++ { - t := 0 - s := make(map[int]bool) - for j := 0; j < len(initial); j++ { - if i == j { - continue - } - if s[find(initial[j])] { - continue + cnt := make([]int, n) + ans, mx := n, 0 + for _, x := range initial { + cnt[uf.find(x)]++ + } + for _, x := range initial { + root := uf.find(x) + if cnt[root] == 1 { + sz := uf.getSize(root) + if sz > mx || sz == mx && x < ans { + ans, mx = x, sz } - s[find(initial[j])] = true - t += size[find(initial[j])] - } - if mi > t { - mi = t - res = initial[i] } } - return res + if ans == n { + return slices.Min(initial) + } + return ans } +``` -func find(x int) int { - if p[x] != x { - p[x] = find(p[x]) - } - return p[x] +```ts +class UnionFind { + p: number[]; + size: number[]; + constructor(n: number) { + this.p = Array(n) + .fill(0) + .map((_, i) => i); + this.size = Array(n).fill(1); + } + + find(x: number): number { + if (this.p[x] !== x) { + this.p[x] = this.find(this.p[x]); + } + return this.p[x]; + } + + union(a: number, b: number): boolean { + const [pa, pb] = [this.find(a), this.find(b)]; + if (pa === pb) { + return false; + } + if (this.size[pa] > this.size[pb]) { + this.p[pb] = pa; + this.size[pa] += this.size[pb]; + } else { + this.p[pa] = pb; + this.size[pb] += this.size[pa]; + } + return true; + } + + getSize(root: number): number { + return this.size[root]; + } +} + +function minMalwareSpread(graph: number[][], initial: number[]): number { + const n = graph.length; + const uf = new UnionFind(n); + for (let i = 0; i < n; ++i) { + for (let j = i + 1; j < n; ++j) { + graph[i][j] && uf.union(i, j); + } + } + let [ans, mx] = [n, 0]; + const cnt: number[] = Array(n).fill(0); + for (const x of initial) { + ++cnt[uf.find(x)]; + } + for (const x of initial) { + const root = uf.find(x); + if (cnt[root] === 1) { + const sz = uf.getSize(root); + if (sz > mx || (sz === mx && x < ans)) { + [ans, mx] = [x, sz]; + } + } + } + return ans === n ? Math.min(...initial) : ans; } ``` diff --git a/solution/0900-0999/0924.Minimize Malware Spread/Solution.cpp b/solution/0900-0999/0924.Minimize Malware Spread/Solution.cpp index 6da9b93e061d7..4398509cd8877 100644 --- a/solution/0900-0999/0924.Minimize Malware Spread/Solution.cpp +++ b/solution/0900-0999/0924.Minimize Malware Spread/Solution.cpp @@ -1,44 +1,68 @@ -class Solution { +class UnionFind { public: - vector p; + UnionFind(int n) { + p = vector(n); + size = vector(n, 1); + iota(p.begin(), p.end(), 0); + } + + bool unite(int a, int b) { + int pa = find(a), pb = find(b); + if (pa == pb) { + return false; + } + if (size[pa] > size[pb]) { + p[pb] = pa; + size[pa] += size[pb]; + } else { + p[pa] = pb; + size[pb] += size[pa]; + } + return true; + } + int find(int x) { + if (p[x] != x) { + p[x] = find(p[x]); + } + return p[x]; + } + + int getSize(int root) { + return size[root]; + } + +private: + vector p, size; +}; + +class Solution { +public: int minMalwareSpread(vector>& graph, vector& initial) { int n = graph.size(); - p.resize(n); - for (int i = 0; i < n; ++i) p[i] = i; - vector size(n, 1); + UnionFind uf(n); for (int i = 0; i < n; ++i) { for (int j = i + 1; j < n; ++j) { if (graph[i][j]) { - int pa = find(i), pb = find(j); - if (pa == pb) continue; - p[pa] = pb; - size[pb] += size[pa]; + uf.unite(i, j); } } } - int mi = 400; - int res = initial[0]; - sort(initial.begin(), initial.end()); - for (int i = 0; i < initial.size(); ++i) { - int t = 0; - unordered_set s; - for (int j = 0; j < initial.size(); ++j) { - if (i == j) continue; - if (s.count(find(initial[j]))) continue; - s.insert(find(initial[j])); - t += size[find(initial[j])]; - } - if (mi > t) { - mi = t; - res = initial[i]; + int ans = n, mx = 0; + vector cnt(n); + for (int x : initial) { + ++cnt[uf.find(x)]; + } + for (int x : initial) { + int root = uf.find(x); + if (cnt[root] == 1) { + int sz = uf.getSize(root); + if (sz > mx || (sz == mx && ans > x)) { + ans = x; + mx = sz; + } } } - return res; - } - - int find(int x) { - if (p[x] != x) p[x] = find(p[x]); - return p[x]; + return ans == n ? *min_element(initial.begin(), initial.end()) : ans; } }; \ No newline at end of file diff --git a/solution/0900-0999/0924.Minimize Malware Spread/Solution.go b/solution/0900-0999/0924.Minimize Malware Spread/Solution.go index 433d7e5f0e26a..787d34a67870e 100644 --- a/solution/0900-0999/0924.Minimize Malware Spread/Solution.go +++ b/solution/0900-0999/0924.Minimize Malware Spread/Solution.go @@ -1,52 +1,69 @@ -var p []int +type unionFind struct { + p, size []int +} -func minMalwareSpread(graph [][]int, initial []int) int { - n := len(graph) - p = make([]int, n) +func newUnionFind(n int) *unionFind { + p := make([]int, n) size := make([]int, n) - for i := 0; i < n; i++ { + for i := range p { p[i] = i size[i] = 1 } - for i := 0; i < n; i++ { + return &unionFind{p, size} +} + +func (uf *unionFind) find(x int) int { + if uf.p[x] != x { + uf.p[x] = uf.find(uf.p[x]) + } + return uf.p[x] +} + +func (uf *unionFind) union(a, b int) bool { + pa, pb := uf.find(a), uf.find(b) + if pa == pb { + return false + } + if uf.size[pa] > uf.size[pb] { + uf.p[pb] = pa + uf.size[pa] += uf.size[pb] + } else { + uf.p[pa] = pb + uf.size[pb] += uf.size[pa] + } + return true +} + +func (uf *unionFind) getSize(root int) int { + return uf.size[root] +} + +func minMalwareSpread(graph [][]int, initial []int) int { + n := len(graph) + uf := newUnionFind(n) + for i := range graph { for j := i + 1; j < n; j++ { if graph[i][j] == 1 { - pa, pb := find(i), find(j) - if pa == pb { - continue - } - p[pa] = pb - size[pb] += size[pa] + uf.union(i, j) } } } - mi := 400 - res := initial[0] - sort.Ints(initial) - for i := 0; i < len(initial); i++ { - t := 0 - s := make(map[int]bool) - for j := 0; j < len(initial); j++ { - if i == j { - continue - } - if s[find(initial[j])] { - continue + cnt := make([]int, n) + ans, mx := n, 0 + for _, x := range initial { + cnt[uf.find(x)]++ + } + for _, x := range initial { + root := uf.find(x) + if cnt[root] == 1 { + sz := uf.getSize(root) + if sz > mx || sz == mx && x < ans { + ans, mx = x, sz } - s[find(initial[j])] = true - t += size[find(initial[j])] - } - if mi > t { - mi = t - res = initial[i] } } - return res -} - -func find(x int) int { - if p[x] != x { - p[x] = find(p[x]) + if ans == n { + return slices.Min(initial) } - return p[x] + return ans } \ No newline at end of file diff --git a/solution/0900-0999/0924.Minimize Malware Spread/Solution.java b/solution/0900-0999/0924.Minimize Malware Spread/Solution.java index 72b684fafcf04..c9879ea0a7c9e 100644 --- a/solution/0900-0999/0924.Minimize Malware Spread/Solution.java +++ b/solution/0900-0999/0924.Minimize Malware Spread/Solution.java @@ -1,54 +1,72 @@ -class Solution { - private int[] p; +class UnionFind { + private final int[] p; + private final int[] size; - public int minMalwareSpread(int[][] graph, int[] initial) { - int n = graph.length; + public UnionFind(int n) { p = new int[n]; + size = new int[n]; for (int i = 0; i < n; ++i) { p[i] = i; + size[i] = 1; + } + } + + public int find(int x) { + if (p[x] != x) { + p[x] = find(p[x]); } - int[] size = new int[n]; - Arrays.fill(size, 1); + return p[x]; + } + + public boolean union(int a, int b) { + int pa = find(a), pb = find(b); + if (pa == pb) { + return false; + } + if (size[pa] > size[pb]) { + p[pb] = pa; + size[pa] += size[pb]; + } else { + p[pa] = pb; + size[pb] += size[pa]; + } + return true; + } + + public int size(int root) { + return size[root]; + } +} + +class Solution { + public int minMalwareSpread(int[][] graph, int[] initial) { + int n = graph.length; + UnionFind uf = new UnionFind(n); for (int i = 0; i < n; ++i) { for (int j = i + 1; j < n; ++j) { if (graph[i][j] == 1) { - int pa = find(i), pb = find(j); - if (pa == pb) { - continue; - } - p[pa] = pb; - size[pb] += size[pa]; + uf.union(i, j); } } } - int mi = Integer.MAX_VALUE; - int res = initial[0]; - Arrays.sort(initial); - for (int i = 0; i < initial.length; ++i) { - int t = 0; - Set s = new HashSet<>(); - for (int j = 0; j < initial.length; ++j) { - if (i == j) { - continue; - } - if (s.contains(find(initial[j]))) { - continue; - } - s.add(find(initial[j])); - t += size[find(initial[j])]; - } - if (mi > t) { - mi = t; - res = initial[i]; - } + int ans = n; + int mi = n, mx = 0; + int[] cnt = new int[n]; + for (int x : initial) { + ++cnt[uf.find(x)]; + mi = Math.min(mi, x); } - return res; - } - private int find(int x) { - if (p[x] != x) { - p[x] = find(p[x]); + for (int x : initial) { + int root = uf.find(x); + if (cnt[root] == 1) { + int sz = uf.size(root); + if (sz > mx || (sz == mx && x < ans)) { + ans = x; + mx = sz; + } + } } - return p[x]; + return ans == n ? mi : ans; } } \ No newline at end of file diff --git a/solution/0900-0999/0924.Minimize Malware Spread/Solution.py b/solution/0900-0999/0924.Minimize Malware Spread/Solution.py index d901ac647c9d4..044d2ac138e27 100644 --- a/solution/0900-0999/0924.Minimize Malware Spread/Solution.py +++ b/solution/0900-0999/0924.Minimize Malware Spread/Solution.py @@ -1,37 +1,46 @@ +class UnionFind: + __slots__ = "p", "size" + + def __init__(self, n: int): + self.p = list(range(n)) + self.size = [1] * n + + def find(self, x: int) -> int: + if self.p[x] != x: + self.p[x] = self.find(self.p[x]) + return self.p[x] + + def union(self, a: int, b: int) -> bool: + pa, pb = self.find(a), self.find(b) + if pa == pb: + return False + if self.size[pa] > self.size[pb]: + self.p[pb] = pa + self.size[pa] += self.size[pb] + else: + self.p[pa] = pb + self.size[pb] += self.size[pa] + return True + + def get_size(self, root: int) -> int: + return self.size[root] + + class Solution: def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int: n = len(graph) - p = list(range(n)) - size = [1] * n - - def find(x): - if p[x] != x: - p[x] = find(p[x]) - return p[x] - + uf = UnionFind(n) for i in range(n): for j in range(i + 1, n): - if graph[i][j] == 1: - pa, pb = find(i), find(j) - if pa == pb: - continue - p[pa] = pb - size[pb] += size[pa] - - mi = inf - res = initial[0] - initial.sort() - for i in range(len(initial)): - t = 0 - s = set() - for j in range(len(initial)): - if i == j: - continue - if find(initial[j]) in s: - continue - s.add(find(initial[j])) - t += size[find(initial[j])] - if mi > t: - mi = t - res = initial[i] - return res + graph[i][j] and uf.union(i, j) + cnt = Counter(uf.find(x) for x in initial) + ans, mx = n, 0 + for x in initial: + root = uf.find(x) + if cnt[root] > 1: + continue + sz = uf.get_size(root) + if sz > mx or (sz == mx and x < ans): + ans = x + mx = sz + return min(initial) if ans == n else ans diff --git a/solution/0900-0999/0924.Minimize Malware Spread/Solution.ts b/solution/0900-0999/0924.Minimize Malware Spread/Solution.ts new file mode 100644 index 0000000000000..10cca6996fde8 --- /dev/null +++ b/solution/0900-0999/0924.Minimize Malware Spread/Solution.ts @@ -0,0 +1,61 @@ +class UnionFind { + p: number[]; + size: number[]; + constructor(n: number) { + this.p = Array(n) + .fill(0) + .map((_, i) => i); + this.size = Array(n).fill(1); + } + + find(x: number): number { + if (this.p[x] !== x) { + this.p[x] = this.find(this.p[x]); + } + return this.p[x]; + } + + union(a: number, b: number): boolean { + const [pa, pb] = [this.find(a), this.find(b)]; + if (pa === pb) { + return false; + } + if (this.size[pa] > this.size[pb]) { + this.p[pb] = pa; + this.size[pa] += this.size[pb]; + } else { + this.p[pa] = pb; + this.size[pb] += this.size[pa]; + } + return true; + } + + getSize(root: number): number { + return this.size[root]; + } +} + +function minMalwareSpread(graph: number[][], initial: number[]): number { + const n = graph.length; + const uf = new UnionFind(n); + for (let i = 0; i < n; ++i) { + for (let j = i + 1; j < n; ++j) { + graph[i][j] && uf.union(i, j); + } + } + let [ans, mx] = [n, 0]; + const cnt: number[] = Array(n).fill(0); + for (const x of initial) { + ++cnt[uf.find(x)]; + } + for (const x of initial) { + const root = uf.find(x); + if (cnt[root] === 1) { + const sz = uf.getSize(root); + if (sz > mx || (sz === mx && x < ans)) { + [ans, mx] = [x, sz]; + } + } + } + return ans === n ? Math.min(...initial) : ans; +}