Train many models by group using the purrr package

Objective

  • Use purrr to run a separate regression for each group defined by the group_by columns
  • Use broom to collect stats
  • Use ggplot2 to visualise model performance

Data

# Packages used by the code in this post. They must be attached before the
# first pipe (%>%) is evaluated, otherwise the snippets below fail with
# "could not find function".
library(dplyr)    # %>%, count(), group_by(), mutate()
library(tidyr)    # nest(), unnest()
library(purrr)    # map(), map_dbl()
library(broom)    # tidy(), glance(), augment()
library(ggplot2)  # ggplot(), geom_point(), labs()

# Load the built-in mtcars data set and print a per-column summary.
data("mtcars")
summary(mtcars)
##       mpg             cyl             disp             hp       
##  Min.   :10.40   Min.   :4.000   Min.   : 71.1   Min.   : 52.0  
##  1st Qu.:15.43   1st Qu.:4.000   1st Qu.:120.8   1st Qu.: 96.5  
##  Median :19.20   Median :6.000   Median :196.3   Median :123.0  
##  Mean   :20.09   Mean   :6.188   Mean   :230.7   Mean   :146.7  
##  3rd Qu.:22.80   3rd Qu.:8.000   3rd Qu.:326.0   3rd Qu.:180.0  
##  Max.   :33.90   Max.   :8.000   Max.   :472.0   Max.   :335.0  
##       drat             wt             qsec             vs        
##  Min.   :2.760   Min.   :1.513   Min.   :14.50   Min.   :0.0000  
##  1st Qu.:3.080   1st Qu.:2.581   1st Qu.:16.89   1st Qu.:0.0000  
##  Median :3.695   Median :3.325   Median :17.71   Median :0.0000  
##  Mean   :3.597   Mean   :3.217   Mean   :17.85   Mean   :0.4375  
##  3rd Qu.:3.920   3rd Qu.:3.610   3rd Qu.:18.90   3rd Qu.:1.0000  
##  Max.   :4.930   Max.   :5.424   Max.   :22.90   Max.   :1.0000  
##        am              gear            carb      
##  Min.   :0.0000   Min.   :3.000   Min.   :1.000  
##  1st Qu.:0.0000   1st Qu.:3.000   1st Qu.:2.000  
##  Median :0.0000   Median :4.000   Median :2.000  
##  Mean   :0.4062   Mean   :3.688   Mean   :2.812  
##  3rd Qu.:1.0000   3rd Qu.:4.000   3rd Qu.:4.000  
##  Max.   :1.0000   Max.   :5.000   Max.   :8.000
# Number of rows in each am / cyl combination (the groups modelled below).
count(mtcars, am, cyl)
##   am cyl  n
## 1  0   4  3
## 2  0   6  4
## 3  0   8 12
## 4  1   4  8
## 5  1   6  3
## 6  1   8  2

Train models

#' Fit a simple linear regression of `mpg` on `wt`.
#'
#' @param df A data frame passed to `lm()` as `data`; the formula refers to
#'   the variables `mpg` and `wt`.
#' @return The fitted model object returned by `lm()`.
train_model <- function(df) {
  fit <- lm(mpg ~ wt, data = df)
  fit
}


# Fit one model per (am, cyl) group and store broom's summaries next to it.
#
# After nest(), each row holds one group's observations in the list-column
# `data`; train_model() is applied to each of those data frames.
#
# The broom functions are namespace-qualified on purpose: the new columns are
# called `tidy`, `glance` and `augment`, the same names as the functions that
# build them. Inside mutate() a bare `tidy` would resolve to the column as soon
# as it exists, so the unqualified form only works because of the order in
# which the expressions are evaluated.
model_df <- mtcars %>%
  group_by(am, cyl) %>%
  nest() %>%
  mutate(model = map(data, train_model)) %>%
  mutate(
    tidy = map(model, broom::tidy),        # coefficient table per model
    glance = map(model, broom::glance),    # one-row fit summary per model
    augment = map(model, broom::augment),  # per-observation fitted values etc.
    rsq = map_dbl(glance, "r.squared")     # pulled from the `glance` column above
  )
# Expand the per-model augment() results into regular columns and preview
# the first rows.
model_df %>%
  unnest(augment) %>%
  head()
## # A tibble: 6 x 15
## # Groups:   cyl, am [2]
##     cyl    am data  model tidy  glance   mpg    wt .fitted .resid  .hat .sigma
##   <dbl> <dbl> <lis> <lis> <lis> <list> <dbl> <dbl>   <dbl>  <dbl> <dbl>  <dbl>
## 1     6     1 <tib… <lm>  <tib… <tibb…  21    2.62    20.6  0.353 0.888 NaN   
## 2     6     1 <tib… <lm>  <tib… <tibb…  21    2.88    20.5  0.505 0.772 Inf   
## 3     6     1 <tib… <lm>  <tib… <tibb…  19.7  2.77    20.6 -0.858 0.340 NaN   
## 4     4     1 <tib… <lm>  <tib… <tibb…  22.8  2.32    25.9 -3.08  0.191   3.34
## 5     4     1 <tib… <lm>  <tib… <tibb…  32.4  2.2     26.8  5.57  0.146   2.50
## 6     4     1 <tib… <lm>  <tib… <tibb…  30.4  1.62    31.4 -1.05  0.281   3.64
## # … with 3 more variables: .cooksd <dbl>, .std.resid <dbl>, rsq <dbl>
# Expand the per-model tidy() results (one row per model term) and preview
# the first rows.
model_df %>%
  unnest(tidy) %>%
  head()
## # A tibble: 6 x 12
## # Groups:   cyl, am [3]
##     cyl    am data  model term  estimate std.error statistic p.value glance
##   <dbl> <dbl> <lis> <lis> <chr>    <dbl>     <dbl>     <dbl>   <dbl> <list>
## 1     6     1 <tib… <lm>  (Int…   22.2       16.1      1.38  3.99e-1 <tibb…
## 2     6     1 <tib… <lm>  wt      -0.594      5.83    -0.102 9.35e-1 <tibb…
## 3     4     1 <tib… <lm>  (Int…   44.2        6.44     6.86  4.73e-4 <tibb…
## 4     4     1 <tib… <lm>  wt      -7.89       3.10    -2.55  4.37e-2 <tibb…
## 5     6     0 <tib… <lm>  (Int…   63.6       11.9      5.36  3.30e-2 <tibb…
## 6     6     0 <tib… <lm>  wt     -13.1        3.50    -3.75  6.42e-2 <tibb…
## # … with 2 more variables: augment <list>, rsq <dbl>
# Plot each group's R-squared: rsq on the x axis, cyl on the y axis (levels
# ordered by rsq via reorder()), points coloured by am.
ggplot(model_df, aes(x = rsq, y = reorder(cyl, rsq))) +
  geom_point(aes(colour = as.factor(am))) +
  labs(y = "cyl", colour = "am")

Avatar
Ray Sun
Data Analytics Professional

My interests include AI/ML and data analytics.

Related