# Part 1: load libraries ----
suppressPackageStartupMessages({
library(SHAPforxgboost)
library(xgboost)
library(data.table)
library(ggplot2)
})
# Part 2: load the data, fit an XGBoost model and compute SHAP values ----
# The CSV is chosen interactively. It must contain an outcome column named
# SPEEDING_CRASH; every other column is used as a feature.
a <- read.csv(file.choose())
label_col <- "SPEEDING_CRASH"
# Use `[[` with an explicit existence check: `$` on a data.frame partially
# matches column names and silently returns NULL for a missing column.
if (!label_col %in% names(a)) {
  stop("Input data has no '", label_col, "' column.", call. = FALSE)
}
# Select features by name rather than by position. The previous `a[, -1]`
# only excluded the label when SPEEDING_CRASH was the first column; otherwise
# the label leaked into the feature matrix.
# NOTE(review): if column 1 is an ID-like column that was meant to be dropped,
# exclude it here explicitly.
feature_cols <- setdiff(names(a), label_col)
X1 <- as.matrix(a[, feature_cols, drop = FALSE])
# as.matrix() yields a non-numeric matrix if any feature column is
# character/factor; fail early with a clear message instead of inside xgboost.
if (!is.numeric(X1)) {
  stop("All feature columns must be numeric.", call. = FALSE)
}
# NOTE(review): squared-error regression is used on SPEEDING_CRASH; if the
# outcome is a 0/1 indicator, "binary:logistic" may be more appropriate.
mod1 <- xgboost::xgboost(
  data = X1, label = a[[label_col]], gamma = 0, eta = 1,
  lambda = 0, nrounds = 10, verbose = FALSE,
  objective = "reg:squarederror"
)
# shap.values(model, X_dataset) returns the SHAP data matrix (shap_score)
# and the features ranked by mean |SHAP| (mean_shap_score).
shap_values <- shap.values(xgb_model = mod1, X_train = X1)
# Print explicitly so the ranking is shown under source() as well.
print(shap_values$mean_shap_score)
# Part 3: reshape the SHAP values to long format and plot them ----
# NOTE: the "_iris" suffix on these names comes from the iris example in the
# SHAPforxgboost documentation; the data here is the CSV loaded in part 2.
shap_values_iris <- shap_values$shap_score
# shap.prep() returns long-format SHAP data built either from the model
# (shap.prep(xgb_model = mod1, X_train = X1), which recomputes the SHAP
# values from scratch) or from an already computed SHAP contribution matrix.
# The SHAP values were computed in part 2, so reuse them instead of
# computing them twice.
shap_long_iris <- shap.prep(shap_contrib = shap_values_iris, X_train = X1)
# -------------------------------------------------------------------------
# shap.plot.summary() returns a plot object; print() it explicitly so the
# plot is drawn even when this script is run via source().
print(shap.plot.summary(shap_long_iris))
# The `dilute` option makes plotting faster when there are thousands of
# observations; see ?shap.plot.summary for details.
# NOTE(review): x_bound = 1.5 is hard-coded and presumably limits the SHAP
# axis; points beyond it may be cut off for this data. Confirm before relying
# on this plot.
print(shap.plot.summary(shap_long_iris, x_bound = 1.5, dilute = 10))
# End of script ----
# (Residue from the web page this script was copied from, kept as comments
#  so the file parses: "0 comments: Post a Comment")