From 0d3ce5e4b6cf674b42fb71440ccc2a0ec408e990 Mon Sep 17 00:00:00 2001 From: houpc Date: Thu, 31 Oct 2024 02:14:22 +0800 Subject: [PATCH] bugfix for remove_zero_valued_subgraphs! --- src/computational_graph/transform.jl | 7 +++++++ src/computational_graph/tree_properties.jl | 2 ++ 2 files changed, 9 insertions(+) diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index a763167b..9d2fb30e 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -399,6 +399,13 @@ function remove_zero_valued_subgraphs!(g::AbstractGraph) mask_zeros = findall(x -> x != zero(x), subg_fac) if isempty(mask_zeros) mask_zeros = [1] # retain eldest(g) if all subfactors are zero + else + if g.operator == Prod + idx = findfirst(x -> x==zero(x), subg_fac) + if !isnothing(idx) + append!(mask_zeros, idx) + end + end end set_subgraphs!(g, subg[mask_zeros]) set_subgraph_factors!(g, subg_fac[mask_zeros]) diff --git a/src/computational_graph/tree_properties.jl b/src/computational_graph/tree_properties.jl index df91a36a..aac26841 100644 --- a/src/computational_graph/tree_properties.jl +++ b/src/computational_graph/tree_properties.jl @@ -94,6 +94,8 @@ end function has_zero_subfactors(g::AbstractGraph) if isleaf(g) return false # convention: subgraph_factors = [] ⟹ subfactorless = false + elseif g.operator == Prod && 0 in subgraph_factors(g) + return true else return iszero(subgraph_factors(g)) end