There is no built-in option to do that in `ctree()`. The easiest method to do this "by hand" is simply:
1. Learn a tree with only `Age` as explanatory variable and `maxdepth = 1` so that this only creates a single split.
2. Split your data using the tree from step 1 and create a subtree for the left branch.
3. Split your data using the tree from step 1 and create a subtree for the right branch.
This does what you want (although I typically wouldn't recommend doing so...).
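The three steps can be wrapped in a small function. `force_first_split()` below is a hypothetical helper (it is not part of partykit), shown as a sketch under the assumption that every row of `data` is kept in the model frame (true for `ctree()`'s default `na.action = na.pass`), so that the fitted node ids from step 1 line up with the rows of `data`. Extra arguments are passed on to the two subtree fits.

```r
library("partykit")

## Hypothetical helper (not a partykit function): force the first split in
## the variable `forced`, then grow one ctree() per daughter node.
## Assumes every row of `data` is kept in the model frame, so that the
## fitted node ids from step 1 line up with the rows of `data`.
force_first_split <- function(response, forced, data, ...) {
  ## Step 1: only the forced variable, at most one split
  f1 <- reformulate(forced, response = response)
  tr1 <- ctree(f1, data = data, maxdepth = 1)
  kids <- nodeids(tr1, terminal = TRUE)
  if (length(kids) != 2L)
    stop("ctree() found no significant split in '", forced, "'")
  node <- predict(tr1, type = "node")
  ## Steps 2 and 3: forced variable first, then all other columns
  f <- reformulate(c(forced, "."), response = response)
  dl <- data[node == kids[1L], , drop = FALSE]
  dr <- data[node == kids[2L], , drop = FALSE]
  list(
    root  = tr1,
    left  = ctree(f, data = dl, ...),
    right = ctree(f, data = dr, ...)
  )
}

## Usage with the iris data that is also used below
data("iris", package = "datasets")
trs <- force_first_split("Species", "Sepal.Length", iris)
trs$left
```

With the original question's data this would be a call like `force_first_split("y", "Age", mydata)`, where `y` and `mydata` are placeholders for the actual response and data frame.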
If you use the `ctree()` implementation from `partykit`, you can also merge these three trees into a single tree again, e.g., for visualization and prediction. It requires a bit of hacking but is still feasible.
I will illustrate this using the `iris` data and I will force a split in the variable `Sepal.Length`, which otherwise wouldn't be used in the tree. Learning the three trees above is easy:
```r
library("partykit")
data("iris", package = "datasets")
## Step 1: a single split in the forced variable
tr1 <- ctree(Species ~ Sepal.Length, data = iris, maxdepth = 1)
## Steps 2 and 3: subtrees for the left (node 2) and right (node 3) branch
tr2 <- ctree(Species ~ Sepal.Length + ., data = iris,
  subset = predict(tr1, type = "node") == 2)
tr3 <- ctree(Species ~ Sepal.Length + ., data = iris,
  subset = predict(tr1, type = "node") == 3)
```
Note, however, that it is important to use the formula with `Sepal.Length + .` to ensure that the variables in the model frame are ordered in exactly the same way in all trees.
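The reason is that a `partysplit` does not store the name of its split variable, only its column position (`varid`) in the model frame of the tree it was fitted on. The check below illustrates this with the `tr1` fit; the `Species ~ Petal.Width + .` formula is a counterexample made up for illustration and is not used anywhere in the actual recipe.

```r
library("partykit")
data("iris", package = "datasets")

## The forced root split, fitted on Species ~ Sepal.Length only
tr1 <- ctree(Species ~ Sepal.Length, data = iris, maxdepth = 1)

## A split refers to its variable by column position, not by name
vid <- varid_split(split_node(node_party(tr1)))

## The same position means different variables depending on the formula
good <- names(model.frame(Species ~ Sepal.Length + ., data = iris))
bad  <- names(model.frame(Species ~ Petal.Width + ., data = iris))
good[vid]  # "Sepal.Length": the variable tr1 actually split on
bad[vid]   # "Petal.Width": a grafted root split would use the wrong column
```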
Next comes the most technical step: we need to extract the raw node structure from all three trees, fix up the node `id`s so that they are in a proper sequence, and then integrate everything into a single node:
```r
## Renumber the nodes of partynode x consecutively (depth-first),
## starting at startid
fixids <- function(x, startid = 1L) {
  id <- startid - 1L
  new_node <- function(x) {
    id <<- id + 1L
    ## remember this node's id before the recursion below increments `id`
    this_id <- id
    if(is.terminal(x)) return(partynode(this_id, info = info_node(x)))
    partynode(this_id,
      split = split_node(x),
      kids = lapply(kids_node(x), new_node),
      surrogates = surrogates_node(x),
      info = info_node(x))
  }
  return(new_node(x))
}
## Keep the root split of tr1 and replace its two terminal kids by the subtrees
no <- node_party(tr1)
no$kids <- list(
  fixids(node_party(tr2), startid = 2L),
  ## tr2 has 3 nodes (ids 2 to 4), so tr3 starts at id 5
  fixids(node_party(tr3), startid = 5L)
)
no
## [1] root
## | [2] V2 <= 5.4
## | | [3] V4 <= 1.9 *
## | | [4] V4 > 1.9 *
## | [5] V2 > 5.4
## | | [6] V4 <= 4.7
## | | | [7] V4 <= 3.6 *
## | | | [8] V4 > 3.6 *
## | | [9] V4 > 4.7
## | | | [10] V5 <= 1.7 *
## | | | [11] V5 > 1.7 *
```
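The start ids `2L` and `5L` depend on the size of `tr2`. A variant that derives them instead could look like the sketch below. `merge_forced()` is a hypothetical helper, and the tiny hand-built stumps (split points copied from the printed tree) merely stand in for `node_party(tr1)`, `node_party(tr2)` and `node_party(tr3)` so that the block runs on its own.

```r
library("partykit")

## fixids() from above: renumber a partynode depth-first from startid
fixids <- function(x, startid = 1L) {
  id <- startid - 1L
  new_node <- function(x) {
    id <<- id + 1L
    this_id <- id  # fix this node's id before the kids are renumbered
    if (is.terminal(x)) return(partynode(this_id, info = info_node(x)))
    partynode(this_id, split = split_node(x),
      kids = lapply(kids_node(x), new_node),
      surrogates = surrogates_node(x), info = info_node(x))
  }
  new_node(x)
}

## Hypothetical helper: replace the two kids of the root split in `root`
## by `left` and `right`; the right subtree starts right after the left one
merge_forced <- function(root, left, right) {
  left  <- fixids(left, startid = 2L)
  right <- fixids(right, startid = 2L + length(nodeids(left)))
  root$kids <- list(left, right)
  root
}

## Stand-ins for the extracted nodes: stumps numbered 1, 2, 3 each,
## just like a freshly fitted stand-alone tree starts at id 1
stump <- function(varid, cut) {
  partynode(1L, split = partysplit(varid, breaks = cut),
            kids = list(partynode(2L), partynode(3L)))
}
no <- merge_forced(stump(2L, 5.4), stump(4L, 1.9), stump(4L, 4.7))
nodeids(no)  # 1 2 3 4 5 6 7
```

With the real trees the call would be `merge_forced(node_party(tr1), node_party(tr2), node_party(tr3))`, which gives the same start ids 2 and 5 as the hard-coded version because `tr2` has three nodes.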
And finally we set up a joint model frame containing all data and combine that with the new joint tree. Some information on fitted nodes and the response is added to be able to turn the tree into a `constparty` for nice visualization and predictions. See `vignette("partykit", package = "partykit")` for the background on this:
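```r
## Joint model frame with the same variable order as in tr2 and tr3
d <- model.frame(Species ~ Sepal.Length + ., data = iris)
## Combine the joint tree with the data, the fitted node ids and the response
tr <- party(no,
  data = d,
  fitted = data.frame(
    "(fitted)" = fitted_node(no, data = d),
    "(response)" = model.response(d),
    check.names = FALSE),
  terms = terms(d)
)
tr <- as.constparty(tr)
```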
And then we're done and can visualize our combined tree with the forced first split:

```r
plot(tr)
```
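As an optional sanity check, one can verify that the merged tree routes observations exactly like the separate trees. The sketch below condenses the code from this answer so that it runs on its own (computing the second start id from the size of `tr2` rather than hard-coding `5L`), then compares, branch by branch, the class predictions of `tr` with those of `tr2` and `tr3`, and finally calls `predict()` on a few rows. This comparison is an illustrative check added here, not part of the recipe itself.

```r
library("partykit")
data("iris", package = "datasets")

## Condensed rebuild of the combined tree
tr1 <- ctree(Species ~ Sepal.Length, data = iris, maxdepth = 1)
tr2 <- ctree(Species ~ Sepal.Length + ., data = iris,
  subset = predict(tr1, type = "node") == 2)
tr3 <- ctree(Species ~ Sepal.Length + ., data = iris,
  subset = predict(tr1, type = "node") == 3)

fixids <- function(x, startid = 1L) {
  id <- startid - 1L
  new_node <- function(x) {
    id <<- id + 1L
    this_id <- id  # fix this node's id before the kids are renumbered
    if (is.terminal(x)) return(partynode(this_id, info = info_node(x)))
    partynode(this_id, split = split_node(x),
      kids = lapply(kids_node(x), new_node),
      surrogates = surrogates_node(x), info = info_node(x))
  }
  new_node(x)
}

no <- node_party(tr1)
no$kids <- list(
  fixids(node_party(tr2), startid = 2L),
  fixids(node_party(tr3), startid = 2L + length(nodeids(tr2))))
d <- model.frame(Species ~ Sepal.Length + ., data = iris)
tr <- as.constparty(party(no, data = d,
  fitted = data.frame("(fitted)" = fitted_node(no, data = d),
    "(response)" = model.response(d), check.names = FALSE),
  terms = terms(d)))

## Check 1: the root split of the merged tree uses the forced variable
root_var <- names(d)[varid_split(split_node(node_party(tr)))]
print(root_var)

## Check 2: within each branch the merged tree predicts exactly what the
## corresponding subtree predicts for its own training observations
node1 <- predict(tr1, type = "node")
pred <- as.character(predict(tr))
same_left  <- identical(unname(pred[node1 == 2]),
                        unname(as.character(predict(tr2))))
same_right <- identical(unname(pred[node1 == 3]),
                        unname(as.character(predict(tr3))))
print(c(left = same_left, right = same_right))

## Usage: class predictions for (here: re-used) new observations
predict(tr, newdata = iris[c(1, 51, 101), ])
```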