R
tidymodels
Machine Learning
Author

Kili

Published

2024-08-02

Load the R packages:
```{r}
library(tidymodels) # for the parsnip package, along with the rest of tidymodels

# Helper packages
library(readr) # for importing data
library(broom.mixed) # for converting bayesian models to tidy tibbles
library(dotwhisker) # for visualizing regression results
```
```{r}
urchins <-
  read_csv("urchins.csv") %>%
  setNames(c("food_regime", "initial_volume", "width")) %>%
  mutate(food_regime = factor(food_regime, levels = c("Initial", "Low", "High")))

glimpse(urchins)
```
Rows: 72
Columns: 3
$ food_regime    <fct> Initial, Initial, Initial, Initial, Initial, Initial, I…
$ initial_volume <dbl> 3.5, 5.0, 8.0, 10.0, 13.0, 13.0, 15.0, 15.0, 16.0, 17.0…
$ width          <dbl> 0.010, 0.020, 0.061, 0.051, 0.041, 0.061, 0.041, 0.071,…
```{r}
#| fig-cap: Linear regression fits grouped by food regime
ggplot(
  urchins,
  aes(
    x = initial_volume,
    y = width,
    group = food_regime,
    col = food_regime
  )
) +
  geom_point() +
  geom_smooth(method = lm, se = FALSE) +
  scale_color_viridis_d(option = "plasma", end = .7)
```

Linear regression fits grouped by food regime

Engines available for linear regression:
```{r}
show_engines("linear_reg")
```
# A tibble: 7 × 2
  engine mode      
  <chr>  <chr>     
1 lm     regression
2 glm    regression
3 glmnet regression
4 stan   regression
5 spark  regression
6 keras  regression
7 brulee regression

Preprocessing data with recipes

```{r}
library(tidymodels) # for the recipes package, along with the rest of tidymodels

# Helper packages
library(nycflights13) # for flight data
library(skimr) # for variable summaries
```

Load the flight data; we will predict whether a flight arrives late:
```{r}
set.seed(123)

flight_data <-
  flights %>%
  mutate(
    # Convert the arrival delay to a factor
    arr_delay = ifelse(arr_delay >= 30, "late", "on_time"),
    arr_delay = factor(arr_delay),
    # We will use the date (not date-time) in the recipe below
    date = lubridate::as_date(time_hour)
  ) %>%
  # Include the weather data
  inner_join(weather, by = c("origin", "time_hour")) %>%
  # Only retain the specific columns we will use
  select(
    dep_time, flight, origin, dest, air_time, distance,
    carrier, date, arr_delay, time_hour
  ) %>%
  # Exclude missing data
  na.omit() %>%
  # For creating models, it is better to have qualitative columns
  # encoded as factors (instead of character strings)
  mutate_if(is.character, as.factor)
```

Take a look:
```{r}
glimpse(flight_data)
```
Rows: 325,819
Columns: 10
$ dep_time  <int> 517, 533, 542, 544, 554, 554, 555, 557, 557, 558, 558, 558, …
$ flight    <int> 1545, 1714, 1141, 725, 461, 1696, 507, 5708, 79, 301, 49, 71…
$ origin    <fct> EWR, LGA, JFK, JFK, LGA, EWR, EWR, LGA, JFK, LGA, JFK, JFK, …
$ dest      <fct> IAH, IAH, MIA, BQN, ATL, ORD, FLL, IAD, MCO, ORD, PBI, TPA, …
$ air_time  <dbl> 227, 227, 160, 183, 116, 150, 158, 53, 140, 138, 149, 158, 3…
$ distance  <dbl> 1400, 1416, 1089, 1576, 762, 719, 1065, 229, 944, 733, 1028,…
$ carrier   <fct> UA, UA, AA, B6, DL, UA, B6, EV, B6, AA, B6, B6, UA, UA, AA, …
$ date      <date> 2013-01-01, 2013-01-01, 2013-01-01, 2013-01-01, 2013-01-01,…
$ arr_delay <fct> on_time, on_time, late, on_time, on_time, on_time, on_time, …
$ time_hour <dttm> 2013-01-01 05:00:00, 2013-01-01 05:00:00, 2013-01-01 05:00:…

We do not want flight and time_hour to be used as predictors, but we keep them as identifiers.

Split the data into training and test sets
```{r}
# Fix the random numbers by setting the seed
# This enables the analysis to be reproducible when random numbers are used
set.seed(222)
# Put 3/4 of the data into the training set
data_split <- initial_split(flight_data, prop = 3 / 4)

# Create data frames for the two sets:
train_data <- training(data_split)
test_data <- testing(data_split)
```

Start creating the recipe
```{r}
flights_rec <-
  recipe(arr_delay ~ ., data = train_data)
```

Now we can add roles to this recipe. The update_role() function lets the recipe know that flight and time_hour are variables with a custom role we call "ID" (a role can be any character value). While our formula includes every variable in the training set other than arr_delay as a predictor, this tells the recipe to keep the two variables without using them as either outcomes or predictors.
```{r}
flights_rec <-
  recipe(arr_delay ~ ., data = train_data) %>%
  update_role(flight, time_hour, new_role = "ID")
summary(flights_rec)
```
# A tibble: 10 × 4
   variable  type      role      source  
   <chr>     <list>    <chr>     <chr>   
 1 dep_time  <chr [2]> predictor original
 2 flight    <chr [2]> ID        original
 3 origin    <chr [3]> predictor original
 4 dest      <chr [3]> predictor original
 5 air_time  <chr [2]> predictor original
 6 distance  <chr [2]> predictor original
 7 carrier   <chr [3]> predictor original
 8 date      <chr [1]> predictor original
 9 time_hour <chr [1]> ID        original
10 arr_delay <chr [3]> outcome   original

Feature-engineer the date into:

  • the day of the week,

  • the month, and

  • whether the date corresponds to a holiday.
```{r}
flights_rec <-
  recipe(arr_delay ~ ., data = train_data) %>%
  update_role(flight, time_hour, new_role = "ID") %>%
  step_date(date, features = c("dow", "month")) %>%
  step_holiday(date,
    holidays = timeDate::listHolidays("US"), # US holidays
    keep_original_cols = FALSE
  ) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors())
```

A quick explanation: step_date() and step_holiday() derive new features from the original date column, and keep_original_cols = FALSE drops the original date; step_dummy(all_nominal_predictors()) converts every nominal (categorical) predictor into dummy variables; and step_zv() removes zero-variance predictors, such as dummy columns for factor levels that never occur in the training data. For example:
```{r}
test_data %>%
  distinct(dest) %>%
  anti_join(train_data, by = "dest")
```
# A tibble: 1 × 1
  dest 
  <fct>
1 LEX  

The destination LEX occurs only in the test set (a single record absent from the training set), so its dummy column would have zero variance in training; step_zv() removes such columns.
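What step_dummy() and step_zv() do can be sketched on a hypothetical toy data frame (not the flight data; column names here are invented for illustration):

```{r}
library(recipes)

# Toy data: one factor predictor and one zero-variance column
toy <- data.frame(
  y     = c(1, 2, 3, 4),
  grp   = factor(c("a", "b", "a", "b")),
  const = 1 # identical in every row, so it has zero variance
)

baked <- recipe(y ~ ., data = toy) |>
  step_dummy(all_nominal_predictors()) |> # grp becomes the dummy column grp_b
  step_zv(all_predictors()) |>            # const is dropped
  prep() |>
  bake(new_data = NULL)

names(baked)
```

prep() estimates the steps from the training data, and bake(new_data = NULL) returns the transformed training set, so you can inspect exactly which columns a recipe produces.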

```{r}
flights_rec[["var_info"]]$type
```
[[1]]
[1] "integer" "numeric"

[[2]]
[1] "integer" "numeric"

[[3]]
[1] "factor"    "unordered" "nominal"  

[[4]]
[1] "factor"    "unordered" "nominal"  

[[5]]
[1] "double"  "numeric"

[[6]]
[1] "double"  "numeric"

[[7]]
[1] "factor"    "unordered" "nominal"  

[[8]]
[1] "date"

[[9]]
[1] "datetime"

[[10]]
[1] "factor"    "unordered" "nominal"  

Four variables carry the nominal type (origin, dest, carrier, and the outcome arr_delay); the three nominal predictors are the ones step_dummy() converts into dummy variables.

Using the recipe
```{r}
lr_mod <- 
  logistic_reg() %>% 
  set_engine("glm") # set the engine

flights_wflow <- 
  workflow() %>% # create a workflow
  add_model(lr_mod) %>% # add the model
  add_recipe(flights_rec) # add the recipe

flights_wflow # print the workflow
```
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps

• step_date()
• step_holiday()
• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)

Computational engine: glm 
```{r}
flights_fit <- 
  flights_wflow %>% 
  fit(data = train_data) # fit the model


flights_fit # print the fitted workflow


flights_fit %>% 
  extract_fit_parsnip() %>% 
  tidy() # model coefficients

flights_fit %>% 
  extract_recipe() %>% 
  tidy() # recipe steps
```
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps

• step_date()
• step_holiday()
• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────

Call:  stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)

Coefficients:
                 (Intercept)                      dep_time  
                    7.276446                     -0.001664  
                    air_time                      distance  
                   -0.044014                      0.005071  
         date_USChristmasDay            date_USColumbusDay  
                    1.329336                      0.723927  
    date_USCPulaskisBirthday  date_USDecorationMemorialDay  
                    0.807165                      0.584694  
          date_USElectionDay             date_USGoodFriday  
                    0.947652                      1.246811  
      date_USInaugurationDay        date_USIndependenceDay  
                    0.228947                      2.119747  
             date_USLaborDay       date_USLincolnsBirthday  
                   -1.933737                      0.583750  
          date_USMemorialDay        date_USMLKingsBirthday  
                    1.519217                      0.428585  
          date_USNewYearsDay          date_USPresidentsDay  
                    0.204203                      0.483797  
      date_USThanksgivingDay            date_USVeteransDay  
                    0.152978                      0.717895  
  date_USWashingtonsBirthday                    origin_JFK  
                    0.043305                      0.107289  
                  origin_LGA                      dest_ACK  
                    0.010961                     -1.737954  
                    dest_ALB                      dest_ANC  
                   -1.679551                     -1.203995  
                    dest_ATL                      dest_AUS  
                   -1.607715                     -0.797304  
                    dest_AVL                      dest_BDL  
                   -1.335753                     -1.374705  
                    dest_BGR                      dest_BHM  
                   -1.268924                     -0.893688  
                    dest_BNA                      dest_BOS  
                   -1.276719                     -1.525764  
                    dest_BQN                      dest_BTV  
                   -1.524842                     -1.573955  
                    dest_BUF                      dest_BUR  
                   -1.474470                      0.128706  
                    dest_BWI                      dest_BZN  
                   -1.678767                     -2.082842  
                    dest_CAE                      dest_CAK  
                   -2.001807                     -1.620195  
                    dest_CHO                      dest_CHS  
                   -0.634494                     -1.480732  
                    dest_CLE                      dest_CLT  
                   -1.387610                     -1.586758  

...
and 116 more lines.
# A tibble: 157 × 5
   term                         estimate std.error statistic  p.value
   <chr>                           <dbl>     <dbl>     <dbl>    <dbl>
 1 (Intercept)                   7.28    2.73           2.67 7.64e- 3
 2 dep_time                     -0.00166 0.0000141   -118.   0       
 3 air_time                     -0.0440  0.000563     -78.2  0       
 4 distance                      0.00507 0.00150        3.38 7.32e- 4
 5 date_USChristmasDay           1.33    0.177          7.49 6.93e-14
 6 date_USColumbusDay            0.724   0.170          4.25 2.13e- 5
 7 date_USCPulaskisBirthday      0.807   0.139          5.80 6.57e- 9
 8 date_USDecorationMemorialDay  0.585   0.117          4.98 6.32e- 7
 9 date_USElectionDay            0.948   0.190          4.98 6.25e- 7
10 date_USGoodFriday             1.25    0.167          7.45 9.40e-14
# ℹ 147 more rows
# A tibble: 4 × 6
  number operation type    trained skip  id           
   <int> <chr>     <chr>   <lgl>   <lgl> <chr>        
1      1 step      date    TRUE    FALSE date_7pdqb   
2      2 step      holiday TRUE    FALSE holiday_kl2Li
3      3 step      dummy   TRUE    FALSE dummy_jB5fX  
4      4 step      zv      TRUE    FALSE zv_VpqIE     

Build a linear regression model

The fit() function
```{r}
linear_reg() %>%
  set_engine("keras") # set the engine
```
Linear Regression Model Specification (regression)

Computational engine: keras 
```{r}
lm_mod <- linear_reg() # save the model specification (default engine: lm)
lm_fit <-
  lm_mod %>%
  fit(width ~ initial_volume * food_regime, data = urchins)
lm_fit
```
parsnip model object


Call:
stats::lm(formula = width ~ initial_volume * food_regime, data = data)

Coefficients:
                   (Intercept)                  initial_volume  
                     0.0331216                       0.0015546  
                food_regimeLow                 food_regimeHigh  
                     0.0197824                       0.0214111  
 initial_volume:food_regimeLow  initial_volume:food_regimeHigh  
                    -0.0012594                       0.0005254  

Inspect the attributes of lm_fit and compare them with a traditional lm() fit:
```{r}
attributes(lm_fit)
attributes(lm_fit$fit)
lm(width ~ initial_volume * food_regime, data = urchins) |> attributes()
```
$names
[1] "lvl"          "spec"         "fit"          "preproc"      "elapsed"     
[6] "censor_probs"

$class
[1] "_lm"       "model_fit"

$names
 [1] "coefficients"  "residuals"     "effects"       "rank"         
 [5] "fitted.values" "assign"        "qr"            "df.residual"  
 [9] "contrasts"     "xlevels"       "call"          "terms"        
[13] "model"        

$class
[1] "lm"

$names
 [1] "coefficients"  "residuals"     "effects"       "rank"         
 [5] "fitted.values" "assign"        "qr"            "df.residual"  
 [9] "contrasts"     "xlevels"       "call"          "terms"        
[13] "model"        

$class
[1] "lm"

Tidy the results:
```{r}
tidy(lm_fit)
```
# A tibble: 6 × 5
  term                            estimate std.error statistic  p.value
  <chr>                              <dbl>     <dbl>     <dbl>    <dbl>
1 (Intercept)                     0.0331    0.00962      3.44  0.00100 
2 initial_volume                  0.00155   0.000398     3.91  0.000222
3 food_regimeLow                  0.0198    0.0130       1.52  0.133   
4 food_regimeHigh                 0.0214    0.0145       1.47  0.145   
5 initial_volume:food_regimeLow  -0.00126   0.000510    -2.47  0.0162  
6 initial_volume:food_regimeHigh  0.000525  0.000702     0.748 0.457   
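The dotwhisker package loaded at the top is meant for exactly this kind of tidy() output. A sketch of dwplot(), shown on a stand-in mtcars model so the block is self-contained (the same call works on tidy(lm_fit)):

```{r}
library(broom)      # tidy()
library(dotwhisker) # dwplot()
library(ggplot2)    # geom_vline()

# Stand-in model: dwplot() only needs a tidy() tibble of coefficients
fit <- lm(mpg ~ wt + hp, data = mtcars)

tidy(fit) |>
  dwplot(vline = geom_vline(xintercept = 0, colour = "grey50", linetype = 2))
```

Each coefficient is drawn as a dot with whiskers for its confidence interval, which makes it easy to see which terms overlap zero.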

Prediction
```{r}
new_points <- expand.grid(
  initial_volume = 20,
  food_regime = c("Initial", "Low", "High")
)
new_points
```
  initial_volume food_regime
1             20     Initial
2             20         Low
3             20        High

Predict the final width at an initial volume of 20:
```{r}
mean_pred <- predict(lm_fit, new_data = new_points)
mean_pred
conf_int_pred <- predict(lm_fit,
  new_data = new_points,
  type = "conf_int"
)
conf_int_pred
```
# A tibble: 3 × 1
   .pred
   <dbl>
1 0.0642
2 0.0588
3 0.0961
# A tibble: 3 × 2
  .pred_lower .pred_upper
        <dbl>       <dbl>
1      0.0555      0.0729
2      0.0499      0.0678
3      0.0870      0.105 

Nonlinear models

First, load the data:
```{r}
library(tidymodels)
library(ISLR)

Wage <- as_tibble(Wage)
```

Polynomial regression and step functions

step_poly(age, degree = 4) expands age into a degree-4 polynomial basis:
```{r}
rec_poly <- recipe(wage ~ age, data = Wage) %>%
  step_poly(age, degree = 4)

lm_spec <- linear_reg() %>%
  set_mode("regression") %>%
  set_engine("lm")

poly_wf <- workflow() %>%
  add_model(lm_spec) %>%
  add_recipe(rec_poly)

poly_fit <- fit(poly_wf, data = Wage)
poly_fit
```
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
1 Recipe Step

• step_poly()

── Model ───────────────────────────────────────────────────────────────────────

Call:
stats::lm(formula = ..y ~ ., data = data)

Coefficients:
(Intercept)   age_poly_1   age_poly_2   age_poly_3   age_poly_4  
     111.70       447.07      -478.32       125.52       -77.91  
```{r}
tidy(poly_fit)
```
# A tibble: 5 × 5
  term        estimate std.error statistic  p.value
  <chr>          <dbl>     <dbl>     <dbl>    <dbl>
1 (Intercept)    112.      0.729    153.   0       
2 age_poly_1     447.     39.9       11.2  1.48e-28
3 age_poly_2    -478.     39.9      -12.0  2.36e-32
4 age_poly_3     126.     39.9        3.14 1.68e- 3
5 age_poly_4     -77.9    39.9       -1.95 5.10e- 2

In fact, step_poly() does not return age, age^2, age^3, and age^4. It returns a basis of orthogonal polynomials, meaning each column is a linear combination of age, age^2, age^3, and age^4.
```{r}
q <- poly(1:6, degree = 4, raw = FALSE)

q

poly(1:6, degree = 4, raw = TRUE)

round(sum(q[,1]*q[,2]))
```
              1          2          3          4
[1,] -0.5976143  0.5455447 -0.3726780  0.1889822
[2,] -0.3585686 -0.1091089  0.5217492 -0.5669467
[3,] -0.1195229 -0.4364358  0.2981424  0.3779645
[4,]  0.1195229 -0.4364358 -0.2981424  0.3779645
[5,]  0.3585686 -0.1091089 -0.5217492 -0.5669467
[6,]  0.5976143  0.5455447  0.3726780  0.1889822
attr(,"coefs")
attr(,"coefs")$alpha
[1] 3.5 3.5 3.5 3.5

attr(,"coefs")$norm2
[1]  1.00000  6.00000 17.50000 37.33333 64.80000 82.28571

attr(,"degree")
[1] 1 2 3 4
attr(,"class")
[1] "poly"   "matrix"
     1  2   3    4
[1,] 1  1   1    1
[2,] 2  4   8   16
[3,] 3  9  27   81
[4,] 4 16  64  256
[5,] 5 25 125  625
[6,] 6 36 216 1296
attr(,"degree")
[1] 1 2 3 4
attr(,"class")
[1] "poly"   "matrix"
[1] 0

The basis has been orthogonalized (Gram-Schmidt) to reduce collinearity.

If you want the raw polynomials instead:

```{r}
rec_raw_poly <- recipe(wage ~ age, data = Wage) %>%
  step_poly(age, degree = 4, options = list(raw = TRUE))

raw_poly_wf <- workflow() %>%
  add_model(lm_spec) %>%
  add_recipe(rec_raw_poly)

raw_poly_fit <- fit(raw_poly_wf, data = Wage)

tidy(raw_poly_fit)
```
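The collinearity difference between the two bases is easy to check directly with base R's poly() (a small sketch on 1:20, not the Wage data):

```{r}
# Raw polynomial columns are strongly correlated with each other...
raw <- poly(1:20, degree = 4, raw = TRUE)
cor(raw[, 1], raw[, 2])

# ...while the orthogonal basis columns are uncorrelated by construction
orth <- poly(1:20, degree = 4)
cor(orth[, 1], orth[, 2])
```

This correlation between raw powers is what inflates coefficient standard errors; the orthogonal basis gives the same fitted values with well-conditioned columns.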

Now, use poly_fit to generate some predictions:
```{r}
age_range <- tibble(age = seq(min(Wage$age), max(Wage$age)))

regression_lines <- bind_cols(
  augment(poly_fit, new_data = age_range),
  predict(poly_fit, new_data = age_range, type = "conf_int")
)
regression_lines
```
# A tibble: 63 × 4
   .pred   age .pred_lower .pred_upper
   <dbl> <int>       <dbl>       <dbl>
 1  51.9    18        41.5        62.3
 2  58.5    19        49.9        67.1
 3  64.6    20        57.5        71.6
 4  70.2    21        64.4        76.0
 5  75.4    22        70.5        80.2
 6  80.1    23        76.0        84.2
 7  84.5    24        80.9        88.1
 8  88.5    25        85.2        91.7
 9  92.1    26        89.1        95.2
10  95.4    27        92.5        98.4
# ℹ 53 more rows
```{r}
#| fig-cap: Regression line in green, confidence bands in blue

Wage %>%
  ggplot(aes(age, wage)) +
  geom_point(alpha = 0.2) +
  geom_line(aes(y = .pred), color = "darkgreen",
            data = regression_lines) +
  geom_line(aes(y = .pred_lower), data = regression_lines, 
            linetype = "dashed", color = "blue") +
  geom_line(aes(y = .pred_upper), data = regression_lines, 
            linetype = "dashed", color = "blue")
```

Regression line in green, confidence bands in blue

Now predict over a wider age range (18-100):
```{r}
wide_age_range <- tibble(age = seq(18, 100))

regression_lines <- bind_cols(
  augment(poly_fit, new_data = wide_age_range),
  predict(poly_fit, new_data = wide_age_range, type = "conf_int")
)

Wage %>%
  ggplot(aes(age, wage)) +
  geom_point(alpha = 0.2) +
  geom_line(aes(y = .pred), color = "darkgreen",
            data = regression_lines) +
  geom_line(aes(y = .pred_lower), data = regression_lines, 
            linetype = "dashed", color = "blue") +
  geom_line(aes(y = .pred_upper), data = regression_lines, 
            linetype = "dashed", color = "blue")
```

The confidence bands balloon at the edges of (and beyond) the observed ages: the variance there is very large, and the model predicts poorly outside the range of the training data.
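The same behavior can be reproduced on synthetic data (a sketch with base R, not the Wage model): the standard error of a degree-4 polynomial fit grows sharply outside the observed range of x.

```{r}
set.seed(1)

# Synthetic data observed only on 18-80, mimicking the age range above
x <- runif(200, 18, 80)
y <- 100 + 2 * x - 0.02 * x^2 + rnorm(200, sd = 10)
fit <- lm(y ~ poly(x, 4))

# Standard error of the mean prediction inside vs. far outside the data
inside  <- predict(fit, newdata = data.frame(x = 50),  se.fit = TRUE)$se.fit
outside <- predict(fit, newdata = data.frame(x = 100), se.fit = TRUE)$se.fit

c(inside = inside, outside = outside)
```

High-degree polynomial terms dominate far from the data, so extrapolated predictions are driven by the least certain coefficients.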
