• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python libfann.training_data函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中pyfann.libfann.training_data函数的典型用法代码示例。如果您想了解training_data函数的具体用法、调用方式或实际使用示例,这里精选的代码示例或许可以为您提供帮助。



在下文中一共展示了training_data函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: test

    def test(self, ann_file, test_file):
        """Test an artificial neural network on a test data file.

        Loads the FANN network from ``ann_file``, reads ``test_file`` through
        ``TrainData.read_from_file`` and logs the network's mean square error
        on that data.

        Args:
            ann_file: Path to a saved FANN network.
            test_file: Path to the test data file.

        Raises:
            IOError: If either file does not exist.
        """
        import sys

        for path in (ann_file, test_file):
            if not os.path.isfile(path):
                raise IOError("Cannot open %s (no such file)" % path)

        # Get the prefix for the classification columns, falling back to the
        # module default when the configuration does not provide one.
        # NOTE(review): assumes a missing setting surfaces as AttributeError or
        # KeyError; a bare ``except`` would also swallow KeyboardInterrupt.
        try:
            dependent_prefix = self.config.data.dependent_prefix
        except (AttributeError, KeyError):
            dependent_prefix = OUTPUT_PREFIX

        self.ann = libfann.neural_net()
        self.ann.create_from_file(ann_file)

        self.test_data = TrainData()
        try:
            self.test_data.read_from_file(test_file, dependent_prefix)
        except IOError as e:
            logging.error("Failed to process the test data: %s", e)
            sys.exit(1)

        logging.info("Testing the neural network...")
        fann_test_data = libfann.training_data()
        fann_test_data.set_train_data(self.test_data.get_input(),
                                      self.test_data.get_output())

        self.ann.test_data(fann_test_data)

        mse = self.ann.get_MSE()
        logging.info("Mean Square Error on test data: %f", mse)
开发者ID:xieyanfu,项目名称:nbclassify,代码行数:32,代码来源:training.py


示例2: main

def main():
    """Train a sparse FANN network to forecast a series and plot the result.

    Reads ``cw3.in`` (one float per line). Only the last 1000 values are used;
    even and odd positions are treated as two interleaved series. The network
    learns to predict the next ``predict_days`` values of the first series
    from the previous ``known_days`` values of both series. The last
    ``verify_days`` points are held out, predicted one step at a time, and
    plotted (red) against the actual values (blue).
    """
    # setting the prediction parameters
    known_days = 7
    predict_days = 1
    verify_days = 30

    # setting up the parameters of the network
    connection_rate = 1
    learning_rate = 0.1
    num_input = known_days * 2
    num_hidden = 60
    num_output = predict_days

    # setting up the parameters of the network, continued
    desired_error = 0.000040
    max_iterations = 10000
    iteration_between_reports = 100

    # setting up the network
    net = libfann.neural_net()
    net.create_sparse_array(connection_rate, (num_input, num_hidden, num_output))
    net.set_learning_rate(learning_rate)
    net.set_activation_function_output(libfann.SIGMOID_SYMMETRIC_STEPWISE)

    # read the input file and format data
    with open("cw3.in") as fin:
        rawdata = [float(line) for line in fin][-1000:]
    datain0 = rawdata[0::2]
    datain1 = rawdata[1::2]
    # Divide by 1.4x the maximum to leave headroom below 1.0
    # (assumes non-negative data -- TODO confirm).
    n0 = max(datain0) * 1.4
    n1 = max(datain1) * 1.4
    datain0 = [x / n0 for x in datain0]
    datain1 = [x / n1 for x in datain1]

    # train the network: each sample is a window of known_days values from
    # both series; the target is the following predict_days values of
    # series 0. The last verify_days points are excluded from training.
    drange = range(len(datain0) - known_days - verify_days)
    inputs = [datain0[x:][:known_days] + datain1[x:][:known_days] for x in drange]
    targets = [datain0[x + known_days:][:predict_days] for x in drange]
    # Real lists rather than Python 3 ``map`` iterators: set_train_data needs
    # sequences it can measure and index.
    data = libfann.training_data()
    data.set_train_data(inputs, targets)
    net.train_on_data(data, max_iterations, iteration_between_reports, desired_error)

    # predict each held-out point from the window just before it
    result = []
    for i in range(verify_days):
        offset = i - known_days - verify_days
        dtest = datain0[offset:][:known_days] + datain1[offset:][:known_days]
        result.append(net.run(dtest)[0] * n0)

    # predictions (red) overlaid on the actual values (blue)
    plot.plot(result, "r")
    plot.plot([x * n0 for x in datain0[-verify_days:]], "b")
    plot.show()

    print("hehe")
开发者ID:starrify,项目名称:CW2013,代码行数:60,代码来源:2013AI_cw3.py


示例3: train

    def train(self, inputs, outputs, params):
        """Train a cascade-correlation FANN network and save it to ``nn.net``.

        Args:
            inputs: Array of shape (n_samples, p) holding the input features.
            outputs: Array of shape (n_samples, n_r, n_c). Each sample's output
                grid is flattened to ``n_r * n_c`` network outputs.
            params: Not used by this method.

        Side effects:
            Sets ``p``, ``n_r``, ``n_c``, ``out_min`` and ``out_max`` on
            ``self``. ``out_min``/``out_max`` define the linear map used to
            normalize the targets. Writes the trained network to ``nn.net``.
        """
        self.p = inputs.shape[1]       # number of input features
        self.n_r = outputs.shape[1]    # size of output grid in rows
        self.n_c = outputs.shape[2]    # size of output grid in cols

        assert inputs.shape[0] == outputs.shape[0]

        # Widen the observed range by 1/98 of its span on BOTH sides, so the
        # normalized targets land in [0.01, 0.99], strictly inside the (0, 1)
        # range of the SIGMOID output activation. Shifting both bounds down
        # would push the largest target above 1. A constant output gets an
        # arbitrary positive margin, which avoids a 0/0 division.
        span = float(outputs.max() - outputs.min())
        margin = span / 98.0 if span > 0 else 0.5
        self.out_min = outputs.min() - margin
        self.out_max = outputs.max() + margin

        outputs = (outputs - self.out_min) / (self.out_max - self.out_min)

        nn = libfann.neural_net()
        nn.create_shortcut_array((self.p, self.n_r * self.n_c))
        nn.set_learning_rate(.7)
        nn.set_activation_function_hidden(libfann.SIGMOID_SYMMETRIC)
        nn.set_activation_function_output(libfann.SIGMOID)

        data = libfann.training_data()
        data.set_train_data(inputs, outputs.reshape((-1, self.n_r * self.n_c)))

        # Cascade training: add up to 15 neurons, report after each one, and
        # stop at a desired error of .001.
        nn.cascadetrain_on_data(data, 15, 1, .001)

        nn.save('nn.net')
        nn.destroy()
开发者ID:bhumbers,项目名称:745approx,代码行数:31,代码来源:neural_approx.py


示例4: __init__

    def __init__(self,
                 datafile,
                 desired_error = 0.0000000001,
                 iterations_between_reports = 1000):
        """Prepare a breeder for the FANN data set ``datafile``.

        Loads ``<datafile>.train`` and ``<datafile>.test`` into FANN
        ``training_data`` objects. Takes the number of inputs and outputs from
        the 2nd and 3rd fields of the training file's header line. Also
        initializes the breeding state, the candidate activation functions and
        the names of the mutation operators.

        Args:
            datafile: Path prefix shared by the ``.train`` and ``.test`` files.
            desired_error: Stored as ``self.desired_error``.
            iterations_between_reports: Stored as
                ``self.iterations_between_reports``.
        """
        self.datafile = datafile
        self.desired_error = desired_error
        self.iterations_between_reports = iterations_between_reports
        # Only the header line is needed here. ``with`` guarantees the handle
        # is closed, and str.split works on both Python 2 and 3 (unlike the
        # removed-in-Py3 ``string.split``).
        with open(datafile + ".train", 'r') as f:
            header = f.readline().split()
        self.num_input = int(header[1])
        self.num_output = int(header[2])
        self.breeding = False
        self.stage = 0
        self.netsTried = 0
        self.maxMutations = 18
        self.populationSize = 12
        self.trainingData = libfann.training_data()
        self.trainingData.read_train_from_file(datafile + ".train")
        self.testData = libfann.training_data()
        self.testData.read_train_from_file(datafile + ".test")
        # Activation functions the breeder may pick from.
        self.flist = [libfann.FANN_LINEAR, libfann.FANN_SIGMOID, libfann.FANN_SIGMOID_STEPWISE,
                      libfann.FANN_SIGMOID_SYMMETRIC, libfann.FANN_SIGMOID_SYMMETRIC_STEPWISE,
                      libfann.FANN_GAUSSIAN, libfann.FANN_GAUSSIAN_SYMMETRIC, libfann.FANN_ELLIOT,
                      libfann.FANN_ELLIOT_SYMMETRIC, libfann.FANN_LINEAR_PIECE,
                      libfann.FANN_LINEAR_PIECE_SYMMETRIC, libfann.FANN_SIN_SYMMETRIC,
                      libfann.FANN_COS_SYMMETRIC]
        self.mutationlist = ["change_connection_rate",
                             "change_learning_rate",
                             "change_num_neurons_hidden",
                             "change_num_layers_hidden",
                             "change_max_iterations",
                             "change_training_algorithm",
                             "change_activation_function_hidden",
                             "change_activation_function_output",
                             "change_learning_momentum",
                             "change_activation_steepness_hidden",
                             "change_activation_steepness_output",
                             "change_training_param"]
        self.trmutlist = ["change_connection_type",
                          "change_quickprop_decay",
                          "change_quickprop_mu",
                          "change_rprop_increase_factor",
                          "change_rprop_decrease_factor",
                          "change_rprop_delta_min",
                          "change_rprop_delta_max",
                          # "change_rprop_delta_zero"
                          ]
开发者ID:Buggaboo,项目名称:Triathlon,代码行数:46,代码来源:Triathlon-Breeder.py


示例5: testNet

def testNet():
    """Print the mean square error of the network in ``nn_file`` on ``test_file``."""
    samples = libfann.training_data()
    samples.read_train_from_file(test_file)

    net = libfann.neural_net()
    net.create_from_file(nn_file)

    # Start from a clean error accumulator so only the test set is measured.
    net.reset_MSE()
    net.test_data(samples)
    mse = net.get_MSE()
    print("Mean square error: {0}".format(mse))
开发者ID:jeffames-cs,项目名称:nnot,代码行数:10,代码来源:ann.py


示例6: load_data_prefix

def load_data_prefix(prefix):
    """Build FANN training data from ``<prefix>_i.txt`` and ``<prefix>_o.txt``.

    Both files are read with ``numpy.loadtxt`` and passed through
    ``check_matrix`` before being handed to ``set_train_data``. The first
    file supplies the inputs; the second supplies the outputs.
    """
    # The generator is consumed in order: inputs are loaded and checked first.
    inputs, targets = (check_matrix(numpy.loadtxt(prefix + suffix))
                       for suffix in ("_i.txt", "_o.txt"))

    training_set = fann.training_data()
    training_set.set_train_data(inputs, targets)
    return training_set
开发者ID:Verderey,项目名称:Classification_Attemption,代码行数:10,代码来源:demo_1.py


示例7: load_data

def load_data(filename, in_outs):
    """Split the columns of a text table into FANN inputs and outputs.

    Args:
        filename: File readable by ``numpy.loadtxt`` (one sample per row).
        in_outs: ``(n_in, n_out)``. The first ``n_in`` columns become the
            inputs and the following ``n_out`` columns the outputs.

    Returns:
        A ``fann.training_data`` holding the two column groups.
    """
    table = numpy.loadtxt(filename)

    # numpy.compress keeps the columns where the mask is non-zero. Columns
    # beyond the end of the mask are dropped.
    input_mask = numpy.ones(in_outs[0])
    inp = check_matrix(numpy.compress(input_mask, table, axis=1))

    output_mask = numpy.concatenate([numpy.zeros(in_outs[0]), numpy.ones(in_outs[1])])
    out = check_matrix(numpy.compress(output_mask, table, axis=1))

    training_set = fann.training_data()
    training_set.set_train_data(inp, out)
    return training_set
开发者ID:Verderey,项目名称:Classification_Attemption,代码行数:11,代码来源:demo_1.py


示例8: test

    def test(self):
        """Train a sparse network on ``tfile`` and print its MSE on ``test_file``.

        The input and output layer sizes follow the width of the first
        training sample. The other settings (``connection_rate``,
        ``num_neurons_hidden``, ``learning_rate``, ``max_iterations``, ...)
        are names defined outside this method.
        """
        # Single-argument print(...) prints the same text on Python 2 and
        # also runs on Python 3.
        print("Creating network.")
        train_data = libfann.training_data()
        train_data.read_train_from_file(tfile)
        num_in = len(train_data.get_input()[0])
        num_out = len(train_data.get_output()[0])

        ann = libfann.neural_net()
        ann.create_sparse_array(connection_rate, (num_in, num_neurons_hidden, num_out))
        ann.set_learning_rate(learning_rate)
        ann.set_activation_function_hidden(libfann.SIGMOID_SYMMETRIC_STEPWISE)
        ann.set_activation_function_output(libfann.SIGMOID_STEPWISE)
        ann.set_training_algorithm(libfann.TRAIN_INCREMENTAL)
        ann.train_on_data(train_data, max_iterations, iterations_between_reports, desired_error)

        print("Testing network")
        test_data = libfann.training_data()
        test_data.read_train_from_file(test_file)
        ann.reset_MSE()
        ann.test_data(test_data)
        print("MSE error on test data: %f" % ann.get_MSE())
开发者ID:psiddarth,项目名称:Neuron,代码行数:20,代码来源:test.py


示例9: load_data

    def load_data(self, data_file, val_file=None):
        """Load FANN training data and, if available, validation data.

        Sets ``train_data``, ``dim_input``, ``dim_output``, ``val_data`` and
        ``do_validation``. When ``val_file`` is not given or does not exist,
        the training data doubles as validation data and ``do_validation`` is
        False. For each loaded set, prints how many samples have a first
        target value below / above 0.5.
        """
        def count_low_high(fann_data):
            # Samples whose first target is exactly 0.5 fall in neither bucket.
            low = high = 0
            for target in fann_data.get_output():
                if target[0] < 0.5:
                    low += 1
                elif target[0] > 0.5:
                    high += 1
            return low, high

        # create training data, and ann object
        print("Loading data")
        self.train_data = libfann.training_data()
        self.train_data.read_train_from_file(data_file)
        self.dim_input = self.train_data.num_input_train_data()
        self.dim_output = self.train_data.num_output_train_data()
        print("\t Train data is %d low and %d high" % count_low_high(self.train_data))

        if val_file and os.path.exists(val_file):
            print("Loading validation data")
            self.do_validation = True
            self.val_data = libfann.training_data()
            self.val_data.read_train_from_file(val_file)
            print("\t Validation data is %d low and %d high" % count_low_high(self.val_data))
        else:
            self.val_data = self.train_data
            self.do_validation = False
开发者ID:DontLookAtMe,项目名称:fann-mrnn,代码行数:38,代码来源:fann_trainer.py


示例10: mainLoop

def mainLoop():
    """Repeatedly retrain the global ``ann`` on randomly spoiled segments.

    Each iteration does the following:
    - Clones and distorts (``spoil``) every segment, then shuffles them.
    - Writes them to ``train_file`` in FANN format.
    - Trains one epoch and evaluates on ``test``.

    The network is saved to ``ann_file`` whenever the test MSE improves.
    The loop stops after ``max_iters_after_save`` iterations without
    improvement, or on Ctrl-C.
    """
    import sys

    n_iter = 0
    last_save = 0
    min_test_MSE = 1.0
    max_iters_after_save = 50
    # sys.stdout.write keeps each iteration's report on one line, and works on
    # both Python 2 and 3.
    out = sys.stdout.write

    try:
        while True:
            n_iter += 1
            out("Iteration: %5d " % n_iter)
            # Distort clones so the original segments stay untouched.
            seg_copy = [(c, spoil(cv.CloneImage(seg))) for c, seg in segments]
            shuffle(seg_copy)

            # ``with`` closes the file even if writing is interrupted
            # (KeyboardInterrupt is caught below, so the loop can end mid-write).
            with open(train_file, "w") as f:
                f.write("%d %d %d\n" % (len(seg_copy), num_input, num_output))
                for c, image in seg_copy:
                    image = adjustSize(image, (segW, segH))
                    # Scales pixel values to value / 159.375 - 0.8, i.e. about
                    # -0.8..0.8 if pixels are 8-bit (0..255) -- TODO confirm.
                    f.write("".join("%f " % (image[y, x] / 159.375 - 0.8)
                                    for y in range(image.height)
                                    for x in range(image.width)))
                    f.write("\n")
                    # Target: 1 at the character's charset index, -1 elsewhere.
                    n = charset.index(c)
                    f.write("-1 " * n + "1" + " -1" * (num_output - n - 1) + "\n")

            train = libfann.training_data()
            train.read_train_from_file(train_file)
            try:
                ann.train_epoch(train)
            finally:
                train.destroy_train()
            out("Train MSE: %f " % ann.get_MSE())
            out("Train bit fail: %5d " % ann.get_bit_fail())
            ann.test_data(test)
            mse = ann.get_MSE()
            out("Test MSE: %f " % mse)
            out("Test bit fail: %5d " % ann.get_bit_fail())
            if mse < min_test_MSE:
                min_test_MSE = mse
                ann.save(ann_file)
                last_save = n_iter
                out("saved")
            out("\n")
            if n_iter - last_save > max_iters_after_save:
                break
    except KeyboardInterrupt:
        out("\nInterrupted by user.\n")
开发者ID:woto,项目名称:EPC,代码行数:47,代码来源:train.py


示例11: train_my_net

def train_my_net(data_file, net=None):
    """Train a network on the FANN data file ``data_file`` and return it.

    Uses ``net`` when given, otherwise a fresh network from ``new_net()``.
    Training runs for at most 100000 epochs with a desired error of 0.01,
    reporting every 100 epochs.
    """
    target_error = 0.01
    epoch_limit = 100000
    report_every = 100

    network = new_net() if net is None else net

    samples = libfann.training_data()
    samples.read_train_from_file(data_file)

    network.train_on_data(samples, epoch_limit, report_every, target_error)
    return network
开发者ID:the-mandarine,项目名称:esiea-school-projects,代码行数:17,代码来源:test_cancer_valid.py


示例12: initNet

def initNet():
    """Create the global ``ann`` and initialize its weights from ``train_file``.

    Builds a fully connected 3-layer network. The hidden layer has a third as
    many neurons as the input layer, and both hidden and output layers use
    SIGMOID_SYMMETRIC_STEPWISE activations.
    """
    global ann

    learning_rate = 0.3
    # Layer sizes must be integers. ``/`` returns a float on Python 3; ``//``
    # floors on both Python 2 and 3.
    num_neurons_hidden = num_input // 3

    ann = libfann.neural_net()
    ann.create_standard_array((num_input, num_neurons_hidden, num_output))
    ann.set_learning_rate(learning_rate)
    ann.set_activation_function_hidden(libfann.SIGMOID_SYMMETRIC_STEPWISE)
    ann.set_activation_function_output(libfann.SIGMOID_SYMMETRIC_STEPWISE)

    train = libfann.training_data()
    train.read_train_from_file(train_file)
    try:
        # init_weights uses the input range of the training data to choose
        # the starting weights.
        ann.init_weights(train)
    finally:
        train.destroy_train()
开发者ID:woto,项目名称:EPC,代码行数:19,代码来源:train.py


示例13: TestOnData

def TestOnData(nn, testdata):
    """Evaluate the FANN network file ``nn`` on the data file ``testdata``.

    If ``args.full`` is set, prints a per-row table with these columns:
    - row number;
    - expected output;
    - network output;
    - 'B' for a missed bad buy or 'G' for a missed good buy, using a 0.8
      threshold.

    Miss rates follow the table. Otherwise prints the bit-fail count and MSE
    reported by ``test_data``.
    """
    ann = libfann.neural_net()
    ann.create_from_file(nn)

    testData = libfann.training_data()
    testData.read_train_from_file(testdata)
    ann.reset_MSE()

    if not args.full:
        ann.test_data(testData)
        print("Bit Fail: " + str(ann.get_bit_fail()))
        print("Mean Squared Error: " + str(ann.get_MSE()))
        return

    def percent(part, whole):
        # A class absent from the data would otherwise raise ZeroDivisionError.
        return float(part) / whole * 100 if whole else 0.0

    threshold = 0.8
    missed_goodbuys = 0
    missed_badbuys = 0
    correct_goodbuys = 0
    correct_badbuys = 0

    print("#Row\tCorrect\tCalc\tFail")

    rows = zip(testData.get_input(), testData.get_output())
    for i, (inp, expected) in enumerate(rows):
        # Run the network once per row and reuse the result for the report.
        nn_out = ann.run(inp)[0]
        c_out = expected[0]
        flag = ' '
        if c_out == 1.0:
            if nn_out < threshold:
                flag = 'B'
                missed_badbuys += 1
            else:
                correct_badbuys += 1
        elif c_out == 0.0:
            if nn_out >= threshold:
                flag = 'G'
                missed_goodbuys += 1
            else:
                correct_goodbuys += 1

        print("%5u\t%1.3f\t%1.3f\t%s" % (i + 1, c_out, nn_out, flag))

    total_bad = correct_badbuys + missed_badbuys
    total_good = correct_goodbuys + missed_goodbuys
    print("Missed %u bad buys of %u (%2.1f%%)"
          % (missed_badbuys, total_bad, percent(missed_badbuys, total_bad)))
    print("Missed %u good buys of %u (%2.1f%%)"
          % (missed_goodbuys, total_good, percent(missed_goodbuys, total_good)))
开发者ID:malthejorgensen,项目名称:DontGetKicked,代码行数:43,代码来源:train.py


示例14: XY_to_fann_train_data

def _fann_train_lines(X, Y):
    """Yield the lines of a FANN training file for the sample pairs in X, Y."""
    # Header: number of samples, input width, output width.
    yield "%d %d %d\n" % (len(X), len(X[0]), len(Y[0]))
    for x_row, y_row in zip(X, Y):
        for row in (x_row, y_row):
            yield "%s\n" % ' '.join(str(float(val)) for val in row)


def XY_to_fann_train_data(X, Y):
    """Convert paired input/output rows into a ``libfann.training_data``.

    FANN is fed through a temporary file in its text training format. The
    temporary file is always removed afterwards.

    Args:
        X: Sequence of input rows (all the same length).
        Y: Sequence of output rows, one per row of ``X``.

    Returns:
        A ``libfann.training_data``. It is left empty when ``X`` is empty.

    Raises:
        ValueError: If ``X`` and ``Y`` have different lengths.
    """
    import os

    if len(X) != len(Y):
        raise ValueError("X and Y must have the same number of lines.")

    train_data = libfann.training_data()
    if not len(X):
        return train_data

    # Text mode so ``str`` can be written on Python 3 (the default mode is
    # binary). delete=False lets FANN reopen the file by name after it is
    # closed.
    tmp = tempfile.NamedTemporaryFile(mode="w", delete=False)
    try:
        with tmp:
            tmp.writelines(_fann_train_lines(X, Y))
        train_data.read_train_from_file(tmp.name)
    finally:
        # os.unlink, not the wrapper's private ``unlink`` attribute, which
        # only exists on Python 2.
        os.unlink(tmp.name)

    return train_data
开发者ID:jmoudrik,项目名称:orange-hacks,代码行数:20,代码来源:fann_neural.py


示例15: __init__

    def __init__(self, xdat, ydat, idxs):
        """Train a small FANN network mapping ``xdat`` to ``ydat`` and save it.

        The network has one hidden layer of 3 neurons and SIGMOID_SYMMETRIC
        activations. It is trained for up to 2000 epochs, or until the desired
        error 0.2 is reached, then saved to ``xor_float.net`` and destroyed.

        Args:
            xdat: 2-D array with one input row per sample.
            ydat: 2-D array with one target row per sample.
            idxs: Not used.

        Raises:
            Exception: If ``xdat`` and ``ydat`` have different numbers of rows.
        """
        if shape(xdat)[0] != shape(ydat)[0]:
            raise Exception('dimension mismatch b/w x, y')

        num_input = shape(xdat)[1]
        num_output = shape(ydat)[1]
        num_neurons_hidden = 3
        desired_error = 0.2
        max_epochs = 2000
        epochs_between_reports = 1000

        net = fann.neural_net()
        # pyfann takes the layer sizes themselves. Unlike the C API, there is no
        # leading layer-count argument; prepending num_layers would build a
        # 4-layer net whose input layer has num_layers neurons.
        net.create_standard_array([num_input, num_neurons_hidden, num_output])

        net.set_activation_function_hidden(fann.SIGMOID_SYMMETRIC)
        net.set_activation_function_output(fann.SIGMOID_SYMMETRIC)

        train_data = fann.training_data()
        train_data.set_train_data(xdat, ydat)
        net.train_on_data(train_data, max_epochs, epochs_between_reports, desired_error)

        net.save("xor_float.net")
        net.destroy()
开发者ID:bh0085,项目名称:compbio,代码行数:37,代码来源:backup_gagd.py


示例16: (模块级脚本片段)

# Map opts.output_activation to a FANN activation function; any other value
# (including None) falls back to SIGMOID_STEPWISE.
output_activations = {
    "SIGMOID_SYMMETRIC_STEPWISE": libfann.SIGMOID_SYMMETRIC_STEPWISE,
    "GAUSSIAN": libfann.GAUSSIAN,
    "GAUSSIAN_SYMMETRIC": libfann.GAUSSIAN_SYMMETRIC,
    "SIGMOID": libfann.SIGMOID,
}
ann.set_activation_function_output(
    output_activations.get(opts.output_activation, libfann.SIGMOID_STEPWISE))
ann.set_activation_steepness_output(opts.steep_out)


########## Import training data #####################
print("Getting training data : %s" % opts.training_file)
# Training reads the .ann counterpart of the .pat file. Only a trailing
# ".pat" extension is swapped, so ".pat" elsewhere in the path is kept.
training_path = opts.training_file
if training_path.endswith(".pat"):
    training_path = training_path[:-len(".pat")] + ".ann"
train_data = libfann.training_data()
train_data.read_train_from_file(training_path)
#train_data.scale_train_data(0.0,1.0)

########## GA Training #####################
print("Setting GA training parameters")
genome = G1DConnections.G1DConnections()
genome.evaluator.set(GAnnEvaluators.evaluateMSE)

genome.setParams(rangemin=opts.range_min, rangemax=opts.range_max,
                 layers=layers, bias=bias,
                 gauss_mu=opts.gauss_mu, gauss_sigma=opts.gauss_sigma)
#genome.mutator.set(GAnnMutators.G1DConnMutateNodes)
ga = GAnnGA.GAnnGA(genome, ann, train_data)
ga.setMutationRate(opts.mutation_rate)
ga.setPopulationSize(opts.population)
ga.setGenerations(opts.generations)
print("Start running GA")
开发者ID:chiewoo,项目名称:GANNCode,代码行数:31,代码来源:GAnnTrainFile.py


示例17: test_ann

def test_ann(ann_path, test_data_path, output_path=None, conf_path=None, error=0.01):
    """Test an artificial neural network.

    Logs the mean square error of the network in ``ann_path`` on the test
    data in ``test_data_path``.

    When ``output_path`` is given, also writes a tab-separated per-sample
    report with these columns: ID, expected class, predicted classes, and a
    +/- match flag. The report ends with the fraction classified correctly.
    Writing the report requires ``conf_path`` to name a YAML file with a
    ``classes`` object.

    Args:
        ann_path: Saved FANN network.
        test_data_path: Test data readable by ``common.TrainData``.
        output_path: Optional path for the classification report.
        conf_path: Optional YAML configuration file.
        error: Tolerance passed to ``get_classification``.

    Returns:
        1 on failure, otherwise None.

    Raises:
        ValueError: If ``output_path`` is set without ``conf_path``.
    """
    for path in (ann_path, test_data_path, conf_path):
        if path and not os.path.exists(path):
            logging.error("Cannot open %s (no such file or directory)", path)
            return 1

    if output_path and not conf_path:
        raise ValueError("Argument `conf_path` must be set when `output_path` is set")

    # yml is only loaded when conf_path is given. Initializing it here keeps
    # the lookups below from raising UnboundLocalError otherwise.
    yml = None
    if conf_path:
        yml = open_yaml(conf_path)
        if not yml:
            return 1
        if 'classes' not in yml:
            logging.error("Classes are not set in the YAML file. Missing object 'classes'.")
            return 1

    # Get the prefix for the classification columns.
    dependent_prefix = "OUT:"
    if yml is not None and 'data' in yml:
        dependent_prefix = getattr(yml.data, 'dependent_prefix', dependent_prefix)

    ann = libfann.neural_net()
    ann.create_from_file(ann_path)

    test_data = common.TrainData()
    try:
        test_data.read_from_file(test_data_path, dependent_prefix)
    except ValueError as e:
        logging.error("Failed to process the test data: %s", e)
        return 1

    logging.info("Testing the neural network...")
    fann_test_data = libfann.training_data()
    fann_test_data.set_train_data(test_data.get_input(), test_data.get_output())

    ann.test_data(fann_test_data)

    mse = ann.get_MSE()
    logging.info("Mean Square Error on test data: %f", mse)

    if not output_path:
        return

    # Get codeword for each class.
    codewords = get_codewords(yml.classes)

    total = 0
    correct = 0
    # ``with`` closes the report on every early return below.
    with open(output_path, 'w') as out_file:
        out_file.write("%s\n" % "\t".join(['ID', 'Class', 'Classification', 'Match']))

        for label, sample_input, output in test_data:
            total += 1

            if len(codewords) != len(output):
                logging.error("Codeword length (%d) does not match the number of classes. "
                              "Please make sure the correct classes are set in %s",
                              len(output), conf_path)
                return 1

            class_e = get_classification(codewords, output, error)
            if len(class_e) != 1:
                logging.error("The codeword for a class can only have one positive value")
                return 1

            codeword = ann.run(sample_input)
            try:
                class_f = get_classification(codewords, codeword, error)
            except ValueError as e:
                logging.error("Classification failed: %s", e)
                return 1

            # Check if the first items of the classifications match.
            match = len(class_f) > 0 and class_f[0] == class_e[0]
            if match:
                correct += 1

            row = [label if label else "", class_e[0], ", ".join(class_f), "+" if match else "-"]
            out_file.write("%s\n" % "\t".join(row))

        # An empty test set would otherwise raise ZeroDivisionError.
        fraction = float(correct) / total if total else 0.0
        out_file.write("%s\n" % "\t".join(['', '', '', "%.3f" % fraction]))

    logging.info("Correctly classified: %.1f%%", fraction * 100)
    logging.info("Testing results written to %s", output_path)
开发者ID:naturalis,项目名称:imgpheno,代码行数:94,代码来源:train.py


示例18: exists

if not exists(output_dir):
    os.makedirs(output_dir)

states_files = args.states_files
if len(states_files) == 1:
    states_files = glob(states_files[0])

# Convert the files and move them to the build path
if args.fast:
    n_max = 200
else:
    n_max = inf
convert_two_particle_hdf5_to_fann(states_files, output_dir, train_ratio=0.85, n_max=n_max, min_distance=args.min_distance, max_distance=args.max_distance)

# Load data
train_data = libfann.training_data()
validate_data = libfann.training_data()
test_data = libfann.training_data()

train_data_filename = str(join(output_dir, "train.fann"))
validate_data_filename = str(join(output_dir, "validate.fann"))
test_data_filename = str(join(output_dir, "test.fann"))

print "Loading data:\n", train_data_filename, "\n", validate_data_filename, "\n", test_data_filename

train_data.read_train_from_file(train_data_filename)
validate_data.read_train_from_file(validate_data_filename)
test_data.read_train_from_file(test_data_filename)

# Create and train networks
best_test_result = inf
开发者ID:dragly,项目名称:fann-md,代码行数:31,代码来源:fann_train_two_particles.py


示例19: get_training_data

def get_training_data(data_file):
    """Return a ``libfann.training_data`` loaded from the FANN file ``data_file``."""
    training_set = libfann.training_data()
    training_set.read_train_from_file(data_file)
    return training_set
开发者ID:the-mandarine,项目名称:esiea-school-projects,代码行数:4,代码来源:test_cancer_valid.py


示例20: run_fann

def run_fann(num_hidden=4, fname="ann_ws496.net", fname_data_prefix='', n_iter=100,
             disp=True, graph=True):
    """Train a 1024-input sparse FANN network, tracking train/test MSE.

    The network has a SIGMOID_SYMMETRIC hidden layer and a LINEAR output.

    Steps:
    - Reads ``<prefix>train_in.data`` and ``<prefix>test_in.data``.
    - Trains one epoch per round for ``n_iter`` rounds, recording the MSE on
      both sets after each round.
    - Saves the network to ``fname`` and, if ``graph`` is set, plots the two
      MSE curves.

    Returns:
        (train_mse, test_mse): lists with one MSE value per round.
    """
    print("num_hidden =", num_hidden)

    fname_data_train = fname_data_prefix + "train_in.data"
    fname_data_test = fname_data_prefix + "test_in.data"

    connection_rate = 1
    learning_rate = 0.7
    num_input = 1024
    num_output = 1

    desired_error = 0.0001
    # One epoch per train_on_data call, so the MSE can be recorded after every
    # epoch.
    max_iterations = 1
    iterations_between_reports = 1

    ann = libfann.neural_net()
    ann.create_sparse_array(connection_rate, (num_input, num_hidden, num_output))
    ann.set_learning_rate(learning_rate)
    ann.set_activation_function_hidden(libfann.SIGMOID_SYMMETRIC)
    ann.set_activation_function_output(libfann.LINEAR)

    # train_data is loaded
    train_data = libfann.training_data()
    train_data.read_train_from_file(fname_data_train)

    # test_data is loaded
    test_data = libfann.training_data()
    test_data.read_train_from_file(fname_data_test)

    train_mse = []
    test_mse = []
    for ii in range(n_iter):
        # Training is performed with training data
        ann.reset_MSE()
        ann.train_on_data(train_data, max_iterations, iterations_between_reports, desired_error)

        # Evaluate on the training data
        ann.reset_MSE()
        ann.test_data(train_data)
        mse_train = ann.get_MSE()
        train_mse.append(mse_train)

        # Evaluate on the held-out test data
        ann.reset_MSE()
        ann.test_data(test_data)
        mse_test = ann.get_MSE()
        test_mse.append(mse_test)

        if disp:
            print(ii, "MSE of train, test", mse_train, mse_test)

    ann.save(fname)

    # We show the results of ANN training with validation.
    if graph:
        plot(train_mse, label='train')
        plot(test_mse, label='test')
        legend(loc=1)
        xlabel('iteration')
        ylabel('MSE')
        grid()
        show()

    return train_mse, test_mse
开发者ID:jskDr,项目名称:jamespy_py3,代码行数:62,代码来源:jann.py



注:本文中的pyfann.libfann.training_data函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python utility.jsonify函数代码示例 (发布时间: 2022-05-25)
下一篇:
Python libfann.neural_net函数代码示例 (发布时间: 2022-05-25)
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap