Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


Extracting variable names from a lightgbm model in R

I have several lightgbm models in R, and I want to validate them by extracting the variable names used during the fit. This is really simple with a glm, but I can't manage to find a way to do it (if it is possible at all) with lightgbm models.

Here is a reproducible example to make everything clearer:

I use the data from lightgbm package:

library(lightgbm)
data(agaricus.train, package = "lightgbm")

I first run the basic lgbm model:

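# formatting the data
# (train was never defined in the original snippet; it is the agaricus.train list)
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)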
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
params <- list(objective = "regression", metric = "l2")
valids <- list(test = dtest)

# running the model
model_lgbm <- lgb.train(
  params = params
  , data = dtrain
  , nrounds = 10L
  , valids = valids
  , min_data = 1L
  , learning_rate = 1.0
  , early_stopping_rounds = 5L
)

Now, I can do the same thing for a glm:

## preparing the data
dd <- data.frame(label = train$label, as(train$data, "matrix")[,1:10])
## making the model
model_glm <- glm(label ~ ., data=dd, family="binomial")

With the glm, there are lots of ways to quickly find the variables used for the modeling. The most obvious one is:

variable.names(model_glm)
 [1] "(Intercept)"         "cap.shape.bell"      "cap.shape.conical"   "cap.shape.convex"   
 [5] "cap.shape.flat"      "cap.shape.knobbed"   "cap.shape.sunken"    "cap.surface.fibrous"
 [9] "cap.surface.grooves" "cap.surface.scaly"  

This function is not implemented for lightgbm models:

variable.names(model_lgbm)
NULL

And trying to get into the model object with str is not helpful:

str(model_lgbm)
Classes 'lgb.Booster', 'R6' <lgb.Booster>
  Public:
    add_valid: function (data, name) 
    best_iter: 3
    best_score: 0
    current_iter: function () 
    dump_model: function (num_iteration = NULL, feature_importance_type = 0L) 
    eval: function (data, name, feval = NULL) 
    eval_train: function (feval = NULL) 
    eval_valid: function (feval = NULL) 
    finalize: function () 
    initialize: function (params = list(), train_set = NULL, modelfile = NULL, 
    lower_bound: function () 
    predict: function (data, start_iteration = NULL, num_iteration = NULL, 
    raw: NA
    record_evals: list
    reset_parameter: function (params, ...) 
    rollback_one_iter: function () 
    save: function () 
    save_model: function (filename, num_iteration = NULL, feature_importance_type = 0L) 
    save_model_to_string: function (num_iteration = NULL, feature_importance_type = 0L) 
    set_train_data_name: function (name) 
    to_predictor: function () 
    update: function (train_set = NULL, fobj = NULL) 
    upper_bound: function () 
  Private:
    eval_names: l2
    get_eval_info: function () 
    handle: 8.19470876878865e-316
    higher_better_inner_eval: FALSE
    init_predictor: NULL
    inner_eval: function (data_name, data_idx, feval = NULL) 
    inner_predict: function (idx) 
    is_predicted_cur_iter: list
    name_train_set: training
    name_valid_sets: list
    num_class: 1
    num_dataset: 2
    predict_buffer: list
    set_objective_to_none: FALSE
    train_set: lgb.Dataset, R6
    train_set_version: 1
    valid_sets: list 

The only way I have managed to access the variable names is through the lgb.importance function, but that is less than ideal: calculating variable importance can be slow for big models, and I'm not even sure it reports all the variables:

lgb.importance(model_lgbm)$Feature
 [1] "odor=none"                      "stalk-root=club"               
 [3] "stalk-root=rooted"              "spore-print-color=green"       
 [5] "odor=almond"                    "odor=anise"                    
 [7] "bruises?=bruises"               "stalk-surface-below-ring=scaly"
 [9] "gill-size=broad"                "cap-surface=grooves"           
[11] "cap-shape=conical"              "gill-color=brown"              
[13] "cap-shape=bell"                 "cap-shape=flat"                
[15] "cap-surface=scaly"              "cap-color=white"               
[17] "population=clustered"  

Is there a way to access only the variable names used in the lightgbm model? Thanks.



1 Answer


The comment "and I'm not even sure it reports all the variables" has me a bit confused about what you're asking for when you say "variable names used during the fit", so I've answered both interpretations here.

Both answers assume this slightly-smaller version of your reproducible example.

library(lightgbm)
data(agaricus.train, package = "lightgbm")

# formatting the data
dtrain <- lgb.Dataset(
    agaricus.train$data
    , label = agaricus.train$label
)
data(agaricus.test, package = "lightgbm")
params <- list(
    objective = "regression"
    , metric = "l2"
)

# running the model
model_lgbm <- lgb.train(
    params = params
    , data = dtrain
    , nrounds = 10L
    , min_data = 1L
    , learning_rate = 1.0
)

Feature Names of the Input Dataset

If you want to know the names of all features in the input dataset that was passed to LightGBM, regardless of whether or not all those columns were chosen for splits, you can examine the dumped model.

parsed_model <- jsonlite::fromJSON(
    model_lgbm$dump_model()
)
parsed_model$feature_names
[1] "cap-shape=bell"
[2] "cap-shape=conical"
[3] "cap-shape=convex"
[4] "cap-shape=flat"
[5] "cap-shape=knobbed"
[6] "cap-shape=sunken"
[7] "cap-surface=fibrous"
[8] "cap-surface=grooves"
[9] "cap-surface=scaly"
[10] "cap-surface=smooth"
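
If you would like variable.names() to work on the model the same way it does for a glm, one option is to wrap the lookup above in an S3 method. This is only a sketch: variable.names.lgb.Booster is a name defined here for illustration and is not something the lightgbm package provides. It relies on the model object having class "lgb.Booster", as shown in the str() output in the question.

```r
library(lightgbm)

# Rebuild the small model from the example above
data(agaricus.train, package = "lightgbm")
dtrain <- lgb.Dataset(
    agaricus.train$data
    , label = agaricus.train$label
)
model_lgbm <- lgb.train(
    params = list(
        objective = "regression"
        , metric = "l2"
        , min_data = 1L
        , learning_rate = 1.0
        , verbose = -1L
    )
    , data = dtrain
    , nrounds = 10L
)

# Hypothetical helper (not part of lightgbm): an S3 method so that
# stats::variable.names() dispatches on lgb.Booster objects.
# It parses the JSON dump and returns the input feature names.
variable.names.lgb.Booster <- function(object, ...) {
    parsed_model <- jsonlite::fromJSON(object$dump_model())
    parsed_model$feature_names
}

feature_names <- variable.names(model_lgbm)
head(feature_names)
```

Because the method is defined in your own session rather than in the package, it has to be sourced before each use. The whole model is still dumped to JSON on every call, so for very large models it may be worth storing the result.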

Features Chosen for Splits

If you want to know which features were actually used in splits chosen by LightGBM, you can use either lgb.model.dt.tree() or the feature importance example you gave above.

modelDT <- lgb.model.dt.tree(model_lgbm)
modelDT$split_feature

lgb.model.dt.tree() returns a data.table representation of the trained model. One row in the table corresponds to either one side of a split or to one leaf node. Rows that refer to a leaf node have NA for $split_feature.
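
To reduce that column to a plain list of features, you can drop the NA rows and de-duplicate. The sketch below also compares the result with the full feature list from the dumped model, to show which input columns were never chosen for a split. The names used_features, all_features, and unused_features are introduced here for illustration only.

```r
library(lightgbm)

# Rebuild the small model from the example above
data(agaricus.train, package = "lightgbm")
dtrain <- lgb.Dataset(
    agaricus.train$data
    , label = agaricus.train$label
)
model_lgbm <- lgb.train(
    params = list(
        objective = "regression"
        , metric = "l2"
        , min_data = 1L
        , learning_rate = 1.0
        , verbose = -1L
    )
    , data = dtrain
    , nrounds = 10L
)

# One row per split or leaf; leaf rows have NA in split_feature
modelDT <- lgb.model.dt.tree(model_lgbm)
split_features <- modelDT$split_feature

# Features that appear in at least one split
used_features <- unique(split_features[!is.na(split_features)])

# All features seen by the model, taken from the JSON dump
all_features <- jsonlite::fromJSON(model_lgbm$dump_model())$feature_names

# Input columns that were never used in any split
unused_features <- setdiff(all_features, used_features)

length(used_features)
length(unused_features)
```

This avoids computing importance values when all you need is the names.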

If you have suggestions for making this easier, PRs and issues are welcome at https://github.com/microsoft/LightGBM.

