在需要经常进行调参的情况下,可以使用 Training Flags 来快速变换参数,比起直接修改模型参数来得快而且不易出错。
https://tensorflow.rstudio.com/tools/training_flags.html
使用 flags()
library(keras)
FLAGS <- flags(
flag_integer("dense_units1", 128),
flag_numeric("dropout1", 0.4),
flag_integer("dense_units2", 128),
flag_numeric("dropout2", 0.3),
flag_integer("epochs", 30),
flag_integer("batch_size", 128),
flag_numeric("learning_rate", 0.001)
)
input <- layer_input(shape = c(784))
predictions <- input %>%
layer_dense(units = FLAGS$dense_units1, activation = \'relu\') %>%
layer_dropout(rate = FLAGS$dropout1) %>%
layer_dense(units = FLAGS$dense_units2, activation = \'relu\') %>%
layer_dropout(rate = FLAGS$dropout2) %>%
layer_dense(units = 10, activation = \'softmax\')
model <- keras_model(input, predictions) %>% compile(
loss = \'categorical_crossentropy\',
optimizer = optimizer_rmsprop(lr = FLAGS$learning_rate),
metrics = c(\'accuracy\')
)
history <- model %>% fit(
x_train, y_train,
batch_size = FLAGS$batch_size,
epochs = FLAGS$epochs,
verbose = 1,
validation_split = 0.2
)
flags()
是 keras 库的函数,不是R语言本身的函数。
使用YAML文件
flags()
可以搭配YAML文件使用。按照官方教程,以为是把参数定义在YAML文件里,然后使用flags(file="flags.yml")
直接读入。但是发现这样行不通,flags(file="flags.yml")
得到的是一个空list。后来发现可能得这样使用才是正确的:
FLAGS <- flags(file = "flags.yml",
flag_integer("dense_units1", 128, "Dense units in first layer"),
flag_numeric("dropout1", 0.4, "Dropout after first layer"),
flag_integer("epochs", 30, "Number of epochs to train for")
)
flags.yml
中的参数优先,会覆盖掉flags()
里的定义,也就是说,如果 flags.yml
里面是这样定义的:
dense_units1: 256
dropout1: 0.4
epochs: 30
那么,dense_units1
这个参数的值是 256,而不是 128。
下面这种用法不正确,
FLAGS <- flags(file = "flags.yml",
)
会得到一个空list。可以认为,flags.yml
其实是用来覆盖或者说修改flags()
里面已有的参数定义。