Last updated: 2021-09-29
Knit directory: myTidyTuesday/
This reproducible R Markdown analysis was created with workflowr (version 1.6.2). The Checks tab describes the reproducibility checks that were applied when the results were created. The Past versions tab lists the development history.
Great! Since the R Markdown file has been committed to the Git repository, you know the exact version of the code that produced these results.
Great job! The global environment was empty. Objects defined in the global environment can affect the analysis in your R Markdown file in unknown ways. For reproducibility it’s best to always run the code in an empty environment.
The command set.seed(20210907) was run prior to running the code in the R Markdown file. Setting a seed ensures that any results that rely on randomness, e.g. subsampling or permutations, are reproducible.
Great job! Recording the operating system, R version, and package versions is critical for reproducibility.
Nice! There were no cached chunks for this analysis, so you can be confident that you successfully produced the results during this run.
Great job! Using relative paths to the files within your workflowr project makes it easier to run your code on other machines.
Great! You are using Git for version control. Tracking code development and connecting the code version to the results is critical for reproducibility.
The results in this page were generated with repository version 78afe8f. See the Past versions tab to see a history of the changes made to the R Markdown and HTML files.
Note that you need to be careful to ensure that all relevant files for the analysis have been committed to Git prior to generating the results (you can use wflow_publish or wflow_git_commit). workflowr only checks the R Markdown file, but you know if there are other scripts or data files that it depends on. Below is the status of the Git repository when the results were generated:
Ignored files:
Ignored: .Rhistory
Ignored: .Rproj.user/
Ignored: catboost_info/
Ignored: data/2021-09-08/
Ignored: data/CNHI_Excel_Chart.xlsx
Ignored: data/CommunityTreemap.jpeg
Ignored: data/Community_Roles.jpeg
Ignored: data/YammerDigitalDataScienceMembership.xlsx
Ignored: data/acs_poverty.rds
Ignored: data/fmhpi.rds
Ignored: data/grainstocks.rds
Ignored: data/hike_data.rds
Ignored: data/us_states.rds
Ignored: data/us_states_hexgrid.geojson
Ignored: data/weatherstats_toronto_daily.csv
Untracked files:
Untracked: code/YammerReach.R
Untracked: code/work list batch targets.R
Unstaged changes:
Modified: code/_common.R
Note that any generated files, e.g. HTML, png, CSS, etc., are not included in this status report because it is ok for generated content to have uncommitted changes.
These are the previous versions of the repository in which changes were made to the R Markdown (analysis/ChicagoTrafficInjuries.Rmd) and HTML (docs/ChicagoTrafficInjuries.html) files. If you’ve configured a remote Git repository (see ?wflow_git_remote), click on the hyperlinks in the table below to view the files as they were in that past version.
File | Version | Author | Date | Message |
---|---|---|---|---|
Rmd | 78afe8f | opus1993 | 2021-09-29 | add viridis color palette |
html | 5a14657 | opus1993 | 2021-09-24 | Build site. |
Rmd | ab4df5f | opus1993 | 2021-09-24 | apply theme_jim() ggplot color schemes |
Inspired by Julia Silge’s blog post, Predicting injuries for Chicago traffic crashes.
Our goal here is to demonstrate how to use the tidymodels framework to model live data on traffic crashes in the City of Chicago, predicting whether a crash involved injuries.
suppressPackageStartupMessages({
library(tidyverse)
library(lubridate)
library(here)
library(tidymodels)
library(RSocrata)
library(themis) # upsample/downsample for unbalanced datasets
library(baguette) # bagging models
})
source(here::here("code","_common.R"),
verbose = FALSE,
local = knitr::knit_global())
ggplot2::theme_set(theme_jim(base_size = 12))
We will load the latest data directly from the Chicago Data Portal. This dataset covers traffic crashes on city streets within Chicago city limits under the jurisdiction of the Chicago Police Department.
Let’s download the last three years of data to train our model.
years_ago <- today() - years(3)
crash_url <- glue::glue("https://data.cityofchicago.org/Transportation/Traffic-Crashes-Crashes/85ca-t3if?$where=CRASH_DATE > '{years_ago}'")
crash_raw <- as_tibble(read.socrata(crash_url))
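The `glue::glue()` call above pastes the cutoff date into a SoQL `$where` clause, so the portal filters rows on the server and only the last three years are downloaded. As a rough base-R sketch of the same URL construction (the helper `build_crash_url()` is hypothetical; it only builds the string, and the download itself still goes through `RSocrata::read.socrata()`):

```r
# Hypothetical helper: build the filtered Socrata URL for crashes after a cutoff date
build_crash_url <- function(today = Sys.Date(), years_back = 3) {
  # seq() with by = "-N years" steps back whole calendar years,
  # much like today() - years(3) in lubridate
  cutoff <- seq(as.Date(today), by = paste0("-", years_back, " years"), length.out = 2)[2]
  paste0(
    "https://data.cityofchicago.org/Transportation/Traffic-Crashes-Crashes/85ca-t3if",
    "?$where=CRASH_DATE > '", format(cutoff, "%Y-%m-%d"), "'"
  )
}

build_crash_url(as.Date("2021-09-29"))
```

Pinning `today` makes the query reproducible; leaving the default pulls whatever the last three years are on the day the page is knit.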
Data preparation
crash <- crash_raw %>%
arrange(desc(crash_date)) %>%
transmute(
injuries = if_else(injuries_total > 0, "injuries", "noninjuries"),
crash_date,
crash_hour,
report_type = if_else(report_type == "", "UNKNOWN", report_type),
num_units,
trafficway_type,
posted_speed_limit,
weather_condition,
lighting_condition,
roadway_surface_cond,
first_crash_type,
prim_contributory_cause,
latitude, longitude
) %>%
na.omit()
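To make the `transmute()` step concrete, here is a tiny base-R version on a made-up data frame (illustrative values only): the outcome becomes a two-level label, empty `report_type` strings become `"UNKNOWN"`, and `na.omit()` drops any row with a missing value in a kept column. In the real data that also removes crashes without coordinates, since `latitude` and `longitude` are kept.

```r
# Made-up rows standing in for crash_raw (illustrative values only)
toy_raw <- data.frame(
  injuries_total = c(0, 2, NA, 1),
  report_type = c("", "ON SCENE", "ON SCENE", ""),
  latitude = c(41.88, 41.79, 41.90, NA),
  stringsAsFactors = FALSE
)

toy <- data.frame(
  # Two-level outcome: any injury vs none (a missing total stays NA)
  injuries = ifelse(toy_raw$injuries_total > 0, "injuries", "noninjuries"),
  # Empty strings become an explicit "UNKNOWN" level
  report_type = ifelse(toy_raw$report_type == "", "UNKNOWN", toy_raw$report_type),
  latitude = toy_raw$latitude,
  stringsAsFactors = FALSE
)

# Drop any row with a missing value in a kept column (rows 3 and 4 here)
toy <- na.omit(toy)
toy
```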
crash %>%
mutate(crash_date = as_date(floor_date(crash_date, unit = "week"))) %>%
count(crash_date, injuries) %>%
filter(
crash_date != last(crash_date),
crash_date != first(crash_date)
) %>%
mutate(name_lab = if_else(crash_date == last(crash_date), injuries, NA_character_)) %>%
ggplot() +
geom_line(aes(as.Date(crash_date), n, color = injuries),
size = 1.5, alpha = 0.7
) +
scale_x_date(
labels = scales::date_format("%Y"),
expand = c(0, 0),
breaks = seq.Date(as_date("2018-09-01"),
as_date("2021-09-01"),
by = "year"
),
minor_breaks = "3 months",
limits = c(as_date("2018-09-01"), as_date("2021-12-01"))
) +
ggrepel::geom_text_repel(
data = . %>% filter(!is.na(name_lab)),
aes(
x = crash_date,
y = n + 200,
label = name_lab,
color = injuries
),
fontface = "bold",
size = 4,
direction = "y",
xlim = c(as.numeric(as_date("2021-10-01")), NA), # limits are in numeric date units (days since 1970-01-01)
hjust = 0,
segment.size = .7,
segment.alpha = .5,
segment.linetype = "dotted",
box.padding = .4,
segment.curvature = -0.1,
segment.ncp = 3,
segment.angle = 20
) +
scale_y_continuous(limits = (c(0, NA))) +
labs(
title = "How have the number of crashes changed over time?",
x = NULL, y = "Number of traffic crashes per week",
color = "Injuries?", caption = "Data: Chicago Data Portal | Visual: @jim_gruman"
) +
theme(legend.position = "none")
This dataset is not balanced: crashes with injuries are a small fraction of all traffic incidents. Let’s look at the weekly percentage.
crash %>%
mutate(crash_date = floor_date(crash_date, unit = "week")) %>%
count(crash_date, injuries) %>%
filter(
crash_date != last(crash_date),
crash_date != first(crash_date)
) %>%
group_by(crash_date) %>%
mutate(percent_injury = n / sum(n)) %>%
ungroup() %>%
filter(injuries == "injuries") %>%
ggplot(aes(as_date(crash_date), percent_injury)) +
geom_line(size = 1.5, alpha = 0.7) +
scale_y_continuous(
limits = c(0, NA),
labels = percent_format(accuracy = 1)
) +
scale_x_date(
labels = scales::date_format("%Y"),
expand = c(0, 0),
breaks = seq.Date(as_date("2018-09-01"),
as_date("2021-09-01"),
by = "year"
),
minor_breaks = "3 months",
limits = c(as_date("2018-09-01"), as_date("2021-12-01"))
) +
labs(
x = NULL, y = "% of crashes that involve injuries",
title = "How has the traffic injury rate changed over time?",
caption = "Data: Chicago Data Portal | Visual: @jim_gruman"
)
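The weekly injury percentage is computed within each week: after `count()`, `group_by(crash_date)` plus `n / sum(n)` divides each class count by that week’s total. A base-R sketch of the same arithmetic on made-up counts, with `ave()` standing in for the grouped `mutate()`:

```r
# Made-up weekly counts in the long format that count(crash_date, injuries) returns
weekly <- data.frame(
  week = rep(as.Date(c("2021-08-01", "2021-08-08")), each = 2),
  injuries = rep(c("injuries", "noninjuries"), times = 2),
  n = c(150, 1050, 120, 1080)
)

# Share of each week's crashes in each class: n / sum(n) within the week
weekly$percent <- weekly$n / ave(weekly$n, weekly$week, FUN = sum)

# Keep only the injury rows, as the plot does
weekly[weekly$injuries == "injuries", c("week", "percent")]
```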
crash %>%
mutate(crash_date = wday(crash_date, label = TRUE)) %>%
count(crash_date, injuries) %>%
group_by(injuries) %>%
mutate(percent = n / sum(n)) %>%
ungroup() %>%
ggplot(aes(percent, crash_date, fill = injuries)) +
geom_col(position = "dodge", alpha = 0.8) +
scale_x_continuous(labels = percent_format()) +
labs(
x = "% of crashes", y = NULL, fill = NULL,
title = "How does the injury rate change through the week?",
caption = "Data: Chicago Data Portal | Visual: @jim_gruman"
)
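Note that the weekday chart normalizes with `group_by(injuries)`, so each bar is a weekday’s share of crashes *within* one outcome class, not the injury rate on that day. The two readings are linked: a weekday whose injury bar is longer than its noninjury bar has a higher-than-average injury rate. A base-R sketch on made-up counts comparing both normalizations:

```r
# Made-up crash counts by weekday and outcome
counts <- data.frame(
  day = rep(c("Sat", "Sun"), each = 2),
  injuries = rep(c("injuries", "noninjuries"), times = 2),
  n = c(30, 170, 20, 180)
)

# What the chart shows: each weekday's share within each outcome class
counts$pct_within_class <- counts$n / ave(counts$n, counts$injuries, FUN = sum)

# A different quantity: the injury rate within each weekday
counts$pct_within_day <- counts$n / ave(counts$n, counts$day, FUN = sum)

counts
```

Here Saturday holds 60% of injury crashes but only about 49% of noninjury crashes, and correspondingly its injury rate (15%) is above the overall rate of 50 / 400 = 12.5%.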
crash %>%
count(first_crash_type, injuries) %>%
mutate(first_crash_type = fct_reorder(first_crash_type, n)) %>%
group_by(injuries) %>%
mutate(percent = n / sum(n)) %>%
ungroup() %>%
group_by(first_crash_type) %>%
filter(sum(n) > 1e4) %>%
ungroup() %>%
ggplot(aes(percent, first_crash_type, fill = injuries)) +
geom_col(position = "dodge", alpha = 0.8) +
scale_x_continuous(labels = percent_format()) +
labs(
x = "% of crashes", y = NULL, fill = NULL,
title = "How do injuries vary with first crash type?",
caption = "Data: Chicago Data Portal | Visual: @jim_gruman"
)
crash %>%
filter(latitude > 0) %>%
ggplot(aes(longitude, latitude, color = injuries)) +
geom_point(size = 0.5, alpha = 0.4) +
labs(color = NULL) +
coord_map() +
guides(col = guide_legend(override.aes = list(size = 3, alpha = 1))) +
theme(axis.text.x = element_blank(), axis.text.y = element_blank()) +
labs(
x = NULL, y = NULL, fill = NULL,
title = "Are injuries more likely in different locations?",
caption = "Data: Chicago Data Portal | Visual: @jim_gruman"
)
This is all the information we will use in building our model to predict which crashes caused injuries.
Let’s start by splitting our data and creating 10 cross-validation folds.
crash_split <- initial_split(crash, strata = injuries)
crash_train <- training(crash_split)
crash_test <- testing(crash_split)
crash_folds <- vfold_cv(crash_train,
v = 10,
strata = injuries
)
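Setting `strata = injuries` spreads the rare injury class evenly across the split and the folds instead of leaving it to chance. The base-R sketch below shows the idea for fold assignment; `stratified_folds()` is a hypothetical helper, and rsample’s own implementation differs in its details.

```r
# Hypothetical helper: assign rows to v folds separately within each class,
# so every fold keeps roughly the overall class ratio
stratified_folds <- function(y, v = 10) {
  fold <- integer(length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)
    # Deal fold labels 1..v out evenly, then shuffle which row gets which label
    fold[idx] <- sample(rep_len(seq_len(v), length(idx)))
  }
  fold
}

set.seed(20210907)
y <- rep(c("injuries", "noninjuries"), times = c(100, 900))
folds <- stratified_folds(y, v = 10)
table(folds, y)
```

With 100 injury rows and 900 noninjury rows, every fold ends up with exactly 10 and 90, the same 10% injury share as the full data.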
Next, let’s create a model.
The feature engineering includes creating date features (day of the week, month, and year) from crash_date and then removing the raw timestamp, handling the high cardinality of weather conditions, contributing cause, etc., and, perhaps most importantly, downsampling to account for the class imbalance (injuries are rarer than non-injury crashes).
crash_rec <- recipe(injuries ~ ., data = crash_train) %>%
step_date(crash_date) %>%
step_rm(crash_date) %>%
step_other(
weather_condition,
first_crash_type,
trafficway_type,
prim_contributory_cause,
other = "OTHER"
) %>%
step_downsample(injuries)
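Two of these steps do most of the work. `step_other()` pools infrequent levels (by default, those below 5% of the training data) into one level, named `"OTHER"` here, and `step_downsample()` randomly drops majority-class rows until the classes are balanced. The base-R helpers below (`lump_other()` and `downsample()`, both hypothetical) sketch those ideas; they are not the recipes or themis code, and unlike recipe steps they are not estimated on the training set and then reapplied to new data.

```r
# Hypothetical helper: pool levels whose share is below `threshold` into `other`
lump_other <- function(x, threshold = 0.05, other = "OTHER") {
  freq <- table(x) / length(x)
  keep <- names(freq)[freq >= threshold]
  ifelse(x %in% keep, x, other)
}

# Hypothetical helper: randomly keep only as many rows per class as the minority class has
downsample <- function(df, class_col) {
  n_min <- min(table(df[[class_col]]))
  rows <- unlist(lapply(split(seq_len(nrow(df)), df[[class_col]]), function(i) {
    i[sample.int(length(i), n_min)]
  }))
  df[sort(rows), , drop = FALSE]
}

set.seed(20210907)
weather <- c(rep("CLEAR", 80), rep("RAIN", 15), rep("FOG", 3), rep("SNOW", 2))
table(lump_other(weather))

crashes <- data.frame(injuries = rep(c("injuries", "noninjuries"), times = c(20, 180)))
table(downsample(crashes, "injuries")$injuries)
```

FOG (3%) and SNOW (2%) fall under the 5% threshold and become OTHER, and the 180 noninjury rows are cut to match the 20 injury rows.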
bag_spec <- bag_tree(min_n = 10) %>%
set_engine("rpart", times = 25) %>%
set_mode("classification")
crash_wf <- workflow() %>%
add_recipe(crash_rec) %>%
add_model(bag_spec)
crash_wf
== Workflow ====================================================================
Preprocessor: Recipe
Model: bag_tree()
-- Preprocessor ----------------------------------------------------------------
4 Recipe Steps
* step_date()
* step_rm()
* step_other()
* step_downsample()
-- Model -----------------------------------------------------------------------
Bagged Decision Tree Model Specification (classification)
Main Arguments:
cost_complexity = 0
min_n = 10
Engine-Specific Arguments:
times = 25
Computational engine: rpart
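`bag_tree()` with the rpart engine and `times = 25` fits 25 decision trees, each on a bootstrap resample of the training data, and combines them. Below is a base-R sketch of that idea using rpart directly on simulated data, averaging the trees’ predicted class probabilities. `bag_rpart()` and `predict_bag()` are hypothetical helpers; baguette’s own aggregation and defaults may differ. We assume the usual parsnip mapping of `min_n` to rpart’s `minsplit` and of `cost_complexity = 0` to `cp = 0`, matching the printed specification.

```r
library(rpart)  # a recommended package shipped with most R installations

# Hypothetical helper: one rpart tree per bootstrap resample
bag_rpart <- function(formula, data, times = 25, min_n = 10) {
  lapply(seq_len(times), function(b) {
    boot <- data[sample.int(nrow(data), replace = TRUE), , drop = FALSE]
    rpart(formula, data = boot, method = "class",
          control = rpart.control(minsplit = min_n, cp = 0, xval = 0))
  })
}

# Hypothetical helper: average the event-class probability across trees
predict_bag <- function(trees, newdata, event = "injuries") {
  probs <- sapply(trees, function(tree) predict(tree, newdata, type = "prob")[, event])
  rowMeans(matrix(probs, nrow = nrow(newdata)))
}

# Simulated crashes: higher speed limits raise the chance of injury (made-up model)
set.seed(20210907)
n <- 400
sim <- data.frame(speed = runif(n, 15, 55), hour = sample(0:23, n, replace = TRUE))
sim$injuries <- factor(ifelse(runif(n) < plogis(-4 + 0.08 * sim$speed),
                              "injuries", "noninjuries"))

trees <- bag_rpart(injuries ~ speed + hour, sim, times = 25)
p <- predict_bag(trees, sim[1:5, ])
round(p, 2)
```

Averaging many deep, high-variance trees is what makes bagging useful here: each tree overfits its own resample, but their average is much more stable.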
Let’s fit this model to the cross-validation resamples to understand how well it will perform.
all_cores <- parallelly::availableCores(omit = 1)
all_cores
system
11
future::plan("multisession", workers = all_cores) # background R sessions; works on Windows, where forked "multicore" is unavailable
crash_res <- fit_resamples(
crash_wf,
crash_folds,
control = control_resamples(save_pred = TRUE)
)
What do the results look like?
collect_metrics(crash_res) # resampled metrics: averaged over the 10 held-out assessment folds of the training set
# A tibble: 2 x 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.725 10 0.000990 Preprocessor1_Model1
2 roc_auc binary 0.817 10 0.000643 Preprocessor1_Model1
Not bad.
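Each row of `collect_metrics()` summarizes the 10 per-fold estimates: as we understand tune’s summary, `mean` is their average, `n` is the number of folds, and `std_err` is their standard deviation divided by the square root of `n`. A base-R sketch of that summary on ten made-up fold AUCs (not the values behind the table above):

```r
# Ten made-up per-fold ROC AUC values (illustrative only, not from this analysis)
fold_auc <- c(0.816, 0.818, 0.815, 0.819, 0.817, 0.816, 0.818, 0.817, 0.820, 0.814)

summary_row <- c(
  mean = mean(fold_auc),
  n = length(fold_auc),
  # Standard error of the mean across folds
  std_err = sd(fold_auc) / sqrt(length(fold_auc))
)
round(summary_row, 5)
```

A small `std_err` like the ones in the table means the folds agree closely, which is expected with this many rows per fold.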
Let’s now fit to the entire training set and evaluate on the testing set.
crash_fit <- last_fit(crash_wf, crash_split)
collect_metrics(crash_fit) # metrics on the test set, look for overfitting
# A tibble: 2 x 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 accuracy binary 0.728 Preprocessor1_Model1
2 roc_auc binary 0.821 Preprocessor1_Model1
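Spot on: the test-set accuracy (0.728) and ROC AUC (0.821) closely match the cross-validation estimates (0.725 and 0.817), so there is no sign of overfitting.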
Which features were most important in predicting an injury?
crash_imp <- crash_fit$.workflow[[1]] %>%
extract_fit_parsnip()
crash_imp$fit$imp %>%
slice_max(value, n = 10) %>%
ggplot(aes(value, fct_reorder(term, value))) +
geom_col(alpha = 0.8) +
labs(x = "Variable importance score", y = NULL) +
theme(panel.grid.major.y = element_blank())
How does the ROC curve for the testing data look?
collect_predictions(crash_fit) %>%
roc_curve(injuries, .pred_injuries) %>%
ggplot(aes(x = 1 - specificity, y = sensitivity)) +
geom_line(size = 1.5, color = hrbrthemes::ipsum_pal()(1)) +
geom_abline(
lty = 2, alpha = 0.5,
color = "gray50",
size = 1.2
) +
coord_equal() +
labs(title = "ROC Curve")
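The ROC curve traces sensitivity against 1 - specificity as the probability threshold for calling a crash an injury sweeps from high to low, and ROC AUC is the area under that curve. A base-R sketch on six made-up predictions, using hypothetical helpers `roc_points()` and `auc_trapezoid()` rather than yardstick’s implementation:

```r
# Hypothetical helper: sensitivity and 1 - specificity at every distinct threshold
roc_points <- function(truth, prob, event = "injuries") {
  is_event <- truth == event
  thresholds <- c(Inf, sort(unique(prob), decreasing = TRUE))
  tpr <- sapply(thresholds, function(t) mean(prob[is_event] >= t))
  fpr <- sapply(thresholds, function(t) mean(prob[!is_event] >= t))
  data.frame(threshold = thresholds, sensitivity = tpr, one_minus_spec = fpr)
}

# Hypothetical helper: area under the curve by the trapezoid rule
auc_trapezoid <- function(roc) {
  x <- roc$one_minus_spec
  y <- roc$sensitivity
  sum(diff(x) * (head(y, -1) + tail(y, -1)) / 2)
}

truth <- c("injuries", "injuries", "noninjuries", "injuries", "noninjuries", "noninjuries")
prob  <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.2)
roc <- roc_points(truth, prob)
auc_trapezoid(roc)
```

With no tied scores, the AUC equals the fraction of (injury, noninjury) pairs in which the injury crash gets the higher probability: 8 of the 9 pairs here, so 8/9.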
crash_wf_model <- crash_fit$.workflow[[1]]
# crash_wf_model <- butcher::butcher(crash_fit$.workflow[[1]])
This fitted workflow is an object we can make predictions with. For example, is the crash in row 222 of the test set predicted to involve injuries?
predict(crash_wf_model, crash_test[222, ])
# A tibble: 1 x 1
.pred_class
<fct>
1 noninjuries
sessionInfo()
R version 4.1.1 (2021-08-10)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 22000)
Matrix products: default
locale:
[1] LC_COLLATE=English_United States.1252
[2] LC_CTYPE=English_United States.1252
[3] LC_MONETARY=English_United States.1252
[4] LC_NUMERIC=C
[5] LC_TIME=English_United States.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] rpart_4.1-15 vctrs_0.3.8 rlang_0.4.11 baguette_0.1.1
[5] themis_0.1.4 RSocrata_1.7.11-2 yardstick_0.0.8 workflowsets_0.1.0
[9] workflows_0.2.3 tune_0.1.6 rsample_0.1.0 recipes_0.1.17
[13] parsnip_0.1.7.900 modeldata_0.1.1 infer_1.0.0 dials_0.0.10
[17] scales_1.1.1 broom_0.7.9 tidymodels_0.1.3 here_1.0.1
[21] lubridate_1.7.10 forcats_0.5.1 stringr_1.4.0 dplyr_1.0.7
[25] purrr_0.3.4 readr_2.0.2 tidyr_1.1.4 tibble_3.1.4
[29] ggplot2_3.3.5 tidyverse_1.3.1 workflowr_1.6.2
loaded via a namespace (and not attached):
[1] utf8_1.2.2 R.utils_2.11.0 tidyselect_1.1.1
[4] grid_4.1.1 pROC_1.18.0 munsell_0.5.0
[7] codetools_0.2-18 ragg_1.1.3 future_1.22.1
[10] withr_2.4.2 colorspace_2.0-2 highr_0.9
[13] knitr_1.36 rstudioapi_0.13 Rttf2pt1_1.3.9
[16] listenv_0.8.0 labeling_0.4.2 git2r_0.28.0
[19] TeachingDemos_2.12 farver_2.1.0 DiceDesign_1.9
[22] rprojroot_2.0.2 mlr_2.19.0 parallelly_1.28.1
[25] generics_0.1.0 ipred_0.9-12 xfun_0.26
[28] R6_2.5.1 doParallel_1.0.16 lhs_1.1.3
[31] cachem_1.0.6 assertthat_0.2.1 promises_1.2.0.1
[34] nnet_7.3-16 gtable_0.3.0 Cubist_0.3.0
[37] globals_0.14.0 timeDate_3043.102 BBmisc_1.11
[40] systemfonts_1.0.2 splines_4.1.1 butcher_0.1.5
[43] extrafontdb_1.0 earth_5.3.1 checkmate_2.0.0
[46] yaml_2.2.1 reshape2_1.4.4 modelr_0.1.8
[49] backports_1.2.1 httpuv_1.6.3 extrafont_0.17
[52] usethis_2.0.1 inum_1.0-4 tools_4.1.1
[55] lava_1.6.10 ellipsis_0.3.2 jquerylib_0.1.4
[58] Rcpp_1.0.7 plyr_1.8.6 parallelMap_1.5.1
[61] ParamHelpers_1.14 viridis_0.6.1 haven_2.4.3
[64] ggrepel_0.9.1 hrbrthemes_0.8.0 fs_1.5.0
[67] furrr_0.2.3 unbalanced_2.0 magrittr_2.0.1
[70] data.table_1.14.2 reprex_2.0.1 RANN_2.6.1
[73] GPfit_1.0-8 mvtnorm_1.1-2 whisker_0.4
[76] ROSE_0.0-4 R.cache_0.15.0 hms_1.1.1
[79] mime_0.12 evaluate_0.14 readxl_1.3.1
[82] gridExtra_2.3 compiler_4.1.1 maps_3.4.0
[85] crayon_1.4.1 R.oo_1.24.0 htmltools_0.5.2
[88] later_1.3.0 tzdb_0.1.2 Formula_1.2-4
[91] libcoin_1.0-9 DBI_1.1.1 dbplyr_2.1.1
[94] MASS_7.3-54 Matrix_1.3-4 cli_3.0.1
[97] C50_0.1.5 R.methodsS3_1.8.1 parallel_4.1.1
[100] gower_0.2.2 pkgconfig_2.0.3 xml2_1.3.2
[103] foreach_1.5.1 bslib_0.3.0 hardhat_0.1.6
[106] plotmo_3.6.1 prodlim_2019.11.13 rvest_1.0.1
[109] digest_0.6.27 rmarkdown_2.11 cellranger_1.1.0
[112] fastmatch_1.1-3 gdtools_0.2.3 curl_4.3.2
[115] lifecycle_1.0.1 jsonlite_1.7.2 mapproj_1.2.7
[118] viridisLite_0.4.0 fansi_0.5.0 pillar_1.6.3
[121] lattice_0.20-44 fastmap_1.1.0 httr_1.4.2
[124] plotrix_3.8-2 survival_3.2-11 glue_1.4.2
[127] conflicted_1.0.4 FNN_1.1.3 iterators_1.0.13
[130] class_7.3-19 stringi_1.7.4 sass_0.4.0
[133] rematch2_2.1.2 textshaping_0.3.5 partykit_1.2-15
[136] styler_1.6.2 future.apply_1.8.1