Skip to content

Commit

Permalink
fix: support LightGBM ensemble containing single leaf trees (stump) (s…
Browse files Browse the repository at this point in the history
  • Loading branch information
thatlittleboy authored Jul 16, 2023
1 parent 123ceaa commit b9f29cc
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

- Fixed segmentation fault errors on our MacOS test suite involving `lightgbm`
([#3093](https://github.com/slundberg/shap/pull/3093) by @thatlittleboy).
- Added support for LightGBM ensembles containing single leaf trees in `TreeExplainer`
([#3094](https://github.com/slundberg/shap/pull/3094) by @thatlittleboy).

### Changed

Expand Down
11 changes: 9 additions & 2 deletions shap/explainers/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,9 @@ def buildTree(index, node):
queue.append(left_child)
queue.append(right_child)
else:
vleaf_idx: int = vertex["leaf_index"] + num_parents
# NOTE: If "leaf_index" is not present as a key, it means we have a
# stump tree. I.e., num_nodes=1.
vleaf_idx: int = vertex.get("leaf_index", 0) + num_parents
self.children_left[vleaf_idx] = -1
self.children_right[vleaf_idx] = -1
self.children_default[vleaf_idx] = -1
Expand All @@ -1476,7 +1478,12 @@ def buildTree(index, node):
self.features[vleaf_idx] = -1
self.thresholds[vleaf_idx] = -1
self.values[vleaf_idx] = [vertex["leaf_value"]]
self.node_sample_weight[vleaf_idx] = vertex["leaf_count"]
# FIXME: "leaf_count" currently doesn't exist if we have a stump tree.
# Technically, we should be assigning the number of samples used to
# train the model as the weight here, but unfortunately this info is
# currently unavailable in `tree`, so we set it to 0 for now.
# cf. https://github.com/microsoft/LightGBM/issues/5962
self.node_sample_weight[vleaf_idx] = vertex.get("leaf_count", 0)
self.values = np.asarray(self.values)
self.values = np.multiply(self.values, scaling)

Expand Down
34 changes: 18 additions & 16 deletions tests/explainers/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,25 +776,26 @@ class TestSingleTree:
"""Tests for the SingleTree class."""

def test_singletree_lightgbm_basic(self):
"""A basic test for checking that a LightGBM dump_model() dictionary
is parsed properly into a SingleTree object.
"""A basic test for checking that a LightGBM `dump_model()["tree_info"]`
dictionary is parsed properly into a `SingleTree` object.
"""

# Stump (only root node) tree
# FIXME: this test should NOT throw a KeyError, see #2044
with pytest.raises(KeyError, match="leaf_index"):
sample_tree = {
"tree_index": 256,
"num_leaves": 1,
"num_cat": 0,
"shrinkage": 1,
"tree_structure": {
"leaf_value": 0,
},
}
stree = SingleTree(sample_tree)
# just ensure that this does not error out
assert stree.children_left[0] == -1
sample_tree = {
"tree_index": 256,
"num_leaves": 1,
"num_cat": 0,
"shrinkage": 1,
"tree_structure": {
"leaf_value": 0,
# "leaf_count": 123, # FIXME(upstream): microsoft/LightGBM#5962
},
}
stree = SingleTree(sample_tree)
# just ensure that this does not error out
assert stree.children_left[0] == -1
# assert stree.node_sample_weight[0] == 123
assert hasattr(stree, "values")

# Depth=1 tree
sample_tree = {
Expand All @@ -821,6 +822,7 @@ def test_singletree_lightgbm_basic(self):
stree = SingleTree(sample_tree)
# just ensure that the tree is parsed correctly
assert stree.node_sample_weight[0] == 100
assert hasattr(stree, "values")


class TestExplainerSklearn:
Expand Down

0 comments on commit b9f29cc

Please sign in to comment.