### Load standard packages
library(tidyverse) # Collection of all the good stuff like dplyr, ggplot2 etc.
library(magrittr) # For extra piping operators (e.g. %<>%)
# Load specific packages
# install.packages("tidymodels") # Install if necessary
library(tidymodels)
Welcome all to this introduction to machine learning (ML). In this session we cover the following topics:

1. Generalizing and validating ML models
2. The bias-variance trade-off
3. Out-of-sample testing and cross-validation workflows
4. Implementing ML workflows with the tidymodels ecosystem
Remember, the steps in the ML workflow are:

1. Obtaining data
2. Cleaning and inspecting
3. Visualizing and exploring data
4. Preprocessing data
5. Fitting and tuning models
6. Validating models
7. Communicating insights
While steps 1-3 are mainly covered by general tidyverse packages such as dplyr and ggplot2, step 7 can be done using, for instance, rmarkdown (as here) or by developing an interactive shiny application. We will touch upon that, but the main focus here lies on steps 5-6, the core of ML work.
These steps are mainly covered by the packages in the tidymodels ecosystem, which take care of sampling, fitting, tuning, and evaluating models and data.
tidymodels is an ecosystem of packages for implementing efficient and consistent SML modelling workflows that follow the tidy principles and fit neatly into tidy workflows. It contains the following packages:

* rsample provides infrastructure for efficient data splitting and resampling.
* parsnip is a tidy, unified interface to models, independent of the particular package syntax.
* recipes is a tidy interface to data pre-processing tools for feature engineering.
* workflows bundles your pre-processing, modeling, and post-processing together.
* tune optimizes the hyperparameters.
* yardstick provides model performance metrics.
* broom converts the information in common statistical R objects into user-friendly tidy formats.
* dials creates and manages tuning parameters and parameter grids.

I will tap into most of them during this and later sessions, so it makes sense to load the complete tidymodels ecosystem upfront.
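The exact set of packages attached by the tidymodels meta-package can drift between releases. A small sketch for checking it in your own installation, assuming `tidymodels_packages()` (exported by tidymodels) is available in your version:

```r
library(tidymodels)

# Character vector of the packages bundled by the tidymodels meta-package
pkgs <- tidymodels_packages()
pkgs

# Are the packages discussed above part of it?
core <- c("rsample", "parsnip", "recipes", "workflows",
          "tune", "yardstick", "broom", "dials")
setNames(core %in% pkgs, core)
```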
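Let's get started.
Let's do a brief example with a simple linear model. We generate some data where \(y\) is a linear function of \(x\) plus some random error: \(y = \beta_0 + \beta_1 x + \varepsilon\) with \(\beta_0 = 15\), \(\beta_1 = 0.3\), and \(\varepsilon \sim N(0, 5^2)\).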
set.seed(1337)
beta0 <- 15
beta1 <- 0.3
data_reg <- tibble(x = runif(500, min = 0, max = 100),
                   y = beta0 + (beta1 * x) + rnorm(500, sd = 5))
data_reg %>% ggplot(aes(x = x, y = y)) +
geom_point() +
geom_rug(size = 0.1, alpha = 0.75)
We can now fit a linear regression model that aims at discovering the underlying relationship.
fit_lm <- data_reg %>% lm(formula = y ~ x)
fit_lm %>% summary()
Call:
lm(formula = y ~ x, data = .)
Residuals:
Min 1Q Median 3Q Max
-15.423 -3.317 -0.170 3.337 17.157
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 14.791865 0.452922 32.66 <2e-16 ***
x 0.303863 0.007865 38.63 <2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 5.123 on 498 degrees of freedom
Multiple R-squared: 0.7498, Adjusted R-squared: 0.7493
F-statistic: 1493 on 1 and 498 DF, p-value: < 2.2e-16
We see it recovered the underlying relationship reasonably well (estimates of 14.79 and 0.304 versus the true 15 and 0.3). Keep in mind that its ability to discover it is also limited by the finite sample, where random noise can push the estimates away from the true values.
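For a single predictor, the OLS estimates that `lm()` reports have a closed form: the slope is \(\hat\beta_1 = \mathrm{cov}(x, y) / \mathrm{var}(x)\) and the intercept is \(\hat\beta_0 = \bar y - \hat\beta_1 \bar x\). A short sketch that re-simulates data with the same settings as above (the variable names `x_sim` and `y_sim` are just for illustration) and compares both routes:

```r
# Simulate data with the same settings as above (beta0 = 15, beta1 = 0.3, sd = 5)
set.seed(1337)
x_sim <- runif(500, min = 0, max = 100)
y_sim <- 15 + 0.3 * x_sim + rnorm(500, sd = 5)

# Closed-form OLS for one predictor
beta1_hat <- cov(x_sim, y_sim) / var(x_sim)          # slope
beta0_hat <- mean(y_sim) - beta1_hat * mean(x_sim)   # intercept

# Same estimates from lm()
coef_lm <- unname(coef(lm(y_sim ~ x_sim)))

rbind(manual = c(beta0_hat, beta1_hat), lm = coef_lm)
```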
Note: This is exactly what geom_smooth() in ggplot2 does when given the method = "lm" argument. Let's take a look at it visually.
data_reg %>% ggplot(aes(x = x, y = y)) +
geom_point() +
geom_smooth(method = "lm", formula = y ~ x, se = TRUE)
We can now use predict() to predict y values based on the fitted model.
data_reg %<>%
mutate(predicted = fit_lm %>% predict())
data_reg %>% ggplot(aes(x = x, y = y)) +
geom_segment(aes(xend = x, yend = predicted), alpha = .2) +
geom_point(alpha = 0.5) +
geom_point(aes(y = predicted), col = 'red', shape = 21)
It obviously predicts along the straight regression line. Due to the random noise we introduced, it is usually off by a bit. Let's calculate the error term:
error_reg <- pull(data_reg, y) - pull(data_reg, predicted)
error_reg %>% mean()
[1] 3.036836e-14
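On average, the error is essentially zero. However, keep in mind that positive and negative errors cancel each other out; for an OLS model with an intercept, the residuals even sum to zero by construction. The root mean squared error (RMSE) is the better measure here.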
sqrt(mean(error_reg ^ 2)) # Calculate RMSE
[1] 5.112672
By the way, this could also be piped:
error_reg^2 %>% mean() %>% sqrt()
[1] 5.112672
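The same metric is also available ready-made in yardstick, which we will use later in the tidymodels workflow. A sketch on hypothetical toy vectors, showing that `rmse_vec()` matches the hand-rolled formula:

```r
library(yardstick)

# Hypothetical toy data: true values and noisy predictions
set.seed(42)
truth_toy    <- rnorm(100, mean = 10)
estimate_toy <- truth_toy + rnorm(100, sd = 2)

# RMSE by hand: square the errors, average them, take the square root
rmse_by_hand <- sqrt(mean((truth_toy - estimate_toy)^2))

# yardstick's vector interface
rmse_ys <- rmse_vec(truth = truth_toy, estimate = estimate_toy)

c(by_hand = rmse_by_hand, yardstick = rmse_ys)
```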
However, we predicted on the same data the model was fitted on. How would it fare on new data?
set.seed(1338)
data_reg_new <- tibble(x = runif(500, min = 0, max = 100),
y = beta0+ (beta1*x) + rnorm(500, sd = 5))
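pred_reg_new <- fit_lm %>% predict(newdata = data_reg_new) # predict.lm() expects 'newdata'; 'new_data' would be silently ignored
error_reg_new <- pull(data_reg_new, y) - pred_reg_new
error_reg_new^2 %>% mean() %>% sqrt() # Out-of-sample RMSE
Since the model is correctly specified, the out-of-sample RMSE should land close to the in-sample value (the noise has sd = 5). Had we written new_data = instead, predict() would have ignored the argument and returned the in-sample fitted values, which, compared against the new y values, yield a misleadingly large error of around 13.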
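The gap between in-sample and out-of-sample error grows with model flexibility, which is where the bias-variance trade-off comes in. A hedged base-R sketch, assuming the same data-generating process as above but only 30 training points, with polynomial fits of degree 1 to 10 standing in for increasingly flexible models:

```r
# Simulate from the same process as above: y = 15 + 0.3 x + noise (sd = 5)
set.seed(1338)
simulate_reg <- function(n) {
  x <- runif(n, min = 0, max = 100)
  data.frame(x = x, y = 15 + 0.3 * x + rnorm(n, sd = 5))
}
train_small <- simulate_reg(30)    # small training sample
test_large  <- simulate_reg(500)   # fresh data for out-of-sample evaluation

rmse_fn <- function(truth, estimate) sqrt(mean((truth - estimate)^2))

# Fit orthogonal polynomials of increasing degree on the training data
degrees <- 1:10
fits <- lapply(degrees, function(d) lm(y ~ poly(x, d), data = train_small))

results_poly <- data.frame(
  degree     = degrees,
  train_rmse = sapply(fits, function(f) rmse_fn(train_small$y, fitted(f))),
  test_rmse  = sapply(fits, function(f) rmse_fn(test_large$y, predict(f, newdata = test_large)))
)
results_poly
```

Because the models are nested, the training RMSE can only go down as the degree increases, while the test RMSE typically stops improving and eventually rises once the extra flexibility is spent on fitting noise.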
OK, let's try the same with a binary classification. Let's create a random x and an associated binary y.
set.seed(1337)
beta1 <- 5
data_clas <- tibble(
x = rnorm(500),
y = rbinom(500, size = 1, prob = 1/(1+exp(-(beta1*x))) ) %>% as.logical() %>% factor()
)
data_clas %>% head()
data_clas %>%
ggplot(aes(x = x, y = y)) +
geom_point(alpha = 0.5)
Let's fit a logistic regression to that:
fit_log <- data_clas %>%
glm(formula = y ~ x, family = 'binomial')
fit_log %>% summary()
Call:
glm(formula = y ~ x, family = "binomial", data = .)
Deviance Residuals:
Min 1Q Median 3Q Max
-2.90735 -0.23058 -0.00276 0.22925 2.68607
Coefficients:
Estimate Std. Error z value Pr(>|z|)
(Intercept) -0.01544 0.17213 -0.090 0.929
x 5.27883 0.52813 9.995 <2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
Null deviance: 692.86 on 499 degrees of freedom
Residual deviance: 219.50 on 498 degrees of freedom
AIC: 223.5
Number of Fisher Scoring iterations: 7
We can again visualize it:
data_clas %>%
mutate(y = y %>% as.logical() %>% as.numeric()) %>%
ggplot(aes(x = x, y = y)) +
geom_point(alpha = 0.5) +
geom_smooth(method = "glm", method.args = list(family = "binomial"), se = FALSE)
We can again use the fitted model to predict the y-class of each data point. Here, we can report either the predicted class or the predicted probability; we do both.
data_clas %<>%
mutate(predicted = fit_log %>% predict(type = 'response'),
predicted_class = predicted %>% round(0) %>% as.logical() %>% factor())
data_clas %>% head()
cm_log <- data_clas %>% conf_mat(y, predicted_class)
cm_log %>% autoplot(type = "heatmap")
cm_log %>% summary() %>% mutate(.estimate = .estimate %>% round(3)) %>% select(-.estimator)
roc_log <- data_clas %>%
roc_curve(y, predicted, event_level = 'second')
roc_log %>% head()
data_clas %>% roc_auc(y, predicted, event_level = 'second')
roc_log %>% autoplot()
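The ROC AUC reported by roc_auc() has a handy interpretation: it is the probability that a randomly chosen positive case receives a higher predicted probability than a randomly chosen negative case (ties counting half). A base-R sketch on freshly simulated toy data (assumption: the same logistic process as data_clas, with 200 observations), computing it two equivalent ways:

```r
# Toy data from the same logistic process as above (beta1 = 5)
set.seed(1337)
x_toy <- rnorm(200)
y_toy <- rbinom(200, size = 1, prob = plogis(5 * x_toy))
p_toy <- predict(glm(y_toy ~ x_toy, family = binomial), type = "response")

pos <- p_toy[y_toy == 1]
neg <- p_toy[y_toy == 0]

# 1) Pairwise definition: share of (positive, negative) pairs ranked correctly
auc_pairs <- mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))

# 2) Rank-based (Mann-Whitney U) formula
ranks <- rank(p_toy)
n_pos <- length(pos)
n_neg <- length(neg)
auc_rank <- (sum(ranks[y_toy == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

c(pairs = auc_pairs, rank = auc_rank)
```

Applied to the same predictions, this should agree with what roc_auc() reports.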
Again, let's create some new data to test on:
set.seed(1338)
beta1 <- 5
data_clas_new <- tibble(
x = rnorm(500),
y = rbinom(500, size = 1, prob = 1/(1+exp(-(beta1*x))) ) %>% as.logical() %>% factor()
)
data_clas_new %<>%
mutate(predicted = fit_log %>% predict(type = 'response', newdata = data_clas_new),
predicted_class = predicted %>% round(0) %>% as.logical() %>% factor())
cm_log_new <- data_clas_new %>% conf_mat(y, predicted_class)
cm_log_new %>% summary() %>% mutate(.estimate = .estimate %>% round(3)) %>% select(-.estimator)
data_clas_new %>% roc_auc(y, predicted, event_level = 'second')
OK, all of that looked a bit cumbersome. Let's make it more advanced and flexible by introducing the tidymodels ML workflow. Here, we apply the following standard workflow:

1. Split the data into a training and a test sample with the rsample function initial_split().
2. Define the preprocessing with a recipe.
3. Define the models with parsnip.
4. Create resamples of the training data with the rsample package.
5. Tune the hyperparameters with the tune package, using the dials package to manage the hyperparameter search.

We will load a standard dataset from mlbench, the BostonHousing dataset. It comes as a data frame with 506 observations on 14 variables, the last one, medv, being the outcome:
* crim: per capita crime rate by town
* zn: proportion of residential land zoned for lots over 25,000 sq.ft
* indus: proportion of non-retail business acres per town
* chas: Charles River dummy variable (= 1 if tract bounds river; 0 otherwise) (deselected in this case)
* nox: nitric oxides concentration (parts per 10 million)
* rm: average number of rooms per dwelling
* age: proportion of owner-occupied units built prior to 1940
* dis: weighted distances to five Boston employment centres
* rad: index of accessibility to radial highways
* tax: full-value property-tax rate per USD 10,000
* ptratio: pupil-teacher ratio by town
* b: 1000(B - 0.63)^2 where B is the proportion of blacks by town
* lstat: percentage of lower status of the population
* medv: median value of owner-occupied homes in USD 1000’s (our outcome to predict)

Source: Harrison, D. and Rubinfeld, D.L. “Hedonic prices and the demand for clean air”, J. Environ. Economics & Management, vol. 5, 81-102, 1978.
These data have been taken from the UCI Repository Of Machine Learning Databases.
# install.packages('mlbench') # Install if necessary
library(mlbench) # Library including many ML benchmark datasets
data(BostonHousing)
data <- BostonHousing %>% as_tibble() %>% select(-chas)
rm(BostonHousing)
data %>% head()
data %>% glimpse()
Rows: 506
Columns: 13
$ crim <dbl> 0.00632, 0.02731, 0.02729, 0.03237, 0.06905, 0.02985, 0.08829, 0.14455, 0.21124, 0.17004, 0.22489, 0.11747, 0.09378, 0.62976, 0.63796, 0.62739, 1.05393, 0.78420, 0.80271, 0…
$ zn <dbl> 18.0, 0.0, 0.0, 0.0, 0.0, 0.0, 12.5, 12.5, 12.5, 12.5, 12.5, 12.5, 12.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,…
$ indus <dbl> 2.31, 7.07, 7.07, 2.18, 2.18, 2.18, 7.87, 7.87, 7.87, 7.87, 7.87, 7.87, 7.87, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14, 8.14…
$ nox <dbl> 0.538, 0.469, 0.469, 0.458, 0.458, 0.458, 0.524, 0.524, 0.524, 0.524, 0.524, 0.524, 0.524, 0.538, 0.538, 0.538, 0.538, 0.538, 0.538, 0.538, 0.538, 0.538, 0.538, 0.538, 0.53…
$ rm <dbl> 6.575, 6.421, 7.185, 6.998, 7.147, 6.430, 6.012, 6.172, 5.631, 6.004, 6.377, 6.009, 5.889, 5.949, 6.096, 5.834, 5.935, 5.990, 5.456, 5.727, 5.570, 5.965, 6.142, 5.813, 5.92…
$ age <dbl> 65.2, 78.9, 61.1, 45.8, 54.2, 58.7, 66.6, 96.1, 100.0, 85.9, 94.3, 82.9, 39.0, 61.8, 84.5, 56.5, 29.3, 81.7, 36.6, 69.5, 98.1, 89.2, 91.7, 100.0, 94.1, 85.7, 90.3, 88.8, 94…
$ dis <dbl> 4.0900, 4.9671, 4.9671, 6.0622, 6.0622, 6.0622, 5.5605, 5.9505, 6.0821, 6.5921, 6.3467, 6.2267, 5.4509, 4.7075, 4.4619, 4.4986, 4.4986, 4.2579, 3.7965, 3.7965, 3.7979, 4.01…
$ rad <dbl> 1, 2, 2, 3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 3, 5, 2, 5…
$ tax <dbl> 296, 242, 242, 222, 222, 222, 311, 311, 311, 311, 311, 311, 311, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 307, 30…
$ ptratio <dbl> 15.3, 17.8, 17.8, 18.7, 18.7, 18.7, 15.2, 15.2, 15.2, 15.2, 15.2, 15.2, 15.2, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0, 21.0…
$ b <dbl> 396.90, 396.90, 392.83, 394.63, 396.90, 394.12, 395.60, 396.90, 386.63, 386.71, 392.52, 396.90, 390.50, 396.90, 380.02, 395.62, 386.85, 386.75, 288.99, 390.95, 376.57, 392.…
$ lstat <dbl> 4.98, 9.14, 4.03, 2.94, 5.33, 5.21, 12.43, 19.15, 29.93, 17.10, 20.45, 13.27, 15.71, 8.26, 10.26, 8.47, 6.58, 14.67, 11.69, 11.28, 21.02, 13.83, 18.72, 19.88, 16.30, 16.51,…
$ medv <dbl> 24.0, 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15.0, 18.9, 21.7, 20.4, 18.2, 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6, 15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4…
In this exercise, we will predict medv (median value of owner-occupied homes in USD 1000’s). In the real world, such a model could be used to predict developments in housing prices, e.g. to inform policy makers or potential investors. When I have only one target outcome, I prefer to name it y. This simple naming convention helps to reuse code across datasets.
data %<>%
rename(y = medv) %>%
relocate(y)
Let's take a look at some descriptives.
data %>%
  summarise(across(everything(), list(min = min, mean = mean, max = max, sd = sd), .names = "{.col}_{.fn}")) %>%
  mutate(across(everything(), ~ round(.x, 2))) %>%
  pivot_longer(everything(),
               names_sep = "_",
               names_to = c("variable", ".value"))
OK, time for some visual exploration. Here I will introduce the GGally package, an extension of ggplot2 with some functions for very nice visual summaries in matrix form.
First, let's look at a classical correlation matrix.
# install.packages('GGally') # Install if necessary
data %>%
GGally::ggcorr(label = TRUE,
label_size = 3,
label_round = 2,
label_alpha = TRUE)
Even cooler, the ggpairs() function creates a scatterplot matrix together with all variable distributions and correlations.
data %>%
GGally::ggpairs(aes(alpha = 0.3),
ggtheme = theme_gray())
First, we split our data into a training and a test sample, using the initial_split() function of the rsample package.
data_split <- initial_split(data, prop = 0.75, strata = y)
data_train <- data_split %>% training()
data_test <- data_split %>% testing()
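With a numeric strata variable, rsample bins the outcome into quantile-based groups (four by default, via the breaks argument) and samples within each group, so that training and test sets end up with similar outcome distributions. A base-R sketch of that idea on a hypothetical skewed outcome; it illustrates the principle and is not rsample's actual implementation:

```r
# Hypothetical skewed outcome, e.g. prices
set.seed(2021)
y_sim <- rexp(500, rate = 1 / 20)

# Quartile bins of the outcome serve as strata
strata <- cut(y_sim,
              breaks = quantile(y_sim, probs = seq(0, 1, 0.25)),
              include.lowest = TRUE)

# Sample 75% of the row indices within each stratum
train_idx <- unlist(lapply(split(seq_along(y_sim), strata), function(idx) {
  idx[sample.int(length(idx), size = round(0.75 * length(idx)))]
}))

train_y <- y_sim[train_idx]
test_y  <- y_sim[-train_idx]

# The outcome distributions of both parts should look alike
rbind(train = quantile(train_y), test = quantile(test_y))
```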
We use the recipes package to automate and standardize all necessary preprocessing steps.
Here, we only do some simple transformations:
* We normalize all numeric predictors by centering (subtracting the mean) and scaling (dividing by the standard deviation).
* We remove features with near-zero variance, which would not help the model much.
* We also deal with missing data already during preprocessing: recipes has built-in missing value imputation algorithms, such as k-nearest neighbors.
data_recipe <- data_train %>%
  recipe(y ~ .) %>%
  step_center(all_numeric(), -all_outcomes()) %>% # Centers all numeric predictors to mean = 0
  step_scale(all_numeric(), -all_outcomes()) %>% # Scales all numeric predictors to sd = 1
  step_nzv(all_predictors()) %>% # Removes predictors with near-zero variance
  step_knnimpute(all_predictors()) %>% # kNN imputation of missing values (step_impute_knn() in recipes >= 0.1.16)
  prep()
data_recipe
Data Recipe
Inputs:
Training data contained 381 data points and no missing data.
Operations:
Centering for crim, zn, indus, nox, rm, age, dis, rad, tax, ptratio, b, lstat [trained]
Scaling for crim, zn, indus, nox, rm, age, dis, rad, tax, ptratio, b, lstat [trained]
Sparse, unbalanced variable filter removed no terms [trained]
K-nearest neighbor imputation for zn, indus, nox, rm, age, dis, rad, tax, ptratio, b, lstat, crim [trained]
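Note the [trained] tags: prep() estimates the centering and scaling constants (means and standard deviations) on the training data only, and applying the recipe to other data later (with bake(), or inside a workflow) reuses those same constants. A base-R sketch of that principle on hypothetical toy vectors:

```r
set.seed(7)
x_train <- rnorm(300, mean = 50, sd = 10)   # hypothetical training feature
x_test  <- rnorm(100, mean = 50, sd = 10)   # hypothetical new data

# "prep": learn the constants from the training data only
mu    <- mean(x_train)
sigma <- sd(x_train)

# "bake": apply the same constants to any data set
standardize <- function(x) (x - mu) / sigma
z_train <- standardize(x_train)
z_test  <- standardize(x_test)

# Training data: exactly mean 0 and sd 1; new data: only approximately
c(train_mean = mean(z_train), train_sd = sd(z_train),
  test_mean  = mean(z_test),  test_sd  = sd(z_test))
```

Estimating the constants on the full data instead would leak information from the test sample into the training process.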
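First of all, we define the models we will run here. In detail, we will run:
1. a simple linear regression (lm),
2. an elastic net regression (glmnet) with tuned penalty and mixture,
3. a random forest (ranger) with tuned mtry and min_n.
There is no particular reason for this choice other than to demonstrate different models with increasing complexity and hyperparameter tuning options.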
To set up a model with parsnip, the following syntax applies:
model_XX <- model_family(mode = 'regression/classification',
parameter_1 = 123,
parameter_2 = tune()) %>%
set_engine('packagename')
model_lm <- linear_reg(mode = 'regression') %>%
set_engine('lm')
model_el <- linear_reg(mode = 'regression',
penalty = tune(),
mixture = tune()) %>%
set_engine("glmnet")
model_rf <- rand_forest(mode = 'regression',
trees = 25,
mtry = tune(),
min_n = tune()
) %>%
set_engine('ranger', importance = 'impurity')
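For orientation, the two tuned parameters of model_el map onto glmnet's elastic-net objective, shown here for the Gaussian (regression) case as documented by glmnet: penalty is \(\lambda\) and mixture is \(\alpha\), so mixture = 1 gives the lasso and mixture = 0 gives ridge regression.

\[
\min_{\beta_0,\,\beta}\; \frac{1}{2N}\sum_{i=1}^{N}\left(y_i - \beta_0 - x_i^\top \beta\right)^2 \;+\; \lambda\left[\frac{1-\alpha}{2}\,\lVert \beta \rVert_2^2 \;+\; \alpha\,\lVert \beta \rVert_1\right]
\]

This is also why centering and scaling the predictors matters here: the penalty treats all coefficients alike, so they should be on comparable scales.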
We now define workflows by bundling the preprocessing recipe with the corresponding models. This is not a necessary step, but I find it neat.
workflow_general <- workflow() %>%
add_recipe(data_recipe)
workflow_lm <- workflow_general %>%
add_model(model_lm)
workflow_el <- workflow_general %>%
add_model(model_el)
workflow_rf <- workflow_general %>%
add_model(model_rf)
data_resample <- bootstraps(data_train,
strata = y,
times = 5)
data_resample %>% glimpse()
Rows: 5
Columns: 2
$ splits <list> [<boot_split[381 x 156 x 381 x 13]>], [<boot_split[381 x 140 x 381 x 13]>], [<boot_split[381 x 134 x 381 x 13]>], [<boot_split[381 x 140 x 381 x 13]>], [<boot_split[381 x 1…
$ id <chr> "Bootstrap1", "Bootstrap2", "Bootstrap3", "Bootstrap4", "Bootstrap5"
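Each bootstrap resample draws as many rows as the training set has, with replacement; the rows never drawn form the assessment (out-of-bag) set. That is why the assessment sizes in the output above (156, 140, 134, ...) hover around 37% of 381: the chance that a given row is never drawn is \((1 - 1/n)^n \approx e^{-1} \approx 0.368\). A base-R sketch of one such resample (not rsample's internals):

```r
set.seed(1337)
n_train <- 381                                                  # size of the training set above
analysis_idx   <- sample.int(n_train, size = n_train, replace = TRUE)  # drawn with replacement
assessment_idx <- setdiff(seq_len(n_train), analysis_idx)        # out-of-bag rows

oob_share    <- length(assessment_idx) / n_train
oob_expected <- (1 - 1 / n_train)^n_train                       # P(row never drawn)

c(observed = oob_share, expected = oob_expected, limit = exp(-1))
```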
tune_el <-
tune_grid(
workflow_el,
resamples = data_resample,
grid = 10
)
tune_el %>% autoplot()
best_param_el <- tune_el %>% select_best(metric = 'rmse')
best_param_el
tune_el %>% show_best(metric = 'rmse', n = 1)
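With grid = 10, tune generates ten candidate combinations automatically (by default a space-filling design). If you prefer to control the candidates yourself, you can pass an explicit grid built with dials instead. A sketch with a regular 3 x 3 grid over the two elastic-net parameters; the choice of levels = 3 is arbitrary and only for illustration:

```r
library(dials)

# 3 values of penalty (spaced on its default log10 scale) x 3 values of mixture
grid_el <- grid_regular(penalty(), mixture(), levels = 3)
grid_el

# Could then be used as: tune_grid(workflow_el, resamples = data_resample, grid = grid_el)
```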
tune_rf <-
tune_grid(
workflow_rf,
resamples = data_resample,
grid = 10
)
tune_rf %>% autoplot()
best_param_rf <- tune_rf %>% select_best(metric = 'rmse')
best_param_rf
tune_rf %>% show_best(metric = 'rmse', n = 1)
Alright, now we can fit the final models. To do so, we first have to update the previously created workflows, filling the tune() placeholders with the best-performing hyperparameters determined above.
workflow_final_el <- workflow_el %>%
finalize_workflow(parameters = best_param_el)
workflow_final_rf <- workflow_rf %>%
finalize_workflow(parameters = best_param_rf)
fit_lm <- workflow_lm %>%
fit(data_train)
fit_el <- workflow_final_el %>%
fit(data_train)
fit_rf <- workflow_final_rf %>%
fit(data_train)
# In-sample predictions on the training data (base = naive mean prediction)
pred_collected <- tibble(
  truth = data_train %>% pull(y),
  base = mean(truth),
  lm = fit_lm %>% predict(new_data = data_train) %>% pull(.pred),
  el = fit_el %>% predict(new_data = data_train) %>% pull(.pred),
  rf = fit_rf %>% predict(new_data = data_train) %>% pull(.pred)
) %>%
  pivot_longer(cols = -truth,
               names_to = 'model',
               values_to = '.pred')
pred_collected %>% head()
pred_collected %>%
group_by(model) %>%
rmse(truth = truth, estimate = .pred) %>%
select(model, .estimate) %>%
arrange(.estimate)
pred_collected %>%
ggplot(aes(x = truth, y = .pred, color = model)) +
geom_abline(lty = 2, color = "gray80", size = 1.5) +
geom_point(alpha = 0.5) +
labs(
x = "Truth",
y = "Predicted price",
color = "Type of model"
)
So, now we are almost there. Since we have chosen the random forest, we only have to predict on our test sample and see how we fare…
fit_last_rf <- workflow_final_rf %>% last_fit(split = data_split)
fit_last_rf %>% collect_metrics()
fit_last_rf %>%
  pluck(".workflow", 1) %>%
  pull_workflow_fit() %>% # superseded by extract_fit_parsnip() in newer versions of workflows
  vip::vip(num_features = 10)
fit_el %>%
pull_workflow_fit() %>%
vip::vip(num_features = 10)
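last_fit() is a convenience wrapper: it fits the finalized workflow on the training part of the split and evaluates it on the test part. A sketch on the built-in mtcars data (chosen purely for illustration) showing that doing the same steps by hand gives the identical test RMSE:

```r
library(tidymodels)

set.seed(123)
split_cars <- initial_split(mtcars, prop = 0.75)

wf_cars <- workflow() %>%
  add_formula(mpg ~ wt + hp) %>%
  add_model(linear_reg() %>% set_engine("lm"))

# 1) The convenience function
rmse_last_fit <- wf_cars %>%
  last_fit(split = split_cars) %>%
  collect_metrics() %>%
  filter(.metric == "rmse") %>%
  pull(.estimate)

# 2) The same steps by hand: fit on training(), predict on testing()
rmse_manual_cars <- wf_cars %>%
  fit(data = training(split_cars)) %>%
  predict(new_data = testing(split_cars)) %>%
  bind_cols(testing(split_cars)) %>%
  rmse(truth = mpg, estimate = .pred) %>%
  pull(.estimate)

c(last_fit = rmse_last_fit, manual = rmse_manual_cars)
```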
Customer churn refers to the situation when a customer ends their relationship with a company, and it’s a costly problem. Customers are the fuel that powers a business. Loss of customers impacts sales. Further, it’s much more difficult and costly to gain new customers than it is to retain existing customers. As a result, organizations need to focus on reducing customer churn.
The good news is that machine learning can help. For many businesses that offer subscription-based services, it is critical both to predict customer churn and to explain which features relate to it.
We now dive into the IBM Watson Telco dataset. According to IBM, the business challenge is:
A telecommunications company [Telco] is concerned about the number of customers leaving their landline business for cable competitors. They need to understand who is leaving. Imagine that you’re an analyst at this company and you have to find out who is leaving and why.
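The dataset includes information about:
* Customers who left within the last month: the column is called Churn
* Services each customer has signed up for: phone, multiple lines, internet, online security, online backup, device protection, tech support, and streaming TV and movies
* Customer account information: tenure, contract, payment method, paperless billing, monthly charges, and total charges
* Demographic information: gender, senior citizen status, and whether they have partners and dependents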
data <- readRDS(url("https://github.com/SDS-AAU/SDS-master/raw/master/00_data/telco_churn.rds")) # Note: for readRDS() the address has to be wrapped in url()
data %>% head()
data %>% glimpse()
Rows: 7,043
Columns: 21
$ customerID <chr> "7590-VHVEG", "5575-GNVDE", "3668-QPYBK", "7795-CFOCW", "9237-HQITU", "9305-CDSKC", "1452-KIOVK", "6713-OKOMC", "7892-POOKP", "6388-TABGU", "9763-GRSKD", "7469-LKB…
$ gender <chr> "Female", "Male", "Male", "Male", "Female", "Female", "Male", "Female", "Female", "Male", "Male", "Male", "Male", "Male", "Male", "Female", "Female", "Male", "Fema…
$ SeniorCitizen <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1…
$ Partner <chr> "Yes", "No", "No", "No", "No", "No", "No", "No", "Yes", "No", "Yes", "No", "Yes", "No", "No", "Yes", "No", "No", "Yes", "No", "No", "Yes", "No", "Yes", "Yes", "No"…
$ Dependents <chr> "No", "No", "No", "No", "No", "No", "Yes", "No", "No", "Yes", "Yes", "No", "No", "No", "No", "Yes", "No", "Yes", "Yes", "No", "No", "No", "No", "No", "Yes", "No", …
$ tenure <int> 1, 34, 2, 45, 2, 8, 22, 10, 28, 62, 13, 16, 58, 49, 25, 69, 52, 71, 10, 21, 1, 12, 1, 58, 49, 30, 47, 1, 72, 17, 71, 2, 27, 1, 1, 72, 5, 46, 34, 11, 10, 70, 17, 63…
$ PhoneService <chr> "No", "Yes", "Yes", "No", "Yes", "Yes", "Yes", "No", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "Yes", "Yes", "Yes",…
$ MultipleLines <chr> "No phone service", "No", "No", "No phone service", "No", "Yes", "Yes", "No phone service", "Yes", "No", "No", "No", "Yes", "Yes", "No", "Yes", "No", "Yes", "No", …
$ InternetService <chr> "DSL", "DSL", "DSL", "DSL", "Fiber optic", "Fiber optic", "Fiber optic", "DSL", "Fiber optic", "DSL", "DSL", "No", "Fiber optic", "Fiber optic", "Fiber optic", "Fi…
$ OnlineSecurity <chr> "No", "Yes", "Yes", "Yes", "No", "No", "No", "Yes", "No", "Yes", "Yes", "No internet service", "No", "No", "Yes", "Yes", "No internet service", "Yes", "No", "No", …
$ OnlineBackup <chr> "Yes", "No", "Yes", "No", "No", "No", "Yes", "No", "No", "Yes", "No", "No internet service", "No", "Yes", "No", "Yes", "No internet service", "No", "No", "Yes", "N…
$ DeviceProtection <chr> "No", "Yes", "No", "Yes", "No", "Yes", "No", "No", "Yes", "No", "No", "No internet service", "Yes", "Yes", "Yes", "Yes", "No internet service", "Yes", "Yes", "Yes"…
$ TechSupport <chr> "No", "No", "No", "Yes", "No", "No", "No", "No", "Yes", "No", "No", "No internet service", "No", "No", "Yes", "Yes", "No internet service", "No", "Yes", "No", "No"…
$ StreamingTV <chr> "No", "No", "No", "No", "No", "Yes", "Yes", "No", "Yes", "No", "No", "No internet service", "Yes", "Yes", "Yes", "Yes", "No internet service", "Yes", "No", "No", "…
$ StreamingMovies <chr> "No", "No", "No", "No", "No", "Yes", "No", "No", "Yes", "No", "No", "No internet service", "Yes", "Yes", "Yes", "Yes", "No internet service", "Yes", "No", "Yes", "…
$ Contract <chr> "Month-to-month", "One year", "Month-to-month", "One year", "Month-to-month", "Month-to-month", "Month-to-month", "Month-to-month", "Month-to-month", "One year", "…
$ PaperlessBilling <chr> "Yes", "No", "Yes", "No", "Yes", "Yes", "Yes", "No", "Yes", "No", "Yes", "No", "No", "Yes", "Yes", "No", "No", "No", "No", "Yes", "Yes", "No", "No", "Yes", "No", "…
$ PaymentMethod <chr> "Electronic check", "Mailed check", "Mailed check", "Bank transfer (automatic)", "Electronic check", "Electronic check", "Credit card (automatic)", "Mailed check",…
$ MonthlyCharges <dbl> 29.85, 56.95, 53.85, 42.30, 70.70, 99.65, 89.10, 29.75, 104.80, 56.15, 49.95, 18.95, 100.35, 103.70, 105.50, 113.25, 20.65, 106.70, 55.20, 90.05, 39.65, 19.80, 20.…
$ TotalCharges <dbl> 29.85, 1889.50, 108.15, 1840.75, 151.65, 820.50, 1949.40, 301.90, 3046.05, 3487.95, 587.45, 326.80, 5681.10, 5036.30, 2686.05, 7895.15, 1022.95, 7382.25, 528.35, 1…
$ Churn <chr> "No", "No", "Yes", "No", "Yes", "Yes", "No", "No", "Yes", "No", "No", "No", "No", "Yes", "No", "No", "No", "No", "Yes", "No", "Yes", "No", "Yes", "No", "No", "No",…
data %<>%
rename(y = Churn) %>%
select(y, everything(), -customerID)
data %>% summary()
y gender SeniorCitizen Partner Dependents tenure PhoneService MultipleLines InternetService OnlineSecurity
Length:7043 Length:7043 Min. :0.0000 Length:7043 Length:7043 Min. : 0.00 Length:7043 Length:7043 Length:7043 Length:7043
Class :character Class :character 1st Qu.:0.0000 Class :character Class :character 1st Qu.: 9.00 Class :character Class :character Class :character Class :character
Mode :character Mode :character Median :0.0000 Mode :character Mode :character Median :29.00 Mode :character Mode :character Mode :character Mode :character
Mean :0.1621 Mean :32.37
3rd Qu.:0.0000 3rd Qu.:55.00
Max. :1.0000 Max. :72.00
OnlineBackup DeviceProtection TechSupport StreamingTV StreamingMovies Contract PaperlessBilling PaymentMethod MonthlyCharges TotalCharges
Length:7043 Length:7043 Length:7043 Length:7043 Length:7043 Length:7043 Length:7043 Length:7043 Min. : 18.25 Min. : 18.8
Class :character Class :character Class :character Class :character Class :character Class :character Class :character Class :character 1st Qu.: 35.50 1st Qu.: 401.4
Mode :character Mode :character Mode :character Mode :character Mode :character Mode :character Mode :character Mode :character Median : 70.35 Median :1397.5
Mean : 64.76 Mean :2283.3
3rd Qu.: 89.85 3rd Qu.:3794.7
Max. :118.75 Max. :8684.8
NA's :11
Next, let's have a first visual inspection. Many models in the upcoming prediction exercise rely on the conditional distributions of the features differing between the outcome states to be predicted. So, let's take a look. Here, ggplot2 plus the ggridges package is my favorite. It is particularly helpful when dealing with many variables, where you want to see differences in their conditional distributions with respect to an outcome of interest.
# install.packages('ggridges') # install if necessary
data %>%
gather(variable, value, -y) %>% # Note: gather() is superseded by pivot_longer()
ggplot(aes(y = as.factor(variable),
fill = as.factor(y),
x = percent_rank(value)) ) +
ggridges::geom_density_ridges(alpha = 0.75)
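The ridgeline plot puts every variable on a common 0 to 1 axis via percent_rank(), which rescales each value's rank: rank minus one, divided by the number of non-missing values minus one. A tiny sketch on a toy vector:

```r
library(dplyr)

v_toy <- c(10, 20, 20, 40)
percent_rank(v_toy)                                # dplyr version
(min_rank(v_toy) - 1) / (sum(!is.na(v_toy)) - 1)   # same by hand: 0, 1/3, 1/3, 1
```

One caveat for the plot above: since gather() stacks numeric and character columns into a single value column, that column becomes character, so numeric values end up ranked alphabetically rather than by size.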
data_split <- initial_split(data, prop = 0.75, strata = y)
data_train <- data_split %>% training()
data_test <- data_split %>% testing()
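Here, I do the following preprocessing:
* Log-transform TotalCharges, which is strongly right-skewed.
* Center and scale all numeric predictors.
* Create dummy variables for all nominal predictors.
* Impute missing values (TotalCharges has 11 NAs) with k-nearest neighbors.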
data_recipe <- data_train %>%
recipe(y ~.) %>%
step_log(TotalCharges) %>%
step_center(all_numeric(), -all_outcomes()) %>%
step_scale(all_numeric(), -all_outcomes()) %>%
step_dummy(all_nominal(), -all_outcomes()) %>%
step_knnimpute(all_predictors()) %>% # kNN imputation of missing values
prep()
model_lg <- logistic_reg(mode = 'classification') %>%
set_engine('glm', family = binomial)
model_dt <- decision_tree(mode = 'classification',
cost_complexity = tune(),
tree_depth = tune(),
min_n = tune()
) %>%
set_engine('rpart')
model_xg <- boost_tree(mode = 'classification',
trees = 100,
mtry = tune(),
min_n = tune(),
tree_depth = tune(),
learn_rate = tune()
) %>%
set_engine("xgboost")
workflow_general <- workflow() %>%
add_recipe(data_recipe)
workflow_lg <- workflow_general %>%
add_model(model_lg)
workflow_dt <- workflow_general %>%
add_model(model_dt)
workflow_xg <- workflow_general %>%
add_model(model_xg)
data_resample <- data_train %>%
vfold_cv(strata = y,
v = 3,
repeats = 3)
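Here we use repeated v-fold cross-validation instead of the bootstrap: in each repeat the training data is shuffled and dealt into v folds, and each fold serves once as the assessment set while the other folds are used for fitting. A base-R sketch on 12 toy rows (not rsample's internals; the stratification by y is left out for brevity):

```r
set.seed(1337)
n_obs   <- 12
v       <- 3
repeats <- 3

# For each repeat: shuffle the rows, then deal them into v folds of equal size
folds <- lapply(seq_len(repeats), function(r) {
  fold_id <- integer(n_obs)
  fold_id[sample.int(n_obs)] <- rep_len(seq_len(v), n_obs)
  fold_id
})

# Long format: one row per (repeat, observation) with its assessment fold
assignments <- do.call(rbind, lapply(seq_len(repeats), function(r) {
  data.frame(repeat_id = r, row = seq_len(n_obs), fold = folds[[r]])
}))

# Each observation is held out once per repeat, so `repeats` times in total;
# v * repeats = 9 models are fit per hyperparameter candidate
table(assignments$row)
```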
tune_dt <-
tune_grid(
workflow_dt,
resamples = data_resample,
grid = 10
)
tune_dt %>% autoplot()
best_param_dt <- tune_dt %>% select_best(metric = 'roc_auc')
best_param_dt
tune_dt %>% show_best(metric = 'roc_auc', n = 1)
tune_xg <-
tune_grid(
workflow_xg,
resamples = data_resample,
grid = 10
)
tune_xg %>% autoplot()
best_param_xg <- tune_xg %>% select_best(metric = 'roc_auc')
best_param_xg
tune_xg %>% show_best(metric = 'roc_auc', n = 1)
workflow_final_dt <- workflow_dt %>%
finalize_workflow(parameters = best_param_dt)
workflow_final_xg <- workflow_xg %>%
finalize_workflow(parameters = best_param_xg)
fit_lg <- workflow_lg %>%
fit(data_train)
fit_dt <- workflow_final_dt %>%
fit(data_train)
fit_xg <- workflow_final_xg %>%
fit(data_train)
[10:08:57] WARNING: amalgamation/../src/learner.cc:1061: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.
pred_collected <- tibble(
  truth = data_train %>% pull(y) %>% as.factor(),
  lg = fit_lg %>% predict(new_data = data_train) %>% pull(.pred_class),
  dt = fit_dt %>% predict(new_data = data_train) %>% pull(.pred_class),
  xg = fit_xg %>% predict(new_data = data_train) %>% pull(.pred_class)
) %>%
  pivot_longer(cols = -truth,
               names_to = 'model',
               values_to = '.pred')
pred_collected %>% head()
pred_collected %>%
group_by(model) %>%
accuracy(truth = truth, estimate = .pred) %>%
select(model, .estimate) %>%
arrange(desc(.estimate))
pred_collected %>%
group_by(model) %>%
bal_accuracy(truth = truth, estimate = .pred) %>%
select(model, .estimate) %>%
arrange(desc(.estimate))
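Balanced accuracy is the average of sensitivity and specificity. When one class dominates, as here where most customers do not churn, plain accuracy can look good even if the minority class is poorly predicted. A sketch with hypothetical toy labels, where "Yes" marks churn:

```r
truth_toy <- factor(c("Yes", "Yes", "No", "No", "No", "No"), levels = c("No", "Yes"))
pred_toy  <- factor(c("Yes", "No",  "No", "No", "No", "Yes"), levels = c("No", "Yes"))

sens    <- mean(pred_toy[truth_toy == "Yes"] == "Yes")  # share of churners caught: 1/2
spec    <- mean(pred_toy[truth_toy == "No"]  == "No")   # share of non-churners caught: 3/4
bal_acc <- (sens + spec) / 2                             # 0.625
acc     <- mean(pred_toy == truth_toy)                   # 4/6

c(accuracy = acc, balanced_accuracy = bal_acc)
```

Swapping which level counts as the event just swaps sensitivity and specificity, so balanced accuracy does not depend on that choice.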
Surprisingly, here the less complex model seems to have the edge!
So, now we are almost there. Since we have chosen the decision tree, we only have to predict on our test sample and see how we fare…
fit_last_dt <- workflow_final_dt %>% last_fit(split = data_split)
fit_last_dt %>% collect_metrics()
fit_last_dt %>%
pluck(".workflow", 1) %>%
pull_workflow_fit() %>%
vip::vip(num_features = 10)
fit_xg %>%
pull_workflow_fit() %>%
vip::vip(num_features = 10)
* tidymodels: Tidy statistical and predictive modeling ecosystem. Full of introductions, examples, and further material.
* caret and other slowly declining ML package ecosystems.
* tidymodels and caret
sessionInfo()
R version 4.0.3 (2020-10-10)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Catalina 10.15.7
Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
Random number generation:
RNG: L'Ecuyer-CMRG
Normal: Inversion
Sample: Rejection
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] xgboost_1.3.2.1 rpart_4.1-15 ranger_0.12.1 glmnet_4.1-1 Matrix_1.3-2 vctrs_0.3.6 rlang_0.4.10 mlbench_2.1-3 yardstick_0.0.7 workflows_0.2.1
[11] tune_0.1.3 rsample_0.0.9 recipes_0.1.15 parsnip_0.1.5 modeldata_0.1.0 infer_0.5.4 dials_0.0.9 scales_1.1.1 broom_0.7.5 tidymodels_0.1.2
[21] knitr_1.31 magrittr_2.0.1 forcats_0.5.1 stringr_1.4.0 dplyr_1.0.5 purrr_0.3.4 readr_1.4.0 tidyr_1.1.3 tibble_3.1.0 ggplot2_3.3.3
[31] tidyverse_1.3.0
loaded via a namespace (and not attached):
[1] colorspace_2.0-0 ellipsis_0.3.1 class_7.3-18 ggridges_0.5.3 rsconnect_0.8.16 base64enc_0.1-3 fs_1.5.0 rstudioapi_0.13 farver_2.1.0
[10] listenv_0.8.0 furrr_0.2.2 prodlim_2019.11.13 fansi_0.4.2 lubridate_1.7.10 xml2_1.3.2 codetools_0.2-18 splines_4.0.3 jsonlite_1.7.2
[19] pROC_1.17.0.1 dbplyr_2.1.0 compiler_4.0.3 httr_1.4.2 backports_1.2.1 assertthat_0.2.1 cli_2.3.1 prettyunits_1.1.1 htmltools_0.5.1.1
[28] tools_4.0.3 gtable_0.3.0 glue_1.4.2 Rcpp_1.0.6 cellranger_1.1.0 jquerylib_0.1.3 DiceDesign_1.9 nlme_3.1-152 debugme_1.1.0
[37] iterators_1.0.13 timeDate_3043.102 gower_0.2.2 xfun_0.21 globals_0.14.0 rvest_0.3.6 lifecycle_1.0.0 pacman_0.5.1 future_1.21.0
[46] MASS_7.3-53.1 ipred_0.9-10 hms_1.0.0 parallel_4.0.3 RColorBrewer_1.1-2 yaml_2.2.1 gridExtra_2.3 sass_0.3.1 reshape_0.8.8
[55] stringi_1.5.3 foreach_1.5.1 lhs_1.1.1 hardhat_0.1.5 shape_1.4.5 lava_1.6.8.1 repr_1.1.3 pkgconfig_2.0.3 evaluate_0.14
[64] lattice_0.20-41 labeling_0.4.2 tidyselect_1.1.0 parallelly_1.23.0 GGally_2.1.1 plyr_1.8.6 R6_2.5.0 generics_0.1.0 DBI_1.1.1
[73] mgcv_1.8-34 pillar_1.5.1 haven_2.3.1 withr_2.4.1 survival_3.2-7 nnet_7.3-15 modelr_0.1.8 crayon_1.4.1 vip_0.3.2
[82] utf8_1.1.4 rmarkdown_2.7 progress_1.2.2 grid_4.0.3 readxl_1.3.1 data.table_1.14.0 reprex_1.0.0 digest_0.6.27 munsell_0.5.0
[91] GPfit_1.0-8 skimr_2.1.3 bslib_0.2.4