From 6b2d0abc434c306dc336ecc876a5aa0491aaf827 Mon Sep 17 00:00:00 2001 From: "tudor.lapusan" Date: Sat, 23 Sep 2023 16:42:23 +0300 Subject: [PATCH] Fix the issues when the node/leaf don't contain samples --- dtreeviz/models/spark_decision_tree.py | 25 ++++--- dtreeviz/trees.py | 98 +++++++++++++++++--------- 2 files changed, 79 insertions(+), 44 deletions(-) diff --git a/dtreeviz/models/spark_decision_tree.py b/dtreeviz/models/spark_decision_tree.py index 0c59f687..c9ceec13 100644 --- a/dtreeviz/models/spark_decision_tree.py +++ b/dtreeviz/models/spark_decision_tree.py @@ -172,16 +172,23 @@ def get_node_feature(self, id) -> int: return self.get_features()[id] def get_node_nsamples_by_class(self, id): - def _get_value(spark_version): - if spark_version >= 3: - return np.array(self.tree_nodes[id].impurityStats().stats()) - elif spark_version >= 2: - return np.array(list(self.tree_nodes[id].impurityStats().stats())) - else: - raise Exception("dtreeviz supports spark versions >= 2") - + all_nodes = self.internal + self.leaves if self.is_classifier(): - return _get_value(ShadowSparkTree._get_pyspark_major_version()) + node_value = [node.n_sample_classes() for node in all_nodes if node.id == id] + return node_value[0] + + # This is the code to return the nsamples/class from tree metadata. It's faster, but the visualisations cannot + # be made on new datasets. + # def _get_value(spark_version): + # if spark_version >= 3: + # return np.array(self.tree_nodes[id].impurityStats().stats()) + # elif spark_version >= 2: + # return np.array(list(self.tree_nodes[id].impurityStats().stats())) + # else: + # raise Exception("dtreeviz supports spark versions >= 2") + # + # if self.is_classifier(): + # return _get_value(ShadowSparkTree._get_pyspark_major_version()) def get_prediction(self, id): return self.tree_nodes[id].prediction() diff --git a/dtreeviz/trees.py b/dtreeviz/trees.py index a94b1db5..051b146c 100644 --- a/dtreeviz/trees.py +++ b/dtreeviz/trees.py @@ -338,11 +338,12 @@ def node_name(node: ShadowDecTreeNode) -> str: def split_node(name, node_name, split): if fancy: + filepath = os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg") labelgraph = node_label(node) if show_node_labels else '' html = f""" {labelgraph} - +
""" else: @@ -356,10 +357,11 @@ def split_node(name, node_name, split): def regr_leaf_node(node, label_fontsize: int = 12): # always generate fancy regr leaves for now but shrink a bit for nonfancy. labelgraph = node_label(node) if show_node_labels else '' + filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg") html = f""" {labelgraph} - +
""" if node.id in highlight_path: @@ -369,10 +371,11 @@ def regr_leaf_node(node, label_fontsize: int = 12): def class_leaf_node(node, label_fontsize: int = 12): labelgraph = node_label(node) if show_node_labels else '' + filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg") html = f""" {labelgraph} - +
""" if node.id in highlight_path: @@ -384,10 +387,11 @@ def node_label(node): return f'Node {node.id}' def class_legend_html(): + filepath = os.path.join(tmp, f"legend_{os.getpid()}.svg") return f""" - +
""" @@ -527,7 +531,7 @@ def get_leaves(): if np.max(class_values) >= n_classes: raise ValueError(f"Target label values (for now) must be 0..{n_classes-1} for n={n_classes} labels") color_map = {v: color_values[i] for i, v in enumerate(class_values)} - _draw_legend(self.shadow_tree, self.shadow_tree.target_name, f"{tmp}/legend_{os.getpid()}.svg", + _draw_legend(self.shadow_tree, self.shadow_tree.target_name, os.path.join(tmp, f"legend_{os.getpid()}.svg"), colors=colors, fontname=fontname) @@ -557,37 +561,40 @@ def get_leaves(): if depth_range_to_display is not None: if node.level not in range(depth_range_to_display[0], depth_range_to_display[1] + 1): continue + node_to_display = True if fancy: if self.shadow_tree.is_classifier(): - _class_split_viz(node, X_train, y_train, - filename=f"{tmp}/node{node.id}_{os.getpid()}.svg", - precision=precision, - colors={**color_map, **colors}, - histtype=histtype, - node_heights=node_heights, - X=x, - ticks_fontsize=ticks_fontsize, - label_fontsize=label_fontsize, - fontname=fontname, - highlight_node=node.id in highlight_path) + node_to_display = _class_split_viz(node, X_train, y_train, + filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"), + precision=precision, + colors={**color_map, **colors}, + histtype=histtype, + node_heights=node_heights, + X=x, + ticks_fontsize=ticks_fontsize, + label_fontsize=label_fontsize, + fontname=fontname, + highlight_node=node.id in highlight_path) else: - _regr_split_viz(node, X_train, y_train, - filename=f"{tmp}/node{node.id}_{os.getpid()}.svg", - target_name=self.shadow_tree.target_name, - y_range=y_range, - X=x, - ticks_fontsize=ticks_fontsize, - label_fontsize=label_fontsize, - fontname=fontname, - highlight_node=node.id in highlight_path, - colors=colors) + node_to_display = _regr_split_viz(node, X_train, y_train, + filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"), + target_name=self.shadow_tree.target_name, + y_range=y_range, + X=x, + ticks_fontsize=ticks_fontsize, + label_fontsize=label_fontsize, + fontname=fontname, + highlight_node=node.id in highlight_path, + colors=colors) nname = node_name(node) if not node.is_categorical_split(): gr_node = split_node(node.feature_name(), nname, split=myround(node.split(), precision)) else: gr_node = split_node(node.feature_name(), nname, split=node.split()[0]) - internal.append(gr_node) + + if node_to_display: + internal.append(gr_node) leaves = [] for node in get_leaves(): @@ -596,24 +603,24 @@ def get_leaves(): continue if self.shadow_tree.is_classifier(): if _class_leaf_viz(node, colors=color_values, - filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg", + filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"), graph_colors=colors, fontname=fontname, leaftype=leaftype): leaves.append(class_leaf_node(node)) else: # for now, always gen leaf - _regr_leaf_viz(node, + if _regr_leaf_viz(node, y_train, target_name=self.shadow_tree.target_name, - filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg", + filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"), y_range=y_range, precision=precision, ticks_fontsize=ticks_fontsize, label_fontsize=label_fontsize, fontname=fontname, - colors=colors) - leaves.append(regr_leaf_node(node)) + colors=colors): + leaves.append(regr_leaf_node(node)) if show_just_path: show_root_edge_labels = False @@ -1133,6 +1140,14 @@ def _class_split_viz(node: ShadowDecTreeNode, # Get X, y data for all samples associated with this node. X_feature = X_train[:, node.feature()] + + + if len(node.samples()) == 0: + if filename is not None: + plt.savefig(filename, bbox_inches='tight', pad_inches=0) + plt.close() + return False + X_node_feature, y_train = X_feature[node.samples()], y_train[node.samples()] n_classes = node.shadow_tree.nclasses() @@ -1231,6 +1246,7 @@ def _class_split_viz(node: ShadowDecTreeNode, if filename is not None: plt.savefig(filename, bbox_inches='tight', pad_inches=0) plt.close() + return True def _class_leaf_viz(node: ShadowDecTreeNode, @@ -1285,6 +1301,12 @@ def _regr_split_viz(node: ShadowDecTreeNode, figsize = (2.5, 1.1) fig, ax = plt.subplots(1, 1, figsize=figsize) + if len(node.samples()) == 0: + if filename is not None: + plt.savefig(filename, bbox_inches='tight', pad_inches=0) + plt.close() + return False + feature_name = node.feature_name() _format_axes(ax, feature_name, target_name if node == node.shadow_tree.root else None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False, pad_for_wedge=True) ax.set_ylim(y_range) @@ -1310,10 +1332,12 @@ def _regr_split_viz(node: ShadowDecTreeNode, right = y_train[right] split = node.split() - ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'], - linewidth=1) + if len(left) > 0: + ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'], + linewidth=1) ax.plot([split, split], [*y_range], '--', color=colors['split_line'], linewidth=1) - ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'], + if len(right) > 0: + ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'], linewidth=1) wedge_ticks = _draw_wedge(ax, x=node.split(), node=node, color=colors['wedge'], is_classifier=False) @@ -1357,6 +1381,7 @@ def _regr_split_viz(node: ShadowDecTreeNode, if filename is not None: plt.savefig(filename, bbox_inches='tight', pad_inches=0) plt.close() + return True def _regr_leaf_viz(node: ShadowDecTreeNode, @@ -1372,6 +1397,8 @@ def _regr_leaf_viz(node: ShadowDecTreeNode, colors = adjust_colors(colors) samples = node.samples() + if len(samples) == 0: + return False y = y[samples] figsize = (.75, .8) @@ -1402,6 +1429,7 @@ def _regr_leaf_viz(node: ShadowDecTreeNode, if filename is not None: plt.savefig(filename, bbox_inches='tight', pad_inches=0) plt.close() + return True def _draw_legend(shadow_tree, target_name, filename, colors, fontname):