diff --git a/README.md b/README.md
index f71e163..e97ec9c 100644
--- a/README.md
+++ b/README.md
@@ -256,10 +256,18 @@ print( stderr )
If you are having issues with run command you can try copying the following files from: https://github.com/xflr6/graphviz/tree/master/graphviz.
-Place them in the AppData\Local\Continuum\anaconda3\Lib\site-packages\graphviz folder.
+Place them in the AppData\Local\Continuum\anaconda3\Lib\site-packages\graphviz folder.
Clean out the __pycache__ directory too.
+For Graphviz for Windows 8.0.5 and the Python `graphviz` interface v0.18+, use:
+```python
+import graphviz.backend as be
+cmd = ["dot", "-V"]
+stdout = be.execute.run_check(cmd, capture_output=True, check=True, quiet=False)
+print( stdout )
+```
+
Jupyter Lab and Jupyter notebook both show the inline .svg images well.
### Verify graphviz installation
diff --git a/dtreeviz/models/spark_decision_tree.py b/dtreeviz/models/spark_decision_tree.py
index 0c59f68..c9ceec1 100644
--- a/dtreeviz/models/spark_decision_tree.py
+++ b/dtreeviz/models/spark_decision_tree.py
@@ -172,16 +172,23 @@ def get_node_feature(self, id) -> int:
return self.get_features()[id]
def get_node_nsamples_by_class(self, id):
- def _get_value(spark_version):
- if spark_version >= 3:
- return np.array(self.tree_nodes[id].impurityStats().stats())
- elif spark_version >= 2:
- return np.array(list(self.tree_nodes[id].impurityStats().stats()))
- else:
- raise Exception("dtreeviz supports spark versions >= 2")
-
+ all_nodes = self.internal + self.leaves
if self.is_classifier():
- return _get_value(ShadowSparkTree._get_pyspark_major_version())
+ node_value = [node.n_sample_classes() for node in all_nodes if node.id == id]
+ return node_value[0]
+
+ # The commented-out code below returns the nsamples/class from the tree metadata. It's faster, but with it the
+ # visualisations cannot be made on new datasets.
+ # def _get_value(spark_version):
+ # if spark_version >= 3:
+ # return np.array(self.tree_nodes[id].impurityStats().stats())
+ # elif spark_version >= 2:
+ # return np.array(list(self.tree_nodes[id].impurityStats().stats()))
+ # else:
+ # raise Exception("dtreeviz supports spark versions >= 2")
+ #
+ # if self.is_classifier():
+ # return _get_value(ShadowSparkTree._get_pyspark_major_version())
def get_prediction(self, id):
return self.tree_nodes[id].prediction()
diff --git a/dtreeviz/trees.py b/dtreeviz/trees.py
index 8c5cec9..051b146 100644
--- a/dtreeviz/trees.py
+++ b/dtreeviz/trees.py
@@ -338,11 +338,12 @@ def node_name(node: ShadowDecTreeNode) -> str:
def split_node(name, node_name, split):
if fancy:
+ filepath = os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg")
labelgraph = node_label(node) if show_node_labels else ''
html = f"""
{labelgraph}
- }.svg) |
+  |
"""
else:
@@ -356,10 +357,11 @@ def split_node(name, node_name, split):
def regr_leaf_node(node, label_fontsize: int = 12):
# always generate fancy regr leaves for now but shrink a bit for nonfancy.
labelgraph = node_label(node) if show_node_labels else ''
+ filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg")
html = f"""
{labelgraph}
- }.svg) |
+  |
"""
if node.id in highlight_path:
@@ -369,10 +371,11 @@ def regr_leaf_node(node, label_fontsize: int = 12):
def class_leaf_node(node, label_fontsize: int = 12):
labelgraph = node_label(node) if show_node_labels else ''
+ filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg")
html = f"""
{labelgraph}
- }.svg) |
+  |
"""
if node.id in highlight_path:
@@ -384,10 +387,11 @@ def node_label(node):
return f'Node {node.id} |
'
def class_legend_html():
+ filepath = os.path.join(tmp, f"legend_{os.getpid()}.svg")
return f"""
- }.svg) |
+  |
"""
@@ -527,7 +531,7 @@ def get_leaves():
if np.max(class_values) >= n_classes:
raise ValueError(f"Target label values (for now) must be 0..{n_classes-1} for n={n_classes} labels")
color_map = {v: color_values[i] for i, v in enumerate(class_values)}
- _draw_legend(self.shadow_tree, self.shadow_tree.target_name, f"{tmp}/legend_{os.getpid()}.svg",
+ _draw_legend(self.shadow_tree, self.shadow_tree.target_name, os.path.join(tmp, f"legend_{os.getpid()}.svg"),
colors=colors,
fontname=fontname)
@@ -557,37 +561,40 @@ def get_leaves():
if depth_range_to_display is not None:
if node.level not in range(depth_range_to_display[0], depth_range_to_display[1] + 1):
continue
+ node_to_display = True
if fancy:
if self.shadow_tree.is_classifier():
- _class_split_viz(node, X_train, y_train,
- filename=f"{tmp}/node{node.id}_{os.getpid()}.svg",
- precision=precision,
- colors={**color_map, **colors},
- histtype=histtype,
- node_heights=node_heights,
- X=x,
- ticks_fontsize=ticks_fontsize,
- label_fontsize=label_fontsize,
- fontname=fontname,
- highlight_node=node.id in highlight_path)
+ node_to_display = _class_split_viz(node, X_train, y_train,
+ filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"),
+ precision=precision,
+ colors={**color_map, **colors},
+ histtype=histtype,
+ node_heights=node_heights,
+ X=x,
+ ticks_fontsize=ticks_fontsize,
+ label_fontsize=label_fontsize,
+ fontname=fontname,
+ highlight_node=node.id in highlight_path)
else:
- _regr_split_viz(node, X_train, y_train,
- filename=f"{tmp}/node{node.id}_{os.getpid()}.svg",
- target_name=self.shadow_tree.target_name,
- y_range=y_range,
- X=x,
- ticks_fontsize=ticks_fontsize,
- label_fontsize=label_fontsize,
- fontname=fontname,
- highlight_node=node.id in highlight_path,
- colors=colors)
+ node_to_display = _regr_split_viz(node, X_train, y_train,
+ filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"),
+ target_name=self.shadow_tree.target_name,
+ y_range=y_range,
+ X=x,
+ ticks_fontsize=ticks_fontsize,
+ label_fontsize=label_fontsize,
+ fontname=fontname,
+ highlight_node=node.id in highlight_path,
+ colors=colors)
nname = node_name(node)
if not node.is_categorical_split():
gr_node = split_node(node.feature_name(), nname, split=myround(node.split(), precision))
else:
gr_node = split_node(node.feature_name(), nname, split=node.split()[0])
- internal.append(gr_node)
+
+ if node_to_display:
+ internal.append(gr_node)
leaves = []
for node in get_leaves():
@@ -595,25 +602,25 @@ def get_leaves():
if node.level not in range(depth_range_to_display[0], depth_range_to_display[1] + 1):
continue
if self.shadow_tree.is_classifier():
- _class_leaf_viz(node, colors=color_values,
- filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg",
+ if _class_leaf_viz(node, colors=color_values,
+ filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"),
graph_colors=colors,
fontname=fontname,
- leaftype=leaftype)
- leaves.append(class_leaf_node(node))
+ leaftype=leaftype):
+ leaves.append(class_leaf_node(node))
else:
# for now, always gen leaf
- _regr_leaf_viz(node,
+ if _regr_leaf_viz(node,
y_train,
target_name=self.shadow_tree.target_name,
- filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg",
+ filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"),
y_range=y_range,
precision=precision,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
- colors=colors)
- leaves.append(regr_leaf_node(node))
+ colors=colors):
+ leaves.append(regr_leaf_node(node))
if show_just_path:
show_root_edge_labels = False
@@ -1133,6 +1140,14 @@ def _class_split_viz(node: ShadowDecTreeNode,
# Get X, y data for all samples associated with this node.
X_feature = X_train[:, node.feature()]
+
+
+ if len(node.samples()) == 0:
+ if filename is not None:
+ plt.savefig(filename, bbox_inches='tight', pad_inches=0)
+ plt.close()
+ return False
+
X_node_feature, y_train = X_feature[node.samples()], y_train[node.samples()]
n_classes = node.shadow_tree.nclasses()
@@ -1231,6 +1246,7 @@ def _class_split_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
+ return True
def _class_leaf_viz(node: ShadowDecTreeNode,
@@ -1255,13 +1271,15 @@ def _class_leaf_viz(node: ShadowDecTreeNode,
# when using another dataset than the training dataset, some leaves could have 0 samples.
# Trying to make a pie chart will raise some deprecation
if sum(counts) == 0:
- return
+ return False
if leaftype == 'pie':
_draw_piechart(counts, size=size, colors=colors, filename=filename, label=f"n={nsamples}\n{prediction}",
graph_colors=graph_colors, fontname=fontname)
+ return True
elif leaftype == 'barh':
_draw_barh_chart(counts, size=size, colors=colors, filename=filename, label=f"n={nsamples}\n{prediction}",
graph_colors=graph_colors, fontname=fontname)
+ return True
else:
raise ValueError(f'Undefined leaftype = {leaftype}')
@@ -1283,6 +1301,12 @@ def _regr_split_viz(node: ShadowDecTreeNode,
figsize = (2.5, 1.1)
fig, ax = plt.subplots(1, 1, figsize=figsize)
+ if len(node.samples()) == 0:
+ if filename is not None:
+ plt.savefig(filename, bbox_inches='tight', pad_inches=0)
+ plt.close()
+ return False
+
feature_name = node.feature_name()
_format_axes(ax, feature_name, target_name if node == node.shadow_tree.root else None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False, pad_for_wedge=True)
ax.set_ylim(y_range)
@@ -1308,10 +1332,12 @@ def _regr_split_viz(node: ShadowDecTreeNode,
right = y_train[right]
split = node.split()
- ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'],
- linewidth=1)
+ if len(left) > 0:
+ ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'],
+ linewidth=1)
ax.plot([split, split], [*y_range], '--', color=colors['split_line'], linewidth=1)
- ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'],
+ if len(right) > 0:
+ ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'],
linewidth=1)
wedge_ticks = _draw_wedge(ax, x=node.split(), node=node, color=colors['wedge'], is_classifier=False)
@@ -1355,6 +1381,7 @@ def _regr_split_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
+ return True
def _regr_leaf_viz(node: ShadowDecTreeNode,
@@ -1370,13 +1397,14 @@ def _regr_leaf_viz(node: ShadowDecTreeNode,
colors = adjust_colors(colors)
samples = node.samples()
+ if len(samples) == 0:
+ return False
y = y[samples]
figsize = (.75, .8)
fig, ax = plt.subplots(1, 1, figsize=figsize)
-
- m = np.mean(y)
+ m = node.prediction()
_format_axes(ax, None, None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False)
ax.set_ylim(y_range)
@@ -1401,6 +1429,7 @@ def _regr_leaf_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
+ return True
def _draw_legend(shadow_tree, target_name, filename, colors, fontname):
@@ -1530,7 +1559,7 @@ def _get_leaf_target_input(shadow_tree: ShadowDecTree, precision: int):
for i, node in enumerate(shadow_tree.leaves):
leaf_index_sample = node.samples()
leaf_target = shadow_tree.y_train[leaf_index_sample]
- leaf_target_mean = np.mean(leaf_target)
+ leaf_target_mean = node.prediction()
np.random.seed(0) # generate the same list of random values for each call
X = np.random.normal(i, sigma, size=len(leaf_target))