Skip to content

Commit

Permalink
Fix the view() plot for GBT
Browse files Browse the repository at this point in the history
  • Loading branch information
tlapusan authored and parrt committed Apr 28, 2023
1 parent 683cd48 commit c5c695d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
9 changes: 7 additions & 2 deletions dtreeviz/models/shadow_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,15 +457,17 @@ def get_shadow_tree(tree_model, X_train, y_train, feature_names, target_name, cl
from dtreeviz.models import lightgbm_decision_tree
return lightgbm_decision_tree.ShadowLightGBMTree(tree_model, tree_index, X_train, y_train,
feature_names, target_name, class_names)
elif "tensorflow_decision_forests.keras.RandomForestModel" in str(type(tree_model)):
elif any(tf_model in str(type(tree_model)) for tf_model in ["tensorflow_decision_forests.keras.RandomForestModel",
"tensorflow_decision_forests.keras.GradientBoostedTreesModel"]):
from dtreeviz.models import tensorflow_decision_tree
return tensorflow_decision_tree.ShadowTensorflowTree(tree_model, tree_index, X_train, y_train,
feature_names, target_name, class_names)
else:
raise ValueError(
f"Tree model must be in (DecisionTreeRegressor, DecisionTreeClassifier, "
"xgboost.core.Booster, lightgbm.basic.Booster, pyspark DecisionTreeClassificationModel, "
f"pyspark DecisionTreeClassificationModel, tensorflow_decision_forests.keras.RandomForestModel) "
f"pyspark DecisionTreeClassificationModel, tensorflow_decision_forests.keras.RandomForestModel, "
f"tensorflow_decision_forests.keras.GradientBoostedTreesModel) "
f"but you passed a {tree_model.__class__.__name__}!")


Expand Down Expand Up @@ -560,6 +562,9 @@ def prediction_name(self) -> (str, None):
Return prediction class or value otherwise.
"""
if self.isclassifier():
# In a GBT model, the trees are always regressive trees (even if the GBT is a classifier).
if "tensorflow_decision_forests.keras.GradientBoostedTreesModel" in str(type(self.shadow_tree.tree_model)):
return round(self.prediction(), 6)
if self.shadow_tree.class_names is not None:
return self.shadow_tree.class_names[self.prediction()]
return self.prediction()
Expand Down
4 changes: 4 additions & 0 deletions dtreeviz/models/tensorflow_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def get_node_nsamples_by_class(self, id):

def get_prediction(self, id):
if self.is_classifier():
# In a GBT model, the trees are always regressive trees (even if the GBT is a classifier). So we don't
# have the probability attribute
if "tensorflow_decision_forests.keras.GradientBoostedTreesModel" in str(type(self.model)):
return self.tree_nodes[id].value.value
return np.argmax(self.tree_nodes[id].value.probability)
else:
return self.tree_nodes[id].value.value
Expand Down

0 comments on commit c5c695d

Please sign in to comment.