本文整理汇总了Python中train.train函数的典型用法代码示例。如果您正苦于以下问题:Python train函数的具体用法?Python train怎么用?Python train使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了train函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: start_offline
def start_offline(dataset):
    """Run the offline stage for *dataset*: data preparation, then training.

    Changes the process working directory to the module-level ``envpath``
    (it is not restored afterwards). It then loads the project-local ``prepare``
    and ``train`` modules and calls their same-named entry points with *dataset*.
    """
    os.chdir(envpath)
    # NOTE(review): the modules are loaded only after the chdir; an earlier
    # revision appended envpath to sys.path instead, so confirm how these
    # modules are actually found.
    preparer = __import__("prepare")
    preparer.prepare(dataset)
    trainer = __import__("train")
    trainer.train(dataset)
开发者ID:wchgit,项目名称:wlan_positioning,代码行数:7,代码来源:go.py
示例2: main
def main():
    """Load the image training set, train a classifier on it, then test it.

    Class names come from the sub-directories of ``train/``. Each training
    image becomes one row of the feature matrix ``X`` with a numeric label in ``y``.
    """
    print("In Main Experiment\n")
    # Class directories are the entries of train/ whose names contain no dot.
    # Sorting makes the class order (and therefore the numeric labels)
    # reproducible; iterating a set yields an arbitrary, run-dependent order.
    all_entries = set(glob.glob(os.path.join("train", "*")))
    dotted_entries = set(glob.glob(os.path.join("train", "*.*")))
    directory_names = sorted(all_entries - dotted_entries)
    # One row for each image in the training dataset.
    num_rows = parseImage.gestNumberofImages(directory_names)
    # Images are rescaled to maxPixel x maxPixel.
    maxPixel = 25
    imageSize = maxPixel * maxPixel
    # NOTE(review): the extra 2 + 128 columns hold features beyond the raw
    # pixels (the original comment only says "for our ratio"); confirm
    # against parseImage.readImage.
    num_features = imageSize + 2 + 128
    X = np.zeros((num_rows, num_features), dtype=float)
    y = np.zeros(num_rows)  # numeric class label per row
    files = []
    namesClasses = []  # class name list, presumably filled in by readImage
    parseImage.readImage(True, namesClasses, directory_names, X, y, files)
    print("Training")
    train.train(X, y, namesClasses)
    print("Testing")
    # Pass the populated class-name list; previously a fresh empty list was
    # passed here, so the test step never saw the class names.
    test.test(num_rows, num_features, X, y, namesClasses=namesClasses)
开发者ID:LvYe-Go,项目名称:Foreign-Exchange,代码行数:28,代码来源:main.py
示例3: parse_command_line
def parse_command_line():
    """Parse the command line, load the LFW pairs, then evaluate or train.

    ``--test_data`` evaluates on the test pairs via ``test_pairings``, and
    ``--test_val`` evaluates on the validation pairs. Both use ``--weights``,
    defaulting to ``constants.TRAINED_WEIGHTS``. Otherwise
    ``train(True, data=lfw)`` is called.

    Exits with status 1 if the ``CAFFEHOME`` environment variable is unset.
    """
    # sys.exit instead of the bare ``exit`` helper, which is injected by the
    # ``site`` module and is missing under ``python -S`` or frozen builds.
    import sys

    parser = argparse.ArgumentParser(
        description="""Train, validate, and test a face detection classifier that will determine if
two faces are the same or different.""")
    parser.add_argument("--test_data", help="Use preTrain model on test data, calcu the accuracy and ROC.", action="store_true")
    parser.add_argument("--test_val", help="Use preTrain model on validation data, calcu the accuracy and ROC.", action="store_true")
    parser.add_argument("--weights", help="""The trained model weights to use; if not provided
defaults to the network that was just trained""", type=str, default=None)
    parser.add_argument("-t", "--threshold", help="The margin of two dense", type=int, default=80)
    args = vars(parser.parse_args())

    if os.environ.get("CAFFEHOME") is None:
        print("You must set CAFFEHOME to point to where Caffe is installed. Example:")
        print("export CAFFEHOME=/usr/local/caffe")
        sys.exit(1)

    # Ensure the random number generator always starts from the same place for consistent tests.
    random.seed(0)

    lfw = data.Lfw()
    lfw.load_data()
    lfw.pair_data()

    if args["weights"] is None:
        args["weights"] = constants.TRAINED_WEIGHTS

    # store_true flags are real booleans, so plain truthiness is exact.
    if args["test_data"]:
        test_pairings(lfw, weight_file=args["weights"], is_test=True, threshold=args["threshold"])
    elif args["test_val"]:
        test_pairings(lfw, weight_file=args["weights"], threshold=args["threshold"])
    else:
        train(True, data=lfw)
开发者ID:SundayDX,项目名称:LFW-adventure,代码行数:34,代码来源:main.py
示例4: main
def main():
    """Dispatch to training, evaluation or inference according to ``flags.mode``.

    Raises:
        ValueError: if ``flags.mode`` is not 'train', 'eval' or 'inference'.
    """
    flags = parse_flags()
    hparams = parse_hparams(flags.hparams)
    if flags.mode == 'train':
        # Resample the training clips, then train from the 'resampled'
        # sub-directory of train_clip_dir.
        utils.resample(sample_rate=flags.sample_rate, dir=flags.train_clip_dir, csv_path=flags.train_csv_path)
        train.train(model_name=flags.model, hparams=hparams,
                    class_map_path=flags.class_map_path,
                    train_csv_path=flags.train_csv_path,
                    train_clip_dir=flags.train_clip_dir + '/resampled',
                    train_dir=flags.train_dir, sample_rate=flags.sample_rate)
    elif flags.mode == 'eval':
        # TODO: uncomment; until then the '<eval_clip_dir>/resampled'
        # directory presumably has to exist already.
        # utils.resample(sample_rate=flags.sample_rate, dir=flags.eval_clip_dir, csv_path=flags.eval_csv_path)
        evaluation.evaluate(model_name=flags.model, hparams=hparams,
                            class_map_path=flags.class_map_path,
                            eval_csv_path=flags.eval_csv_path,
                            eval_clip_dir=flags.eval_clip_dir + '/resampled',
                            checkpoint_path=flags.checkpoint_path)
    elif flags.mode == 'inference':
        utils.resample(sample_rate=flags.sample_rate, dir=flags.test_clip_dir, csv_path='test')
        # NOTE(review): unlike train/eval, prediction reads test_clip_dir
        # itself rather than its 'resampled' sub-directory; confirm intended.
        inference.predict(model_name=flags.model, hparams=hparams,
                          class_map_path=flags.class_map_path,
                          test_clip_dir=flags.test_clip_dir,
                          checkpoint_path=flags.checkpoint_path,
                          predictions_csv_path=flags.predictions_csv_path)
    else:
        # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
        raise ValueError("Unknown mode %r; expected 'train', 'eval' or 'inference'" % (flags.mode,))
开发者ID:ssgalitsky,项目名称:Research-Audio-classification-using-Audioset-Freesound-Databases,代码行数:29,代码来源:main.py
示例5: main
def main():
    """Train one small NeuralNet on the 'lin' data, then on the 'xor' data.

    The same network instance is passed to both training calls.
    """
    from train import train

    net = NeuralNet(n_features=2, n_hidden=10)
    net.optimizer.lr = 0.2
    for dataset in ('lin', 'xor'):
        train(model=net, data=dataset)
开发者ID:ticcky,项目名称:nn_intro,代码行数:8,代码来源:neural_net.py
示例6: main
def main(FLAGS):
    """Run training or inference depending on ``FLAGS.mode``.

    Args:
        FLAGS: parsed options object. ``FLAGS.mode`` selects the action, and
            the whole object is forwarded to ``train`` or ``infer``.

    Raises:
        ValueError: if ``FLAGS.mode`` is neither "train" nor "infer".
    """
    if FLAGS.mode == "train":
        train(FLAGS)
    elif FLAGS.mode == "infer":
        infer(FLAGS)
    else:
        # ValueError subclasses Exception, so existing ``except Exception`` handlers still catch it.
        raise ValueError("Choose --mode=<train|infer>")
开发者ID:GKarmakar,项目名称:oreilly-pytorch,代码行数:10,代码来源:main.py
示例7: train_model
def train_model(db_file, entity_db_file, vocab_file, word2vec, **kwargs):
    """Load the training resources from disk and hand them to ``train.train``.

    Args:
        db_file: path opened read-only as an AbstractDB.
        entity_db_file: path loaded with ``EntityDB.load``.
        vocab_file: path loaded with ``Vocab.load``.
        word2vec: path passed to ``ModelReader``; if falsy, ``None`` is passed instead.
        **kwargs: forwarded unchanged to ``train.train``.
    """
    abstract_db = AbstractDB(db_file, 'r')
    entities = EntityDB.load(entity_db_file)
    vocabulary = Vocab.load(vocab_file)
    word_vectors = ModelReader(word2vec) if word2vec else None
    train.train(abstract_db, entities, vocabulary, word_vectors, **kwargs)
开发者ID:studio-ousia,项目名称:ntee,代码行数:11,代码来源:cli.py
示例8: test_train_success
def test_train_success(self):
    """Smoke test: each stage builds, adds summaries and trains on random data."""
    cfg = self._config
    root_dir = cfg['train_root_dir']
    if not tf.gfile.Exists(root_dir):
        tf.gfile.MakeDirs(root_dir)
    for stage in train.get_stage_ids(**cfg):
        # Start every stage from an empty default graph.
        tf.reset_default_graph()
        images = provide_random_data()
        stage_model = train.build_model(stage, images, **cfg)
        train.add_model_summaries(stage_model, **cfg)
        train.train(stage_model, **cfg)
开发者ID:ALISCIFP,项目名称:models,代码行数:11,代码来源:train_test.py
示例9: train_PNet
def train_PNet(base_dir, prefix, end_epoch, display, lr):
    """Train the P-Net by calling ``train`` with the ``P_Net`` network factory.

    :param base_dir: data location passed to ``train`` as its fourth positional
        argument (presumably the tfrecord path; the old docstring called it
        ``dataset_dir``).
    :param prefix: passed to ``train`` unchanged.
    :param end_epoch: passed to ``train`` unchanged.
    :param display: passed to ``train`` as ``display``.
    :param lr: passed to ``train`` as ``base_lr``.
    :return: None
    """
    train(P_Net, prefix, end_epoch, base_dir, display=display, base_lr=lr)
开发者ID:jiapei100,项目名称:MTCNN-Tensorflow,代码行数:12,代码来源:train_PNet.py
示例10: main
def main():
    """Check the environment, build hyper-parameters from the YAML config and train."""
    util.check_tensorflow_version()
    util.check_and_mkdir()

    config = load_yaml()
    check_config(config)
    hparams = create_hparams(config)
    print(hparams.values())

    # Expose the Log instance's logger through hparams before training.
    hparams.logger = Log(hparams).logger
    train.train(hparams)
开发者ID:zeroToAll,项目名称:tensorflow_practice,代码行数:13,代码来源:main.py
示例11: predict
def predict(corpusPath, modelsPath, dummy, corpusId=None, connection=None, directed="both"):
    """Train and evaluate every model target from getModels that does not exist yet.

    Args:
        corpusPath, modelsPath, corpusId, directed: forwarded to ``getModels``.
        dummy: if true, only print which targets would be processed.
        connection: forwarded to ``train.train``.
    """
    # Value passed as the "examples" classifier parameter for every target.
    c_values = "c=1,10,100,500,1000,1500,2500,3500,4000,4500,5000,7500,10000,20000,25000,27500,28000,29000,30000,35000,40000,50000,60000,65000"
    for model in getModels(corpusPath, modelsPath, corpusId, directed):
        target = model["model"]
        if os.path.exists(target):
            # Formatted into one string so the output is identical on Python 2 and 3.
            print("Skipping existing target %s" % (target,))
            continue
        print("Processing target %s directed = %s" % (target, model["directed"]))
        if dummy:
            continue
        # NOTE(review): task uses the module-level CORPUS_ID, not the corpusId
        # argument; confirm this is intended.
        train.train(target, task=CORPUS_ID, corpusDir=model["corpusDir"], connection=connection,
                    exampleStyles={"examples": model["exampleStyle"]}, parse="McCC",
                    classifierParams={"examples": c_values})
        for dataset in ("devel", "test"):
            if os.path.exists(model[dataset]):
                evaluate(model[dataset], model[dataset + "-gold"], model[dataset + "-eval"])
开发者ID:jbjorne,项目名称:TEES,代码行数:14,代码来源:SemEval2010Task8Tools.py
示例12: main
def main():
    """Train the arrows network.

    Usage: <script> data_dir save_dir logs_dir
    """
    if len(sys.argv) < 4:
        # Clear usage message instead of an IndexError traceback.
        sys.exit("usage: %s data_dir save_dir logs_dir" % sys.argv[0])
    data_dir, save_dir, logs_dir = sys.argv[1:4]
    # Using the Session as a context manager installs it as the default
    # session (as the explicit as_default() did) and also closes it on exit,
    # which the original never did.
    with tf.Session():
        train_data, test_data = arrows.get_input_producers(data_dir)
        train.train(arrows.build_net, train_data, test_data, logs_dir=logs_dir, save_dir=save_dir)
开发者ID:vlpolyansky,项目名称:video-cnn,代码行数:14,代码来源:arrows_train.py
示例13: train_dataset
def train_dataset(dataset, train_params):
    """Train a model on *dataset* and return the generated model name.

    The name has the form "<num_layers>_<rnn_size>_<model>". The model
    directory is ``<dataset_dir>/<dataset>/<models_dir>/<name>``, where
    ``dataset_dir`` and ``models_dir`` are module-level settings.
    """
    data_dir = os.path.join(dataset_dir, dataset)
    print("Data Directory: %s" % data_dir)

    name_parts = (train_params.num_layers, train_params.rnn_size, train_params.model)
    model_name = "%d_%d_%s" % name_parts
    model_dir = os.path.join(data_dir, models_dir, model_name)
    print("Model Dir: %s" % model_dir)

    training_args = train_params.get_training_arguments(data_dir, model_dir)
    # Each run starts from an empty default TensorFlow graph.
    tf.reset_default_graph()
    train.train(training_args)
    return model_name
开发者ID:Zbot21,项目名称:char-rnn-tensorflow,代码行数:15,代码来源:automated_testing.py
示例14: main
def main():
    """Train the movie network, resuming from saved state (need_load=True).

    Usage: <script> data_dir save_dir logs_dir
    """
    if len(sys.argv) < 4:
        # Clear usage message instead of an IndexError traceback.
        sys.exit("usage: %s data_dir save_dir logs_dir" % sys.argv[0])
    data_dir, save_dir, logs_dir = sys.argv[1:4]
    # Using the Session as a context manager installs it as the default
    # session (as the explicit as_default() did) and also closes it on exit,
    # which the original never did.
    with tf.Session():
        train_data, test_data = movie.get_input_producers(data_dir)
        train.train(movie.build_net, train_data, test_data, logs_dir=logs_dir, save_dir=save_dir, need_load=True,
                    init_rate=0.0005, test_only=False)
开发者ID:vlpolyansky,项目名称:video-cnn,代码行数:15,代码来源:movie_train.py
示例15: k_result
def k_result(k):
    """Train on k randomly chosen entries of ``train_set`` and return the test result.

    The sampled entries are written verbatim to ``<tempdir>/scp_k``, which is
    passed to ``train``. The directory ``train`` returns is then passed to ``test``.
    """
    sampled = random.sample(train_set, k)
    scp_path = os.path.join(tempdir, 'scp_k')
    # Entries are written as-is, so they are expected to carry their own newlines.
    with open(scp_path, 'w') as scp_file:
        scp_file.write(''.join(sampled))
    trained_dir = train(outdir, config, scp_path, proto, htk_dict, words_mlf, monophones, tempdir)
    return test(outdir, trained_dir, wdnet, htk_dict, monophones, scp_test, words_mlf, tempdir)
开发者ID:Tdebel,项目名称:HTK-scripts,代码行数:7,代码来源:graph.py
示例16: main
def main():
    """Train the cp2f1d CIFAR-10 model once for each NumPy seed 0..n_seeds-1.

    Usage: <script> n_seeds n_train_steps
    """
    x, y = load_train_data('../data/cifar-10-batches-py')
    parser = build_argparser()
    n_seeds = int(sys.argv[1])
    n_train_steps = int(sys.argv[2])  # parsed once instead of on every iteration
    # range instead of the Python 2-only xrange; n_seeds is a small count.
    for seed in range(n_seeds):
        hp = {
            'model': 'cp2f1d',
            'batch_size': 512,
            'n_train_steps': n_train_steps,
            'np_seed': seed,
            'checkpoint_dir': 'checkpoints/train_order/cp2f1d-s%i' % seed,
        }
        # Flatten {key: value} into ['--key', 'value', ...] in one pass; the
        # previous sum(list_of_lists, []) copied the list once per key.
        str_hp = [token for key in hp for token in ('--%s' % key, '%s' % hp[key])] + ['--restart']
        print('* arguments')
        print(str_hp)
        args = parser.parse_args(str_hp)
        train(x, y, vars(args))
开发者ID:falcondai,项目名称:cifar10,代码行数:16,代码来源:train_order.py
示例17: main
def _ratio(numerator, denominator):
    """Return numerator / denominator as a float, or NaN when denominator is 0."""
    if not denominator:
        return float('nan')
    # float() keeps this a true division on Python 2 as well.
    return numerator / float(denominator)


def main():
    """Train an HMM tagger on argv[1] and report tagging accuracy on argv[2].

    For every sentence, prints the path probability, the per-sentence accuracy
    and the decoded word/tag path. At the end, prints overall word accuracy,
    the share of sentences with non-zero probability, and word accuracy
    restricted to those sentences. Ratios with an empty denominator print as nan.
    """
    hmm = HMM(*train(sys.argv[1]))
    correct = wrong = 0
    correct_sents = wrong_sents = 0
    correct_known = wrong_known = 0
    with open(sys.argv[2]) as f:
        for sent in Reader(f):
            prob, path = hmm.decode([word for (word, pos) in sent])
            # Compare gold and predicted pairs position by position.
            matches = [gold == predicted for (gold, predicted) in zip(sent, path)]
            correct1 = sum(matches)
            wrong1 = len(matches) - correct1
            print('%e\t%.3f\t%s' % (prob, _ratio(correct1, correct1 + wrong1), ' '.join('%s/%s' % pair for pair in path)))
            if prob > 0:
                correct_sents += 1
                correct_known += correct1
                wrong_known += wrong1
            else:
                wrong_sents += 1
            correct += correct1
            wrong += wrong1
    print("Correctly tagged words: %s" % _ratio(correct, correct + wrong))
    print("Sentences with non-zero probability: %s" % _ratio(correct_sents, correct_sents + wrong_sents))
    print("Correctly tagged words when only considering sentences with non-zero probability: %s" % _ratio(correct_known, correct_known + wrong_known))
开发者ID:steffervescency,项目名称:compling,代码行数:35,代码来源:tagger.py
示例18: main
def main():
    """Train the text categorizer on posts from MySQL and report test accuracy.

    Posts at offset 0-19 are used for training and posts at offset 20-25 for
    testing (getX/getY take the connection, an offset and a limit).
    """
    # SECURITY NOTE(review): database credentials are hard-coded; move them to
    # configuration or environment variables.
    db = MySQLdb.connect(host="localhost", user="root", passwd="123456", db="xataka")
    try:
        X_train = getX(db, 0, 20)   # feature matrix built from posts [0, 20)
        Y_train = getY(db, 0, 20)   # category matrix for the same posts
        X_test = getX(db, 20, 6)
        Y_test = getY(db, 20, 6)
    finally:
        # The connection is used only by the queries above; always release it.
        db.close()
    train(X_train, Y_train, 100)  # train(X, Y, batch size)
    pred = predict(X_test, Y_test, 1000)  # predict(X, Y, batch size)
    acc = prcntagaccurcy(pred, Y_test)  # accuracy of the classifier, in percent
    # Formatted into one string so the output is identical on Python 2 and 3.
    print("Accuracy : %s %%" % (acc,))
    # NOTE(review): ``start`` is not defined in this function; presumably it is
    # a module-level timeit.default_timer() timestamp taken at program start.
    stop = timeit.default_timer()
    print("%s seconds" % (stop - start,))
开发者ID:agilemedialab,项目名称:TextCategorizer,代码行数:16,代码来源:main.py
示例19: handle_predict
def handle_predict(shop_id, pay_fn, view_fn, start_date, end_date):
    """Predict per-weekday pay values for one shop and print them as a CSV row.

    One prediction is made per entry of ``handle_sample``'s result, or 0 when
    ``train.train`` returns None. The printed row is ``shop_id`` followed by the
    predictions for Tuesday..Sunday, then Monday..Sunday, then Monday.
    """
    pay_action_count = load_action_stat(pay_fn)
    view_action_count = load_action_stat(view_fn)
    week_sample = handle_sample(pay_action_count, view_action_count, start_date, end_date, "%Y-%m-%d")
    # NOTE(review): the slicing at the end assumes predict_list[0] is Monday,
    # i.e. that week_sample iterates its weekdays Monday-first; confirm.
    predict_list = []
    for sample_list in week_sample.values():
        p = train.train(sample_list)
        # ``is None``: p is multiplied element-wise below, and for a NumPy
        # array ``p == None`` is element-wise and ambiguous in an ``if``.
        if p is None:
            predict_list.append(0)
            continue
        # A list (not a lazy map object) so np.multiply works on Python 3 too.
        sample = [int(v) for v in get_latest_sample_pay("2016-10-24", "2016-10-31", pay_action_count)]
        predict_list.append(np.sum(np.multiply(p, sample)))
    # Materialized as a list so the slicing below also works on Python 3.
    predict_list = [post_handle(v) for v in predict_list]
    result = [shop_id]
    result += predict_list[1:]   # Tuesday to Sunday
    result += predict_list       # Monday to Sunday
    result += predict_list[0:1]  # Monday
    print(",".join(str(v) for v in result))
开发者ID:jiaorenyu,项目名称:learning,代码行数:28,代码来源:common.py
示例20: main
def main(_):
    """Build and train the model once for every stage id from ``train.get_stage_ids``.

    Creates ``FLAGS.train_root_dir`` if needed and logs the configuration
    built from the flags. For each stage, it starts from a fresh default graph.
    """
    if not tf.gfile.Exists(FLAGS.train_root_dir):
        tf.gfile.MakeDirs(FLAGS.train_root_dir)
    config = _make_config_from_flags()
    # items() instead of the Python 2-only iteritems().
    logging.info('\n'.join('{}={}'.format(k, v) for k, v in config.items()))
    for stage_id in train.get_stage_ids(**config):
        tf.reset_default_graph()
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # The input pipeline is pinned to the CPU.
            with tf.device('/cpu:0'), tf.name_scope('inputs'):
                real_images = _provide_real_images(**config)
            model = train.build_model(stage_id, real_images, **config)
            train.add_model_summaries(model, **config)
            train.train(model, **config)
开发者ID:ALISCIFP,项目名称:models,代码行数:16,代码来源:train_main.py
注:本文中的train.train函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论