Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

299#make leaves to be placeholders #307

Merged
merged 3 commits into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,18 @@ print( stderr )

If you are having issues with run command you can try copying the following files from: https://github.com/xflr6/graphviz/tree/master/graphviz.

Place them in the AppData\Local\Continuum\anaconda3\Lib\site-packages\graphviz folder.
Place them in the AppData\Local\Continuum\anaconda3\Lib\site-packages\graphviz folder.

Clean out the __pycache__ directory too.

For graphviz windows install 8.0.5 and python interface v0.18+ :
```python
import graphviz.backend as be
cmd = ["dot", "-V"]
stdout = be.execute.run_check(cmd, capture_output=True, check=True, quiet=False)
print( stdout )
```

Jupyter Lab and Jupyter notebook both show the inline .svg images well.

### Verify graphviz installation
Expand Down
25 changes: 16 additions & 9 deletions dtreeviz/models/spark_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,23 @@ def get_node_feature(self, id) -> int:
return self.get_features()[id]

def get_node_nsamples_by_class(self, id):
def _get_value(spark_version):
if spark_version >= 3:
return np.array(self.tree_nodes[id].impurityStats().stats())
elif spark_version >= 2:
return np.array(list(self.tree_nodes[id].impurityStats().stats()))
else:
raise Exception("dtreeviz supports spark versions >= 2")

all_nodes = self.internal + self.leaves
if self.is_classifier():
return _get_value(ShadowSparkTree._get_pyspark_major_version())
node_value = [node.n_sample_classes() for node in all_nodes if node.id == id]
return node_value[0]

# This is the code to return the nsamples/class from tree metadata. It's faster, but the visualisations cannot
# be made on new datasets.
# def _get_value(spark_version):
# if spark_version >= 3:
# return np.array(self.tree_nodes[id].impurityStats().stats())
# elif spark_version >= 2:
# return np.array(list(self.tree_nodes[id].impurityStats().stats()))
# else:
# raise Exception("dtreeviz supports spark versions >= 2")
#
# if self.is_classifier():
# return _get_value(ShadowSparkTree._get_pyspark_major_version())

def get_prediction(self, id):
return self.tree_nodes[id].prediction()
Expand Down
113 changes: 71 additions & 42 deletions dtreeviz/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,12 @@ def node_name(node: ShadowDecTreeNode) -> str:

def split_node(name, node_name, split):
if fancy:
filepath = os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg")
labelgraph = node_label(node) if show_node_labels else ''
html = f"""<table border="0">
{labelgraph}
<tr>
<td><img src="{tmp}/node{node.id}_{os.getpid()}.svg"/></td>
<td><img src="{filepath}"/></td>
</tr>
</table>"""
else:
Expand All @@ -356,10 +357,11 @@ def split_node(name, node_name, split):
def regr_leaf_node(node, label_fontsize: int = 12):
# always generate fancy regr leaves for now but shrink a bit for nonfancy.
labelgraph = node_label(node) if show_node_labels else ''
filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg")
html = f"""<table border="0">
{labelgraph}
<tr>
<td><img src="{tmp}/leaf{node.id}_{os.getpid()}.svg"/></td>
<td><img src="{filepath}"/></td>
</tr>
</table>"""
if node.id in highlight_path:
Expand All @@ -369,10 +371,11 @@ def regr_leaf_node(node, label_fontsize: int = 12):

def class_leaf_node(node, label_fontsize: int = 12):
labelgraph = node_label(node) if show_node_labels else ''
filepath = os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg")
html = f"""<table border="0" CELLBORDER="0">
{labelgraph}
<tr>
<td><img src="{tmp}/leaf{node.id}_{os.getpid()}.svg"/></td>
<td><img src="{filepath}"/></td>
</tr>
</table>"""
if node.id in highlight_path:
Expand All @@ -384,10 +387,11 @@ def node_label(node):
return f'<tr><td CELLPADDING="0" CELLSPACING="0"><font face="{fontname}" color="{colors["node_label"]}" point-size="14"><i>Node {node.id}</i></font></td></tr>'

def class_legend_html():
filepath = os.path.join(tmp, f"legend_{os.getpid()}.svg")
return f"""
<table border="0" cellspacing="0" cellpadding="0">
<tr>
<td border="0" cellspacing="0" cellpadding="0"><img src="{tmp}/legend_{os.getpid()}.svg"/></td>
<td border="0" cellspacing="0" cellpadding="0"><img src="{filepath}"/></td>
</tr>
</table>
"""
Expand Down Expand Up @@ -527,7 +531,7 @@ def get_leaves():
if np.max(class_values) >= n_classes:
raise ValueError(f"Target label values (for now) must be 0..{n_classes-1} for n={n_classes} labels")
color_map = {v: color_values[i] for i, v in enumerate(class_values)}
_draw_legend(self.shadow_tree, self.shadow_tree.target_name, f"{tmp}/legend_{os.getpid()}.svg",
_draw_legend(self.shadow_tree, self.shadow_tree.target_name, os.path.join(tmp, f"legend_{os.getpid()}.svg"),
colors=colors,
fontname=fontname)

Expand Down Expand Up @@ -557,63 +561,66 @@ def get_leaves():
if depth_range_to_display is not None:
if node.level not in range(depth_range_to_display[0], depth_range_to_display[1] + 1):
continue
node_to_display = True
if fancy:
if self.shadow_tree.is_classifier():
_class_split_viz(node, X_train, y_train,
filename=f"{tmp}/node{node.id}_{os.getpid()}.svg",
precision=precision,
colors={**color_map, **colors},
histtype=histtype,
node_heights=node_heights,
X=x,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
highlight_node=node.id in highlight_path)
node_to_display = _class_split_viz(node, X_train, y_train,
filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"),
precision=precision,
colors={**color_map, **colors},
histtype=histtype,
node_heights=node_heights,
X=x,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
highlight_node=node.id in highlight_path)
else:
_regr_split_viz(node, X_train, y_train,
filename=f"{tmp}/node{node.id}_{os.getpid()}.svg",
target_name=self.shadow_tree.target_name,
y_range=y_range,
X=x,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
highlight_node=node.id in highlight_path,
colors=colors)
node_to_display = _regr_split_viz(node, X_train, y_train,
filename=os.path.join(tmp, f"node{node.id}_{os.getpid()}.svg"),
target_name=self.shadow_tree.target_name,
y_range=y_range,
X=x,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
highlight_node=node.id in highlight_path,
colors=colors)

nname = node_name(node)
if not node.is_categorical_split():
gr_node = split_node(node.feature_name(), nname, split=myround(node.split(), precision))
else:
gr_node = split_node(node.feature_name(), nname, split=node.split()[0])
internal.append(gr_node)

if node_to_display:
internal.append(gr_node)

leaves = []
for node in get_leaves():
if depth_range_to_display is not None:
if node.level not in range(depth_range_to_display[0], depth_range_to_display[1] + 1):
continue
if self.shadow_tree.is_classifier():
_class_leaf_viz(node, colors=color_values,
filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg",
if _class_leaf_viz(node, colors=color_values,
filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"),
graph_colors=colors,
fontname=fontname,
leaftype=leaftype)
leaves.append(class_leaf_node(node))
leaftype=leaftype):
leaves.append(class_leaf_node(node))
else:
# for now, always gen leaf
_regr_leaf_viz(node,
if _regr_leaf_viz(node,
y_train,
target_name=self.shadow_tree.target_name,
filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg",
filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"),
y_range=y_range,
precision=precision,
ticks_fontsize=ticks_fontsize,
label_fontsize=label_fontsize,
fontname=fontname,
colors=colors)
leaves.append(regr_leaf_node(node))
colors=colors):
leaves.append(regr_leaf_node(node))

if show_just_path:
show_root_edge_labels = False
Expand Down Expand Up @@ -1133,6 +1140,14 @@ def _class_split_viz(node: ShadowDecTreeNode,

# Get X, y data for all samples associated with this node.
X_feature = X_train[:, node.feature()]


if len(node.samples()) == 0:
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
return False

X_node_feature, y_train = X_feature[node.samples()], y_train[node.samples()]

n_classes = node.shadow_tree.nclasses()
Expand Down Expand Up @@ -1231,6 +1246,7 @@ def _class_split_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
return True


def _class_leaf_viz(node: ShadowDecTreeNode,
Expand All @@ -1255,13 +1271,15 @@ def _class_leaf_viz(node: ShadowDecTreeNode,
# when using another dataset than the training dataset, some leaves could have 0 samples.
# Trying to make a pie chart will raise some deprecation
if sum(counts) == 0:
return
return False
if leaftype == 'pie':
_draw_piechart(counts, size=size, colors=colors, filename=filename, label=f"n={nsamples}\n{prediction}",
graph_colors=graph_colors, fontname=fontname)
return True
elif leaftype == 'barh':
_draw_barh_chart(counts, size=size, colors=colors, filename=filename, label=f"n={nsamples}\n{prediction}",
graph_colors=graph_colors, fontname=fontname)
return True
else:
raise ValueError(f'Undefined leaftype = {leaftype}')

Expand All @@ -1283,6 +1301,12 @@ def _regr_split_viz(node: ShadowDecTreeNode,
figsize = (2.5, 1.1)
fig, ax = plt.subplots(1, 1, figsize=figsize)

if len(node.samples()) == 0:
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
return False

feature_name = node.feature_name()
_format_axes(ax, feature_name, target_name if node == node.shadow_tree.root else None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False, pad_for_wedge=True)
ax.set_ylim(y_range)
Expand All @@ -1308,10 +1332,12 @@ def _regr_split_viz(node: ShadowDecTreeNode,
right = y_train[right]
split = node.split()

ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'],
linewidth=1)
if len(left) > 0:
ax.plot([overall_feature_range[0], split], [np.mean(left), np.mean(left)], '--', color=colors['split_line'],
linewidth=1)
ax.plot([split, split], [*y_range], '--', color=colors['split_line'], linewidth=1)
ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'],
if len(right) > 0:
ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'],
linewidth=1)

wedge_ticks = _draw_wedge(ax, x=node.split(), node=node, color=colors['wedge'], is_classifier=False)
Expand Down Expand Up @@ -1355,6 +1381,7 @@ def _regr_split_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
return True


def _regr_leaf_viz(node: ShadowDecTreeNode,
Expand All @@ -1370,13 +1397,14 @@ def _regr_leaf_viz(node: ShadowDecTreeNode,
colors = adjust_colors(colors)

samples = node.samples()
if len(samples) == 0:
return False
y = y[samples]

figsize = (.75, .8)

fig, ax = plt.subplots(1, 1, figsize=figsize)

m = np.mean(y)
m = node.prediction()

_format_axes(ax, None, None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False)
ax.set_ylim(y_range)
Expand All @@ -1401,6 +1429,7 @@ def _regr_leaf_viz(node: ShadowDecTreeNode,
if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close()
return True


def _draw_legend(shadow_tree, target_name, filename, colors, fontname):
Expand Down Expand Up @@ -1530,7 +1559,7 @@ def _get_leaf_target_input(shadow_tree: ShadowDecTree, precision: int):
for i, node in enumerate(shadow_tree.leaves):
leaf_index_sample = node.samples()
leaf_target = shadow_tree.y_train[leaf_index_sample]
leaf_target_mean = np.mean(leaf_target)
leaf_target_mean = node.prediction()
np.random.seed(0) # generate the same list of random values for each call
X = np.random.normal(i, sigma, size=len(leaf_target))

Expand Down