# Friday, September 18, 2020
#
# SHAP plot

 #Part 1: library input

suppressPackageStartupMessages({

  library(SHAPforxgboost)

  library(xgboost)

  library(data.table)

  library(ggplot2)

})


#part 2:

#file load and shap value calculation

# Load the data set picked interactively by the user and fit the xgboost
# regression model whose SHAP values are analysed further down.
a <- read.csv(file.choose())

# Fail early with a clear message when the label column is absent. `[[`
# matches the column name exactly, whereas `$` would silently partial-match
# a longer column name or return NULL.
if (!"SPEEDING_CRASH" %in% names(a)) {
  stop("Column 'SPEEDING_CRASH' not found in the selected file.",
       call. = FALSE)
}

# Features: every column except the first one (as before) and never the
# label itself, so the target cannot leak into the predictors when it is not
# stored in column 1. `drop = FALSE` keeps a data frame -- and therefore the
# column name -- even when only a single feature column remains.
feature_cols <- setdiff(names(a)[-1], "SPEEDING_CRASH")
X1 <- as.matrix(a[, feature_cols, drop = FALSE])

mod1 <- xgboost::xgboost(
  data = X1, label = a[["SPEEDING_CRASH"]], gamma = 0, eta = 1,
  lambda = 0, nrounds = 10, verbose = FALSE, objective = "reg:squarederror"
)


# Compute the SHAP values of the fitted model on the training matrix.
# shap.values(model, X_dataset) returns the SHAP data matrix together with
# the features ranked by mean |SHAP|.
shap_values <- SHAPforxgboost::shap.values(
  xgb_model = mod1,
  X_train = X1
)

# Show the mean |SHAP| score of the ranked features.
shap_values[["mean_shap_score"]]



#part 3:

# SHAP score matrix taken from the shap.values() result.
shap_values_iris <- shap_values$shap_score

# shap.prep() returns the long-format SHAP data, built either from the model
# or from an already computed SHAP matrix passed as `shap_contrib`. Both
# forms give the same result, so only the precomputed matrix is used here
# instead of running shap.prep() twice and discarding the first result.
# Equivalent call from the model:
#   shap_long_iris <- shap.prep(xgb_model = mod1, X_train = X1)
shap_long_iris <- shap.prep(shap_contrib = shap_values_iris, X_train = X1)


# -------------------------------------------------------------------------



# SHAP summary plot built from the long-format data.
SHAPforxgboost::shap.plot.summary(data_long = shap_long_iris)

# The `dilute` option is offered to make plotting faster when there are
# thousands of observations; please see the package documentation for
# details.
SHAPforxgboost::shap.plot.summary(
  data_long = shap_long_iris,
  x_bound = 1.5,
  dilute = 10
)

#end...................................



# 0 comments:
#
# Post a Comment