
Python trainer.Trainer Class Code Examples


This article collects typical usage examples of the Python trainer.Trainer class. If you are wondering how the Trainer class is used in practice, or are looking for concrete examples of Trainer in real code, the curated class examples below may help.



The following presents 20 code examples of the Trainer class, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
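Although each Trainer below comes from a different project, the usage pattern is largely the same: construct the trainer with its model, data, or configuration, then call a method such as train(), run(), or tune(). The following is a minimal, self-contained sketch of that pattern; the toy Trainer here is purely illustrative and does not correspond to any of the projects below.

# Illustrative sketch of the construct-then-train pattern shared by the examples below.
class Trainer:
    def __init__(self, model, learning_rate=0.01, num_epochs=10):
        self.model = model
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs

    def train(self):
        for epoch in range(self.num_epochs):
            # A real Trainer would run forward/backward passes and update weights here.
            print("epoch %d: training %r with lr=%g"
                  % (epoch, self.model, self.learning_rate))

trainer = Trainer(model="toy-model", learning_rate=0.001, num_epochs=2)
trainer.train()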

Example 1: NutTrainer

class NutTrainer(object):

    def __init__(self):
        self.collection = Collection('./data/')
        self.collection.importFromGProtocolBuffer('gProtoBuf')

    def command(self, argv):
        try:
            opts, args = getopt.getopt(argv,"mt:",["topic="])
        except getopt.GetoptError as err:
            print(err)
            self.usage()
            sys.exit(2)
        for opt, arg in opts:
            if opt == "-m":
                self.modify = True
            elif opt in ("-t", "--topic"):
                if hasattr(self, 'modify'):
                    self.collection.modifyCollection(arg)
                else:
                    self.trainer = Trainer()
                    if arg in self.collection.topics:
                        self.trainer.start(self.collection.topics[arg])
                    else:
                        print("Error")
            elif opt in ("-h", "--help"):
                self.usage()
                sys.exit()
            else:
                assert False, "unhandled option"
        if len(args) == 1 and hasattr(self,'topic'):
            print(args)

    def usage(self):
        print('remtrainer.py -t <topic>')
Developer: AlexBelger, Project: nuttrainer, Lines: 35, Source: nuttrainer.py


Example 2: __init__

    def __init__(self, module, dataset=None, learningrate=0.01, lrdecay=1.0,
                 momentum=0., verbose=False, batchlearning=False,
                 weightdecay=0.):
        """Create a BackpropTrainer to train the specified `module` on the
        specified `dataset`.

        The learning rate gives the ratio by which parameters are changed in
        the direction of the gradient. The learning rate decays by `lrdecay`,
        which is used to multiply the learning rate after each training
        step. The parameters are also adjusted with respect to `momentum`, which
        is the ratio by which the gradient of the last timestep is used.

        If `batchlearning` is set, the parameters are updated only at the end of
        each epoch. Default is False.

        `weightdecay` corresponds to the weightdecay rate, where 0 is no weight
        decay at all.
        """
        Trainer.__init__(self, module)
        self.setData(dataset)
        self.verbose = verbose
        self.batchlearning = batchlearning
        self.weightdecay = weightdecay
        self.epoch = 0
        self.totalepochs = 0
        # set up gradient descender
        self.descent = GradientDescent()
        self.descent.alpha = learningrate
        self.descent.momentum = momentum
        self.descent.alphadecay = lrdecay
        self.descent.init(module.params)
Developer: kaeufl, Project: pybrain, Lines: 31, Source: backprop.py
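The constructor above is PyBrain's BackpropTrainer. A minimal usage sketch, assuming PyBrain's standard shortcut API (buildNetwork, SupervisedDataSet), might look like this:

from pybrain.tools.shortcuts import buildNetwork
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import BackpropTrainer

net = buildNetwork(2, 3, 1)            # 2 inputs, 3 hidden units, 1 output
ds = SupervisedDataSet(2, 1)           # toy XOR-style dataset
for x, y in [((0, 0), 0), ((0, 1), 1), ((1, 0), 1), ((1, 1), 0)]:
    ds.addSample(x, (y,))

trainer = BackpropTrainer(net, ds, learningrate=0.01, momentum=0.9, weightdecay=0.0)
print(trainer.train())                 # one epoch of backprop; returns the training error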


Example 3: main

def main(_):
  prepare_dirs_and_logger(config)

  if not config.task.lower().startswith('tsp'):
    raise Exception("[!] Task should starts with TSP")

  if config.max_enc_length is None:
    config.max_enc_length = config.max_data_length
  if config.max_dec_length is None:
    config.max_dec_length = config.max_data_length

  rng = np.random.RandomState(config.random_seed)
  tf.set_random_seed(config.random_seed)

  trainer = Trainer(config, rng)
  save_config(config.model_dir, config)

  if config.is_train:
    trainer.train()
  else:
    if not config.load_path:
      raise Exception("[!] You should specify `load_path` to load a pretrained model")
    trainer.test()

  tf.logging.info("Run finished.")
Developer: huyuxiang, Project: tensorflow_practice, Lines: 25, Source: main.py


Example 4: train

def train(args):
    debug = args.debug
    logger.info(
        "Start training in {} model".format('debug' if debug else 'normal'))
    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["dcnet"]

    batch_size = loader_conf["batch_size"]
    logger.info(
        "Training in {}".format("per utterance" if batch_size == 1 else
                                '{} utterance per batch'.format(batch_size)))

    train_loader = uttloader(
        config_dict["train_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=True)
    valid_loader = uttloader(
        config_dict["valid_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=False)
    checkpoint = config_dict["trainer"]["checkpoint"]
    logger.info("Training for {} epoches -> {}...".format(
        args.num_epoches, "default checkpoint"
        if checkpoint is None else checkpoint))

    dcnet = DCNet(num_bins, **dcnnet_conf)
    trainer = Trainer(dcnet, **config_dict["trainer"])
    trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)
Developer: jhuiac, Project: deep-clustering, Lines: 34, Source: train_dcnet.py


Example 5: main

def main(config):
    prepare_dirs_and_logger(config)
    save_config(config)

    if config.is_train:
        from trainer import Trainer
        if config.dataset == 'line':
            from data_line import BatchManager
        elif config.dataset == 'ch':
            from data_ch import BatchManager
        elif config.dataset == 'kanji':
            from data_kanji import BatchManager
        elif config.dataset == 'baseball' or\
             config.dataset == 'cat':
            from data_qdraw import BatchManager

        batch_manager = BatchManager(config)
        trainer = Trainer(config, batch_manager)
        trainer.train()
    else:
        from tester import Tester
        if config.dataset == 'line':
            from data_line import BatchManager
        elif config.dataset == 'ch':
            from data_ch import BatchManager
        elif config.dataset == 'kanji':
            from data_kanji import BatchManager
        elif config.dataset == 'baseball' or\
             config.dataset == 'cat':
            from data_qdraw import BatchManager
        
        batch_manager = BatchManager(config)
        tester = Tester(config, batch_manager)
        tester.test()
Developer: byungsook, Project: vectornet, Lines: 34, Source: main.py


Example 6: tesT_TrainingOnSentances

    def tesT_TrainingOnSentances(self):

        c = Corpus(self.txt)
        rnn = RNN(100, c.V, 50)

        trainer = Trainer(c,rnn, nepochs=50, alpha = 1.8)
        trainer.train()
Developer: liuhy0908, Project: rnnlm-1, Lines: 7, Source: trainer_test.py


Example 7: plot_stats

def plot_stats(X,Y,model,costs):
	#two plots, the decision fcn and points and the cost over time
	y_onehot = Trainer.class_to_onehot(Y)
	f,(p1,p2) = plot.subplots(1,2)
	p2.plot(range(len(costs)),costs)
	p2.set_title("Cost over time")
	
	#plot points/centroids/decision fcn
	cls_ct = y_onehot.shape[1]
	y_cls = Trainer.onehot_to_int(y_onehot)
	colors = get_cmap("RdYlGn")(np.linspace(0,1,cls_ct))
	
	#model_cents = model.c.get_value()
	#p1.scatter(model_cents[:,0], model_cents[:,1], c='black', s=81)
	for curclass,curcolor in zip(range(cls_ct),colors):
		inds = [i for i,yi in enumerate(y_cls) if yi==curclass]
		p1.scatter(X[inds,0], X[inds,1], c=curcolor)
		
	nx,ny = 200, 200
	x = np.linspace(X[:,0].min()-1,X[:,0].max()+1,nx)
	y = np.linspace(X[:,1].min()-1,X[:,1].max()+1,ny)
	xv,yv = np.meshgrid(x,y)
	
	Z = np.array([z for z in np.c_[xv.ravel(), yv.ravel()]])
	Zp = Trainer.onehot_to_int(np.array(model.probability(Z)))
	Zp = Zp.reshape(xv.shape)
	p1.imshow(Zp, interpolation='nearest', 
				extent=(xv.min(), xv.max(), yv.min(), yv.max()),
				origin = 'lower', cmap=get_cmap("Set1"))
	
	p1.set_title("Decision boundaries and centroids")
	f.tight_layout()
	plot.show()					
Developer: ChenglongChen, Project: RBFnet, Lines: 33, Source: wrapper.py
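Both this example and Example 14 rely on Trainer.class_to_onehot and Trainer.onehot_to_int. Their implementation is not shown here, but under the usual one-hot conventions such helpers would look roughly like the following (a NumPy sketch, not the RBFnet source):

import numpy as np

def class_to_onehot(y):
    # Map integer class labels, e.g. [0, 2, 1], to one-hot rows.
    y = np.asarray(y, dtype=int)
    onehot = np.zeros((y.size, y.max() + 1))
    onehot[np.arange(y.size), y] = 1.0
    return onehot

def onehot_to_int(onehot):
    # Recover integer labels by taking the argmax of each row.
    return np.argmax(onehot, axis=1)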


Example 8: Test_Trainer

class Test_Trainer(unittest.TestCase):

    def setUp(self):
        self.class_number = 21
        self.input_shape = (300, 300, 3)
        self.model = SSD300v2(self.input_shape, num_classes=self.class_number)

    def test_train(self):
        base_lr=3e-4
        self.trainer = Trainer(class_number=self.class_number,
                               input_shape=self.input_shape,
                               priors_file='prior_boxes_ssd300.pkl',
                               train_file='VOC2007_test.pkl',
                               path_prefix='./VOCdevkit/VOC2007/JPEGImages/',
                               model=self.model,
                               weight_file='weights_SSD300.hdf5',
                               freeze=('input_1', 'conv1_1', 'conv1_2', 'pool1',
                                       'conv2_1', 'conv2_2', 'pool2',
                                       'conv3_1', 'conv3_2', 'conv3_3', 'pool3'),
                               save_weight_file='./checkpoints/weights.{epoch:02d}-{val_loss:.2f}.hdf5',  # noqa
                               optim=keras.optimizers.Adam(lr=base_lr),
                               )
        self.trainer.train(nb_epoch=1)

    def teardown(self):
        try:
            subprocess.call("rm -rf " + self.trainer.log_dir, shell=True)
        except subprocess.CalledProcessError as cpe:
            print(str(cpe))
Developer: SnowMasaya, Project: ssd_keras, Lines: 29, Source: test_trainer.py


Example 9: pre_train

def pre_train(data, das, nep = 600):
    x = data
    for ec, dc in das:
        dc.x(ec.y)
        tr = Trainer(ec.x, dc.y, src = x, dst = x, lrt = 0.005)
        tr.tune(nep, npt = 10)
        ec.x(x)
        x = ec.y().eval()
    del x
Developer: xiaoran831213, Project: az, Lines: 9, Source: test1.py


Example 10: train

    def train(self,
              training_set_x,
              training_set_y,
              hyper_parameters,
              regularization_methods,
              activation_method,
              top=50,
              print_verbose=False,
              validation_set_x=None,
              validation_set_y=None):

        #need to convert the input into tensor variable
        training_set_x = shared(training_set_x, 'training_set_x', borrow=True)
        training_set_y = shared(training_set_y, 'training_set_y', borrow=True)

        symmetric_double_encoder = StackedDoubleEncoder(hidden_layers=[],
                                                        numpy_range=self._random_range,
                                                        input_size_x=training_set_x.get_value(borrow=True).shape[1],
                                                        input_size_y=training_set_y.get_value(borrow=True).shape[1],
                                                        batch_size=hyper_parameters.batch_size,
                                                        activation_method=activation_method)

        params = []

        #In this phase we train the stacked encoder one layer at a time
        #once a layer was added, weights not belonging to the new layer are
        #not changed
        for layer_size in hyper_parameters.layer_sizes:

            self._add_cross_encoder_layer(layer_size,
                                          symmetric_double_encoder,
                                          hyper_parameters.method_in,
                                          hyper_parameters.method_out)


        params = []
        for layer in symmetric_double_encoder:
            params.append(layer.Wx)
            params.append(layer.bias_x)
            params.append(layer.bias_y)

        params.append(symmetric_double_encoder[0].bias_x_prime)
        params.append(symmetric_double_encoder[-1].bias_y_prime)
        params.append(symmetric_double_encoder[-1].Wy)

        Trainer.train(train_set_x=training_set_x,
                      train_set_y=training_set_y,
                      hyper_parameters=hyper_parameters,
                      symmetric_double_encoder=symmetric_double_encoder,
                      params=params,
                      regularization_methods=regularization_methods,
                      print_verbose=print_verbose,
                      validation_set_x=validation_set_x,
                      validation_set_y=validation_set_y)

        return symmetric_double_encoder
Developer: aviveise, Project: double_encoder, Lines: 56, Source: iterative_training_nonsequential_stratagy.py


Example 11: run_customization

def run_customization(image_loader, feature_extractor):
    logging.info("Start customize svm")
    logging.info("Generate sample")
    data = get_class_data(params.first_class_params, params.sample_size/2) + get_class_data(params.second_class_params, params.sample_size/2)
    random.shuffle(data)
    trainer = Trainer(image_loader, feature_extractor)
    c_range = [10 ** i for i in xrange(-5, 10)]
    gamma_range = [10 ** i for i in xrange(-5, 5)]
    results = trainer.svm_params_customization(data, params.svm_params, c_range, gamma_range)
    return results
Developer: SergeevPavel, Project: object_class_recognition, Lines: 10, Source: main.py


Example 12: run_cross_validation

def run_cross_validation(image_loader, feature_extractor):
    logging.info("Start 5-fold cross validation")
    logging.info("For cat and dogs")
    logging.info(params.svm_params)
    logging.info("Generate sample")
    data = get_class_data(params.first_class_params, params.sample_size / 2) + get_class_data(
        params.second_class_params, params.sample_size / 2)
    random.shuffle(data)
    trainer = Trainer(image_loader, feature_extractor)
    return trainer.k_fold_cross_validation(5, data, params.svm_params, params.labels)
Developer: ktisha, Project: object_class_recognition, Lines: 10, Source: main.py


Example 13: train

def train(*args):
    """
    trains the model based on files in the input folder
    """
    input_folder = args[0][0]
    if not input_folder:
        print "Must specify a directory of models"
        return

    trainer = Trainer(input_folder, options.output)
    trainer.train()
Developer: wschurman, Project: kittenmash, Lines: 11, Source: kittenmash.py


Example 14: theano_perf

def theano_perf(model, Xnew, ynew):
	#Xnew,ynew = gaussian_data_gen()
	#Xnew,ynew = exotic_data_gen()
	ynew_onehot = Trainer.class_to_onehot(ynew)
	yhat = np.array(model.predict(Xnew))
	yhat = Trainer.onehot_to_int(yhat)
	errs= 0
	for yh,t in zip(yhat,ynew):
		errs += 1 if yh != t else 0
	err_rate = 100*float(errs)/ynew.shape[0]
	print 'Accuracy:',100-err_rate,'Errors:',errs
Developer: ChenglongChen, Project: RBFnet, Lines: 11, Source: wrapper.py


Example 15: fine_tune

def fine_tune(data, das, nep = 600):
    x = data

    ## re-wire encoders and decoders
    ecs, dcs = zip(*das)
    sda = list(ecs) + list(reversed(dcs))
    for i, j in zip(sda[:-1], sda[1:]):
        j.x(i.y) # lower output -> higher input

    tr = Trainer(sda[0].x, sda[-1].y, src = data, dst = data, lrt = 0.0005)
    tr.tune(nep, npt= 10)
    return tr
Developer: xiaoran831213, Project: az, Lines: 12, Source: test1.py


Example 16: Recognizer

class Recognizer():
    def __init__(self):
        self.trainer = None

    def train(self, dataFileName):
        self.trainer = Trainer(dataFileName)
        self.trainer.trainAll()
        self.trainer.dump()

    def load(self):
        trainer = Trainer()
        trainer.load()
        self.trainer = trainer

    def classify(self, X, label1, label2):
        '''
        Given input vector X, predict its class label, choosing between label1 and label2.
        '''
        positiveLabel = min(label1, label2)
        negativeLabel = max(label1, label2)
        svm = self.trainer.getSvmInstance(positiveLabel, negativeLabel)
        y = svm.predict(X)
        if y == 1:
            return positiveLabel
        elif y == -1:
            return negativeLabel
        else:
            raise

    def predict(self, X):
        count_dict = {} #{label : times}
        for i in range(10):
            for j in range(i, 10, 1):
                if i == j:
                    continue
                label = self.classify(X, i, j)
                if count_dict.has_key(label):
                    count_dict[label] += 1
                else:
                    count_dict[label] = 1

        maxTime = -1
        maxLabel = -1
        for label in count_dict:
            time = count_dict[label]
            if time > maxTime:
                maxTime = time
                maxLabel = label
        return maxLabel
Developer: jinyyu, Project: machine-learning, Lines: 49, Source: recognizer.py
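The predict method above implements one-vs-one voting: the nested loop enumerates all 45 unordered digit pairs, each pairwise SVM votes for one label, and the label with the most votes wins. The same tally can be written compactly with collections.Counter (a sketch of the voting step only; classify stands in for Recognizer.classify above):

from collections import Counter
from itertools import combinations

def vote_predict(classify, X, num_classes=10):
    # classify(X, i, j) returns the label (i or j) preferred by the pairwise SVM.
    votes = Counter(classify(X, i, j) for i, j in combinations(range(num_classes), 2))
    return votes.most_common(1)[0][0]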


Example 17: setUp

    def setUp(self):
        from trainer import Trainer
        from database import TrainingDataBase,WordDataBase,WordRecord

        self.tr_empty = Trainer(WordDataBase(),TrainingDataBase())

        wdb = WordDataBase()
        wdb.addWord(WordRecord("aaa"))
        wdb.addWord(WordRecord("bbb"))
        wdb.addWord(WordRecord("ccc"))
        tdb = TrainingDataBase()
        tdb.add([WordRecord("aaa"),WordRecord("bbb"),WordRecord("ccc")],[WordRecord("ccc"),WordRecord("bbb")])
        tdb.add([WordRecord("aaa"),WordRecord("ccc")],[WordRecord("ccc"),WordRecord("ccc")])

        self.tr_notempty = Trainer(wdb,tdb)
Developer: 0x1001, Project: jarvis, Lines: 15, Source: trainertest.py


Example 18: test_trainSvm

    def test_trainSvm(self):
        return
        file = os.path.join('..', 'data', 'sample')
        trainer = Trainer(file)
        t_svm = trainer._trainSvm(5, 8)

        dataSet = DigitDataSet()
        dataSet.load(file).map(5, 8)
        svm = SVM()
        svm.train(dataSet, 2, 0.0001)
        m,n = dataSet.shape()
        for i in range(m):
            X = dataSet.getData(i)
            t_y = t_svm.predict(X)
            y = svm.predict(X)
            self.assertTrue(t_y == y)
Developer: jinyyu, Project: machine-learning, Lines: 16, Source: testTrainer.py


Example 19: command

    def command(self, argv):
        try:
            opts, args = getopt.getopt(argv,"mt:",["topic="])
        except getopt.GetoptError as err:
            print(err)
            self.usage()
            sys.exit(2)
        for opt, arg in opts:
            if opt == "-m":
                self.modify = True
            elif opt in ("-t", "--topic"):
                if hasattr(self, 'modify'):
                    self.collection.modifyCollection(arg)
                else:
                    self.trainer = Trainer()
                    if arg in self.collection.topics:
                        self.trainer.start(self.collection.topics[arg])
                    else:
                        print("Error")
            elif opt in ("-h", "--help"):
                self.usage()
                sys.exit()
            else:
                assert False, "unhandled option"
        if len(args) == 1 and hasattr(self,'topic'):
            print(args)
Developer: AlexBelger, Project: nuttrainer, Lines: 26, Source: nuttrainer.py


Example 20: __init__

    def __init__(self, vc, opts):
        self.vc = vc
        ret, im = vc.read()
        self.numGestures = opts.num
        self.imHeight, self.imWidth, self.channels = im.shape
        self.trainer = Trainer(numGestures=opts.num, numFramesPerGesture=opts.frames,
                               minDescriptorsPerFrame=opts.desc, numWords=opts.words,
                               descType=opts.type, kernel=opts.kernel, numIter=opts.iter,
                               parent=self)
        self.tester = Tester(numGestures=opts.num, minDescriptorsPerFrame=opts.desc,
                             numWords=opts.words, descType=opts.type, numPredictions=7,
                             parent=self)
Developer: arpitgit, Project: Talk2dHand, Lines: 7, Source: recognizer.py



Note: The trainer.Trainer class examples in this article were compiled by 纯净天空 from source-code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers; copyright of the source code belongs to the original authors, and any redistribution or use should follow the license of the corresponding project. Do not reproduce without permission.

