Commit 83c261b5 authored by Helfenstein, Anatol

variable importance, PDP and introduction to xML method using Shapley values

parent abc07af1
# Name: 53_model_evaluation_var_imp_xML.R
# (var = variable; imp = importance; xML = explainable machine learning)
#
# Content:  - read in target variable regression matrix and fitted model
#           - Help explain/interpret ML/AI (xML/xAI) using 3 methods:
#             1. Variable importance (permutation or impurity)
#             2. Partial dependence plots (PDP): TARGET/response variable vs.
#                predictors/covariates
#             3. Shapley values (used for local xML)
#
# Refs:     - QRF package and vignettes:
#             https://cran.r-project.org/web/packages/quantregForest/quantregForest.pdf
#             https://mran.microsoft.com/snapshot/2015-07-15/web/packages/quantregForest/vignettes/quantregForest.pdf
#
# Inputs:   - target regression matrix: out/data/model/tbl_regmat_[TARGET].Rds
#           - final fitted QRF model using all calibration data (optimal):
#             "out/data/model/QRF_fit_[TARGET]_obs[]_p[]_[]CV_optimal.Rds"
#
# Output:   - variable importance, partial dependence (PDP) and Shapley value plots
#
# Project BIS-4D Masterclass
# Author: Anatol Helfenstein
# Updated: 2023-07-13
#-------------------------------------------------------------------------------
### Empty memory and workspace; load required packages =========================
gc()
rm(list=ls())
pkgs <- c("tidyverse", "ranger", "caret", "foreach", "viridis", "doParallel", "pdp",
"raster", "sf", "fastshap")
lapply(pkgs, library, character.only = TRUE)
### Designate script parameters & load modelling data ==========================
# 1) Specify target soil property
TARGET = "SOM_per"
# expression of TARGET for model evaluation plots
TARGET_EXP = "SOM [%] (observed)"
TARGET_PRED = expression(paste(hat(SOM), " [%] (predicted)"))
# 2) Specify whether to (log) transform, observation quality & dynamic or static model
TRANSFORM = "" # if no transformation of response is needed: empty quotes
OBS_QUAL = "_lab" # lab measurements & field estimates: ""; lab meas. only: "_lab"
TIME = "_dyn" # calibration 3D+T model using 2D+T & 3D+T dynamic covariates
TIME_DIR = "dynamic" # either "static" or "dynamic"; directory to save plots
# 3) PFB validation sites field campaign 2022 data
tbl_val_PFB <- read_csv("data/soil/field_campaign/PFB_val_2022.csv") %>%
# common IDs to select samples which were re-sampled in 2022
unite("id", site_id, d_upper, d_lower, remove = FALSE) %>%
arrange(id)
# 4) Regression matrix data containing calibration and validation data
tbl_regmat_target <- read_rds(paste0(
"out/data/model/", TARGET, "/dynamic/tbl_regmat_", TARGET, "_dyn.Rds"
))
# if model should only include lab measurements, remove field estimates
if (OBS_QUAL == "_lab") {
tbl_regmat_target <- tbl_regmat_target %>%
filter(!BIS_type == "field")
}
# Separate tables for calibration and validation data
tbl_regmat_target_cal <- tbl_regmat_target %>%
filter(split %in% "train")
tbl_regmat_target_val <- tbl_regmat_target %>%
filter(split %in% "test")
# double check 2022 validation data was not used for model calibration
tbl_regmat_target %>%
filter(BIS_tbl %in% "PFB") %>%
filter(site_id %in% tbl_val_PFB$site_id) %>%
pull(split) %>%
unique() # should be "test", NOT "train"
# join old BIS data with new field campaign data via common IDs
tbl_regmat_target_val_PFB <- tbl_regmat_target_val %>%
filter(BIS_tbl == "PFB") %>%
# common IDs to select samples which were re-sampled in 2022
unite("id", site_id, d_upper, d_lower, remove = FALSE) %>%
filter(id %in% tbl_val_PFB$id) %>%
arrange(id) %>%
# remove empty horizon type cols
dplyr::select(-hor) %>%
full_join(., tbl_val_PFB) %>%
# re-arrange cols
dplyr::select(all_of(colnames(tbl_regmat_target_val)), sample_id, year2,
d_upper_new, d_lower_new, SOM_per_2022:metadata_2022) %>%
dplyr::select(split:site_id, sample_id, X:d_mid, d_upper_new, d_lower_new, year, year2,
SOM_per, SOM_per_2022, pH_KCl, pH_KCl_2022, clay_per:metadata_2022,
contains("_1km")) %>%
# remove NA's in validation data;
# due to the coarser covariate resolution (1 km) used in this masterclass
# compared to the full BIS-4D workflow, some validation sites have NA covariate
# values (probably too close to a body of water or built-up areas)
filter(!BIS_tbl %in% NA)
# keep only the validation sites retained above
tbl_val_PFB <- tbl_val_PFB %>%
filter(site_id %in% unique(tbl_regmat_target_val_PFB$site_id))
# 5) Number of covariates used in model
COV = ncol(dplyr::select(tbl_regmat_target, -c(split:hor, year, all_of(TARGET))))
# 6) Value of K (in k-fold CV); i.e. number of folds in CV
K = 10
# 7) Specify which (previously fitted) optimal model to use:
QRF_FIT_optimal <- read_rds(paste0(
"out/data/model/", TARGET, "/", TIME_DIR, "/QRF_fit_", TRANSFORM, TARGET, OBS_QUAL, TIME,
"_obs", nrow(tbl_regmat_target_cal), "_p", COV, "_LLO_", K, "FCV_optimal.Rds"))
# 8) Specify which quantiles we want to predict:
# 50th (median), 5th & 95th quantile to calculate 90th prediction interval (PI90)
QUANTILES = c(0.05, 0.50, 0.95)
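# (the 90% prediction interval is then PI90 = [quant_5, quant_95], i.e. its
# width is quant_95 - quant_5)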
### Global xML method 1: permutation or impurity ===============================
# Permutation: computed by permuting OOB data: for each tree, the prediction
# error on the out-of-bag (OOB) portion of the data is recorded (error rate for
# classification, MSE for regression). The same is done after permuting each
# predictor variable. The differences between the two are averaged over all
# trees and normalized by the standard deviation of the differences. If the
# standard deviation of the differences is 0 for a variable, the division is
# skipped (but the average is almost always 0 in that case).
# Impurity: variable importance measured using node impurity: the total decrease
# in node impurity from splitting on the variable, averaged over all trees.
# For regression, node impurity is measured by the residual sum of squares.
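# For intuition, a minimal sketch of permutation importance for a single
# covariate ("df_oob" is a hypothetical held-out data frame containing the
# response and all covariates; ranger computes this internally per tree):
# mse_base <- mean((df_oob[[TARGET]] -
#                     predict(QRF_FIT_optimal$finalModel, data = df_oob)$predictions)^2)
# df_perm <- df_oob
# df_perm[["d_mid"]] <- sample(df_perm[["d_mid"]]) # permute one covariate
# mse_perm <- mean((df_perm[[TARGET]] -
#                     predict(QRF_FIT_optimal$finalModel, data = df_perm)$predictions)^2)
# mse_perm - mse_base # increase in MSE = importance of "d_mid"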
# mode chosen during model training to calculate variable importance
QRF_FIT_optimal$finalModel$importance.mode
# 20 most important variables using mode chosen during model training (see above)
varImp(QRF_FIT_optimal)
# tbl of variable importances for all covariates
tbl_var_imp <- tibble(Covariate = names(QRF_FIT_optimal$finalModel$variable.importance),
Importance = QRF_FIT_optimal$finalModel$variable.importance) %>%
arrange(-Importance)
# plot variable importance using mode chosen above (top 33 covariates if p > 33)
if (COV > 33) {
p_var_imp <- tbl_var_imp %>%
slice(1:33) %>%
ggplot(., aes(y = reorder(Covariate, Importance))) +
geom_bar(aes(weight = Importance)) +
xlab(paste0("Variable importance (",
QRF_FIT_optimal$finalModel$importance.mode, ")")) +
ylab(paste0("Covariates (p = ", COV, ")")) +
theme_bw()
} else {
p_var_imp <- tbl_var_imp %>%
ggplot(., aes(y = reorder(Covariate, Importance))) +
geom_bar(aes(weight = Importance)) +
xlab(paste0("Variable importance (",
QRF_FIT_optimal$finalModel$importance.mode, ")")) +
ylab(paste0("Covariates (p = ", COV, ")")) +
theme_bw()
}
# save plot to disk
# ggsave(paste0("p_QRF_", TRANSFORM, TARGET, OBS_QUAL, TIME, "_var_imp_",
# QRF_FIT_optimal$finalModel$importance.mode, ".pdf"),
# p_var_imp,
# path = paste0("out/figs/models/", TARGET, "/", TIME_DIR),
# width = 8, height = 8)
# local variable importance (for each sample); only available if the model was
# fitted with local.importance = TRUE, hence the NULL check below
# if (!is.null(QRF_FIT_optimal$finalModel$variable.importance.local)) {
# tbl_var_imp_local <- as_tibble(QRF_FIT_optimal$finalModel$variable.importance.local) %>%
# bind_cols(tbl_regmat_target %>%
# filter(split %in% "train") %>%
# dplyr::select(split:hor),
# .) %>%
# rownames_to_column('id') %>%
# gather(covariate, count, d_upper:water_wetness_probability_2015_1km) %>%
# group_by(id) %>%
# slice(which.max(count)) %>%
# ungroup()
#
# # plot local variable importance
# p_var_imp_local <- plyr::count(tbl_var_imp_local, vars = "covariate") %>%
# as_tibble() %>%
# filter(!freq %in% 1) %>%
# ggplot(., aes(y = reorder(covariate, freq))) +
# geom_bar(aes(weight = freq)) +
# xlab("Local variable importance") +
# ylab("Covariates") +
# theme_bw()
# }
# save plot to disk
# ggsave(paste0("p_QRF_", TARGET, "_var_imp_local.pdf"),
# p_var_imp_local,
# path = paste0("out/figs/models/", TARGET, "/", TIME_DIR),
# width = 6, height = 8)
### Global xML method 2: partial dependence plots (PDP) ========================
## PDP: mean vs. covariate (1-dimensional interaction) -------------------------
# although function "pdp::partial()" can be parallelized, this doesn't speed up
# computation and even leads to socket reading errors b/c ranger is already running
# in parallel...
# this returns MEAN prediction (default RF) vs. most important covariate
system.time(
tbl_pd_mean_var1 <- partial(
QRF_FIT_optimal,
tbl_var_imp$Covariate[1],
grid.resolution = 21, # e.g. at intervals of 0.05 for peat_xydt
# parallel = TRUE,
# paropts = list(.packages = c("caret", "ranger")),
progress = TRUE
) %>%
as_tibble() %>% # for easier data manipulation
add_column(pred = "Mean")
)
# time elapsed SOM lab (grid res: 21): 10 sec
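# Conceptually, the PDP value at a grid point v is the model prediction averaged
# over the calibration data with the chosen covariate forced to v. A minimal
# manual sketch of what pdp::partial() computes above (slow; for intuition only,
# assuming the most important covariate is continuous):
# cov1 <- tbl_var_imp$Covariate[1]
# df_train <- QRF_FIT_optimal$trainingData
# v_grid <- seq(min(df_train[[cov1]]), max(df_train[[cov1]]), length.out = 21)
# yhat_pdp <- sapply(v_grid, function(v) {
#   df_tmp <- df_train
#   df_tmp[[cov1]] <- v # fix covariate at grid value for all calibration rows
#   mean(predict(QRF_FIT_optimal$finalModel, data = df_tmp)$predictions)
# })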
# list of MEAN predictions vs. each covariate
system.time(
ls_tbl_pd_mean <- foreach(p = 1:COV) %do% {
pdp::partial(QRF_FIT_optimal,
tbl_var_imp$Covariate[p],
grid.resolution = 21,
progress = TRUE) %>%
tibble::as_tibble() %>%
rename(cov = tbl_var_imp$Covariate[p]) %>%
add_column(pred = "Mean")
} %>%
set_names(tbl_var_imp$Covariate) # name list element instead of col in each tibble
)
# time elapsed SOM lab (grid res: 21): 4 min
# list of PDP ggplots for each covariate
ls_p_pd_mean <- foreach(p = 1:COV) %do% {
ls_tbl_pd_mean[[p]] %>%
as_tibble() %>%
ggplot(aes(x = cov, y = yhat, group = 1)) +
geom_line() +
labs(x = tbl_var_imp$Covariate[p], y = TARGET_PRED) +
theme_bw()
}
# arrange PDP ggplots into one grid
p_pd_mean <- gridExtra::grid.arrange(grobs = ls_p_pd_mean,
ncol = 6)
# save plot to disk
ggsave(paste0("p_QRF_", TRANSFORM, TARGET, OBS_QUAL, TIME,
"_var_imp_PDP_1dim_mean_all_cov.pdf"),
p_pd_mean,
path = paste0("out/figs/models/", TARGET, "/", TIME_DIR, "/", OBS_QUAL),
width = 20, height = 20)
## PDP: mean, median & PI90 vs. covariate (1-dimensional interaction) ----------
# Use modified pdp::partial() function that works for QRF
source("R/other/fun_partial_qrf.R")
# this returns chosen quantile predictions (QRF)
tbl_pd_quant_var1 <- partial_qrf(
QRF_FIT_optimal$finalModel, # needs to be class ranger
tbl_var_imp$Covariate[1],
train = QRF_FIT_optimal$trainingData,
Q = c(0.05, 0.5, 0.95),
grid.resolution = 21,
progress = TRUE
) %>%
as_tibble() %>%
rename(yhat = .outcome) %>%
bind_rows(tbl_pd_mean_var1) # add MEAN predictions
# plot MEAN, MEDIAN, Q5 and Q95 vs. most important covariate
p_pd_var1 <- tbl_pd_quant_var1 %>%
ggplot(aes(x = get(tbl_var_imp$Covariate[1]), y = yhat, color = pred)) +
geom_line(aes(color = pred)) +
scale_color_manual(values = c("black", "#1b9e77", "#d95f02", "#7570b3")) +
labs(x = tbl_var_imp$Covariate[1], y = TARGET_PRED, color = NULL) +
theme_bw()
# could also use geom_ribbon for PI90 as in script 70 (time series script)
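# A hedged sketch of that geom_ribbon alternative, assuming the quantile level
# names produced by partial_qrf() ("q0_05", "q0_5", "q0_95"; see function at the
# bottom of this script):
# tbl_pd_quant_var1 %>%
#   filter(pred != "Mean") %>%
#   pivot_wider(names_from = pred, values_from = yhat) %>%
#   ggplot(aes(x = get(tbl_var_imp$Covariate[1]))) +
#   geom_ribbon(aes(ymin = q0_05, ymax = q0_95), fill = "gray80") + # PI90 band
#   geom_line(aes(y = q0_5)) + # median
#   labs(x = tbl_var_imp$Covariate[1], y = TARGET_PRED) +
#   theme_bw()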
# list of TARGET (MEAN and quantiles) vs. each covariate
system.time(
ls_tbl_pd_quant <- foreach(p = 1:COV) %do% {
partial_qrf(
QRF_FIT_optimal$finalModel, # needs to be class ranger
tbl_var_imp$Covariate[p],
train = QRF_FIT_optimal$trainingData,
Q = c(0.05, 0.5, 0.95),
grid.resolution = 21,
progress = TRUE
) %>%
as_tibble() %>%
rename(yhat = .outcome,
cov = tbl_var_imp$Covariate[p]) %>%
bind_rows(ls_tbl_pd_mean[[p]])
} %>%
set_names(tbl_var_imp$Covariate)
)
# time elapsed SOM lab (grid res: 21): 44 min
# to plot according to covariate type (categorical vs. continuous), create a
# vector with data type info
v_cov_type <- unlist(map(QRF_FIT_optimal$ptyp, ~class(.x)))
# rearrange to same order as variable importance
v_cov_type <- v_cov_type[order(match(names(v_cov_type), tbl_var_imp$Covariate))]
# list of PDP ggplots (MEAN and quantiles) for each covariate
ls_p_pd_quant <- foreach(p = 1:COV) %do% {
if (v_cov_type[p] == "factor") {
# for categorical covariates use geom_point
ls_tbl_pd_quant[[p]] %>%
ggplot(aes(x = cov, y = yhat)) +
geom_line(color = "gray50") +
geom_point(aes(color = pred), size = 2, shape = 15) +
scale_color_manual(values = c("black", "#1b9e77", "#d95f02", "#7570b3")) +
labs(x = tbl_var_imp$Covariate[p], y = TARGET_PRED, color = NULL) +
theme_bw()
} else {
# for continuous covariates use geom_line
ls_tbl_pd_quant[[p]] %>%
ggplot(aes(x = cov, y = yhat, color = pred)) +
geom_line(aes(color = pred)) +
scale_color_manual(values = c("black", "#1b9e77", "#d95f02", "#7570b3")) +
labs(x = tbl_var_imp$Covariate[p], y = TARGET_PRED, color = NULL) +
theme_bw()
}
}
# arrange PDP ggplots into one grid
p_pd_quant <- gridExtra::grid.arrange(grobs = ls_p_pd_quant,
ncol = 5)
# save plot to disk
ggsave(paste0("p_QRF_", TRANSFORM, TARGET, OBS_QUAL, TIME,
"_var_imp_PDP_1dim_quant_all_cov.pdf"),
p_pd_quant,
path = paste0("out/figs/models/", TARGET, "/", TIME_DIR, "/", OBS_QUAL),
width = 20, height = 20)
## PDP: mean & median vs. 2 covariates (2-dimensional interaction) -------------
# TARGET (MEAN) vs. 2 most important covariates
system.time(
tbl_pd_mean_2dim <- partial(
QRF_FIT_optimal,
tbl_var_imp$Covariate[1:2],
grid.resolution = 21, # e.g. at intervals of 0.05 for peat_xydt
progress = TRUE
)
)
# time elapsed SOM lab (grid res: 21): 3 min
p_pd_mean_2dim <- plotPartial(
tbl_pd_mean_2dim,
col.regions = hcl.colors(n = nrow(tbl_pd_mean_2dim), palette = "YlOrBr", rev = TRUE)
)
# TARGET (MEDIAN) vs. 2 most important covariates
system.time(
tbl_pd_median_2dim <- partial_qrf(QRF_FIT_optimal$finalModel,
tbl_var_imp$Covariate[1:2],
train = QRF_FIT_optimal$trainingData,
Q = 0.5, # MEDIAN only
grid.resolution = 21,
progress = TRUE) %>%
as_tibble() %>%
rename(yhat = .outcome) # ...otherwise plotPartial() doesn't find the col
)
# time elapsed SOM lab (grid res: 21): 35 min
# function to assign "partial" & "data.frame" class for plotPartial() fun
fun_class_partial <- function(x){structure(x, class = c("partial", "data.frame"))}
p_pd_median_2dim <- plotPartial(
fun_class_partial(tbl_pd_median_2dim),
col.regions = hcl.colors(n = nrow(tbl_pd_median_2dim), palette = "YlOrBr", rev = TRUE)
)
# list of MEAN predictions vs. the most important covariate paired with each other covariate
system.time(
ls_tbl_pd_mean_2dim <- foreach(p = 2:COV) %do% {
pdp::partial(QRF_FIT_optimal,
tbl_var_imp$Covariate[c(1, p)],
grid.resolution = 21,
progress = TRUE)
}
)
# time elapsed SOM lab (grid res: 21): 1.5 h
# list of plots of MEAN predictions vs. most important covariate & each other covariate
ls_p_pd_mean_2dim <- foreach(p = 1:(COV-1)) %do% {
plotPartial(ls_tbl_pd_mean_2dim[[p]],
col.regions = hcl.colors(n = nrow(ls_tbl_pd_mean_2dim[[p]]),
palette = "YlOrBr", rev = TRUE),
main = tbl_var_imp$Covariate[p+1])
}
# arrange PDP ggplots into one grid
p_pd_mean_2dim_all <- gridExtra::grid.arrange(grobs = ls_p_pd_mean_2dim,
ncol = 5)
# save plot to disk
ggsave(paste0("p_QRF_", TRANSFORM, TARGET, OBS_QUAL, TIME,
"_var_imp_PDP_2dim_mean_all_cov.pdf"),
p_pd_mean_2dim_all,
path = paste0("out/figs/models/", TARGET, "/", TIME_DIR, "/", OBS_QUAL),
width = 30, height = 30)
# same as above but for MEDIAN predictions...
# list of MEDIAN predictions vs. the most important covariate paired with each other covariate
system.time(
ls_tbl_pd_median_2dim <- foreach(p = 2:COV) %do% {
partial_qrf(QRF_FIT_optimal$finalModel,
tbl_var_imp$Covariate[c(1, p)],
train = QRF_FIT_optimal$trainingData,
Q = 0.5, # MEDIAN only
grid.resolution = 21,
progress = TRUE) %>%
as_tibble() %>%
rename(yhat = .outcome) # ...otherwise plotPartial() doesn't find the col
}
)
# time elapsed SOM lab:
# list of plots of MEDIAN predictions vs. most important covariate & each other covariate
ls_p_pd_median_2dim <- foreach(p = 1:(COV-1)) %do% {
plotPartial(fun_class_partial(ls_tbl_pd_median_2dim[[p]]),
col.regions = hcl.colors(n = nrow(ls_tbl_pd_median_2dim[[p]]),
palette = "YlOrBr", rev = TRUE),
main = tbl_var_imp$Covariate[p+1])
}
# arrange PDP ggplots into one grid
p_pd_median_2dim_all <- gridExtra::grid.arrange(grobs = ls_p_pd_median_2dim,
ncol = 5)
# save plot to disk
ggsave(paste0("p_QRF_", TRANSFORM, TARGET, OBS_QUAL, TIME,
"_var_imp_PDP_2dim_median_all_cov.pdf"),
p_pd_median_2dim_all,
path = paste0("out/figs/models/", TARGET, "/", TIME_DIR, "/", OBS_QUAL),
width = 30, height = 30)
# read in tables with reclassified values (to assign proper description to classes
# of categorical covariates)
# read in covariate metadata
# tbl_cov_static_cat <- read_csv("data/covariates/covariates_metadata.csv") %>%
# # only interested in covariates we use in model
# filter(name %in% tbl_var_imp$Covariate) %>%
# filter(values_type %in% "categorical")
#
# ls_tbl_cov_static_cat_recl <- foreach(tbl = 1:nrow(tbl_cov_static_cat)) %do% {
# readr::read_csv(paste0("data/covariates/", tbl_cov_static_cat$category[tbl], "/",
# tbl_cov_static_cat$name[tbl], "_reclassify.csv"))
# } %>%
# map(., ~dplyr::select(.x, value_rcl, description_rcl)) %>%
# map(., ~distinct(.x)) %>%
# set_names(tbl_cov_static_cat$name)
# TODO (unfinished): join the "description_rcl" labels onto the categorical PDP
# tibbles (e.g. left_join() by reclassified value) so factor levels are plotted
# with proper class descriptions
### Prepare dynamic covariates at 3D locations & years to compute Shapley values =====
## Obtain values from covariates used to make dynamic covariates ---------------
# locate covariates from which we will derive dynamic (2D+T and 3D+T) covariates
v_cov_dyn_names <- dir("out/data/covariates/final_stack_dyn",
pattern = "\\.grd$", recursive = TRUE) %>%
.[c(1:7, 12:16, 8:11)] # change order of HGNs and LGNs by year of creation
# read in covariates and make stack
ls_r_cov_dyn <- foreach(cov = 1:length(v_cov_dyn_names)) %do%
raster(paste0("out/data/covariates/final_stack_dyn/",
v_cov_dyn_names[[cov]]))
r_stack_cov_dyn <- stack(ls_r_cov_dyn)
# Read in other covariates useful for 3D+T modelling
# r_luc_freq <- raster("out/data/covariates/final_stack_dyn/luc_freq_1km.tif")
r_bodem50_2021_peatcode <- raster("out/data/covariates/final_stack/bodem50_2021_peatcode_1km.grd")
r_bodem50_2006_peatcode <- raster("out/data/covariates/final_stack/bodem50_2006_peatcode_1km.grd")
# Read in metadata of soilmap years: original mapping year and updated year
r_bodem50_2021_update <- raster("data/other/bodem50_2021_update_1km.tif")
r_bodem50_2006_date <- raster("data/other/bodem50_2006_date_1km.tif")
# add to raster stack
r_stack_cov_dyn <- stack(r_stack_cov_dyn,
r_bodem50_2021_peatcode, r_bodem50_2006_peatcode,
r_bodem50_2021_update, r_bodem50_2006_date)
# convert tbl to sf object
sf_val_PFB <- tbl_val_PFB %>%
st_as_sf(., coords = c("X", "Y"), crs = crs(r_stack_cov_dyn))
# extract dynamic covariate values at soil sampling locations
tbl_cov_dyn_val_PFB <- raster::extract(r_stack_cov_dyn, sf_val_PFB)
# time elapsed sequential: 2 min
# make into tibble
tbl_cov_dyn_val_PFB <- as_tibble(tbl_cov_dyn_val_PFB)
## Compute covariate values for years of interest ------------------------------
# read in dynamic covariate functions
# see also script "R/other/fun_cov_dyn_xyt_xydt.R" to see how functions work
source("R/other/fun_cov_dyn_xyt_xydt.R")
# choose years of interest
YEAR = c(1970, 1980)
# combine year with static covariates to compute dynamic covariates
tbl_cov_dyn_val_PFB <- tibble(
d_upper = tbl_regmat_target_val_PFB$d_upper,
d_lower = tbl_regmat_target_val_PFB$d_lower,
d_mid = tbl_regmat_target_val_PFB$d_mid,
tbl_cov_dyn_val_PFB
)
# list of regression matrix for years of interest
ls_tbl_cov_dyn_val_PFB <- map(
YEAR, ~mutate(dplyr::select(tbl_cov_dyn_val_PFB, -contains(c("_xyt_", "_xydt_"))),
year = .x)
)
# compute dynamic covariate values
# land use (2D+T): LU_xyt[_delta]
ls_tbl_cov_dyn_val_PFB <- foreach(y = 1:length(ls_tbl_cov_dyn_val_PFB)) %do% {
LU_xyt(
data = ls_tbl_cov_dyn_val_PFB[[y]],
year = "year",
LU = c("hgn_1900_filled_1km", "hgn_1960_1km", "hgn_1970_1km", "hgn_1980_1km",
"lgn1_1km", "hgn_1990_1km", "lgn2_1km", "lgn3_1km", "lgn4_1km",
"lgn5_1km", "lgn6_1km", "lgn7_1km", "lgn2018_1km", "lgn2019_1km",
"lgn2020_1km", "lgn2021_1km"),
LU_years = list(1913:1930, 1931:1965, 1966:1975, 1976:1982,
1983:1988, 1989:1990, 1991:1994, 1995:1998, 1999:2001,
2002:2005, 2006:2010, 2011:2015, 2016:2018, 2019,
2020, 2021:2022),
LU_year_min = 1953,
LU_year_max = 2022
) %>%
bind_cols(ls_tbl_cov_dyn_val_PFB[[y]], .)
}
# peat categories (2D+T): peat[1:8]_xyt
# vector of colnames and add empty cols
v_colnames_peat_xyt <- paste0("peat", 1:8, "_xyt_1km")
ls_tbl_cov_dyn_val_PFB <- map(
ls_tbl_cov_dyn_val_PFB,
~add_column(.x, !!!set_names(as.list(rep(NA, 8)), nm = v_colnames_peat_xyt))
)
# fill in values of new cols using "peat_xyt" function
foreach(y = 1:length(ls_tbl_cov_dyn_val_PFB)) %do% {
for (i in 1:length(v_colnames_peat_xyt)) {
ls_tbl_cov_dyn_val_PFB[[y]] <- ls_tbl_cov_dyn_val_PFB[[y]] %>%
mutate(across(starts_with(paste0("peat", i)),
~peat_xyt(version1 = bodem50_2006_peatcode_1km,
version2 = bodem50_2021_peatcode_1km,
version1_year = bodem50_2006_date_1km,
version2_year = bodem50_2021_update_1km,
class = i, year = year),
.names = v_colnames_peat_xyt[i]))
}
}
# peat binary variable (3D+T): peat_xydt
ls_tbl_cov_dyn_val_PFB <- foreach(y = 1:length(ls_tbl_cov_dyn_val_PFB)) %do% {
mutate(ls_tbl_cov_dyn_val_PFB[[y]],
peat_xydt_1km = peat_xydt(version1 = bodem50_2006_peatcode_1km,
version2 = bodem50_2021_peatcode_1km,
version1_year = bodem50_2006_date_1km,
version2_year = bodem50_2021_update_1km,
year = year,
d_upper = d_upper,
d_lower = d_lower))
}
# remove LU and peat maps used to derive dynamic covariates and name list
ls_tbl_cov_dyn_val_PFB <- ls_tbl_cov_dyn_val_PFB %>%
map(., ~dplyr::select(.x, d_upper:d_mid, contains(c("_xyt_", "_xydt_")))) %>%
set_names(paste0("tbl_val_PFB_", YEAR))
# add static covariates to list of regression matrices from different years
ls_tbl_cov_dyn_val_PFB <- ls_tbl_cov_dyn_val_PFB %>%
map(., ~bind_cols(.x,
tbl_regmat_target_val_PFB %>%
dplyr::select(contains("1km")) %>%
dplyr::select(-contains(c("_xyt_", "_xydt_")))))
# add sample and site metadata
ls_tbl_cov_dyn_val_PFB <- map(ls_tbl_cov_dyn_val_PFB, ~tibble(
dplyr::select(tbl_regmat_target_val_PFB, split:hor), .x
))
## Predict for years of interest -----------------------------------------------
# Use modified ranger function from ISRIC to calculate mean in addition to the
# usual quantiles from QRF
source("R/other/predict_qrf_fun.R")
# predict independent test dataset for original sampling years
ls_qrf_tree_val_PFB <- foreach(y = 1:length(ls_tbl_cov_dyn_val_PFB)) %do% {
predict.ranger.tree(
QRF_FIT_optimal$finalModel,
data = ls_tbl_cov_dyn_val_PFB[[y]],
type = "treepred"
)
}
# predict all quantiles and then also the mean
ls_tbl_pred_val_PFB <- foreach(y = 1:length(ls_qrf_tree_val_PFB)) %do% {
data.frame(t(apply(ls_qrf_tree_val_PFB[[y]]$predictions,
1, quantile, QUANTILES, na.rm = TRUE))) %>%
as_tibble() %>%
rename_all(~ paste0("quant_", QUANTILES * 100)) %>%
# predict mean value
add_column(pred_mean = apply(ls_qrf_tree_val_PFB[[y]]$predictions,
1, mean, na.rm = TRUE),
.before = "quant_5")
}
# add observations, depths & name list
ls_tbl_pred_val_PFB <- ls_tbl_pred_val_PFB %>%
map(., ~bind_cols(dplyr::select(tbl_regmat_target_val_PFB, c(site_id, sample_id)),
dplyr::select(tbl_regmat_target_val_PFB, d_upper:d_mid),
.x)) %>%
set_names(paste0("tbl_pred_", YEAR))
# create list of tibbles where each element of list is one site with a tibble
# of predictions for each year of interest for each sample at that site
ls_tbl_pred_timeseries <- rbindlist(ls_tbl_pred_val_PFB) %>%
as_tibble() %>%
group_split(site_id) %>%
map(., ~add_column(.x,
Year = rep(YEAR, each = nrow(.)/length(YEAR)),
obs = NA, .before = "pred_mean")) %>%
map(., ~arrange(.x, Year, sample_id)) %>%
set_names(unique(sort(tbl_regmat_target_val_PFB$site_id)))
### Local xML method: Shapley values for time series predictions ===============
# prepare specific location of interest to compute Shapley values for at different times
ls_df_PFB_1676a_cov <- map(ls_tbl_cov_dyn_val_PFB, ~.x %>%
filter(sample_id == "1676a") %>%
# arrange rows by increasing depth
arrange(d_upper) %>%
dplyr::select(contains(tbl_var_imp$Covariate)) %>%
# order cols by variable importance
relocate(tbl_var_imp$Covariate) %>%
as.data.frame())
# Prediction wrapper for MEAN predictions
fun_pred_mean <- function(object, newdata) {
predict(object, data = newdata)$predictions
}
# Prediction wrapper for MEDIAN predictions
fun_pred_median <- function(object, newdata) {
predict(object, data = newdata,
type = "quantiles", quantiles = 0.5)$predictions
}
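# Note: fastshap::explain() expects pred_wrapper(object, newdata) to return a
# numeric vector of predictions, hence the two wrappers above; a quick check of
# the wrapper output for the location of interest, e.g.:
# fun_pred_mean(QRF_FIT_optimal$finalModel, ls_df_PFB_1676a_cov[[1]])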
# Compute individual explanations (Shapley values)
system.time({ # estimate run time
set.seed(2023)
ls_tbl_shap_sample_1676a <- foreach(y = 1:length(ls_df_PFB_1676a_cov)) %do% {
explain(
QRF_FIT_optimal$finalModel,
# order cols by variable importance
X = QRF_FIT_optimal$trainingData[,tbl_var_imp$Covariate],
pred_wrapper = fun_pred_mean,
nsim = 1e3,
newdata = ls_df_PFB_1676a_cov[[y]]
)
}
})
# time elapsed SOM (mean; lab only; 1000 sim): 1 min
# time elapsed SOM (median; lab only; 2 sim):
# Restructure cols into rows & combine into one tibble
tbl_shap_sample_1676a <- map2(ls_tbl_shap_sample_1676a, YEAR, ~tibble(
Covariate = colnames(.x),
Contribution = apply(.x, MARGIN = 2, FUN = function(x) x)
) %>%
arrange(-Contribution) %>%
add_column(year = .y)) %>%
bind_rows() %>%
mutate(sign = case_when(Contribution <= 0 ~ "negative",
Contribution > 0 ~ "positive"),
Contribution = round(Contribution, 2))
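# Sanity check (hedged sketch): Monte Carlo Shapley values from fastshap should
# sum approximately to this sample's prediction minus the baseline (the mean
# prediction over the training data passed as X above):
# baseline <- mean(fun_pred_mean(QRF_FIT_optimal$finalModel,
#                                QRF_FIT_optimal$trainingData[, tbl_var_imp$Covariate]))
# sum(as.matrix(ls_tbl_shap_sample_1676a[[1]])) + baseline # ~ prediction for 1676a in 1970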
# plot Shapley explanations/contributions at one prediction location
ggplot(tbl_shap_sample_1676a,
aes(reorder(Covariate, Contribution), Contribution, fill = sign)) +
geom_col() +
scale_fill_manual(values = c("#ca0020", "#0571b0")) +
coord_flip() +
facet_wrap(vars(year)) +
labs(x = NULL, y = "Shapley value", fill = NULL) +
theme_bw() +
theme(legend.position = "none")
# or using the "waterfalls" pkg
library(waterfalls)
# mean predictions of sample and years of interest
v_pred_mean_sample <- ls_tbl_pred_timeseries$`1676` %>%
filter(sample_id == "1676a") %>%
pull(pred_mean)
# plot Shapley explanations/contributions at one prediction location
waterfall(tbl_shap_sample_1676a[1:33,]) + # one year's rows
coord_flip() +
scale_fill_manual(values = c("#ca0020", "#0571b0")) +
geom_hline(yintercept = v_pred_mean_sample[1L]) +
# facet_wrap(vars(year)) +
labs(x = NULL, y = "Shapley value", fill = NULL) +
theme_bw()
# Partial Dependence Plot for Quantile Random Forest
# for the dependence of any quantile in QRF, we can use an adjusted function based on:
# https://github.com/vlyubchich/MML/blob/master/R/partial_qrf.R
partial_qrf <- function(object, pred.var
,Q = c(0.05, 0.5, 0.95)
,...)
{
Q <- sort(unique(Q))
predfun <- function(object, newdata) {
qpred <- predict(object, newdata, type = "quantiles", quantiles = Q)$predictions
tmp <- c(apply(qpred, 2, mean))
# "q_temp" was chosen b/c should be name that doesn't occur in R environment
# see https://stackoverflow.com/questions/44702134/r-error-cannot-change-value-of-locked-binding-for-df
q_temp <- paste0("q", Q)
q_temp <<- q_temp <- gsub("\\.", "_", q_temp)
names(tmp) <- q_temp
return(tmp)
}
pdpout <- pdp::partial(object
,pred.var = pred.var
,pred.fun = predfun
,...)
if ("call" %in% ls(object)) {
response <- as.character(object$call)
response <- base::strsplit(response[2], "\\ ")[[1]][1]
colnames(pdpout)[length(pred.var) + 1] <- response
}
if (length(Q) > 1) {
colnames(pdpout)[ncol(pdpout)] <- "pred"
pdpout$pred <- factor(pdpout$pred, levels = rev(q_temp))
}
return(pdpout)
}