Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


machine learning - How to specify split in a decision tree in R programming?

I am trying to fit a decision tree here. A decision tree chooses the split at each node itself, but at the first node I want to split my tree on the basis of "Age". How do I force that?

library(party)    
fit2 <- ctree(Churn ~ Gender + Age + LastTransaction + Payment.Method + spend + marStat, data = tsdata)

1 Answer


There is no built-in option to do that in ctree(). The easiest method to do this "by hand" is simply:

  1. Learn a tree with only Age as explanatory variable and maxdepth = 1 so that this only creates a single split.

  2. Split your data using the tree from step 1 and create a subtree for the left branch.

  3. Split your data using the tree from step 1 and create a subtree for the right branch.

This does what you want (although I typically wouldn't recommend doing so...).
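Applied to the variables from the question, the three steps might look like the sketch below. The asker's tsdata is not shown, so the data frame here is simulated and purely hypothetical; only the column names Churn, Age, spend and Gender come from the question. The sketch uses ctree() from partykit and assumes that the stump actually finds a split on Age, so that its two terminal nodes have ids 2 and 3.

```r
library("partykit")

# Hypothetical stand-in for the asker's `tsdata` (simulated, not real data)
set.seed(1)
n <- 200
age <- round(runif(n, 18, 70))
tsdata <- data.frame(
  Churn  = factor(ifelse(age + rnorm(n, sd = 10) > 45, "yes", "no")),
  Age    = age,
  spend  = rexp(n, rate = 1 / 100),
  Gender = factor(sample(c("F", "M"), n, replace = TRUE))
)

# Step 1: a stump that may only split on Age
stump <- ctree(Churn ~ Age, data = tsdata, maxdepth = 1)

# Terminal node id (2 = left branch, 3 = right branch) for every observation
node <- predict(stump, type = "node")

# Steps 2 and 3: one subtree per branch, grown on all explanatory variables
left  <- ctree(Churn ~ Age + ., data = tsdata[node == 2, ])
right <- ctree(Churn ~ Age + ., data = tsdata[node == 3, ])
```

Printing stump, left and right then shows the forced split and the two subtrees that follow it.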

If you use the ctree() implementation from partykit you can also merge these three trees into a single tree again for visualizations and predictions etc. It requires a bit of hacking but is still feasible.

I will illustrate this using the iris data and I will force a split in the variable Sepal.Length which otherwise wouldn't be used in the tree. Learning the three trees above is easy:

library("partykit")
data("iris", package = "datasets")
tr1 <- ctree(Species ~ Sepal.Length,     data = iris, maxdepth = 1)
tr2 <- ctree(Species ~ Sepal.Length + ., data = iris,
  subset = predict(tr1, type = "node") == 2)
tr3 <- ctree(Species ~ Sepal.Length + ., data = iris,
  subset = predict(tr1, type = "node") == 3)

Note, however, that it is important to use the formula with Sepal.Length + . to ensure that the variables in the model frame are ordered in exactly the same way in all trees, because the splits refer to variables by their column position in the model frame.
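The effect of the formula on the column order can be checked with base R alone. In iris, Sepal.Length happens to be the first column anyway, so the sketch below uses Petal.Width as a hypothetical forced variable to make the difference visible.

```r
# Petal.Width as the forced split variable is a hypothetical choice for illustration
mf_stump <- model.frame(Species ~ Petal.Width,     data = iris)
mf_full  <- model.frame(Species ~ Petal.Width + ., data = iris)
mf_dot   <- model.frame(Species ~ .,               data = iris)

# Column position of Petal.Width in each model frame
match("Petal.Width", names(mf_stump))  # 2
match("Petal.Width", names(mf_full))   # 2, same position as in the stump
match("Petal.Width", names(mf_dot))    # 5, would not line up with the stump
```

With the forced variable written first, it sits in the same column in the stump and in the subtrees, so a split on "variable 2" means the same thing in all three trees.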

Next comes the most technical step: We need to extract the raw node structure from all three trees, fix up the node ids so that they are in a proper sequence, and then integrate everything into a single node:

fixids <- function(x, startid = 1L) {
  id <- startid - 1L
  new_node <- function(x) {
    id <<- id + 1L
    if(is.terminal(x)) return(partynode(id, info = info_node(x)))
    partynode(id,
      split = split_node(x),
      kids = lapply(kids_node(x), new_node),
      surrogates = surrogates_node(x),
      info = info_node(x))
  }

  return(new_node(x))   
}
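# Take the root node (holding the forced split) from tr1 and attach the
# renumbered subtrees as its kids. tr2 has 3 nodes, which get ids 2 to 4,
# so the nodes of tr3 start at id 5.
no <- node_party(tr1)
no$kids <- list(
  fixids(node_party(tr2), startid = 2L),
  fixids(node_party(tr3), startid = 5L)
)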
no
## [1] root
## |   [2] V2 <= 5.4
## |   |   [3] V4 <= 1.9 *
## |   |   [4] V4 > 1.9 *
## |   [5] V2 > 5.4
## |   |   [6] V4 <= 4.7
## |   |   |   [7] V4 <= 3.6 *
## |   |   |   [8] V4 > 3.6 *
## |   |   [9] V4 > 4.7
## |   |   |   [10] V5 <= 1.7 *
## |   |   |   [11] V5 > 1.7 *
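The renumbering in fixids() relies on a counter that lives in the enclosing function and is incremented with <<- on every visited node, which yields ids in depth-first order. The same pattern on a plain nested list, without any partykit objects, might look like this; renumber() and the toy list are hypothetical names used only for illustration.

```r
# Renumber a nested list in depth-first (pre-order) sequence, starting at `startid`
renumber <- function(x, startid = 1L) {
  id <- startid - 1L
  visit <- function(node) {
    id <<- id + 1L  # updates the counter in renumber()'s frame
    list(id = id, kids = lapply(node$kids, visit))
  }
  visit(x)
}

# Toy tree: a root with two kids, the second of which has two kids itself
toy <- list(id = 10L, kids = list(
  list(id = 20L, kids = list()),
  list(id = 30L, kids = list(
    list(id = 40L, kids = list()),
    list(id = 50L, kids = list())
  ))
))

out <- renumber(toy, startid = 5L)
unlist(out, use.names = FALSE)  # 5 6 7 8 9
```

The start id of the second subtree is the start id of the first one plus the number of nodes in the first, which is why tr3 begins at 5 above.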

And finally we set up a joint model frame containing all data and combine that with the new joint tree. Some information on fitted nodes and the response is added to be able to turn the tree into a constparty for nice visualization and predictions. See vignette("partykit", package = "partykit") for the background on this:

d <- model.frame(Species ~ Sepal.Length + ., data = iris)
tr <- party(no, 
  data = d,
  fitted = data.frame(
    "(fitted)" = fitted_node(no, data = d),
    "(response)" = model.response(d),
    check.names = FALSE),
  terms = terms(d)
)
tr <- as.constparty(tr)

And then we're done and can visualize our combined tree with the forced first split:

plot(tr)

[Figure: combined tree]

