diff --git a/README.md b/README.md index f71e163..e97ec9c 100644 --- a/README.md +++ b/README.md @@ -256,10 +256,18 @@ print( stderr ) If you are having issues with run command you can try copying the following files from: https://github.com/xflr6/graphviz/tree/master/graphviz. -Place them in the AppData\Local\Continuum\anaconda3\Lib\site-packages\graphviz folder. +Place them in the AppData\Local\Continuum\anaconda3\Lib\site-packages\graphviz folder. Clean out the __pycache__ directory too. +For graphviz windows install 8.0.5 and python interface v0.18+ : +```python +import graphviz.backend as be +cmd = ["dot", "-V"] +stdout = be.execute.run_check(cmd, capture_output=True, check=True, quiet=False) +print( stdout ) +``` + Jupyter Lab and Jupyter notebook both show the inline .svg images well. ### Verify graphviz installation diff --git a/dtreeviz/models/spark_decision_tree.py b/dtreeviz/models/spark_decision_tree.py index 0c59f68..c9ceec1 100644 --- a/dtreeviz/models/spark_decision_tree.py +++ b/dtreeviz/models/spark_decision_tree.py @@ -172,16 +172,23 @@ def get_node_feature(self, id) -> int: return self.get_features()[id] def get_node_nsamples_by_class(self, id): - def _get_value(spark_version): - if spark_version >= 3: - return np.array(self.tree_nodes[id].impurityStats().stats()) - elif spark_version >= 2: - return np.array(list(self.tree_nodes[id].impurityStats().stats())) - else: - raise Exception("dtreeviz supports spark versions >= 2") - + all_nodes = self.internal + self.leaves if self.is_classifier(): - return _get_value(ShadowSparkTree._get_pyspark_major_version()) + node_value = [node.n_sample_classes() for node in all_nodes if node.id == id] + return node_value[0] + + # This is the code to return the nsamples/class from tree metadata. It's faster, but the visualisations cannot + # be made on new datasets. + # def _get_value(spark_version): + # if spark_version >= 3: + # return np.array(self.tree_nodes[id].impurityStats().stats()) + # elif spark_version >= 2: + # return np.array(list(self.tree_nodes[id].impurityStats().stats())) + # else: + # raise Exception("dtreeviz supports spark versions >= 2") + # + # if self.is_classifier(): + # return _get_value(ShadowSparkTree._get_pyspark_major_version()) def get_prediction(self, id): return self.tree_nodes[id].prediction() diff --git a/dtreeviz/trees.py b/dtreeviz/trees.py index 8c5cec9..051b146 100644 --- a/dtreeviz/trees.py +++ b/dtreeviz/trees.py @@ -338,11 +338,12 @@ def node_name(node: ShadowDecTreeNode) -> str: def split_node(name, node_name, split): if fancy: + filepath = os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg") labelgraph = node_label(node) if show_node_labels else '' html = f""" {labelgraph} - +
""" else: @@ -356,10 +357,11 @@ def split_node(name, node_name, split): def regr_leaf_node(node, label_fontsize: int = 12): # always generate fancy regr leaves for now but shrink a bit for nonfancy. labelgraph = node_label(node) if show_node_labels else '' + filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg") html = f""" {labelgraph} - +
""" if node.id in highlight_path: @@ -369,10 +371,11 @@ def regr_leaf_node(node, label_fontsize: int = 12): def class_leaf_node(node, label_fontsize: int = 12): labelgraph = node_label(node) if show_node_labels else '' + filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg") html = f""" {labelgraph} - +
""" if node.id in highlight_path: @@ -384,10 +387,11 @@ def node_label(node): return f'Node {node.id}' def class_legend_html(): + filepath = os.path.join(tmp, f"legend_{os.getpid()}.svg") return f""" - +
""" @@ -527,7 +531,7 @@ def get_leaves(): if np.max(class_values) >= n_classes: raise ValueError(f"Target label values (for now) must be 0..{n_classes-1} for n={n_classes} labels") color_map = {v: color_values[i] for i, v in enumerate(class_values)} - _draw_legend(self.shadow_tree, self.shadow_tree.target_name, f"{tmp}/legend_{os.getpid()}.svg", + _draw_legend(self.shadow_tree, self.shadow_tree.target_name, os.path.join(tmp, f"legend_{os.getpid()}.svg"), colors=colors, fontname=fontname) @@ -557,37 +561,40 @@ def get_leaves(): if depth_range_to_display is not None: if node.level not in range(depth_range_to_display[0], depth_range_to_display[1] + 1): continue + node_to_display = True if fancy: if self.shadow_tree.is_classifier(): - _class_split_viz(node, X_train, y_train, - filename=f"{tmp}/node{node.id}_{os.getpid()}.svg", - precision=precision, - colors={**color_map, **colors}, - histtype=histtype, - node_heights=node_heights, - X=x, - ticks_fontsize=ticks_fontsize, - label_fontsize=label_fontsize, - fontname=fontname, - highlight_node=node.id in highlight_path) + node_to_display = _class_split_viz(node, X_train, y_train, + filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"), + precision=precision, + colors={**color_map, **colors}, + histtype=histtype, + node_heights=node_heights, + X=x, + ticks_fontsize=ticks_fontsize, + label_fontsize=label_fontsize, + fontname=fontname, + highlight_node=node.id in highlight_path) else: - _regr_split_viz(node, X_train, y_train, - filename=f"{tmp}/node{node.id}_{os.getpid()}.svg", - target_name=self.shadow_tree.target_name, - y_range=y_range, - X=x, - ticks_fontsize=ticks_fontsize, - label_fontsize=label_fontsize, - fontname=fontname, - highlight_node=node.id in highlight_path, - colors=colors) + node_to_display = _regr_split_viz(node, X_train, y_train, + filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"), + target_name=self.shadow_tree.target_name, + y_range=y_range, + X=x, + ticks_fontsize=ticks_fontsize, + label_fontsize=label_fontsize, + fontname=fontname, + highlight_node=node.id in highlight_path, + colors=colors) nname = node_name(node) if not node.is_categorical_split(): gr_node = split_node(node.feature_name(), nname, split=myround(node.split(), precision)) else: gr_node = split_node(node.feature_name(), nname, split=node.split()[0]) - internal.append(gr_node) + + if node_to_display: + internal.append(gr_node) leaves = [] for node in get_leaves(): @@ -595,25 +602,25 @@ def get_leaves(): if node.level not in range(depth_range_to_display[0], depth_range_to_display[1] + 1): continue if self.shadow_tree.is_classifier(): - _class_leaf_viz(node, colors=color_values, - filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg", + if _class_leaf_viz(node, colors=color_values, + filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"), graph_colors=colors, fontname=fontname, - leaftype=leaftype) - leaves.append(class_leaf_node(node)) + leaftype=leaftype): + leaves.append(class_leaf_node(node)) else: # for now, always gen leaf - _regr_leaf_viz(node, + if _regr_leaf_viz(node, y_train, target_name=self.shadow_tree.target_name, - filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg", + filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"), y_range=y_range, precision=precision, ticks_fontsize=ticks_fontsize, label_fontsize=label_fontsize, fontname=fontname, - colors=colors) - leaves.append(regr_leaf_node(node)) + colors=colors): + leaves.append(regr_leaf_node(node)) if show_just_path: show_root_edge_labels = False @@ -1133,6 +1140,14 @@ def _class_split_viz(node: ShadowDecTreeNode, # Get X, y data for all samples associated with this node. X_feature = X_train[:, node.feature()] + + + if len(node.samples()) == 0: + if filename is not None: + plt.savefig(filename, bbox_inches='tight', pad_inches=0) + plt.close() + return False + X_node_feature, y_train = X_feature[node.samples()], y_train[node.samples()] n_classes = node.shadow_tree.nclasses() @@ -1231,6 +1246,7 @@ def _class_split_viz(node: ShadowDecTreeNode, if filename is not None: plt.savefig(filename, bbox_inches='tight', pad_inches=0) plt.close() + return True def _class_leaf_viz(node: ShadowDecTreeNode, @@ -1255,13 +1271,15 @@ def _class_leaf_viz(node: ShadowDecTreeNode, # when using another dataset than the training dataset, some leaves could have 0 samples. # Trying to make a pie chart will raise some deprecation if sum(counts) == 0: - return + return False if leaftype == 'pie': _draw_piechart(counts, size=size, colors=colors, filename=filename, label=f"n={nsamples}\n{prediction}", graph_colors=graph_colors, fontname=fontname) + return True elif leaftype == 'barh': _draw_barh_chart(counts, size=size, colors=colors, filename=filename, label=f"n={nsamples}\n{prediction}", graph_colors=graph_colors, fontname=fontname) + return True else: raise ValueError(f'Undefined leaftype = {leaftype}') @@ -1283,6 +1301,12 @@ def _regr_split_viz(node: ShadowDecTreeNode, figsize = (2.5, 1.1) fig, ax = plt.subplots(1, 1, figsize=figsize) + if len(node.samples()) == 0: + if filename is not None: + plt.savefig(filename, bbox_inches='tight', pad_inches=0) + plt.close() + return False + feature_name = node.feature_name() _format_axes(ax, feature_name, target_name if node == node.shadow_tree.root else None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False, pad_for_wedge=True) ax.set_ylim(y_range) @@ -1308,10 +1332,12 @@ def _regr_split_viz(node: ShadowDecTreeNode, right = y_train[right] split = node.split() - ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'], - linewidth=1) + if len(left) > 0: + ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'], + linewidth=1) ax.plot([split, split], [*y_range], '--', color=colors['split_line'], linewidth=1) - ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'], + if len(right) > 0: + ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'], linewidth=1) wedge_ticks = _draw_wedge(ax, x=node.split(), node=node, color=colors['wedge'], is_classifier=False) @@ -1355,6 +1381,7 @@ def _regr_split_viz(node: ShadowDecTreeNode, if filename is not None: plt.savefig(filename, bbox_inches='tight', pad_inches=0) plt.close() + return True def _regr_leaf_viz(node: ShadowDecTreeNode, @@ -1370,13 +1397,14 @@ def _regr_leaf_viz(node: ShadowDecTreeNode, colors = adjust_colors(colors) samples = node.samples() + if len(samples) == 0: + return False y = y[samples] figsize = (.75, .8) fig, ax = plt.subplots(1, 1, figsize=figsize) - - m = np.mean(y) + m = node.prediction() _format_axes(ax, None, None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False) ax.set_ylim(y_range) @@ -1401,6 +1429,7 @@ def _regr_leaf_viz(node: ShadowDecTreeNode, if filename is not None: plt.savefig(filename, bbox_inches='tight', pad_inches=0) plt.close() + return True def _draw_legend(shadow_tree, target_name, filename, colors, fontname): @@ -1530,7 +1559,7 @@ def _get_leaf_target_input(shadow_tree: ShadowDecTree, precision: int): for i, node in enumerate(shadow_tree.leaves): leaf_index_sample = node.samples() leaf_target = shadow_tree.y_train[leaf_index_sample] - leaf_target_mean = np.mean(leaf_target) + leaf_target_mean = node.prediction() np.random.seed(0) # generate the same list of random values for each call X = np.random.normal(i, sigma, size=len(leaf_target))