Some model parameters cannot be learned directly from a data set during model training; these kinds of parameters are called hyperparameters. Examples of hyperparameters include the number of predictors that are sampled at splits in a tree-based model (we call this mtry in tidymodels) and the learning rate in a boosted tree model (we call this learn_rate). Instead of learning these hyperparameters during model training, we can estimate their best values by training many models on resampled data sets and comparing how well all these models perform. This process is called tuning.

The General Social Survey (revisited)

In our previous Evaluate your model with resampling article, we introduced a data set of survey respondents who indicated whether or not they believed Muslim clergymen who express anti-American attitudes should be allowed to teach in a college or university. We trained a random forest model to predict respondents’ responses, and used resampling to estimate the performance of our model on this data.

library(tidyverse)
library(tidymodels)

data("gss", package = "rcis")

# select a smaller subset of variables for analysis
gss <- gss %>%
  select(
    id, wtss, colmslm, age, black, degree,
    hispanic_2, polviews, pray, sex, south, tolerance
  ) %>%
  # drop observations with missing values - could always use imputation instead
  drop_na()

# summarize the data set (Table 1)
skimr::skim(gss)

Table 1: Data summary

Number of rows          940
Number of columns       12
Column type frequency:
  factor                8
  numeric               4
Group variables         None

Variable type: factor

skim_variable  n_missing  complete_rate  ordered  n_unique  top_counts
colmslm        0          1              FALSE    2         Not: 582, Yes: 358
black          0          1              FALSE    2         No: 779, Yes: 161
degree         0          1              FALSE    5         HS: 477, Bac: 190, Gra: 105, <HS: 91
hispanic_2     0          1              FALSE    2         No: 856, Yes: 84
polviews       0          1              FALSE    7         Mod: 335, Con: 160, Slg: 135, Lib: 123
pray           0          1              FALSE    6         ONC: 295, SEV: 256, NEV: 125, LT : 107
sex            0          1              FALSE    2         Fem: 509, Mal: 431
south          0          1              FALSE    2         Non: 561, Sou: 379

Variable type: numeric


Predicting attitudes, but better

Random forest models are a tree-based ensemble method, and typically perform well with default hyperparameters. However, the accuracy of some other tree-based models, such as boosted tree models or decision tree models, can be sensitive to the values of hyperparameters. In this article, we will train a decision tree model. There are several hyperparameters for decision tree models that can be tuned for better performance. Let’s explore:

  • the complexity parameter (which we call cost_complexity in tidymodels) for the tree, and
  • the maximum tree_depth.

Tuning these hyperparameters can improve model performance because decision tree models are prone to overfitting. This happens because single tree models tend to fit the training data too well — so well, in fact, that they over-learn patterns present in the training data that end up being detrimental when predicting new data.

We will tune the model hyperparameters to avoid overfitting. Tuning the value of cost_complexity helps by pruning back our tree. It adds a cost, or penalty, to the error rates of more complex trees; a cost closer to zero decreases the number of tree nodes pruned and is more likely to result in an overfit tree. However, a high cost increases the number of tree nodes pruned and can result in the opposite problem: an underfit tree. Tuning tree_depth, on the other hand, helps by stopping our tree from growing after it reaches a certain depth. We want to tune these hyperparameters to find the values of both that let our model do the best job of predicting respondents' answers to the colmslm question.
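To make the pruning trade-off concrete, here is a minimal base R sketch of the textbook cost-complexity criterion, which scores each candidate subtree as its training error plus a penalty alpha times its number of terminal nodes. The subtrees and error rates are made up for illustration. As we understand it, rpart scales its cp penalty by the error of the root node, so these alpha values are not on the same scale as cost_complexity.

```r
# Hypothetical candidate subtrees of one large tree: number of terminal
# nodes (leaves) and training misclassification rate of each subtree
candidates <- data.frame(
  leaves = c(1, 3, 6, 12, 25),
  error  = c(0.38, 0.20, 0.17, 0.15, 0.14)
)

# Cost-complexity criterion: training error + alpha * (number of leaves).
# Returns the number of leaves of the subtree with the lowest penalized error.
pick_subtree <- function(candidates, alpha) {
  penalized <- candidates$error + alpha * candidates$leaves
  candidates$leaves[which.min(penalized)]
}

pick_subtree(candidates, alpha = 0)      # no penalty: the largest tree (25 leaves)
pick_subtree(candidates, alpha = 0.005)  # moderate penalty: 6 leaves
pick_subtree(candidates, alpha = 0.1)    # heavy penalty: a single-node stump
```

As alpha grows, the preferred subtree shrinks from the full tree toward a stump, which is the overfit-to-underfit spectrum described above.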

Before we start the tuning process, we split our data into training and testing sets, just like when we trained the model with one default set of hyperparameters. As before, we can use strata = colmslm if we want our training and testing sets to be created using stratified sampling, so that both have the same proportion of both kinds of response.

gss_split <- initial_split(gss, strata = colmslm)

gss_train <- training(gss_split)
gss_test <- testing(gss_split)

We use the training data for tuning the model.
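Stratified sampling itself is simple to sketch in base R: sample the same fraction of rows within each outcome class. This is a simplified illustration, assuming the 3/4 training proportion that initial_split() uses by default and a simulated outcome with the class counts from Table 1; rsample's own rounding may differ slightly, so the set sizes will not match gss_train exactly.

```r
# Simulated outcome with the same class counts as colmslm in Table 1
outcome <- factor(rep(c("Yes, allowed", "Not allowed"), times = c(358, 582)))

# Stratified split: sample the same fraction of rows within each class
stratified_split <- function(y, prop = 3/4) {
  rows_by_class <- split(seq_along(y), y)
  train_rows <- lapply(rows_by_class, function(rows) {
    rows[sample.int(length(rows), size = floor(prop * length(rows)))]
  })
  sort(unlist(train_rows, use.names = FALSE))
}

set.seed(123)
train_rows <- stratified_split(outcome)

# The share of "Yes, allowed" is nearly identical in all three sets
mean(outcome == "Yes, allowed")
mean(outcome[train_rows] == "Yes, allowed")
mean(outcome[-train_rows] == "Yes, allowed")
```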

Tuning hyperparameters

Let’s start with the parsnip package, using a decision_tree() model with the rpart engine. To tune the decision tree hyperparameters cost_complexity and tree_depth, we create a model specification that identifies which hyperparameters we plan to tune.

tune_spec <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune()
  ) %>%
  set_engine("rpart") %>%
  set_mode("classification")

tune_spec

## Decision Tree Model Specification (classification)
## Main Arguments:
##   cost_complexity = tune()
##   tree_depth = tune()
## Computational engine: rpart

Think of tune() here as a placeholder. After the tuning process, we will select a single numeric value for each of these hyperparameters. For now, we specify our parsnip model object and identify the hyperparameters we will tune().

We can’t train this specification on a single data set (such as the entire training set) and learn what the hyperparameter values should be, but we can train many models using resampled data and see which models turn out best. We can create a regular grid of values to try using some convenience functions for each hyperparameter:

tree_grid <- grid_regular(cost_complexity(),
                          tree_depth(),
                          levels = 5)

The function grid_regular() is from the dials package. It chooses sensible values to try for each hyperparameter; here, we asked for 5 of each. Since we have two to tune, grid_regular() returns 5 × 5 = 25 different possible tuning combinations to try in a tidy tibble format.

tree_grid

## # A tibble: 25 x 2
##    cost_complexity tree_depth
##              <dbl>      <int>
##  1    0.0000000001          1
##  2    0.0000000178          1
##  3    0.00000316            1
##  4    0.000562              1
##  5    0.1                   1
##  6    0.0000000001          4
##  7    0.0000000178          4
##  8    0.00000316            4
##  9    0.000562              4
## 10    0.1                   4
## # … with 15 more rows

Here, you can see all 5 values of cost_complexity ranging up to 0.1. These values get repeated for each of the 5 values of tree_depth:

tree_grid %>%
  count(tree_depth)

## # A tibble: 5 x 2
##   tree_depth     n
##        <int> <int>
## 1          1     5
## 2          4     5
## 3          8     5
## 4         11     5
## 5         15     5
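The same kind of regular grid can be sketched in base R: pick evenly spaced levels for each hyperparameter, then take every combination. The ranges here (10^-10 to 10^-1 on a log10 scale for cost_complexity, 1 to 15 for tree_depth) are read off the output above rather than taken from the dials source, and truncating depths to whole numbers is an assumption that happens to reproduce the tree_depth values shown.

```r
n_levels <- 5

# cost_complexity: evenly spaced on the log10 scale from 1e-10 to 1e-1
cost_values <- 10^seq(-10, -1, length.out = n_levels)

# tree_depth: evenly spaced from 1 to 15, truncated to whole numbers
depth_values <- as.integer(floor(seq(1, 15, length.out = n_levels)))

# Every combination; expand.grid() varies the first column fastest,
# matching the row order of the tibble above
grid_sketch <- expand.grid(
  cost_complexity = cost_values,
  tree_depth = depth_values
)

nrow(grid_sketch)  # 25 combinations
table(grid_sketch$tree_depth)
```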

Armed with our grid filled with 25 candidate decision tree models, let’s create cross-validation folds for tuning:

gss_folds <- vfold_cv(gss_train)

Tuning in tidymodels requires a resampled object created with the rsample package.
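Under the hood, V-fold cross-validation only needs a random assignment of rows to folds. A minimal base R sketch, assuming vfold_cv()'s default of v = 10 and the 706 rows of gss_train; it does not claim to reproduce rsample's exact assignment, but it yields the same 635/71 and 636/70 analysis/assessment sizes shown in the tuning output.

```r
# Assign each of n rows to one of v folds of (nearly) equal size, then
# hold each fold out once as the assessment set
vfold_sketch <- function(n, v = 10) {
  fold_id <- sample(rep(seq_len(v), length.out = n))
  lapply(seq_len(v), function(k) {
    list(
      analysis   = which(fold_id != k),  # rows used to fit the model
      assessment = which(fold_id == k)   # rows used to measure performance
    )
  })
}

set.seed(234)
folds <- vfold_sketch(n = 706, v = 10)  # 706 rows in gss_train

# analysis/assessment sizes per fold
t(sapply(folds, function(f) c(analysis = length(f$analysis),
                              assessment = length(f$assessment))))
```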

Model tuning with a grid

We are ready to tune! Let’s use tune_grid() to fit models at all the different values we chose for each tuned hyperparameter. There are several options for building the object for tuning:

  • Tune a model specification along with a recipe or model, or

  • Tune a workflow() that bundles together a model specification and a recipe or model preprocessor.

Here we use a workflow() with a straightforward formula; if this model required more involved data preprocessing, we could use add_recipe() instead of add_formula().


tree_wf <- workflow() %>%
  add_model(tune_spec) %>%
  add_formula(colmslm ~ .)

tree_res <-
  tree_wf %>%
  tune_grid(
    resamples = gss_folds,
    grid = tree_grid
  )

tree_res

## # Tuning results
## # 10-fold cross-validation 
## # A tibble: 10 x 4
##    splits           id     .metrics          .notes          
##    <list>           <chr>  <list>            <list>          
##  1 <split [635/71]> Fold01 <tibble [50 × 6]> <tibble [0 × 1]>
##  2 <split [635/71]> Fold02 <tibble [50 × 6]> <tibble [0 × 1]>
##  3 <split [635/71]> Fold03 <tibble [50 × 6]> <tibble [0 × 1]>
##  4 <split [635/71]> Fold04 <tibble [50 × 6]> <tibble [0 × 1]>
##  5 <split [635/71]> Fold05 <tibble [50 × 6]> <tibble [0 × 1]>
##  6 <split [635/71]> Fold06 <tibble [50 × 6]> <tibble [0 × 1]>
##  7 <split [636/70]> Fold07 <tibble [50 × 6]> <tibble [0 × 1]>
##  8 <split [636/70]> Fold08 <tibble [50 × 6]> <tibble [0 × 1]>
##  9 <split [636/70]> Fold09 <tibble [50 × 6]> <tibble [0 × 1]>
## 10 <split [636/70]> Fold10 <tibble [50 × 6]> <tibble [0 × 1]>
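Conceptually, tune_grid() loops over every candidate and every fold: fit on the analysis rows, score the held-out rows, then summarize each candidate across folds. A minimal base R sketch of that loop, calling rpart directly on the kyphosis data that ship with rpart (a stand-in, since the GSS data are not assumed here), with a small hypothetical grid and accuracy only; tune's own implementation is more involved.

```r
library(rpart)  # recommended package bundled with R; provides rpart() and kyphosis

# Small hypothetical grid. In rpart, cp plays the role of cost_complexity
# and maxdepth the role of tree_depth.
cand_grid <- expand.grid(cp = c(1e-10, 1e-3, 0.1), maxdepth = c(1, 4, 8))

set.seed(42)
v <- 5
fold_id <- sample(rep(seq_len(v), length.out = nrow(kyphosis)))

# For each candidate: fit on the analysis rows of every fold, compute accuracy
# on the held-out rows, then summarize across folds (mean, n, standard error)
results <- do.call(rbind, lapply(seq_len(nrow(cand_grid)), function(i) {
  fold_acc <- vapply(seq_len(v), function(k) {
    fit <- rpart(
      Kyphosis ~ ., data = kyphosis[fold_id != k, ], method = "class",
      control = rpart.control(cp = cand_grid$cp[i],
                              maxdepth = cand_grid$maxdepth[i])
    )
    held_out <- kyphosis[fold_id == k, ]
    mean(predict(fit, held_out, type = "class") == held_out$Kyphosis)
  }, numeric(1))
  data.frame(cand_grid[i, ], mean = mean(fold_acc), n = v,
             std_err = sd(fold_acc) / sqrt(v))
}))

# Best candidates first
results[order(results$mean, decreasing = TRUE), ]
```

The mean, n, and std_err columns here play the same roles as the columns collect_metrics() reports.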

Once we have our tuning results, we can explore them through visualization and then select the best result. The function collect_metrics() gives us a tidy tibble with all the results. We had 25 candidate models and two metrics, accuracy and roc_auc, so we get one row for each combination of .metric and model (50 rows in all).

tree_res %>%
  collect_metrics()

## # A tibble: 50 x 8
##    cost_complexity tree_depth .metric  .estimator  mean     n std_err .config   
##              <dbl>      <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>     
##  1    0.0000000001          1 accuracy binary     0.812    10  0.0111 Preproces…
##  2    0.0000000001          1 roc_auc  binary     0.809    10  0.0108 Preproces…
##  3    0.0000000178          1 accuracy binary     0.812    10  0.0111 Preproces…
##  4    0.0000000178          1 roc_auc  binary     0.809    10  0.0108 Preproces…
##  5    0.00000316            1 accuracy binary     0.812    10  0.0111 Preproces…
##  6    0.00000316            1 roc_auc  binary     0.809    10  0.0108 Preproces…
##  7    0.000562              1 accuracy binary     0.812    10  0.0111 Preproces…
##  8    0.000562              1 roc_auc  binary     0.809    10  0.0108 Preproces…
##  9    0.1                   1 accuracy binary     0.812    10  0.0111 Preproces…
## 10    0.1                   1 roc_auc  binary     0.809    10  0.0108 Preproces…
## # … with 40 more rows

We might get more out of plotting these results:

tree_res %>%
  collect_metrics() %>%
  mutate(tree_depth = factor(tree_depth)) %>%
  ggplot(aes(cost_complexity, mean, color = tree_depth)) +
  geom_line(size = 1.5, alpha = 0.6) +
  geom_point(size = 2) +
  facet_wrap(facets = vars(.metric), scales = "free", nrow = 2) +
  scale_x_log10(labels = scales::label_number()) +
  scale_color_viridis_d(option = "plasma", begin = .9, end = 0)

We can see that our “stubbiest” tree, with a depth of 1, is the worst model according to roc_auc across all candidate values of cost_complexity (though, surprisingly, it is the most accurate). Deeper trees tend to do better for this problem. However, the best tree seems to be an intermediate one, with a tree depth of 8. The show_best() function shows us the top 5 candidate models by default:

tree_res %>%
  show_best(metric = "roc_auc")

## # A tibble: 5 x 8
##   cost_complexity tree_depth .metric .estimator  mean     n std_err .config     
##             <dbl>      <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>       
## 1    0.0000000001          8 roc_auc binary     0.839    10  0.0157 Preprocesso…
## 2    0.0000000178          8 roc_auc binary     0.839    10  0.0157 Preprocesso…
## 3    0.00000316            8 roc_auc binary     0.839    10  0.0157 Preprocesso…
## 4    0.000562              8 roc_auc binary     0.839    10  0.0157 Preprocesso…
## 5    0.0000000001         15 roc_auc binary     0.838    10  0.0179 Preprocesso…

We can also use the select_best() function to pull out the single set of hyperparameter values for our best decision tree model:

best_tree <- tree_res %>%
  select_best(metric = "roc_auc")

best_tree

## # A tibble: 1 x 3
##   cost_complexity tree_depth .config              
##             <dbl>      <int> <chr>                
## 1    0.0000000001          8 Preprocessor1_Model11

These are the values for tree_depth and cost_complexity that maximize the mean cross-validated ROC AUC in this data set of respondents.

Finalizing our model

We can update (or “finalize”) our workflow object tree_wf with the values from select_best().

final_wf <-
  tree_wf %>%
  finalize_workflow(best_tree)

final_wf

## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: decision_tree()
## ── Preprocessor ────────────────────────────────────────────────────────────────
## colmslm ~ .
## ── Model ───────────────────────────────────────────────────────────────────────
## Decision Tree Model Specification (classification)
## Main Arguments:
##   cost_complexity = 1e-10
##   tree_depth = 8
## Computational engine: rpart

Our tuning is done!
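Finalizing amounts to swapping each tune() placeholder for the value chosen by select_best(). The toy sketch below illustrates only that substitution step; the list-based spec and finalize_sketch() are hypothetical and are not how parsnip or workflows store model specifications internally.

```r
# Toy model specification: arguments marked for tuning hold the unevaluated
# call tune(), mirroring the `tune()` placeholders printed for tune_spec
spec <- list(
  cost_complexity = quote(tune()),
  tree_depth = quote(tune()),
  engine = "rpart"
)

# Values chosen by select_best() (copied from the output above)
best <- list(cost_complexity = 1e-10, tree_depth = 8L)

# Replace every tune() placeholder with the matching tuned value,
# and fail loudly if a placeholder has no value
finalize_sketch <- function(spec, best) {
  is_placeholder <- vapply(
    spec, function(arg) identical(arg, quote(tune())), logical(1)
  )
  unfilled <- setdiff(names(spec)[is_placeholder], names(best))
  if (length(unfilled) > 0) {
    stop("No tuned value for: ", paste(unfilled, collapse = ", "))
  }
  spec[is_placeholder] <- best[names(spec)[is_placeholder]]
  spec
}

final_spec <- finalize_sketch(spec, best)
str(final_spec)
```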

Exploring results

Let’s fit this final model to the training data. What does the decision tree look like?

final_tree <-
  final_wf %>%
  fit(data = gss_train)

final_tree

## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: decision_tree()
## ── Preprocessor ────────────────────────────────────────────────────────────────
## colmslm ~ .
## ── Model ───────────────────────────────────────────────────────────────────────
## n= 706 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
##   1) root 706 269 Not allowed (0.38101983 0.61898017)  
##     2) tolerance>=12.5 305  83 Yes, allowed (0.72786885 0.27213115)  
##       4) tolerance>=13.5 238  48 Yes, allowed (0.79831933 0.20168067)  
##         8) id< 1840 223  38 Yes, allowed (0.82959641 0.17040359)  
##          16) id>=1632.5 27   0 Yes, allowed (1.00000000 0.00000000) *
##          17) id< 1632.5 196  38 Yes, allowed (0.80612245 0.19387755)  
##            34) wtss>=1.697784 35   2 Yes, allowed (0.94285714 0.05714286) *
##            35) wtss< 1.697784 161  36 Yes, allowed (0.77639752 0.22360248)  
##              70) id< 463.5 44   4 Yes, allowed (0.90909091 0.09090909) *
##              71) id>=463.5 117  32 Yes, allowed (0.72649573 0.27350427)  
##               142) id>=595.5 103  24 Yes, allowed (0.76699029 0.23300971) *
##               143) id< 595.5 14   6 Not allowed (0.42857143 0.57142857) *
##         9) id>=1840 15   5 Not allowed (0.33333333 0.66666667) *
##       5) tolerance< 13.5 67  32 Not allowed (0.47761194 0.52238806)  
##        10) pray=SEVERAL TIMES A DAY,ONCE A DAY,SEVERAL TIMES A WEEK,LT ONCE A WEEK 55  24 Yes, allowed (0.56363636 0.43636364)  
##          20) polviews=ExtrmLib,Liberal,ExtrmCons 12   2 Yes, allowed (0.83333333 0.16666667) *
##          21) polviews=SlghtLib,Moderate,SlghtCons,Conserv 43  21 Not allowed (0.48837209 0.51162791)  
##            42) degree=HS,Junior Coll,Bachelor deg 35  15 Yes, allowed (0.57142857 0.42857143)  
##              84) sex=Male 14   3 Yes, allowed (0.78571429 0.21428571) *
##              85) sex=Female 21   9 Not allowed (0.42857143 0.57142857)  
##               170) degree=Junior Coll,Bachelor deg 10   4 Yes, allowed (0.60000000 0.40000000) *
##               171) degree=HS 11   3 Not allowed (0.27272727 0.72727273) *
##            43) degree=<HS,Graduate deg 8   1 Not allowed (0.12500000 0.87500000) *
##        11) pray=ONCE A WEEK,NEVER 12   1 Not allowed (0.08333333 0.91666667) *
##     3) tolerance< 12.5 401  47 Not allowed (0.11720698 0.88279302)  
##       6) id< 646.5 109  24 Not allowed (0.22018349 0.77981651)  
##        12) tolerance>=7.5 83  24 Not allowed (0.28915663 0.71084337)  
##          24) id>=570.5 7   1 Yes, allowed (0.85714286 0.14285714) *
##          25) id< 570.5 76  18 Not allowed (0.23684211 0.76315789)  
##            50) pray=SEVERAL TIMES A DAY,ONCE A DAY,NEVER 55  16 Not allowed (0.29090909 0.70909091)  
##             100) polviews=ExtrmLib,Liberal,ExtrmCons 10   4 Yes, allowed (0.60000000 0.40000000) *
##             101) polviews=SlghtLib,Moderate,SlghtCons,Conserv 45  10 Not allowed (0.22222222 0.77777778)  
##               202) age< 59.5 33  10 Not allowed (0.30303030 0.69696970)  
##                 404) age>=53.5 9   4 Yes, allowed (0.55555556 0.44444444) *
##                 405) age< 53.5 24   5 Not allowed (0.20833333 0.79166667) *
##               203) age>=59.5 12   0 Not allowed (0.00000000 1.00000000) *
##            51) pray=SEVERAL TIMES A WEEK,ONCE A WEEK,LT ONCE A WEEK 21   2 Not allowed (0.09523810 0.90476190) *
##        13) tolerance< 7.5 26   0 Not allowed (0.00000000 1.00000000) *
##       7) id>=646.5 292  23 Not allowed (0.07876712 0.92123288)  
##        14) age< 26.5 36   9 Not allowed (0.25000000 0.75000000)  
##          28) pray=ONCE A DAY,SEVERAL TIMES A WEEK,ONCE A WEEK,NEVER 25   9 Not allowed (0.36000000 0.64000000)  
##            56) polviews=Liberal,SlghtLib 7   3 Yes, allowed (0.57142857 0.42857143) *
##            57) polviews=Moderate,SlghtCons 18   5 Not allowed (0.27777778 0.72222222) *
##          29) pray=SEVERAL TIMES A DAY,LT ONCE A WEEK 11   0 Not allowed (0.00000000 1.00000000) *
##        15) age>=26.5 256  14 Not allowed (0.05468750 0.94531250) *
## ...
## and 0 more lines.

This final_tree object has the finalized, fitted model object inside. You may want to extract the model object from the workflow. To do this, you can use the helper function pull_workflow_fit() (in newer versions of the workflows package, this function is superseded by extract_fit_parsnip()).

For example, perhaps we would also like to understand what variables are important in this final model. We can use the vip package to estimate variable importance.


library(vip)

final_tree %>%
  pull_workflow_fit() %>%
  vip()

These are the survey variables that are the most important in driving predictions on the Muslim clergymen question.
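Here vip works from the fitted rpart model; to our understanding, its model-based importance for rpart fits comes from the variable.importance component that rpart stores on the fitted object. A minimal sketch of inspecting that component directly, using rpart's bundled kyphosis data as a stand-in (the GSS data are not assumed to be available) and the tuned values from above:

```r
library(rpart)  # recommended package bundled with R; provides rpart() and kyphosis

# Stand-in data, fitted with the tuned values from above
# (cp corresponds to cost_complexity, maxdepth to tree_depth)
fit <- rpart(
  Kyphosis ~ Age + Number + Start,
  data = kyphosis,
  method = "class",
  control = rpart.control(cp = 1e-10, maxdepth = 8)
)

# rpart's importance: summed goodness of split where a variable is the primary
# split, plus a contribution weighted by agreement where it is a surrogate
imp <- fit$variable.importance

# Rescale so the scores sum to 100 for easier comparison
round(100 * imp / sum(imp), 1)
```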

The last fit

Finally, let’s return to our test data and estimate the model performance we expect to see with new data. We can use the function last_fit() with our finalized model; this function fits the finalized model on the full training data set and evaluates the finalized model on the testing data.

final_fit <-
  final_wf %>%
  last_fit(gss_split)

final_fit %>%
  collect_metrics()

## # A tibble: 2 x 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.765 Preprocessor1_Model1
## 2 roc_auc  binary         0.801 Preprocessor1_Model1

final_fit %>%
  collect_predictions() %>%
  roc_curve(colmslm, `.pred_Yes, allowed`) %>%
  autoplot()

The test-set ROC AUC (0.801) is somewhat lower than the cross-validated estimate for our best candidate (0.839), but not dramatically so, which suggests that our tuning procedure did not badly overfit.
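The roc_auc value summarizes the whole ROC curve as a single number. One standard way to compute it is the rank-based (Mann-Whitney) formulation: the probability that a randomly chosen "event" case receives a higher predicted probability than a randomly chosen non-event case, with ties counted as one half. This gives the same area as integrating the ROC curve with the trapezoidal rule; we do not claim yardstick computes it this way internally. The predictions below are hypothetical toy values, not results from the GSS model.

```r
# Rank-based ROC AUC; average ranks from rank() give tied scores half credit
roc_auc_rank <- function(truth, prob, event) {
  is_event <- truth == event
  n_pos <- sum(is_event)
  n_neg <- sum(!is_event)
  ranks <- rank(prob)
  (sum(ranks[is_event]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Hypothetical predictions for six respondents
truth <- factor(c("Yes, allowed", "Not allowed", "Yes, allowed",
                  "Not allowed", "Yes, allowed", "Not allowed"))
prob_yes <- c(0.9, 0.8, 0.7, 0.4, 0.35, 0.1)

roc_auc_rank(truth, prob_yes, event = "Yes, allowed")  # 6 of 9 pairs ordered correctly
```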

We leave it to you to explore whether you can tune a different decision tree hyperparameter. You can explore the reference docs, or use the args() function to see which parsnip object arguments are available:

args(decision_tree)

## function (mode = "unknown", cost_complexity = NULL, tree_depth = NULL, 
##     min_n = NULL) 

You could tune the other hyperparameter we didn't use here, min_n, which sets the minimum number of data points in a node that are required for the node to be split further. This is another early stopping method for decision trees that can help prevent overfitting. Use the searchable table of parsnip model arguments on the tidymodels website to find the original argument for min_n in the rpart package (hint: look at the documentation for rpart.control()). See whether you can tune a different combination of hyperparameters and/or values to improve a tree's ability to predict respondents' answers.


Example drawn from Get Started - Tune model parameters and licensed under CC BY-SA 4.0.

