Skip to content

Commit

Permalink
Fix the issues when the node/leaf don't contain samples
Browse files Browse the repository at this point in the history
  • Loading branch information
tlapusan authored and parrt committed Sep 23, 2023
1 parent e5d4a2b commit 6b2d0ab
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 44 deletions.
25 changes: 16 additions & 9 deletions dtreeviz/models/spark_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,23 @@ def get_node_feature(self, id) -> int:
return self.get_features()[id]

def get_node_nsamples_by_class(self, id):
def _get_value(spark_version):
if spark_version >= 3:
return np.array(self.tree_nodes[id].impurityStats().stats())
elif spark_version >= 2:
return np.array(list(self.tree_nodes[id].impurityStats().stats()))
else:
raise Exception("dtreeviz supports spark versions >= 2")

all_nodes = self.internal + self.leaves
if self.is_classifier():
return _get_value(ShadowSparkTree._get_pyspark_major_version())
node_value = [node.n_sample_classes() for node in all_nodes if node.id == id]
return node_value[0]

# This is the code to return the nsamples/class from tree metadata. It's faster, but the visualisations cannot
# be made on new datasets.
# def _get_value(spark_version):
# if spark_version >= 3:
# return np.array(self.tree_nodes[id].impurityStats().stats())
# elif spark_version >= 2:
# return np.array(list(self.tree_nodes[id].impurityStats().stats()))
# else:
# raise Exception("dtreeviz supports spark versions >= 2")
#
# if self.is_classifier():
# return _get_value(ShadowSparkTree._get_pyspark_major_version())

def get_prediction(self, id):
return self.tree_nodes[id].prediction()
Expand Down
98 changes: 63 additions & 35 deletions dtreeviz/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,12 @@ def node_name(node: ShadowDecTreeNode) -> str:

def split_node(name, node_name, split):
if fancy:
filepath = os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg")
labelgraph = node_label(node) if show_node_labels else ''
html = f"""<table border="0">
{labelgraph}
<tr>
<td><img src="{tmp}/node{node.id}_{os.getpid()}.svg"/></td>
<td><img src="{filepath}"/></td>
</tr>
</table>"""
else:
Expand All @@ -356,10 +357,11 @@ def split_node(name, node_name, split):
def regr_leaf_node(node, label_fontsize: int = 12):
# always generate fancy regr leaves for now but shrink a bit for nonfancy.
labelgraph = node_label(node) if show_node_labels else ''
filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg")
html = f"""<table border="0">
{labelgraph}
<tr>
<td><img src="{tmp}/leaf{node.id}_{os.getpid()}.svg"/></td>
<td><img src="{filepath}"/></td>
</tr>
</table>"""
if node.id in highlight_path:
Expand All @@ -369,10 +371,11 @@ def regr_leaf_node(node, label_fontsize: int = 12):

def class_leaf_node(node, label_fontsize: int = 12):
labelgraph = node_label(node) if show_node_labels else ''
filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg")
html = f"""<table border="0" CELLBORDER="0">
{labelgraph}
<tr>
<td><img src="{tmp}/leaf{node.id}_{os.getpid()}.svg"/></td>
<td><img src="{filepath}"/></td>
</tr>
</table>"""
if node.id in highlight_path:
Expand All @@ -384,10 +387,11 @@ def node_label(node):
return f'<tr><td CELLPADDING="0" CELLSPACING="0"><font face="{fontname}" color="{colors["node_label"]}" point-size="14"><i>Node {node.id}</i></font></td></tr>'

def class_legend_html():
filepath = os.path.join(tmp, f"legend_{os.getpid()}.svg")
return f"""
<table border="0" cellspacing="0" cellpadding="0">
<tr>
<td border="0" cellspacing="0" cellpadding="0"><img src="{tmp}/legend_{os.getpid()}.svg"/></td>
<td border="0" cellspacing="0" cellpadding="0"><img src="{filepath}"/></td>
</tr>
</table>
"""
Expand Down Expand Up @@ -527,7 +531,7 @@ def get_leaves():
if np.max(class_values) >= n_classes:
raise ValueError(f"Target label values (for now) must be 0..{n_classes-1} for n={n_classes} labels")
color_map = {v: color_values[i] for i, v in enumerate(class_values)}
_draw_legend(self.shadow_tree, self.shadow_tree.target_name, f"{tmp}/legend_{os.getpid()}.svg",
_draw_legend(self.shadow_tree, self.shadow_tree.target_name, os.path.join(tmp, f"legend_{os.getpid()}.svg"),
colors=colors,
fontname=fontname)

Expand Down Expand Up @@ -557,37 +561,40 @@ def get_leaves():
if depth_range_to_display is not None:
if node.level not in range(depth_range_to_display[0], depth_range_to_display[1] + 1):
continue
node_to_display = True
if fancy:
if self.shadow_tree.is_classifier():
_class_split_viz(node, X_train, y_train,
filename=f"{tmp}/node{node.id}_{os.getpid()}.svg",
precision=precision,
colors={**color_map, **colors},
histtype=histtype,
node_heights=node_heights,
X=x,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
highlight_node=node.id in highlight_path)
node_to_display = _class_split_viz(node, X_train, y_train,
filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"),
precision=precision,
colors={**color_map, **colors},
histtype=histtype,
node_heights=node_heights,
X=x,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
highlight_node=node.id in highlight_path)
else:
_regr_split_viz(node, X_train, y_train,
filename=f"{tmp}/node{node.id}_{os.getpid()}.svg",
target_name=self.shadow_tree.target_name,
y_range=y_range,
X=x,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
highlight_node=node.id in highlight_path,
colors=colors)
node_to_display = _regr_split_viz(node, X_train, y_train,
filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"),
target_name=self.shadow_tree.target_name,
y_range=y_range,
X=x,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
highlight_node=node.id in highlight_path,
colors=colors)

nname = node_name(node)
if not node.is_categorical_split():
gr_node = split_node(node.feature_name(), nname, split=myround(node.split(), precision))
else:
gr_node = split_node(node.feature_name(), nname, split=node.split()[0])
internal.append(gr_node)

if node_to_display:
internal.append(gr_node)

leaves = []
for node in get_leaves():
Expand All @@ -596,24 +603,24 @@ def get_leaves():
continue
if self.shadow_tree.is_classifier():
if _class_leaf_viz(node, colors=color_values,
filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg",
filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"),
graph_colors=colors,
fontname=fontname,
leaftype=leaftype):
leaves.append(class_leaf_node(node))
else:
# for now, always gen leaf
_regr_leaf_viz(node,
if _regr_leaf_viz(node,
y_train,
target_name=self.shadow_tree.target_name,
filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg",
filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"),
y_range=y_range,
precision=precision,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
colors=colors)
leaves.append(regr_leaf_node(node))
colors=colors):
leaves.append(regr_leaf_node(node))

if show_just_path:
show_root_edge_labels = False
Expand Down Expand Up @@ -1133,6 +1140,14 @@ def _class_split_viz(node: ShadowDecTreeNode,

# Get X, y data for all samples associated with this node.
X_feature = X_train[:, node.feature()]


if len(node.samples()) == 0:
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
return False

X_node_feature, y_train = X_feature[node.samples()], y_train[node.samples()]

n_classes = node.shadow_tree.nclasses()
Expand Down Expand Up @@ -1231,6 +1246,7 @@ def _class_split_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
return True


def _class_leaf_viz(node: ShadowDecTreeNode,
Expand Down Expand Up @@ -1285,6 +1301,12 @@ def _regr_split_viz(node: ShadowDecTreeNode,
figsize = (2.5, 1.1)
fig, ax = plt.subplots(1, 1, figsize=figsize)

if len(node.samples()) == 0:
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
return False

feature_name = node.feature_name()
_format_axes(ax, feature_name, target_name if node == node.shadow_tree.root else None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False, pad_for_wedge=True)
ax.set_ylim(y_range)
Expand All @@ -1310,10 +1332,12 @@ def _regr_split_viz(node: ShadowDecTreeNode,
right = y_train[right]
split = node.split()

ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'],
linewidth=1)
if len(left) > 0:
ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'],
linewidth=1)
ax.plot([split, split], [*y_range], '--', color=colors['split_line'], linewidth=1)
ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'],
if len(right) > 0:
ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'],
linewidth=1)

wedge_ticks = _draw_wedge(ax, x=node.split(), node=node, color=colors['wedge'], is_classifier=False)
Expand Down Expand Up @@ -1357,6 +1381,7 @@ def _regr_split_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
return True


def _regr_leaf_viz(node: ShadowDecTreeNode,
Expand All @@ -1372,6 +1397,8 @@ def _regr_leaf_viz(node: ShadowDecTreeNode,
colors = adjust_colors(colors)

samples = node.samples()
if len(samples) == 0:
return False
y = y[samples]

figsize = (.75, .8)
Expand Down Expand Up @@ -1402,6 +1429,7 @@ def _regr_leaf_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
return True


def _draw_legend(shadow_tree, target_name, filename, colors, fontname):
Expand Down

0 comments on commit 6b2d0ab

Please sign in to comment.