Objective
- Use purrr to fit a separate regression for each group defined by the `group_by()` columns
- Use broom to collect model statistics
- Use ggplot2 to visualise model performance
Data
# Load the built-in `mtcars` data set and print per-column summary statistics.
data(mtcars)
summary(mtcars)
## mpg cyl disp hp
## Min. :10.40 Min. :4.000 Min. : 71.1 Min. : 52.0
## 1st Qu.:15.43 1st Qu.:4.000 1st Qu.:120.8 1st Qu.: 96.5
## Median :19.20 Median :6.000 Median :196.3 Median :123.0
## Mean :20.09 Mean :6.188 Mean :230.7 Mean :146.7
## 3rd Qu.:22.80 3rd Qu.:8.000 3rd Qu.:326.0 3rd Qu.:180.0
## Max. :33.90 Max. :8.000 Max. :472.0 Max. :335.0
## drat wt qsec vs
## Min. :2.760 Min. :1.513 Min. :14.50 Min. :0.0000
## 1st Qu.:3.080 1st Qu.:2.581 1st Qu.:16.89 1st Qu.:0.0000
## Median :3.695 Median :3.325 Median :17.71 Median :0.0000
## Mean :3.597 Mean :3.217 Mean :17.85 Mean :0.4375
## 3rd Qu.:3.920 3rd Qu.:3.610 3rd Qu.:18.90 3rd Qu.:1.0000
## Max. :4.930 Max. :5.424 Max. :22.90 Max. :1.0000
## am gear carb
## Min. :0.0000 Min. :3.000 Min. :1.000
## 1st Qu.:0.0000 1st Qu.:3.000 1st Qu.:2.000
## Median :0.0000 Median :4.000 Median :2.000
## Mean :0.4062 Mean :3.688 Mean :2.812
## 3rd Qu.:1.0000 3rd Qu.:4.000 3rd Qu.:4.000
## Max. :1.0000 Max. :5.000 Max. :8.000
# Number of cars in each (am, cyl) combination, i.e. rows per regression group.
count(mtcars, am, cyl)
## am cyl n
## 1 0 4 3
## 2 0 6 4
## 3 0 8 12
## 4 1 4 8
## 5 1 6 3
## 6 1 8 2
Train models
# Fit a linear model to one group's data.
#
# df      - a data frame containing the columns used in `formula`
#           (for the default formula: `mpg` and `wt`).
# formula - model formula. Defaults to mpg ~ wt, so single-argument calls
#           such as map(data, train_model) behave exactly as before.
#
# Returns the fitted `lm` object.
train_model <- function(df, formula = mpg ~ wt) {
  lm(formula, data = df)
}
# Fit one regression per (am, cyl) group and collect broom outputs.
# Columns of the resulting nested tibble:
#   data    - the group's rows of mtcars (from nest())
#   model   - fitted lm object returned by train_model()
#   tidy    - coefficient table (broom::tidy)
#   glance  - one-row model summary (broom::glance)
#   augment - per-observation fitted values / residuals (broom::augment)
#   rsq     - R-squared extracted from `glance` as a plain double
# NOTE: some groups are tiny. am = 1 & cyl = 8 has only 2 cars, so its line
# passes exactly through both points and rsq is trivially 1. Do not read that
# as a good model.
model_df <- mtcars %>%
  group_by(am, cyl) %>%
  nest() %>%
  # nest() keeps the grouping; drop it so the verbs below (and later
  # unnest()/head() calls) operate on the whole table, not group by group.
  ungroup() %>%
  mutate(
    model = map(data, train_model),
    # Namespace the broom generics explicitly: list columns with the same
    # names are created in this mutate() and must not shadow the functions.
    tidy = map(model, broom::tidy),
    glance = map(model, broom::glance),
    augment = map(model, broom::augment),
    # mutate() builds columns left to right, so `glance` here is the column.
    rsq = map_dbl(glance, "r.squared")
  )
# Peek at the broom::augment() output: one row per observation, with
# fitted values and residuals alongside the original mpg and wt.
model_df %>%
  unnest(augment) %>%
  head()
## # A tibble: 6 x 15
## # Groups: cyl, am [2]
## cyl am data model tidy glance mpg wt .fitted .resid .hat .sigma
## <dbl> <dbl> <lis> <lis> <lis> <list> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 6 1 <tib… <lm> <tib… <tibb… 21 2.62 20.6 0.353 0.888 NaN
## 2 6 1 <tib… <lm> <tib… <tibb… 21 2.88 20.5 0.505 0.772 Inf
## 3 6 1 <tib… <lm> <tib… <tibb… 19.7 2.77 20.6 -0.858 0.340 NaN
## 4 4 1 <tib… <lm> <tib… <tibb… 22.8 2.32 25.9 -3.08 0.191 3.34
## 5 4 1 <tib… <lm> <tib… <tibb… 32.4 2.2 26.8 5.57 0.146 2.50
## 6 4 1 <tib… <lm> <tib… <tibb… 30.4 1.62 31.4 -1.05 0.281 3.64
## # … with 3 more variables: .cooksd <dbl>, .std.resid <dbl>, rsq <dbl>
# Peek at the broom::tidy() coefficient tables: one row per model term
# (intercept and wt slope) for each group.
model_df %>%
  unnest(tidy) %>%
  head()
## # A tibble: 6 x 12
## # Groups: cyl, am [3]
## cyl am data model term estimate std.error statistic p.value glance
## <dbl> <dbl> <lis> <lis> <chr> <dbl> <dbl> <dbl> <dbl> <list>
## 1 6 1 <tib… <lm> (Int… 22.2 16.1 1.38 3.99e-1 <tibb…
## 2 6 1 <tib… <lm> wt -0.594 5.83 -0.102 9.35e-1 <tibb…
## 3 4 1 <tib… <lm> (Int… 44.2 6.44 6.86 4.73e-4 <tibb…
## 4 4 1 <tib… <lm> wt -7.89 3.10 -2.55 4.37e-2 <tibb…
## 5 6 0 <tib… <lm> (Int… 63.6 11.9 5.36 3.30e-2 <tibb…
## 6 6 0 <tib… <lm> wt -13.1 3.50 -3.75 6.42e-2 <tibb…
## # … with 2 more variables: augment <list>, rsq <dbl>
# Dot plot of R-squared for every (cyl, am) model.
# reorder() sorts the cyl levels by their mean rsq, so the cylinder count
# with the highest mean rsq appears at the top of the y-axis.
ggplot(model_df, aes(x = rsq, y = reorder(cyl, rsq))) +
  geom_point(aes(colour = as.factor(am))) +
  labs(y = "cyl", colour = "am")