diff --git a/dtreeviz/models/spark_decision_tree.py b/dtreeviz/models/spark_decision_tree.py
index 0c59f68..c9ceec1 100644
--- a/dtreeviz/models/spark_decision_tree.py
+++ b/dtreeviz/models/spark_decision_tree.py
@@ -172,16 +172,23 @@ def get_node_feature(self, id) -> int:
return self.get_features()[id]
def get_node_nsamples_by_class(self, id):
- def _get_value(spark_version):
- if spark_version >= 3:
- return np.array(self.tree_nodes[id].impurityStats().stats())
- elif spark_version >= 2:
- return np.array(list(self.tree_nodes[id].impurityStats().stats()))
- else:
- raise Exception("dtreeviz supports spark versions >= 2")
-
+ all_nodes = self.internal + self.leaves
if self.is_classifier():
- return _get_value(ShadowSparkTree._get_pyspark_major_version())
+ node_value = [node.n_sample_classes() for node in all_nodes if node.id == id]
+ return node_value[0]
+
+ # This is the code to return the nsamples/class from tree metadata. It's faster, but the visualisations cannot
+ # be made on new datasets.
+ # def _get_value(spark_version):
+ # if spark_version >= 3:
+ # return np.array(self.tree_nodes[id].impurityStats().stats())
+ # elif spark_version >= 2:
+ # return np.array(list(self.tree_nodes[id].impurityStats().stats()))
+ # else:
+ # raise Exception("dtreeviz supports spark versions >= 2")
+ #
+ # if self.is_classifier():
+ # return _get_value(ShadowSparkTree._get_pyspark_major_version())
def get_prediction(self, id):
return self.tree_nodes[id].prediction()
diff --git a/dtreeviz/trees.py b/dtreeviz/trees.py
index a94b1db..051b146 100644
--- a/dtreeviz/trees.py
+++ b/dtreeviz/trees.py
@@ -338,11 +338,12 @@ def node_name(node: ShadowDecTreeNode) -> str:
def split_node(name, node_name, split):
if fancy:
+ filepath = os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg")
labelgraph = node_label(node) if show_node_labels else ''
html = f"""
{labelgraph}
- data:image/s3,"s3://crabby-images/b8400/b84000cf61cb9fbfd1319c521ea3a0b3c6239fa6" alt=""}.svg) |
+ data:image/s3,"s3://crabby-images/f192a/f192a9b85c12cf6c73154c38f0df95a58858bb7d" alt="" |
"""
else:
@@ -356,10 +357,11 @@ def split_node(name, node_name, split):
def regr_leaf_node(node, label_fontsize: int = 12):
# always generate fancy regr leaves for now but shrink a bit for nonfancy.
labelgraph = node_label(node) if show_node_labels else ''
+ filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg")
html = f"""
{labelgraph}
- data:image/s3,"s3://crabby-images/a3672/a3672898b52fa4174f3f38a6719bfaaa40e35641" alt=""}.svg) |
+ data:image/s3,"s3://crabby-images/f192a/f192a9b85c12cf6c73154c38f0df95a58858bb7d" alt="" |
"""
if node.id in highlight_path:
@@ -369,10 +371,11 @@ def regr_leaf_node(node, label_fontsize: int = 12):
def class_leaf_node(node, label_fontsize: int = 12):
labelgraph = node_label(node) if show_node_labels else ''
+ filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg")
html = f"""
{labelgraph}
- data:image/s3,"s3://crabby-images/a3672/a3672898b52fa4174f3f38a6719bfaaa40e35641" alt=""}.svg) |
+ data:image/s3,"s3://crabby-images/f192a/f192a9b85c12cf6c73154c38f0df95a58858bb7d" alt="" |
"""
if node.id in highlight_path:
@@ -384,10 +387,11 @@ def node_label(node):
return f'Node {node.id} |
'
def class_legend_html():
+ filepath = os.path.join(tmp, f"legend_{os.getpid()}.svg")
return f"""
- data:image/s3,"s3://crabby-images/b55fb/b55fb98d72dc3149cdf9fa7f96230f2b1328ddf2" alt=""}.svg) |
+ data:image/s3,"s3://crabby-images/f192a/f192a9b85c12cf6c73154c38f0df95a58858bb7d" alt="" |
"""
@@ -527,7 +531,7 @@ def get_leaves():
if np.max(class_values) >= n_classes:
raise ValueError(f"Target label values (for now) must be 0..{n_classes-1} for n={n_classes} labels")
color_map = {v: color_values[i] for i, v in enumerate(class_values)}
- _draw_legend(self.shadow_tree, self.shadow_tree.target_name, f"{tmp}/legend_{os.getpid()}.svg",
+ _draw_legend(self.shadow_tree, self.shadow_tree.target_name, os.path.join(tmp, f"legend_{os.getpid()}.svg"),
colors=colors,
fontname=fontname)
@@ -557,37 +561,40 @@ def get_leaves():
if depth_range_to_display is not None:
if node.level not in range(depth_range_to_display[0], depth_range_to_display[1] + 1):
continue
+ node_to_display = True
if fancy:
if self.shadow_tree.is_classifier():
- _class_split_viz(node, X_train, y_train,
- filename=f"{tmp}/node{node.id}_{os.getpid()}.svg",
- precision=precision,
- colors={**color_map, **colors},
- histtype=histtype,
- node_heights=node_heights,
- X=x,
- ticks_fontsize=ticks_fontsize,
- label_fontsize=label_fontsize,
- fontname=fontname,
- highlight_node=node.id in highlight_path)
+ node_to_display = _class_split_viz(node, X_train, y_train,
+ filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"),
+ precision=precision,
+ colors={**color_map, **colors},
+ histtype=histtype,
+ node_heights=node_heights,
+ X=x,
+ ticks_fontsize=ticks_fontsize,
+ label_fontsize=label_fontsize,
+ fontname=fontname,
+ highlight_node=node.id in highlight_path)
else:
- _regr_split_viz(node, X_train, y_train,
- filename=f"{tmp}/node{node.id}_{os.getpid()}.svg",
- target_name=self.shadow_tree.target_name,
- y_range=y_range,
- X=x,
- ticks_fontsize=ticks_fontsize,
- label_fontsize=label_fontsize,
- fontname=fontname,
- highlight_node=node.id in highlight_path,
- colors=colors)
+ node_to_display = _regr_split_viz(node, X_train, y_train,
+ filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"),
+ target_name=self.shadow_tree.target_name,
+ y_range=y_range,
+ X=x,
+ ticks_fontsize=ticks_fontsize,
+ label_fontsize=label_fontsize,
+ fontname=fontname,
+ highlight_node=node.id in highlight_path,
+ colors=colors)
nname = node_name(node)
if not node.is_categorical_split():
gr_node = split_node(node.feature_name(), nname, split=myround(node.split(), precision))
else:
gr_node = split_node(node.feature_name(), nname, split=node.split()[0])
- internal.append(gr_node)
+
+ if node_to_display:
+ internal.append(gr_node)
leaves = []
for node in get_leaves():
@@ -596,24 +603,24 @@ def get_leaves():
continue
if self.shadow_tree.is_classifier():
if _class_leaf_viz(node, colors=color_values,
- filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg",
+ filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"),
graph_colors=colors,
fontname=fontname,
leaftype=leaftype):
leaves.append(class_leaf_node(node))
else:
# for now, always gen leaf
- _regr_leaf_viz(node,
+ if _regr_leaf_viz(node,
y_train,
target_name=self.shadow_tree.target_name,
- filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg",
+ filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"),
y_range=y_range,
precision=precision,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
- colors=colors)
- leaves.append(regr_leaf_node(node))
+ colors=colors):
+ leaves.append(regr_leaf_node(node))
if show_just_path:
show_root_edge_labels = False
@@ -1133,6 +1140,14 @@ def _class_split_viz(node: ShadowDecTreeNode,
# Get X, y data for all samples associated with this node.
X_feature = X_train[:, node.feature()]
+
+
+ if len(node.samples()) == 0:
+ if filename is not None:
+ plt.savefig(filename, bbox_inches='tight', pad_inches=0)
+ plt.close()
+ return False
+
X_node_feature, y_train = X_feature[node.samples()], y_train[node.samples()]
n_classes = node.shadow_tree.nclasses()
@@ -1231,6 +1246,7 @@ def _class_split_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
+ return True
def _class_leaf_viz(node: ShadowDecTreeNode,
@@ -1285,6 +1301,12 @@ def _regr_split_viz(node: ShadowDecTreeNode,
figsize = (2.5, 1.1)
fig, ax = plt.subplots(1, 1, figsize=figsize)
+ if len(node.samples()) == 0:
+ if filename is not None:
+ plt.savefig(filename, bbox_inches='tight', pad_inches=0)
+ plt.close()
+ return False
+
feature_name = node.feature_name()
_format_axes(ax, feature_name, target_name if node == node.shadow_tree.root else None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False, pad_for_wedge=True)
ax.set_ylim(y_range)
@@ -1310,10 +1332,12 @@ def _regr_split_viz(node: ShadowDecTreeNode,
right = y_train[right]
split = node.split()
- ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'],
- linewidth=1)
+ if len(left) > 0:
+ ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'],
+ linewidth=1)
ax.plot([split, split], [*y_range], '--', color=colors['split_line'], linewidth=1)
- ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'],
+ if len(right) > 0:
+ ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'],
linewidth=1)
wedge_ticks = _draw_wedge(ax, x=node.split(), node=node, color=colors['wedge'], is_classifier=False)
@@ -1357,6 +1381,7 @@ def _regr_split_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
+ return True
def _regr_leaf_viz(node: ShadowDecTreeNode,
@@ -1372,6 +1397,8 @@ def _regr_leaf_viz(node: ShadowDecTreeNode,
colors = adjust_colors(colors)
samples = node.samples()
+ if len(samples) == 0:
+ return False
y = y[samples]
figsize = (.75, .8)
@@ -1402,6 +1429,7 @@ def _regr_leaf_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
+ return True
def _draw_legend(shadow_tree, target_name, filename, colors, fontname):