Ruleex
Algorithms of rule extraction from neural networks
Install / Use
ruleex
ruleex is a Python package that implements rule extraction algorithms. Three algorithms are currently available (HypInv, ANN-DT, and DeepRED), all slightly modified: in particular, extracted rules are stored in a tree structure whose decision nodes can have any form, e.g., an axis-parallel or linear split. An implementation of a general decision tree, i.e., a tree with any kind of decision node, is also included, along with some handy operations.
The following description is only a short overview focused on the code (classes, functions, and their arguments). For more details on the theory behind the implementation, see my master thesis (to be made available).
tree
This subpackage contains the implementation of a general decision tree, i.e., a decision tree with any kind of decision nodes. It can also operate on so-called decision graphs, in which a node may have more than one incoming edge.
class Rule
Rule is an abstract class describing a decision node inside the tree. An object of this class stores at least these attributes:
- true_branch and false_branch: pointers to the next node or None.
- class_set: the set of classes that the node represents (in a tree, these sets should be nested along each path, i.e., ordered by inclusion).
- class_hits: the list with the number of samples from each class.
- num_true and num_false: the number of samples that were redirected to the true and false branch, respectively.
The methods that need to be overridden are:
- eval_rule: evaluates the rule, i.e., returns True if the rule's condition is fulfilled.
- eval_all: evaluates a list of samples (uses numpy to optimize the decision process).
- to_string: renders the node as a string.
- copy: returns a new instance that is a copy of the current object.
It is also convenient to extend the __init__ method to store other values crucial for a specific node type.
The implemented subclasses are AxisRule and LinearRule, representing axis-parallel and linear split nodes, respectively. A Leaf class, the leaf node of the tree, is also provided.
Other subclasses of Rule can be implemented; however, the leaves of the tree should always be of class Leaf, because the algorithms implemented in the RuleTree class depend on that property.
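The node interface described above can be sketched as follows. The class and attribute names (Rule, AxisRule, Leaf, true_branch, false_branch, eval_rule) mirror the README, but this is a standalone toy re-implementation, not the ruleex source; the classify helper is a hypothetical addition for the demo.

```python
class Rule:
    """Abstract decision node with two outgoing branches."""
    def __init__(self, true_branch=None, false_branch=None):
        self.true_branch = true_branch
        self.false_branch = false_branch
        self.class_set = set()
        self.class_hits = []
        self.num_true = 0
        self.num_false = 0

    def eval_rule(self, x):
        """Return True if the node's condition is fulfilled for sample x."""
        raise NotImplementedError

class AxisRule(Rule):
    """Axis-parallel split: x[index] > threshold."""
    def __init__(self, index, threshold, **kwargs):
        super().__init__(**kwargs)
        self.index = index
        self.threshold = threshold

    def eval_rule(self, x):
        return x[self.index] > self.threshold

class Leaf(Rule):
    """Terminal node holding the predicted class set."""
    def __init__(self, class_set):
        super().__init__()
        self.class_set = set(class_set)

def classify(node, x):
    """Follow branches from a node until a Leaf is reached."""
    while not isinstance(node, Leaf):
        node = node.true_branch if node.eval_rule(x) else node.false_branch
    return node.class_set

# One axis-parallel split with two leaves: x[0] > 0.5 -> class 1, else class 0.
root = AxisRule(0, 0.5, true_branch=Leaf({1}), false_branch=Leaf({0}))
```

The same traversal works unchanged for any Rule subclass, which is why only the leaves must be of class Leaf.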
class RuleTree
The main class of this subpackage. It represents a general decision tree.
This class is not encapsulated; it is rather free to operate on. Its implementation serves two purposes. The first is storage of general information about the tree, such as
- root node,
- number of classes,
- number of input values.
The second purpose is to implement some simple operations on the tree structure.
- static methods
- rt.half_in_ruletree(number_of_classes): creates a tree with as many nodes as there are classes. This tree classifies a sample by the first value that is greater than 0.5.
- rt.load(filename): loads RuleTree object from the pickle file.
- evaluation
- rt.eval_one(x): evaluates one sample.
- rt.eval_all(x, all_nodes=None, prevs=None): evaluates all samples stored in the list x. When the arguments all_nodes and prevs are passed, the method does not need to make recursive calls and is faster (the same holds for other methods with these arguments).
- copy, save, and print
- rt.copy(): makes a deep copy of the tree.
- rt.save(filename): saves the ruletree into a pickle file.
- rt.to_string(): prints the ruletree. Be aware of the computational complexity of this function when the tree is big.
- rt.print_expanded_rules(): prints all rules, i.e., each path from the root to a leaf node, as an IF-THEN rule with a conjunction of the node conditions.
- rt.view_graph(): returns a graphviz Digraph. It can also display it and store it as a PDF.
- getters
- rt.get_all_nodes(): returns all nodes of the decision graph, starting from rt.root.
- rt.get_predecessor_dict(): returns a map from each node to the list of its predecessors.
- rt.get_all_expanded_rules(): returns a list of all paths in the graph that start at rt.root and end in a leaf node.
- rt.get_thresholds(): for trees with AxisRule nodes, returns all thresholds for each attribute.
- rt.get_rule_index_dict(x): returns a map from each node to the indexes of the samples in x that visited the node during evaluation.
- rt.get_stats(): returns some chosen properties of the ruletree: the number of rules, the number of nodes, the number of input indexes used by the tree (supports only AxisRule and LinearRule), the sum of the lengths of all rules, and the sum of the indexes used in each rule.
- graph operations
- rt.replace_leaf_with_set(leaf_class_set, replacing_rule, all_nodes=None, prevs=None): replaces all leaves whose class set equals leaf_class_set with replacing_rule.
- rt.replace_rule(matching_rule, replacement, prev_entries=None): replaces matching_rule with replacement. If prev_entries is present, the method does not search the graph recursively for matching_rule. (Hint: prev_entries = prevs\[matching_rule\])
- rt.delete_node(node, branch, all_nodes, prevs): deletes the node from the graph structure and replaces it with the branch specified by the argument.
- rt.fill_class_sets(): fills the class sets according to the leaf class sets. A node whose two branches are Leafs with class sets {1} and {0,3} will have class set {0,1,3}.
- rt.fill_hits(x, labels, remove_redundat=False): fills class_hits in all nodes of the graph by evaluating samples x against their true labels. If remove_redundat is True, the nodes that receive no hits are removed.
- checks
- rt.check_none_inside_tree(): returns True if only Leaf nodes have None in a branch.
- rt.check_one_evaluation(): checks that the graph evaluates each sample to exactly one class.
- rt.graph_max_indegree(all_nodes=None, return_all=False): returns the maximal indegree occurring in the graph; if return_all is True, also returns a map from each node to its indegree.
- condition pruning
- rt.fill_hits with remove_redundat=True option.
- rt.delete_interval_redundancy(): deletes nodes whose result is forced by the conditions of the previous nodes on the path from rt.root.
- rt.delete_class_redundacy(): deletes the parts of the RuleTree structure that lead to the same result (classification).
- rt.remove_unvisited(visited_rules): removes all nodes other than those specified in the argument.
- rt.remove_unvisited_edges(all_nodes, prevs): removes nodes whose num_true or num_false is equal to zero.
- rt.prune_using_hits(min_split_fraction=0.0, min_rule_samples=1, all_nodes=None):
removes all nodes that have a sum of class_hits less than min_rule_samples and
max(num_true/num_false, num_false/num_true) < min_split_fraction.
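Two of the operations above can be illustrated with a small standalone sketch: building the "first value greater than 0.5" decision chain that rt.half_in_ruletree describes, and enumerating expanded rules (root-to-leaf paths) as rt.get_all_expanded_rules does. The names Node, half_in_chain, and expanded_rules are hypothetical stand-ins, not ruleex API.

```python
class Node:
    def __init__(self, index=None, class_set=None):
        self.index = index          # attribute tested: x[index] > 0.5
        self.class_set = class_set  # set on leaves only
        self.true_branch = None
        self.false_branch = None

    @property
    def is_leaf(self):
        return self.class_set is not None

def half_in_chain(number_of_classes):
    """Node i sends a sample to class i iff x[i] is the first value above 0.5;
    if no value exceeds 0.5, the last class is assigned."""
    nxt = Node(class_set={number_of_classes - 1})
    # Build the chain from the last class backwards.
    for i in range(number_of_classes - 2, -1, -1):
        node = Node(index=i)
        node.true_branch = Node(class_set={i})
        node.false_branch = nxt
        nxt = node
    return nxt

def expanded_rules(node, path=()):
    """Yield every root-to-leaf path as a tuple of visited nodes."""
    if node.is_leaf:
        yield path + (node,)
        return
    yield from expanded_rules(node.true_branch, path + (node,))
    yield from expanded_rules(node.false_branch, path + (node,))

root = half_in_chain(3)
paths = list(expanded_rules(root))  # one expanded rule per class
```

Each tuple in paths corresponds to one IF-THEN rule: the conjunction of the inner nodes' conditions implies the leaf's class set.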
Building RuleTree
The standard method of building a DT is to examine all possible splits and take the one that minimizes the weighted sum of impurities of the resulting division. The main concept of this algorithm is implemented in the build module by the function build_from_rules.
Required arguments are
- rule_list: a list of rules, e.g., objects of the class AxisRule
- data: a list of inputs of training samples
- labels: a list of labels of training samples
Warning: do not use this function with more than 100 possible nodes; it is general and not optimized. If you want to train a standard DT, use sklearn and then convert the result with the sklearndt_to_ruletree function in the utils module. (Even the ANN-DT algorithm can be configured to build a DT in the standard way, and it is optimized.)
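The split-selection idea behind build_from_rules can be sketched as follows: try every candidate rule, partition the training data, and keep the rule whose split minimizes the weighted sum of impurities. The helper names (gini, best_rule) and the use of Gini impurity are illustrative assumptions, not the ruleex implementation.

```python
from collections import Counter

def gini(labels):
    """Gini impurity of a label multiset."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_rule(candidate_rules, data, labels):
    """candidate_rules: predicates sample -> bool.
    Returns the rule with the lowest weighted impurity and its score."""
    best, best_score = None, float("inf")
    n = len(labels)
    for rule in candidate_rules:
        true_lab = [l for x, l in zip(data, labels) if rule(x)]
        false_lab = [l for x, l in zip(data, labels) if not rule(x)]
        score = (len(true_lab) / n) * gini(true_lab) \
              + (len(false_lab) / n) * gini(false_lab)
        if score < best_score:
            best, best_score = rule, score
    return best, best_score
```

Because every candidate is evaluated against every sample, the cost grows as O(rules × samples) per node, which is why the README warns against using the unoptimized general builder on large rule lists.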
utils.py
The main reason for this module is conversion between other DT models and the RuleTree class.
sklearndt_to_ruletree converts an sklearn model into a RuleTree object. It is also possible to add additional pruning settings such as min_split_fraction (see rt.prune_using_hits).
pybliquedt_to_ruletree and train_OC1 require the pyblique package. They produce a RuleTree with LinearRule splits.
rutils.py
This module requires rpy2 and an installation of R. It provides functions to train a C5.0 DT and convert it to a RuleTree object.
ANN-DT method (anndt subpackage)
This method is very similar to standard DT building, but at each node it checks whether there are enough samples; if not, new samples are generated by sampling the input space and taking the NN's output. It also offers additional impurity functions, because the NN provides a continuous output for each sample rather than just a label.
The main function is anndt(model_fun, x, params, MeasureClass=None, stat_test=None, sampler=None, init_restrictions=None), where the arguments are:
- model_fun: a function that, when called on a list of samples, returns the NN's output.
- x: a list of training samples.
- MeasureClass: the class of the impurity measure (defined in the measures module).
- stat_test: the statistical test function to use (some are defined in the stat_test module).
- sampler: a subclass of Sampler defined in the sampling module, where some examples are implemented. Pay particular attention to configuring or implementing this class for your data.
- init_restrictions: a list of minimal and maximal values of the input.
- params: a dictionary of additional parameters; the main ones are:
- min_samples: the minimal number of samples for creating a new node. Enough new samples are generated to fulfill this condition.
- min_train: the minimal number of training samples needed to create a new node. If the number of training samples is lower, building of the current subtree stops.
- max_depth: maximal depth of a generated DT.
- min_split_fraction: one of the stopping criteria; see rt.prune_using_hits for more details.
- verbose: the verbosity level of the output.
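The core ANN-DT mechanism described above can be sketched in a few lines: when a node holds fewer than min_samples training points, draw extra points from the input space and label them by querying the network (model_fun). The uniform sampler, the ensure_min_samples helper, and the stand-in model below are illustrative assumptions, not ruleex code.

```python
import random

def uniform_sampler(init_restrictions, n):
    """Draw n samples uniformly inside [low, high] per input dimension."""
    return [[random.uniform(lo, hi) for lo, hi in init_restrictions]
            for _ in range(n)]

def ensure_min_samples(node_samples, model_fun, min_samples,
                       init_restrictions):
    """Top up a node's (sample, label) list to min_samples,
    labelling the newly drawn points with the network's output."""
    missing = min_samples - len(node_samples)
    if missing <= 0:
        return node_samples
    new_x = uniform_sampler(init_restrictions, missing)
    new_y = model_fun(new_x)
    return node_samples + list(zip(new_x, new_y))

# Stand-in "network": class 1 iff the first input exceeds 0.5.
model_fun = lambda xs: [int(x[0] > 0.5) for x in xs]
node = ensure_min_samples([([0.2, 0.3], 0)], model_fun, 5, [(0.0, 1.0)] * 2)
```

In ruleex the sampling strategy is provided by the Sampler subclass, which is why the README stresses adapting it to your data: a sampler that misses the region a node covers yields labels the split statistics cannot use.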
