How to get node information for a Spark decision tree model
I would like to get more details about each node of a decision tree model trained with Spark MLlib. The closest I can get using the API is print(model.toDebugString()), which returns something like this (taken from the PySpark documentation):
DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 0 <= 0.0)
   Predict: 0.0
  Else (feature 0 > 0.0)
   Predict: 1.0
How can I modify the MLlib source code to get, for example, the impurity and depth of each node? (And how do I call a new Scala function from PySpark, if needed?)
Solution
I’ll try to complement @mostOfMajority’s answer by describing how I did this with PySpark 2.4.3.
Root node
Given a trained decision tree model, this is how to get its root node:
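from pyspark.ml.classification import DecisionTreeClassificationModel

def _get_root_node(tree: DecisionTreeClassificationModel):
    return tree._call_java('rootNode')  # private PySpark helper; calls rootNode on the JVM model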
Impurities
We can get the impurities by traversing the tree down from the root node. The in-order traversal below collects the impurity of every internal node; leaf nodes are skipped:
from typing import List

def get_impurities(tree: DecisionTreeClassificationModel) -> List[float]:
    def recur(node):
        # Leaf nodes have no descendants and are skipped
        if node.numDescendants() == 0:
            return []
        ni = node.impurity()
        # In-order: left subtree, this node, right subtree
        return recur(node.leftChild()) + [ni] + recur(node.rightChild())
    return recur(_get_root_node(tree))
Example
In [1]: print(tree.toDebugString)
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_f90ba6dbb0fe) of depth 3 with 7 nodes
  If (feature 0 <= 6.5)
   If (feature 0 <= 3.5)
    Predict: 1.0
   Else (feature 0 > 3.5)
    If (feature 0 <= 5.0)
     Predict: 0.0
    Else (feature 0 > 5.0)
     Predict: 1.0
  Else (feature 0 > 6.5)
   Predict: 0.0

In [2]: cat.get_impurities(tree)
Out[2]: [0.4444444444444444, 0.5, 0.5]
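Depth per node

The question also asks for the depth of each node. The same recursion can carry a depth counter along. What follows is only a sketch under a few assumptions: get_node_info is a name I made up, it relies on nothing beyond the node calls already used in get_impurities (numDescendants, impurity, leftChild, rightChild), and FakeNode is a hypothetical pure-Python stand-in (not part of PySpark) so the snippet runs without a Spark session. The impurity values in the fake tree are invented; only its shape mirrors the 7-node tree above.

```python
from typing import Any, Dict, List


def get_node_info(root: Any) -> List[Dict[str, Any]]:
    """Pre-order walk returning depth, impurity and a leaf flag for every node.

    `root` is assumed to expose numDescendants(), impurity(), leftChild()
    and rightChild(), the same calls used by get_impurities.
    """
    rows: List[Dict[str, Any]] = []

    def recur(node: Any, depth: int) -> None:
        is_leaf = node.numDescendants() == 0
        rows.append(
            {"depth": depth, "impurity": node.impurity(), "is_leaf": is_leaf}
        )
        if not is_leaf:
            # Children sit one level below the current node
            recur(node.leftChild(), depth + 1)
            recur(node.rightChild(), depth + 1)

    recur(root, 0)
    return rows


class FakeNode:
    """Hypothetical stand-in for a Spark tree node, for illustration only."""

    def __init__(self, impurity: float, left=None, right=None):
        self._impurity = impurity
        self._left = left
        self._right = right

    def impurity(self) -> float:
        return self._impurity

    def leftChild(self):
        return self._left

    def rightChild(self):
        return self._right

    def numDescendants(self) -> int:
        if self._left is None:
            return 0
        return 2 + self._left.numDescendants() + self._right.numDescendants()


# Same shape as the 7-node tree above; all impurity values here are invented.
fake_root = FakeNode(
    0.5,
    left=FakeNode(
        0.4,
        left=FakeNode(0.0),
        right=FakeNode(0.3, left=FakeNode(0.0), right=FakeNode(0.0)),
    ),
    right=FakeNode(0.0),
)

info = get_node_info(fake_root)
for row in info:
    print(row)
```

With a real model, the call would presumably be get_node_info(_get_root_node(tree)). Unlike get_impurities, this sketch visits nodes in pre-order and includes the leaves, so the rows follow the same top-to-bottom order as the toDebugString output.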