I've been reading up on decision trees and cross-validation, and I understand both concepts. However, I'm having trouble understanding cross-validation as it pertains to decision trees. Essentially, cross-validation lets you alternate between training and testing when your dataset is relatively small, so that you get the most out of your data when estimating error. A very simple algorithm goes something like this:
1. Decide on the number of folds you want (k).
2. Subdivide your dataset into k folds.
3. Use k-1 folds as the training set to build a tree.
4. Use the remaining fold as the test set to estimate statistics about the error in your tree.
5. Save your results for later.
6. Repeat steps 3-5 k times, leaving out a different fold for your test set each time.
7. Average the errors across your iterations to estimate the overall error.
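The steps above can be sketched in plain Python. To keep it short and dependency-free, this assumes a single numeric feature, binary labels, and a depth-1 decision tree (a "stump") standing in for a full tree. `fit_stump`, `predict`, and `k_fold_cv` are hypothetical helpers written for illustration, not functions from any library.

```python
import random
from typing import List, Tuple

# Hypothetical toy setup: one numeric feature and a binary class label.
Example = Tuple[float, int]
Stump = Tuple[float, int, int]  # (threshold, label if x <= threshold, label otherwise)


def fit_stump(train: List[Example]) -> Stump:
    """Fit a depth-1 decision tree by trying every midpoint between sorted feature values."""
    xs = sorted({x for x, _ in train})
    candidates = [(a + b) / 2 for a, b in zip(xs, xs[1:])] or [xs[0]]
    best = None
    for t in candidates:
        for left, right in ((0, 1), (1, 0)):
            errors = sum((left if x <= t else right) != y for x, y in train)
            if best is None or errors < best[0]:
                best = (errors, t, left, right)
    _, t, left, right = best
    return t, left, right


def predict(stump: Stump, x: float) -> int:
    t, left, right = stump
    return left if x <= t else right


def k_fold_cv(data: List[Example], k: int, seed: int = 0):
    """Return (mean error, per-fold errors, per-fold models)."""
    rng = random.Random(seed)
    shuffled = data[:]
    rng.shuffle(shuffled)
    folds = [shuffled[i::k] for i in range(k)]           # step 2: k roughly equal folds

    fold_errors, models = [], []
    for i in range(k):                                    # step 6: each fold is held out once
        test = folds[i]
        train = [ex for j, fold in enumerate(folds) if j != i for ex in fold]
        stump = fit_stump(train)                          # step 3: build on the other k-1 folds
        err = sum(predict(stump, x) != y for x, y in test) / len(test)  # step 4
        fold_errors.append(err)                           # step 5: save for later
        models.append(stump)
    return sum(fold_errors) / k, fold_errors, models      # step 7: average


data = [(float(x), int(x >= 5)) for x in range(10)]
mean_err, fold_errs, stumps = k_fold_cv(data, k=5)
print(f"mean error: {mean_err:.2f}")
print("per-fold thresholds:", [s[0] for s in stumps])
```

Because `k_fold_cv` returns the k fitted stumps alongside the errors, you can inspect their thresholds directly; depending on which points land in each held-out fold, they need not agree, which is exactly the k-different-trees situation.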
The problem I can't figure out is that at the end you'll have k decision trees that could all be slightly different, because they might not split the same way, etc. Which tree do you pick? One idea I had was to pick the one with the lowest error (although that doesn't make it optimal, just that it performed best on the fold it was given; maybe using stratification will help, but everything I've read says it only helps a little bit).
As I understand cross-validation, the point is to compute in-node statistics that can later be used for pruning. So each node in the tree will have statistics calculated for it based on the test set given to it. What's important are these in-node stats, but if you're averaging your error, how do you merge these stats within each node across k trees when each tree could vary in what it chooses to split on, etc.?
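One concrete reading of "in-node statistics", borrowed from reduced-error pruning (an assumption; the question doesn't name a specific pruning method), is: route every held-out example down a tree grown on the other folds, and at each node count how many examples arrive and how many the node would misclassify if it were a leaf. The `Node` class and the helper functions below are hypothetical, and the tree is hand-built rather than learned.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Node:
    label: int                       # majority class of the *training* examples at this node
    feature: Optional[int] = None    # None marks a leaf
    threshold: float = 0.0
    left: Optional["Node"] = None    # branch for x[feature] <= threshold
    right: Optional["Node"] = None   # branch for x[feature] > threshold
    n_test: int = 0                  # held-out examples that reached this node
    test_errors: int = 0             # of those, how many this node's `label` gets wrong


def record_test_stats(root: Node, test: List[Tuple[List[float], int]]) -> None:
    """Send each held-out example from the root down to a leaf, updating every node on its path."""
    for x, y in test:
        node = root
        while True:
            node.n_test += 1
            node.test_errors += int(node.label != y)
            if node.feature is None:
                break
            node = node.left if x[node.feature] <= node.threshold else node.right


def subtree_leaf_errors(node: Node) -> int:
    """Held-out errors made by the leaves under `node`, i.e. if the subtree is kept."""
    if node.feature is None:
        return node.test_errors
    return subtree_leaf_errors(node.left) + subtree_leaf_errors(node.right)


def prunable(node: Node) -> bool:
    """Internal node whose collapse to a leaf would not increase held-out error."""
    return node.feature is not None and node.test_errors <= subtree_leaf_errors(node)


# A small hand-built tree, standing in for one grown on the training folds.
inner = Node(label=1, feature=1, threshold=2.0, left=Node(label=1), right=Node(label=0))
root = Node(label=0, feature=0, threshold=5.0, left=Node(label=0), right=inner)

held_out = [([1.0, 0.0], 0), ([3.0, 9.0], 0), ([7.0, 1.0], 1), ([8.0, 3.0], 1), ([9.0, 5.0], 0)]
record_test_stats(root, held_out)
print(root.n_test, root.test_errors, prunable(root), prunable(inner))
```

Every counter lives on a specific node of one specific tree, so the numbers collected on one fold have no direct counterpart in a tree grown on another fold that splits on a different feature or threshold; that is the merging problem raised in the question.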
What's the point of calculating the overall error across all iterations? That's not something that could be used during pruning.
Any help with this little wrinkle would be much appreciated.