• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python persist.save_obj函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中neon.util.persist.save_obj函数的典型用法代码示例。如果您正苦于以下问题:Python save_obj函数的具体用法?Python save_obj怎么用?Python save_obj使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了save_obj函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: save_history

    def save_history(self, epoch, model):
        """
        Save a per-epoch checkpoint, keeping only the most recent ones.

        Each checkpoint is written to ``self.save_path`` with ``_<epoch>``
        inserted before the extension (e.g. ``model.pkl`` -> ``model_3.pkl``).
        The oldest paths tracked in ``self.checkpoint_files`` are deleted so
        that, after this call, at most ``max(self.history, 1)`` checkpoints
        remain. ``self.save_path`` itself is then (re)created as a symlink
        to the newest checkpoint.

        Arguments:
            epoch (int): index of the epoch just completed.
            model: model whose ``serialize(keep_states=True)`` output is saved.
        """
        # Evict *before* appending so the deque holds at most self.history
        # entries once the new file is added. The previous `>` test let
        # history + 1 files accumulate on disk. The emptiness check avoids
        # popping from an empty deque when history <= 0.
        while self.checkpoint_files and len(self.checkpoint_files) >= self.history:
            fn = self.checkpoint_files.popleft()
            try:
                os.remove(fn)
                logger.info("removed old checkpoint %s", fn)
            except OSError:
                logger.warning("Could not delete old checkpoint file %s", fn)

        root, ext = os.path.splitext(self.save_path)
        save_path = "%s_%d%s" % (root, epoch, ext)
        # add the current file to the deque
        self.checkpoint_files.append(save_path)
        save_obj(model.serialize(keep_states=True), save_path)

        # Maintain a symlink pointing to the latest model params. The link
        # target is only the basename, so it is resolved relative to the
        # directory containing the link.
        try:
            if os.path.islink(self.save_path):
                os.remove(self.save_path)
            os.symlink(os.path.split(save_path)[-1], self.save_path)
        except OSError:
            logger.warning("Could not create latest model symlink %s -> %s",
                           self.save_path, save_path)
开发者ID:yapjiaqing,项目名称:neon,代码行数:27,代码来源:callbacks.py


示例2: on_epoch_end

 def on_epoch_end(self, callback_data, model, epoch):
     """
     Save the model to ``self.best_path`` whenever the epoch's cost improves.

     Arguments:
         callback_data: callback data store, passed to
             ``self._get_cached_epoch_loss``.
         model: model to serialize when a new best cost is observed.
         epoch (int): index of the epoch that just ended.
     """
     _eil = self._get_cached_epoch_loss(callback_data, model, epoch, "loss")
     if _eil:
         # Test for None first: on Python 3 `cost < None` raises TypeError,
         # so the previous ordering crashed on the first recorded cost.
         if self.best_cost is None or _eil["cost"] < self.best_cost:
             # TODO: switch this to a general serialization op
             save_obj(model.serialize(keep_states=True), self.best_path)
             self.best_cost = _eil["cost"]
开发者ID:yapjiaqing,项目名称:neon,代码行数:7,代码来源:callbacks.py


示例3: serialize

    def serialize(self, fn=None, keep_states=True):
        """
        Build (and optionally write) a dict describing the model's parameters.

        Arguments:
            fn (str, optional): when given, the dict is written to this path
                with ``save_obj`` (pkl format) and nothing is returned.
            keep_states (bool): whether optimizer states are included.

        Returns:
            dict or None: the output of ``get_description`` plus
            ``'epoch_index'`` (stored as ``self.epoch_index + 1``) and, once
            the model is initialized, ``'train_input_shape'``. ``None`` when
            ``fn`` is given.
        """
        description = self.get_description(get_weights=True, keep_states=keep_states)
        description['epoch_index'] = self.epoch_index + 1

        if self.initialized:
            layers = self.layers
            if hasattr(layers, 'decoder'):
                # encoder/decoder containers: record both input shapes, concatenated
                input_shape = layers.encoder.in_shape + layers.decoder.in_shape
            else:
                input_shape = layers.in_shape
            description['train_input_shape'] = input_shape

        if fn is None:
            return description
        save_obj(description, fn)
开发者ID:rlugojr,项目名称:neon,代码行数:26,代码来源:model.py


示例4: save_params

    def save_params(self, param_path, keep_states=True):
        """
        Serializes and saves model parameters to the path specified.

        Arguments:
            param_path (str): File to write serialized parameter dict to.
            keep_states (bool): Whether to save optimizer states too.
                                Defaults to True.
        """
        # Pass keep_states by keyword. When serialize has the signature
        # serialize(self, fn=None, keep_states=True), a positional argument
        # would bind to `fn`. serialize would then try to write to that value
        # and return None, and None would be what gets saved here.
        save_obj(self.serialize(keep_states=keep_states), param_path)
开发者ID:bin2000,项目名称:neon,代码行数:10,代码来源:model.py


示例5: on_epoch_end

    def on_epoch_end(self, epoch):
        """
        Save the model to ``self.best_path`` when the validation cost improves.

        Validation cost is only recorded every ``epoch_freq`` epochs (read from
        the dataset's ``attrs``). Nothing happens on other epochs, or when no
        ``'cost/validation'`` entry exists.

        Arguments:
            epoch (int): index of the epoch that just ended.
        """
        if 'cost/validation' not in self.callback_data:
            return

        val_costs = self.callback_data['cost/validation']
        val_freq = val_costs.attrs['epoch_freq']
        if (epoch + 1) % val_freq != 0:
            return

        # Floor division: on Python 3 `/` yields a float, which is not a
        # valid index into the recorded costs.
        validation_cost = val_costs[epoch // val_freq]

        # Check for None first; `cost < None` raises TypeError on Python 3.
        if self.best_cost is None or validation_cost < self.best_cost:
            save_obj(self.model.serialize(keep_states=True), self.best_path)
            self.best_cost = validation_cost
开发者ID:rupertsmall,项目名称:neon,代码行数:10,代码来源:callbacks.py


示例6: save_meta

 def save_meta(self):
     """Write the batch writer's dataset metadata dict to ``self.meta_file``."""
     meta = dict(
         ntrain=self.ntrain,
         nval=self.nval,
         train_start=self.train_start,
         val_start=self.val_start,
         macro_size=self.macro_size,
         batch_prefix=self.batch_prefix,
         global_mean=self.global_mean,
         label_dict=self.label_dict,
         label_names=self.label_names,
         val_nrec=self.val_nrec,
         train_nrec=self.train_nrec,
         # stored under 'img_size' although the attribute is target_size
         img_size=self.target_size,
         nclass=self.nclass,
     )
     save_obj(meta, self.meta_file)
开发者ID:GerritKlaschke,项目名称:neon,代码行数:14,代码来源:batch_writer.py


示例7: on_sigint_catch

    def on_sigint_catch(self, epoch, minibatch):
        """
        Handle a SIGINT (Ctrl-C) received during training.

        Restores the default SIGINT handler, writes a checkpoint to
        ``self.save_path`` when one is configured, and then re-raises the
        interrupt.

        Arguments:
            epoch (int): index of current epoch
            minibatch (int): index of minibatch that is ending

        Raises:
            KeyboardInterrupt: always; the message names the checkpoint file
                when one was written.
        """
        # restore the original (default) handler
        signal.signal(signal.SIGINT, signal.SIG_DFL)

        if self.save_path is None:
            raise KeyboardInterrupt

        # self.model is called to obtain the model object (presumably a weakref)
        save_obj(self.model().serialize(keep_states=True), self.save_path)
        raise KeyboardInterrupt("Checkpoint file saved to {0}".format(self.save_path))
开发者ID:yapjiaqing,项目名称:neon,代码行数:17,代码来源:callbacks.py


示例8: serialize

    def serialize(self, fn=None, keep_states=True):
        """
        Collect the model's layer parameters into a dict, optionally saving it.

        Arguments:
            fn (str, optional): when given, the dict is written to this path
                (pkl format) with ``save_obj`` and the method returns ``None``.
            keep_states (bool): whether optimizer states are included.

        Returns:
            dict or None: the output of ``get_description`` plus an
            ``'epoch_index'`` entry (``self.epoch_index + 1``); ``None`` when
            ``fn`` is given.
        """
        model_dict = self.get_description(get_weights=True, keep_states=keep_states)
        model_dict['epoch_index'] = self.epoch_index + 1
        if fn is None:
            return model_dict
        save_obj(model_dict, fn)
        return None
开发者ID:maony,项目名称:neon,代码行数:19,代码来源:model.py


示例9: save_meta

 def save_meta(self):
     """Persist the dataset metadata dict to ``self.meta_file`` via save_obj."""
     # (metadata key, attribute name) pairs, in the order they are written.
     # Only img_size is read from a differently named attribute (target_size).
     key_to_attr = (
         ("ntrain", "ntrain"),
         ("nval", "nval"),
         ("train_start", "train_start"),
         ("val_start", "val_start"),
         ("macro_size", "macro_size"),
         ("batch_prefix", "batch_prefix"),
         ("global_mean", "global_mean"),
         ("label_dict", "label_dict"),
         ("label_names", "label_names"),
         ("val_nrec", "val_nrec"),
         ("train_nrec", "train_nrec"),
         ("img_size", "target_size"),
         ("nclass", "nclass"),
     )
     meta = {key: getattr(self, attr) for key, attr in key_to_attr}
     save_obj(meta, self.meta_file)
开发者ID:hgl888,项目名称:neon,代码行数:19,代码来源:batch_writer.py


示例10: save_history

    def save_history(self, epoch):
        """
        Save a per-epoch checkpoint of ``self.model``, keeping only recent ones.

        Each checkpoint is written to ``self.save_path`` with ``_<epoch>``
        inserted before the extension (e.g. ``model.pkl`` -> ``model_3.pkl``).
        The oldest paths tracked in ``self.checkpoint_files`` are deleted so
        that, after this call, at most ``max(self.history, 1)`` checkpoints
        remain.

        Arguments:
            epoch (int): index of the epoch just completed.
        """
        # Evict *before* appending so the deque holds at most self.history
        # entries once the new file is added. The previous `>` test let
        # history + 1 files accumulate on disk. The emptiness check avoids
        # popping from an empty deque when history <= 0.
        while self.checkpoint_files and len(self.checkpoint_files) >= self.history:
            fn = self.checkpoint_files.popleft()
            try:
                os.remove(fn)
                logger.info('removed old checkpoint %s', fn)
            except OSError:
                logger.warning('Could not delete old checkpoint file %s', fn)

        root, ext = os.path.splitext(self.save_path)
        save_path = '%s_%d%s' % (root, epoch, ext)
        # add the current file to the deque
        self.checkpoint_files.append(save_path)
        save_obj(self.model.serialize(keep_states=True), save_path)
开发者ID:rupertsmall,项目名称:neon,代码行数:19,代码来源:callbacks.py


示例11: get_w2v_vocab

def get_w2v_vocab(fname, max_vocab_size, cache=True):
    """
    Get ordered dict of vocab from a Google word2vec binary file.

    The file is read as a text header line ``"<vocab_size> <embed_dim>"``.
    For each word it then reads the word's bytes up to a space, followed by
    ``embed_dim`` float32 values. The vectors are skipped; only the words
    are kept.

    Arguments:
        fname (str): path to the word2vec binary file.
        max_vocab_size (int): upper bound on the number of words read.
        cache (bool): if True, load the vocab from (or save it to) a
            ``.vocab`` file alongside ``fname``.

    Returns:
        tuple: ``(vocab, vocab_size)``. ``vocab`` is an OrderedDict mapping
        each word to its position in the file.

    Raises:
        EOFError: if the file ends before ``vocab_size`` words were read.
    """
    if cache:
        # os.path.splitext strips only the final extension. The previous
        # fname.split('.')[0] cut at the first dot anywhere in the path
        # (e.g. './data/w2v.bin' -> '' -> '.vocab').
        cache_fname = os.path.splitext(fname)[0] + ".vocab"

        if os.path.isfile(cache_fname):
            # NOTE(review): a cached vocab is returned as-is, even if it was
            # built with a different max_vocab_size.
            vocab, vocab_size = load_obj(cache_fname)
            neon_logger.display("Word2Vec vocab cached, size is: {}".format(vocab_size))
            return vocab, vocab_size

    with open(fname, 'rb') as f:
        header = f.readline()
        vocab_size, embed_dim = map(int, header.split())
        binary_len = np.dtype('float32').itemsize * embed_dim

        neon_logger.display("Word2Vec vocab size is: {}".format(vocab_size))
        vocab_size = min(max_vocab_size, vocab_size)
        neon_logger.display("Reducing vocab size to: {}".format(vocab_size))

        vocab = OrderedDict()

        for i in range(vocab_size):
            chars = []
            while True:
                ch = f.read(1)
                if not ch:
                    # EOF: without this check the loop never terminates,
                    # because b'' never matches the b' ' terminator.
                    raise EOFError(
                        "Unexpected end of word2vec file {} after {} words".format(fname, i))
                if ch == b' ':
                    break
                # newlines separating records are not part of the word
                if ch != b'\n':
                    chars.append(ch)
            word = b''.join(chars).decode('utf-8')
            # skip this word's embedding vector
            f.read(binary_len)
            vocab[word] = i

    if cache:
        save_obj((vocab, vocab_size), cache_fname)

    return vocab, vocab_size
开发者ID:rlugojr,项目名称:neon,代码行数:39,代码来源:util.py


示例12: inception.py 脚本片段

# Optimizers: weights and biases get separate GradientDescentMomentum
# instances sharing one PolySchedule. MultiOptimizer presumably routes
# 'Bias' layers to opt_biases and everything else to the 'default' entry.
lr_sched = PolySchedule(total_epochs=10, power=0.5)
opt_gdm = GradientDescentMomentum(0.01, 0.9, wdecay=0.0002, schedule=lr_sched)
opt_biases = GradientDescentMomentum(0.02, 0.9, schedule=lr_sched)

opt = MultiOptimizer({'default': opt_gdm, 'Bias': opt_biases})
if not args.resume:
    # fit the model for 3 epochs
    model.fit(train, optimizer=opt, num_epochs=3, cost=cost, callbacks=callbacks)

train.reset()
# get 1 image (the first minibatch from the iterator). If `train` yields
# nothing, `im` and `l` stay unbound and the next lines raise NameError.
for im, l in train:
    break
train.exit_batch_provider()
# NOTE(review): this save runs unconditionally, so with --resume the file is
# overwritten with the current minibatch right before it is loaded back
# below. Confirm whether the save was meant to be skipped when resuming.
save_obj((im.get(), l.get()), 'im1.pkl')
im_save = im.get().copy()
if args.resume:
    (im2, l2) = load_obj('im1.pkl')
    im.set(im2)
    l.set(l2)

# run fprop and bprop on this minibatch save the results
# (only the fprop passes are visible in this excerpt: the first runs on `im`,
# the second on the restored copy `im_save`, and the outputs must match exactly)
out_fprop = model.fprop(im)

out_fprop_save = [x.get() for x in out_fprop]
im.set(im_save)
out_fprop = model.fprop(im)
out_fprop_save2 = [x.get() for x in out_fprop]
for x, y in zip(out_fprop_save, out_fprop_save2):
    assert np.max(np.abs(x - y)) == 0.0, '2 fprop iterations do not match'
开发者ID:JediKoder,项目名称:neon,代码行数:30,代码来源:inception.py


示例13: update_dataset_cache.py 主程序

if __name__ == "__main__":
    # Rewrite a dataset cache file in place. Its "global_mean" is converted
    # from a full per-pixel mean of shape (img_size * img_size * 3, 1) to a
    # (3, 1) array of per-channel means with the channel order reversed.
    parser = argparse.ArgumentParser()
    parser.add_argument("cache_file", help="path to data cache file")
    args = parser.parse_args()

    cache_file = args.cache_file

    # Check for RW access to the file. Explicit exceptions are used instead
    # of `assert`, which is stripped when Python runs with -O.
    if not os.path.exists(cache_file):
        raise IOError("file does not exist %s" % cache_file)
    if not os.access(os.path.abspath(cache_file), os.R_OK | os.W_OK):
        raise IOError("Need to add read and/or write permissions on file %s" % cache_file)

    dc = load_obj(cache_file)

    # Report exactly which required keys are absent. The old message always
    # blamed global_mean, even when only img_size was missing.
    missing = [key for key in ("global_mean", "img_size") if key not in dc]
    if missing:
        raise ValueError("data cache file missing key(s): {}".format(", ".join(missing)))

    sz = dc["img_size"]
    gm = dc["global_mean"]

    if len(gm.shape) != 2 or (gm.shape[0] != sz * sz * 3 or gm.shape[1] != 1):
        raise ValueError("global mean shape {} does not match format expected".format(gm.shape))

    # Collapse the full tensor mean into channel means and correct the order
    # (RGB <-> BGR). reshape(3, -1) assumes the flattened mean stores each
    # channel contiguously (channel-major). TODO: confirm against the writer.
    dc["global_mean"] = np.mean(gm.reshape(3, -1), axis=1).reshape(3, 1)[::-1]

    save_obj(dc, cache_file)

    neon_logger.display("%s updated to new format" % cache_file)
开发者ID:Jokeren,项目名称:neon,代码行数:29,代码来源:update_dataset_cache.py


示例14: load_data


#.........这里部分代码省略.........
        neon_logger.display("open existing vocab file: {}".format(vocab_file_name))
        vocab, rev_vocab, word_count = load_obj(vocab_file_name)
    else:
        neon_logger.display("Building  vocab file")

        # build vocab
        word_count = defaultdict(int)
        for sent in all_sent:
            sent_words = tokenize(sent)

            if len(sent_words) > max_len_w or len(sent_words) == 0:
                continue

            for word in sent_words:
                word_count[word] += 1

        # sort the word_count , re-assign ids by its frequency. Useful for downstream tasks
        # only done for train vocab
        vocab_sorted = sorted(word_count.items(), key=lambda kv: kv[1], reverse=True)

        vocab = OrderedDict()

        # get word count as array in same ordering as vocab (but with maximum length)
        word_count_ = np.zeros((len(word_count), ), dtype=np.int64)
        for i, t in enumerate(list(zip(*vocab_sorted))[0][:max_vocab_size]):
            word_count_[i] = word_count[t]
            vocab[t] = i
        word_count = word_count_

        # generate the reverse vocab
        rev_vocab = dict((wrd_id, wrd) for wrd, wrd_id in vocab.items())

        neon_logger.display("vocabulary from {} is saved into {}".format(path, vocab_file_name))
        save_obj((vocab, rev_vocab, word_count), vocab_file_name)

    vocab_size = len(vocab)
    neon_logger.display("\nVocab size from the dataset is: {}".format(vocab_size))

    neon_logger.display("\nProcessing and saving training data into {}".format(h5_file_name))

    # now process and save the train/valid data
    h5f = h5py.File(h5_file_name, 'w', libver='latest')
    shape, maxshape = (len(train_sent),), (None)
    dt = np.dtype([('text', h5py.special_dtype(vlen=str)),
                   ('num_words', np.uint16)])
    report_text_train = h5f.create_dataset('report_train', shape=shape,
                                           maxshape=maxshape, dtype=dt,
                                           compression='gzip')
    report_train = h5f.create_dataset('train', shape=shape, maxshape=maxshape,
                                      dtype=h5py.special_dtype(vlen=np.int32),
                                      compression='gzip')

    # map text to integers
    wdata = np.zeros((1, ), dtype=dt)
    ntrain = 0
    for sent in train_sent:
        text_int = [-1 if t not in vocab else vocab[t] for t in tokenize(sent)]

        # enforce maximum sentence length
        if len(text_int) > max_len_w or len(text_int) == 0:
            continue

        report_train[ntrain] = text_int

        wdata['text'] = clean_string(sent)
        wdata['num_words'] = len(text_int)
开发者ID:NervanaSystems,项目名称:neon,代码行数:67,代码来源:data_loader.py


示例15: inference.py 脚本片段

        (im_shape, im_scale, gt_boxes, gt_classes,
            num_gt_boxes, difficult) = valid_set.get_metadata_buffers()

        num_gt_boxes = int(num_gt_boxes.get())
        im_scale = float(im_scale.get())

        # retrieve region proposals generated by the model
        (proposals, num_proposals) = proposalLayer.get_proposals()

        # convert outputs to bounding boxes
        boxes = faster_rcnn.get_bboxes(outputs, proposals, num_proposals, num_classes,
                                       im_shape.get(), im_scale, max_per_image, thresh, nms_thresh)

        all_boxes[mb_idx] = boxes

        # retrieve gt boxes
        # we add a extra column to track detections during the AP calculation
        detected = np.array([False] * num_gt_boxes)
        gt_boxes = np.hstack([gt_boxes.get()[:num_gt_boxes] / im_scale,
                              gt_classes.get()[:num_gt_boxes],
                              difficult.get()[:num_gt_boxes], detected[:, np.newaxis]])

        all_gt_boxes[mb_idx] = gt_boxes

# Score the collected detections against the ground-truth boxes per class.
# use_07_metric=True presumably selects the VOC2007 AP definition.
neon_logger.display('Evaluating detections')
avg_precision = voc_eval(all_boxes, all_gt_boxes, valid_set.CLASSES, use_07_metric=True)

# Optionally persist the raw detections together with the AP results.
if args.output is not None:
    neon_logger.display('Saving inference results to {}'.format(args.output))
    save_obj([all_boxes, avg_precision], args.output)
开发者ID:NervanaSystems,项目名称:neon,代码行数:30,代码来源:inference.py


示例16: __init__


#.........这里部分代码省略.........
        # how many ROIs to use to train frcnn
        self.frcn_rois_per_img = frcn_rois_per_img if frcn_rois_per_img \
            else self.FRCNN_ROI_PER_IMAGE

        assert self.img_per_batch == 1, "Only a minibatch of 1 is supported."

        self.num_classes = len(self.CLASSES)
        self._class_to_index = dict(list(zip(self.CLASSES, list(range(self.num_classes)))))

        # shape of the final conv layer
        if conv_size:
            self._conv_size = conv_size
        else:
            self._conv_size = int(np.floor(self.MAX_SIZE * self.SCALE))
        self._feat_stride = 1 / float(self.SCALE)
        self._num_scales = len(self.SCALES) * len(self.RATIOS)
        self._total_anchors = self._conv_size * self._conv_size * self._num_scales
        self.shuffle = shuffle
        self.deterministic = deterministic
        self.add_flipped = add_flipped

        # load the configure the dataset paths
        self.config = self.load_data()

        # annotation metadata
        self._annotation_file_ext = '.xml'
        self._annotation_obj_tag = 'object'
        self._annotation_class_tag = 'name'
        self._annotation_xmin_tag = 'xmin'
        self._annotation_xmax_tag = 'xmax'
        self._annotation_ymin_tag = 'ymin'
        self._annotation_ymax_tag = 'ymax'

        # self.rois_per_batch is 128 (2*64) ROIs
        # But the image path batch size is self.img_per_batch
        # need to control the batch size here
        assert self.img_per_batch is 1, "Only a batch size of 1 image is supported"

        neon_logger.display("Backend batchsize is changed to be {} "
                            "from Object Localization dataset".format(
                             self.img_per_batch))

        self.be.bsz = self.img_per_batch

        # 0. allocate buffers
        self.allocate()

        if not self.mock_db:
            # 1. read image index file
            assert os.path.exists(self.config['image_path']), \
                'Image index file does not exist: {}'.format(self.config['image_path'])
            with open(self.config['index_path']) as f:
                self.image_index = [x.strip() for x in f.readlines()]

            num_images = len(self.image_index)
            self.num_image_entries = num_images * 2 if self.add_flipped else num_images
            self.ndata = self.num_image_entries * self.rois_per_img
        else:
            self.num_image_entries = 1
            self.ndata = self.num_image_entries * self.rois_per_img

        assert (subset_pct > 0 and subset_pct <= 100), ('subset_pct must be between 0 and 100')

        if n_mb is not None:
            self.nbatches = n_mb
        else:
            self.nbatches = int(self.num_image_entries / self.img_per_batch * subset_pct / 100)

        self.cache_file = self.config['cache_path']

        if os.path.exists(self.cache_file) and not rebuild_cache and not self.mock_db:
            self.roi_db = load_obj(self.cache_file)
            neon_logger.display('ROI dataset loaded from file {}'.format(self.cache_file))

        elif not self.mock_db:
            # 2. read object Annotations (XML)
            roi_db = self.load_roi_groundtruth()

            if(self.add_flipped):
                roi_db = self.add_flipped_db(roi_db)

            # 3. construct acnhor targets
            self.roi_db = self.add_anchors(roi_db)

            if NORMALIZE_BBOX_TARGETS:
                # 4. normalize bbox targets by class
                self.roi_db = self.normalize_bbox_targets(self.roi_db)

            save_obj(self.roi_db, self.cache_file)
            neon_logger.display('wrote ROI dataset to {}'.format(self.cache_file))

        else:
            assert self.mock_db is not None
            roi_db = [self.mock_db]
            self.roi_db = self.add_anchors(roi_db)

        # 4. map anchors back to full canvas.
        # This is neccessary because the network outputs reflect the full canvas.
        # We cache the files in the unmapped state (above) to save memory.
        self.roi_db = unmap(self.roi_db)
开发者ID:Jokeren,项目名称:neon,代码行数:101,代码来源:objectlocalization.py


示例17: test_model_serialize

def test_model_serialize(backend_default, data):
    """
    Check that a trained multi-stream model survives a save/load round trip.

    After a few training iterations the model is serialized with optimizer
    states and reloaded with ``Model.load_weights``. Its inference outputs,
    per-layer states and per-layer params are then compared with the values
    captured before saving.
    """
    import tempfile

    (X_train, y_train), (X_test, y_test), nclass = load_mnist(path=data)

    train_set = DataIterator(
        [X_train, X_train], y_train, nclass=nclass, lshape=(1, 28, 28))

    init_norm = Gaussian(loc=0.0, scale=0.01)

    # initialize model: two parallel streams merged by stacking, followed by
    # two affine layers
    path1 = Sequential([Conv((5, 5, 16), init=init_norm, bias=Constant(0), activation=Rectlin()),
                        Pooling(2),
                        Affine(nout=20, init=init_norm, bias=init_norm, activation=Rectlin())])
    path2 = Sequential([Affine(nout=100, init=init_norm, bias=Constant(0), activation=Rectlin()),
                        Dropout(keep=0.5),
                        Affine(nout=20, init=init_norm, bias=init_norm, activation=Rectlin())])
    layers = [MergeMultistream(layers=[path1, path2], merge="stack"),
              Affine(nout=20, init=init_norm, batch_norm=True, activation=Rectlin()),
              Affine(nout=10, init=init_norm, activation=Logistic(shortcut=True))]

    mlp = Model(layers=layers)
    mlp.optimizer = GradientDescentMomentum(learning_rate=0.1, momentum_coef=0.9)
    mlp.cost = GeneralizedCost(costfunc=CrossEntropyBinary())
    mlp.initialize(train_set, cost=mlp.cost)
    n_test = 3
    num_epochs = 3
    # Train for num_epochs epochs, each cut short once i > n_test
    for epoch in range(num_epochs):
        for i, (x, t) in enumerate(train_set):
            x = mlp.fprop(x)
            delta = mlp.cost.get_errors(x, t)
            mlp.bprop(delta)
            mlp.optimizer.optimize(mlp.layers_to_optimize, epoch=epoch)
            if i > n_test:
                break

    # Get expected outputs of the first batches and states of all layers
    outputs_exp = []
    pdicts_exp = [layer.get_params_serialize() for layer in mlp.layers_to_optimize]
    for i, (x, t) in enumerate(train_set):
        outputs_exp.append(mlp.fprop(x, inference=True))
        if i > n_test:
            break

    # Serialize to a unique temp file instead of a fixed name in the CWD, so
    # concurrent runs cannot collide. The `finally` removes the file even
    # when an assertion below fails (previously it was left behind).
    fd, tmp_save = tempfile.mkstemp(suffix='.pickle')
    os.close(fd)
    try:
        save_obj(mlp.serialize(keep_states=True), tmp_save)

        # Load model
        mlp = Model(layers=layers)
        mlp.load_weights(tmp_save)

        outputs = []
        pdicts = [layer.get_params_serialize() for layer in mlp.layers_to_optimize]
        for i, (x, t) in enumerate(train_set):
            outputs.append(mlp.fprop(x, inference=True))
            if i > n_test:
                break

        # Check outputs, states, and params are the same
        for output, output_exp in zip(outputs, outputs_exp):
            assert np.allclose(output.get(), output_exp.get())

        for pd, pd_exp in zip(pdicts, pdicts_exp):
            for s, s_e in zip(pd['states'], pd_exp['states']):
                if isinstance(s, list):  # this is the batch norm case
                    for _s, _s_e in zip(s, s_e):
                        assert np.allclose(_s, _s_e)
                else:
                    assert np.allclose(s, s_e)
            for p, p_e in zip(pd['params'], pd_exp['params']):
                assert type(p) == type(p_e)
                if isinstance(p, list):  # this is the batch norm case
                    for _p, _p_e in zip(p, p_e):
                        assert np.allclose(_p, _p_e)
                elif isinstance(p, np.ndarray):
                    assert np.allclose(p, p_e)
                else:
                    assert p == p_e
    finally:
        os.remove(tmp_save)
开发者ID:GerritKlaschke,项目名称:neon,代码行数:80,代码来源:test_model.py


示例18: train_all_specs.py 代码片段

                            X_test, y_test, cluster)
                        spec_out = nout
                        spec_set = DataIterator(
                            X_spec, y_spec, nclass=spec_out, lshape=(3, 32, 32))
                        spec_test = DataIterator(
                            X_spec_test, y_spec_test, nclass=spec_out, lshape=(3, 32, 32))

                        # Train the specialist
                        specialist, opt, cost = spec_net(nout=spec_out, archive_path=gene_path)
                        callbacks = Callbacks(specialist, spec_set, args, eval_set=spec_test)
                        callbacks.add_early_stop_callback(early_stop)
                        callbacks.add_save_best_state_callback(path)
                        specialist.fit(spec_set, optimizer=opt,
                                    num_epochs=specialist.epoch_index + num_epochs, cost=cost, callbacks=callbacks)

                        # Print results
                        print 'Specialist Train misclassification error: ', specialist.eval(spec_set, metric=Misclassification())
                        print 'Specialist Test misclassification error: ', specialist.eval(spec_test, metric=Misclassification())
                        print 'Generalist Train misclassification error: ', generalist.eval(spec_set, metric=Misclassification())
                        print 'Generalist Test misclassification error: ', generalist.eval(spec_test, metric=Misclassification())
                        # specialists.append(specialist)
                        save_obj(specialist.serialize(), path)
                except:
                    path = confusion_matrix_name + '_' + clustering_name + '_' + str(num_clusters) + 'clusters/'
                    print 'Failed for ', path
                    failed.append(path)

    for f in failed:
        print f

开发者ID:seba-1511,项目名称:specialists,代码行数:29,代码来源:train_all_specs.py


示例19: save_weights

 def save_weights(self, save_path):
   """Serialize ``self.model`` (keeping optimizer states) and write it to save_path."""
   params = self.model.serialize(keep_states=True)
   save_obj(params, save_path)
开发者ID:rockhowse,项目名称:simple_dqn,代码行数:2,代码来源:deepqnetwork.py


示例20: train

  def train(self, minibatch, epoch = 0):
    """
    Run one Q-learning update on a minibatch of transitions.

    The network output has ``self.num_actions`` rows per sample. The first
    ``self.num_steers`` rows are steering Q-values and the last
    ``self.num_speeds`` rows are speed Q-values. Each taken steer/speed
    action gets its own target.

    Arguments:
      minibatch: tuple ``(prestates, steers, speeds, rewards, poststates,
        terminals)``. prestates/poststates are 2-D and the rest are 1-D, all
        sharing the same leading (sample) dimension.
      epoch (int): passed through to ``self.optimizer.optimize``.
    """
    # expand components of minibatch
    prestates, steers, speeds, rewards, poststates, terminals = minibatch
    assert len(prestates.shape) == 2
    assert len(poststates.shape) == 2
    assert len(steers.shape) == 1
    assert len(speeds.shape) == 1
    assert len(rewards.shape) == 1
    assert len(terminals.shape) == 1
    assert prestates.shape == poststates.shape
    assert prestates.shape[0] == steers.shape[0] == speeds.shape[0] == rewards.shape[0] == poststates.shape[0] == terminals.shape[0]

    # every target_steps updates, copy the online model's weights into the
    # target model (skipped entirely when target_steps is falsy)
    if self.target_steps and self.train_iterations % self.target_steps == 0:
      # HACK: serialize network to disk and read it back to clone
      filename = self.save_weights_prefix + "_target.pkl"
      save_obj(self.model.serialize(keep_states = False), filename)
      self.target_model.load_weights(filename)

    # feed-forward pass for poststates to get Q-values
    self._setInput(poststates)
    postq = self.target_model.fprop(self.input, inference = True)
    assert postq.shape == (self.num_actions, self.batch_size)

    # calculate max Q-value for each poststate, separately over the
    # steering rows and over the speed rows
    postq = postq.asnumpyarray()
    maxsteerq = np.max(postq[:self.num_steers,:], axis=0)
    assert maxsteerq.shape == (self.batch_size,), "size: %s" % str(maxsteerq.shape)
    maxspeedq = np.max(postq[-self.num_speeds:,:], axis=0)
    assert maxspeedq.shape == (self.batch_size,)

    # feed-forward pass for prestates
    self._setInput(prestates)
    preq = self.model.fprop(self.input, inference = False)
    assert preq.shape == (self.num_actions, self.batch_size)

    # make copy of prestate Q-values as targets
    # HACK: copy() was needed to make it work on CPU
    targets = preq.asnumpyarray().copy()

    # update Q-value targets for actions taken; entries for actions not
    # taken keep the model's own prediction
    for i, (steer, speed) in enumerate(zip(steers, speeds)):
      if terminals[i]:
        # terminal transition: target is the reward alone (no bootstrapping)
        targets[steer, i] = float(rewards[i])
        targets[self.num_steers + speed, i] = float(rewards[i])
      else:
        targets[steer, i] = float(rewards[i]) + self.discount_rate * maxsteerq[i]
        targets[self.num_steers + speed, i] = float(rewards[i]) + self.discount_rate * maxspeedq[i]

    # copy targets to GPU memory
    self.targets.set(targets)

    # calculate errors
    deltas = self.cost.get_errors(preq, self.targets)
    assert deltas.shape == (self.num_actions, self.batch_size)
    #assert np.count_nonzero(deltas.asnumpyarray()) == 2 * self.batch_size, str(np.count_nonzero(deltas.asnumpyarray()))

    # calculate cost, just in case
    cost = self.cost.get_cost(preq, self.targets)
    assert cost.shape == (1,1)
    #print "cost:", cost.asnumpyarray()

    # clip errors to [-clip_error, clip_error] in place (skipped when falsy)
    if self.clip_error:
      self.be.clip(deltas, -self.clip_error, self.clip_error, out = deltas)

    # perform back-propagation of gradients
    self.model.bprop(deltas)

    # perform optimization
    self.optimizer.optimize(self.model.layers_to_optimize, epoch)

    # NOTE: the triple-quoted string below is disabled debugging code (it uses
    # Python 2 print statements). As a bare string expression it has no
    # effect at runtime.
    '''
    if np.any(rewards < 0):
        preqq = preq.asnumpyarray().copy()
        self._setInput(prestates)
        qvalues = self.model.fprop(self.input, inference = True).asnumpyarray().copy()
        indexes = rewards < 0
        print "indexes:", indexes
        print "preq:", preqq[:, indexes].T
        print "preq':", qvalues[:, indexes].T
        print "diff:", (qvalues[:, indexes]-preqq[:, indexes]).T
        print "steers:", steers[indexes]
        print "speeds:", speeds[indexes]
        print "rewards:", rewards[indexes]
        print "terminals:", terminals[indexes]
        print "preq[0]:", preqq[:, 0]
        print "preq[0]':", qvalues[:, 0]
        print "diff:", qvalues[:, 0] - preqq[:, 0]
        print "deltas:", deltas.asnumpyarray()[:, indexes].T
        raw_input("Press Enter to continue...")
    '''

    # increase number of weight updates (needed for target clone interval)
    self.train_iterations += 1
开发者ID:tambetm,项目名称:botmobile,代码行数:94,代码来源:deepqnetwork.py



注:本文中的neon.util.persist.save_obj函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python testing.assert_tensor_equal函数代码示例发布时间:2022-05-27
下一篇:
Python persist.load_obj函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap