I've been reading up on Decision Trees and Cross Validation, and I understand both concepts. However, I'm having trouble understanding Cross Validation as it pertains to Decision Trees. Essentially, Cross Validation lets you alternate between training and testing when your dataset is relatively small, so that you can make the most of your data when estimating the error. A very simple algorithm goes something like this:

  1. Decide on the number of folds you want (k)
  2. Subdivide your dataset into k folds
  3. Use k-1 folds for a training set to build a tree.
  4. Use the remaining fold as the test set to estimate statistics about the error in your tree.
  5. Save your results for later
  6. Repeat steps 3-5 k times, leaving out a different fold for your test set each time.
  7. Average the errors across your iterations to predict the overall error
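
The seven steps above can be sketched in Python. This is a minimal sketch, not any library's API: to keep it short, a hypothetical `fit_stump` (a depth-1 decision tree, i.e. a decision stump) stands in for a full tree learner, and `k_fold_indices` and `cross_validate` are illustrative names. Note that the k trained models are thrown away; only their fold errors survive.

```python
import random
from collections import Counter

def k_fold_indices(n, k, seed=0):
    """Steps 1-2: shuffle the indices 0..n-1 and deal them into k near-equal folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]

def fit_stump(X, y):
    """Hypothetical stand-in for a tree learner: a depth-1 tree (decision stump) that
    keeps the single feature/threshold split with the fewest training errors."""
    overall = Counter(y).most_common(1)[0][0]
    best = None
    for feat in range(len(X[0])):
        for thr in sorted({row[feat] for row in X}):
            left = [lab for row, lab in zip(X, y) if row[feat] <= thr]
            right = [lab for row, lab in zip(X, y) if row[feat] > thr]
            # Each side predicts its majority label; an empty side falls back to the overall one.
            l_lab = Counter(left).most_common(1)[0][0] if left else overall
            r_lab = Counter(right).most_common(1)[0][0] if right else overall
            errs = sum(lab != l_lab for lab in left) + sum(lab != r_lab for lab in right)
            if best is None or errs < best[0]:
                best = (errs, feat, thr, l_lab, r_lab)
    _, f, t, l_lab, r_lab = best
    return lambda row: l_lab if row[f] <= t else r_lab

def cross_validate(X, y, k, fit=fit_stump, seed=0):
    """Steps 3-7: returns (average error, list of per-fold errors)."""
    folds = k_fold_indices(len(X), k, seed)
    fold_errors = []
    for i, test_idx in enumerate(folds):
        # Step 3: build a model on the other k-1 folds.
        train_idx = [j for m, fold in enumerate(folds) if m != i for j in fold]
        model = fit([X[j] for j in train_idx], [y[j] for j in train_idx])
        # Steps 4-5: score it on the held-out fold and keep only the error rate;
        # the model itself is discarded.
        wrong = sum(model(X[j]) != y[j] for j in test_idx)
        fold_errors.append(wrong / len(test_idx))
    # Step 7: average over the k iterations.
    return sum(fold_errors) / k, fold_errors
```

For example, `cross_validate([[i] for i in range(10)], [0] * 5 + [1] * 5, k=5)` returns the average error together with the five per-fold errors. The sketch assumes `k` is no larger than the number of samples, so that no fold is empty.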

The problem I can't figure out is at the end you'll have k Decision trees that could all be slightly different because they might not split the same way, etc. Which tree do you pick? One idea I had was to pick the one with the fewest errors (although that doesn't make it optimal, just that it performed best on the fold it was given; maybe using stratification will help, but everything I've read says it only helps a little bit).

As I understand cross validation, the point is to compute in-node statistics that can later be used for pruning. So really each node in the tree will have statistics calculated for it based on the test set given to it. What's important are these in-node stats, but if you're averaging your error, how do you merge these stats within each node across k trees when each tree could vary in what it chooses to split on, etc.?

What's the point of calculating the overall error across each iteration? That's not something that could be used during pruning.

Any help with this little wrinkle would be much appreciated.

+1  A: 

Cross validation isn't used for building/pruning the decision tree. It's used to estimate how well the tree (built on all of the data) will perform, by simulating the arrival of new data (building the tree without some elements, just as you wrote). It doesn't really make sense to pick one of the trees generated by it, because the model is constrained by the data you have (and not using all of it might actually be worse when you use the tree for new data).
The tree is built over the data that you choose (usually all of it). Pruning is usually done by using some heuristic (e.g. 90% of the elements in the node belong to class A, so we don't go any further, or the information gain is too small).
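
The two stopping rules mentioned here (a node dominated by one class, or a split whose information gain is too small), often called pre-pruning, can be written as a small check. A minimal sketch: the 0.9 purity threshold comes from the example above, while the 0.01 gain threshold and the function names are illustrative assumptions.

```python
import math
from collections import Counter

def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())

def information_gain(parent, left, right):
    """Entropy reduction obtained by splitting parent into left and right."""
    n = len(parent)
    return entropy(parent) - (len(left) / n) * entropy(left) - (len(right) / n) * entropy(right)

def should_stop(labels, best_gain, purity=0.9, min_gain=0.01):
    """Stop splitting if one class makes up at least `purity` of the node,
    or if the best available split gains less than `min_gain` bits."""
    majority_share = Counter(labels).most_common(1)[0][1] / len(labels)
    return majority_share >= purity or best_gain < min_gain
```

For instance, `should_stop(["A"] * 9 + ["B"], best_gain=0.5)` is `True`, because 90% of the node already belongs to class A.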

pablochan
+6  A: 

The problem I can't figure out is at the end you'll have k Decision trees that could all be slightly different because they might not split the same way, etc. Which tree do you pick?

The purpose of cross validation is not to help select a particular instance of the classifier (or decision tree, or whatever automatic learning application) but rather to qualify the model, i.e. to provide metrics such as the average error ratio, the deviation relative to this average, etc., which can be useful in assessing the level of precision one can expect from the application. One of the things cross validation can help assess is whether the training data is big enough.

With regards to selecting a particular tree, you should instead run yet another training on 100% of the training data available, as this typically will produce a better tree. (The downside of the Cross Validation approach is that we need to divide the [typically little] amount of training data into "folds" and as you hint in the question this can lead to trees which are either overfit or underfit for particular data instances).
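
A minimal sketch of that workflow: cross-validation only yields the mean and (sample) standard deviation of the fold errors, i.e. the metrics mentioned above, and the tree you actually keep is trained afterwards on all of the data. `fit_majority` is a hypothetical placeholder learner standing in for a real decision-tree learner, and `estimate_then_fit` is an illustrative name.

```python
import random
import statistics
from collections import Counter

def fit_majority(X, y):
    """Hypothetical placeholder learner: always predicts the most common training label.
    Swap in a real decision-tree learner here."""
    label = Counter(y).most_common(1)[0][0]
    return lambda row: label

def estimate_then_fit(X, y, k, fit, seed=0):
    """Return (mean fold error, std. dev. of fold errors, model trained on all data)."""
    idx = list(range(len(X)))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    fold_errors = []
    for i, test_idx in enumerate(folds):
        train_idx = [j for m, fold in enumerate(folds) if m != i for j in fold]
        # This per-fold model is only used for scoring, then discarded.
        model = fit([X[j] for j in train_idx], [y[j] for j in train_idx])
        fold_errors.append(sum(model(X[j]) != y[j] for j in test_idx) / len(test_idx))
    # The model you keep is trained on 100% of the available data.
    final_model = fit(X, y)
    return statistics.mean(fold_errors), statistics.stdev(fold_errors), final_model
```

With `X = [[i] for i in range(10)]`, `y = ["A"] * 7 + ["B"] * 3` and `k=5`, the mean fold error comes out as 0.3, because every training split still has a majority of A. `statistics.stdev` needs at least two folds, so `k >= 2` is assumed.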

In the case of decision tree, I'm not sure what your reference to statistics gathered in the node and used to prune the tree pertains to. Maybe a particular use of cross-validation related techniques?...

mjv
OK, if I understand what this means, I could compare different types of classifiers (Decision Tree vs. SVM) on a given data set and say which one would be more likely to predict correctly. I think cross validation really has nothing to do with pruning or the stats used to prune. Is it sufficient to have a training set used to build the tree and another set to perform pruning on (i.e. a testing set)? Most of the UCI datasets come with separate training and testing sets, so that's why I ask.
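
Comparing two learners this way works best when both are scored on the same folds, so that differences come from the models rather than from the split. A minimal sketch: `fit_majority` and `fit_1nn` (1-nearest-neighbour) are hypothetical stand-ins for the decision tree and the SVM, and `compare_on_same_folds` is an illustrative name.

```python
import random
from collections import Counter

def fit_majority(X, y):
    """Hypothetical baseline learner: always predicts the most common training label."""
    label = Counter(y).most_common(1)[0][0]
    return lambda row: label

def fit_1nn(X, y):
    """Hypothetical second learner: 1-nearest-neighbour by squared Euclidean distance."""
    def predict(row):
        dists = [sum((a - b) ** 2 for a, b in zip(row, x)) for x in X]
        return y[dists.index(min(dists))]
    return predict

def compare_on_same_folds(X, y, learners, k, seed=0):
    """Cross-validate every learner on one shared set of folds; returns {name: mean error}."""
    idx = list(range(len(X)))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]  # built once, reused for every learner
    results = {}
    for name, fit in learners.items():
        fold_errors = []
        for i, test_idx in enumerate(folds):
            train_idx = [j for m, fold in enumerate(folds) if m != i for j in fold]
            model = fit([X[j] for j in train_idx], [y[j] for j in train_idx])
            fold_errors.append(sum(model(X[j]) != y[j] for j in test_idx) / len(test_idx))
        results[name] = sum(fold_errors) / k
    return results
```

With two well-separated clusters, e.g. `X = [[0], [1], [2], [3], [4], [100], [101], [102], [103], [104]]` and `y = [0] * 5 + [1] * 5`, calling `compare_on_same_folds(X, y, {"majority": fit_majority, "1-NN": fit_1nn}, k=5)` gives 1-NN a mean error of 0.0, while the majority baseline scores at least 0.5.
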
chubbard
+1  A: 

The main point of using cross-validation is that it gives you a better estimate of the performance of your trained model when used on different data.

Which tree do you pick? One option would be to build a new tree using all your data as the training set.

Mr. Brownstone
+1  A: 

For the first part, and like the others have pointed out, we usually use the entire dataset for building the final model, but we use cross-validation (CV) to get a better estimate of the generalization error on new unseen data.

For the second part, I think you are confusing CV with the validation set, which is used to avoid overfitting the tree by pruning a node when some function value computed on the validation set does not increase after the split.
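
Pruning against a held-out validation set in this way is commonly known as reduced-error pruning. A minimal sketch, assuming a hypothetical dict-based tree in which every internal node stores `feature`, `threshold`, `left`, `right` and the majority training `label` at that node; the function names are illustrative, and the "function value" here is simply the number of validation errors.

```python
def predict(node, row):
    """Walk from the root to a leaf; rows with row[feature] <= threshold go left."""
    while "feature" in node:
        node = node["left"] if row[node["feature"]] <= node["threshold"] else node["right"]
    return node["label"]

def errors(node, X, y):
    """Number of misclassified rows when the (sub)tree rooted at node is used."""
    return sum(predict(node, row) != label for row, label in zip(X, y))

def reduced_error_prune(node, X_val, y_val):
    """Return a pruned copy of node (the input tree is not modified). Children are
    pruned first (bottom-up); then the subtree is replaced by a leaf carrying the
    node's stored majority label if that leaf makes no more errors than the subtree
    on the validation rows that reach this node."""
    if "feature" not in node:
        return node  # leaves are kept as they are
    f, t = node["feature"], node["threshold"]
    left = [(r, l) for r, l in zip(X_val, y_val) if r[f] <= t]
    right = [(r, l) for r, l in zip(X_val, y_val) if r[f] > t]
    node = dict(node,
                left=reduced_error_prune(node["left"], [r for r, _ in left], [l for _, l in left]),
                right=reduced_error_prune(node["right"], [r for r, _ in right], [l for _, l in right]))
    leaf = {"label": node["label"]}
    # Ties (including "no validation rows reach this node") favour the smaller tree.
    if errors(leaf, X_val, y_val) <= errors(node, X_val, y_val):
        return leaf
    return node
```

In this sketch the validation set only decides which splits survive; the splits themselves still come from the training set.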

Amro