本文整理汇总了Java中org.deeplearning4j.nn.api.Model类的典型用法代码示例。如果您正在寻找以下问题的答案：Java Model类的具体用法是什么？Java Model怎么用？有哪些Java Model的使用示例？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
Model类属于org.deeplearning4j.nn.api包，下文共展示了Model类的20个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于我们的系统推荐更好的Java代码示例。
示例1: fromFile
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Restores a serialized DL4J model from {@code file}.
 *
 * <p>The file is first tried as a ComputationGraph. If that fails, it is tried as a
 * MultiLayerNetwork.
 *
 * @param file serialized model file
 * @return a {@code DLModel} wrapping whichever model type could be restored
 * @throws Exception the ComputationGraph restore failure if neither format loads; the
 *     MultiLayerNetwork restore failure is attached to it as a suppressed exception
 */
public static DLModel fromFile(File file) throws Exception {
    Model model = null;
    try {
        System.out.println("Trying to load file as computation graph: " + file);
        model = ModelSerializer.restoreComputationGraph(file);
        System.out.println("Loaded Computation Graph.");
    } catch (Exception e) {
        try {
            System.out.println("Failed to load computation graph. Trying to load model.");
            model = ModelSerializer.restoreMultiLayerNetwork(file);
            System.out.println("Loaded Multilayernetwork");
        } catch (Exception e1) {
            System.out.println("Give up trying to load file: " + file);
            // Keep the second failure visible instead of silently dropping it.
            e.addSuppressed(e1);
            throw e;
        }
    }
    return new DLModel(model);
}
开发者ID:jesuino,项目名称:java-ml-projects,代码行数:19,代码来源:DLModel.java
示例2: onEpochEnd
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Epoch-end callback.
 *
 * <p>Advances the epoch counter. On every {@code n}-th epoch it logs an "Epoch [x/y]"
 * header. When intermediate evaluations are enabled, the header is followed by
 * training-set results and, if a validation iterator exists, validation-set results.
 */
@Override
public void onEpochEnd(Model model) {
    currentEpoch++;
    boolean isEvaluationEpoch = currentEpoch % n == 0;
    if (isEvaluationEpoch) {
        StringBuilder report = new StringBuilder();
        report.append("Epoch [").append(currentEpoch).append("/").append(numEpochs).append("]\n");
        if (enableIntermediateEvaluations) {
            report.append("Train Set: \n").append(evaluateDataSetIterator(model, trainIterator, true));
            if (validationIterator != null) {
                report.append("Validation Set: \n")
                        .append(evaluateDataSetIterator(model, validationIterator, false));
            }
        }
        log(report.toString());
    }
}
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:21,代码来源:EpochListener.java
示例3: iterationDone
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Every {@code printIterations} iterations, scores the whole held-out iterator and logs
 * the summed score divided by the number of label positions that were scored.
 *
 * <p>A non-positive {@code printIterations} is clamped to 1, which means report every
 * iteration.
 */
@Override
public void iterationDone(Model model, int i) {
    if (printIterations <= 0)
        printIterations = 1;
    if (iterCount % printIterations == 0) {
        iter.reset();
        double cost = 0;
        double count = 0;
        while (iter.hasNext()) {
            DataSet minibatch = iter.next(miniBatchSize);
            cost += ((MultiLayerNetwork) model).scoreExamples(minibatch, false).sumNumber().doubleValue();
            count += countScoredLabelPositions(minibatch);
        }
        log.info(String.format("Iteration %5d test set score: %.4f", iterCount, cost / count));
    }
    iterCount++;
}

/**
 * Returns the number of label positions in {@code minibatch} that contribute to its score.
 *
 * <p>If a labels mask is present, the result is the sum of the mask. Otherwise every
 * position is counted, which is what an all-ones mask would sum to. The previous code
 * dereferenced the mask unconditionally and threw a NullPointerException on unmasked data.
 */
private static double countScoredLabelPositions(DataSet minibatch) {
    if (minibatch.getLabelsMaskArray() != null) {
        return minibatch.getLabelsMaskArray().sumNumber().doubleValue();
    }
    // NOTE(review): assumes labels are [examples, nOut] or, for time series,
    // [examples, nOut, timeSteps] — confirm against the data pipeline.
    double positions = minibatch.getLabels().size(0);
    if (minibatch.getLabels().rank() == 3) {
        positions *= minibatch.getLabels().size(2);
    }
    return positions;
}
开发者ID:jpatanooga,项目名称:strata-2016-nyc-dl4j-rnn,代码行数:18,代码来源:HeldoutScoreIterationListener.java
示例4: iterationDone
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Per-iteration early-stopping check.
 *
 * <p>Publishes the latest score to the trainer and asks each iteration termination
 * condition, in order, whether to stop. The first condition that fires is recorded as
 * the termination reason. If the trainer is marked for termination, the parallel wrapper
 * is told to stop fitting. The trainer's iteration counter is always incremented.
 */
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    double score = model.score();
    trainer.setLatestScore(score);

    IterationTerminationCondition firedCondition = null;
    for (IterationTerminationCondition condition : esConfig.getIterationTerminationConditions()) {
        if (condition.terminate(score)) {
            firedCondition = condition;
            break;
        }
    }
    if (firedCondition != null) {
        trainer.setTermination(true);
        trainer.setTerminationReason(firedCondition);
    }

    // Built-in kill switch: asks the wrapper to abort the ongoing fit.
    if (trainer.getTermination()) {
        wrapper.stopFit();
    }
    trainer.incrementIteration();
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:20,代码来源:EarlyStoppingParallelTrainer.java
示例5: testListenersForModel
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Fits {@code model} through a ParallelWrapper with two workers and checks the recorded
 * listener callbacks.
 *
 * <p>Each worker gets one single-example DataSet and the averaging frequency is 1.
 *
 * @param model model to wrap and train
 * @param listeners listeners to install on the wrapper, or {@code null} for none
 */
private static void testListenersForModel(Model model, List<IterationListener> listeners) {
    int nWorkers = 2;
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model).workers(nWorkers).averagingFrequency(1)
            .reportScoreAfterAveraging(true).build();
    if (listeners != null) {
        wrapper.setListeners(listeners);
    }

    // One 1x10 feature / 1x10 label example per worker.
    List<DataSet> data = new ArrayList<>();
    for (int i = 0; i < nWorkers; i++) {
        data.add(new DataSet(Nd4j.rand(1, 10), Nd4j.rand(1, 10)));
    }
    DataSetIterator iter = new ExistingDataSetIterator(data);

    TestListener.clearCounts();
    wrapper.fit(iter);

    // Expectations scale with nWorkers rather than a hard-coded 2. Presumably there is one
    // worker ID per worker, a single shared session, and one forward plus one backward
    // pass per worker (one minibatch each) — confirm against TestListener.
    assertEquals(nWorkers, TestListener.workerIDs.size());
    assertEquals(1, TestListener.sessionIDs.size());
    assertEquals(nWorkers, TestListener.forwardPassCount.get());
    assertEquals(nWorkers, TestListener.backwardPassCount.get());
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:26,代码来源:TestListeners.java
示例6: updateGradientAccordingToParams
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Applies the configured updater to {@code gradient} for {@code model}.
 *
 * <p>A ComputationGraph uses a lazily created {@code ComputationGraphUpdater}. Any other
 * model uses a lazily created updater from {@code UpdaterCreator} and is treated as a
 * {@code Layer}. Both updaters are created inside
 * {@code scopeOutOfWorkspaces()}, presumably so their state is not allocated in a
 * workspace that may be reset — confirm.
 */
@Override
public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
    if (!(model instanceof ComputationGraph)) {
        if (updater == null) {
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                updater = UpdaterCreator.getUpdater(model);
            }
        }
        updater.update((Layer) model, gradient, getIterationCount(model), getEpochCount(model), batchSize);
        return;
    }

    if (computationGraphUpdater == null) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph) model);
        }
    }
    computationGraphUpdater.update(gradient, getIterationCount(model), getEpochCount(model), batchSize);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:22,代码来源:BaseOptimizer.java
示例7: onForwardPass
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Collects activation statistics on reporting iterations.
 *
 * <p>Runs only when statistics are computed from activations, the reporting frequency is
 * positive, and the current iteration is a multiple of that frequency (iteration 0
 * included). Histograms, mean, standard deviation and mean magnitude are each collected
 * only when enabled in the update configuration.
 */
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    int iterCount = getModelInfo(model).iterCount;
    if (!calcFromActivations()) {
        return;
    }
    int frequency = updateConfig.reportingFrequency();
    if (frequency <= 0 || (iterCount != 0 && iterCount % frequency != 0)) {
        return;
    }

    final StatsType kind = StatsType.Activations;
    if (updateConfig.collectHistograms(kind)) {
        activationHistograms = getHistograms(activations, updateConfig.numHistogramBins(kind));
    }
    if (updateConfig.collectMean(kind)) {
        meanActivations = calculateSummaryStats(activations, StatType.Mean);
    }
    if (updateConfig.collectStdev(kind)) {
        stdevActivations = calculateSummaryStats(activations, StatType.Stdev);
    }
    if (updateConfig.collectMeanMagnitudes(kind)) {
        meanMagActivations = calculateSummaryStats(activations, StatType.MeanMagnitude);
    }
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:20,代码来源:BaseStatsListener.java
示例8: onGradientCalculation
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Collects gradient statistics on reporting iterations.
 *
 * <p>Runs only when statistics are computed from gradients, the reporting frequency is
 * positive, and the current iteration is a multiple of that frequency (iteration 0
 * included). The model's gradient is fetched only on those iterations. Each enabled
 * statistic is computed over the per-variable gradient map.
 */
@Override
public void onGradientCalculation(Model model) {
    int iterCount = getModelInfo(model).iterCount;
    if (!calcFromGradients()) {
        return;
    }
    int frequency = updateConfig.reportingFrequency();
    if (frequency <= 0 || (iterCount != 0 && iterCount % frequency != 0)) {
        return;
    }

    Gradient gradient = model.gradient();
    final StatsType kind = StatsType.Gradients;
    if (updateConfig.collectHistograms(kind)) {
        gradientHistograms = getHistograms(gradient.gradientForVariable(), updateConfig.numHistogramBins(kind));
    }
    if (updateConfig.collectMean(kind)) {
        meanGradients = calculateSummaryStats(gradient.gradientForVariable(), StatType.Mean);
    }
    if (updateConfig.collectStdev(kind)) {
        stdevGradient = calculateSummaryStats(gradient.gradientForVariable(), StatType.Stdev);
    }
    if (updateConfig.collectMeanMagnitudes(kind)) {
        meanMagGradients = calculateSummaryStats(gradient.gradientForVariable(), StatType.MeanMagnitude);
    }
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:22,代码来源:BaseStatsListener.java
示例9: configureListeners
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Installs the worker's iteration listeners on model {@code m}.
 *
 * <p>When a router provider is configured, routing listeners are given the provider's
 * storage router and a worker ID of the form {@code <JVM UID>_<counter>}. The model is
 * treated as a MultiLayerNetwork if it is one, otherwise as a ComputationGraph. Nothing
 * happens when no listeners are configured.
 */
private void configureListeners(Model m, int counter) {
    if (iterationListeners == null) {
        return;
    }
    List<IterationListener> configured = new ArrayList<>(iterationListeners.size());
    for (IterationListener listener : iterationListeners) {
        boolean needsRouting = listenerRouterProvider != null && listener instanceof RoutingIterationListener;
        if (needsRouting) {
            RoutingIterationListener routing = (RoutingIterationListener) listener;
            routing.setStorageRouter(listenerRouterProvider.getRouter());
            routing.setWorkerID(UIDProvider.getJVMUID() + "_" + counter);
        }
        // Don't need to clone listeners: not from broadcast, so deserialization handles
        configured.add(listener);
    }

    if (m instanceof MultiLayerNetwork) {
        ((MultiLayerNetwork) m).setListeners(configured);
    } else {
        ((ComputationGraph) m).setListeners(configured);
    }
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:19,代码来源:ParameterAveragingTrainingWorker.java
示例10: testLoadNormalizersFile
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Round-trip test: writes a network to a temp file, adds a fitted min-max normalizer to
 * it, then checks that {@code ModelGuesser} loads both the model and the normalizer back
 * from the file path.
 */
@Test
public void testLoadNormalizersFile() throws Exception {
    MultiLayerNetwork net = getNetwork();
    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    tempFile.deleteOnExit();
    ModelSerializer.writeModel(net, tempFile, true);

    NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
    ModelSerializer.addNormalizerToModel(tempFile, normalizer);

    Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
    Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath());

    // assertEquals takes (expected, actual): the written network is the expected value.
    assertEquals(net, model);
    assertEquals(normalizer, normalizer1);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:19,代码来源:ModelGuesserTest.java
示例11: testLoadNormalizersInputStream
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Round-trip test: writes a network to a temp file and adds a fitted min-max normalizer
 * to it. The model is then loaded back from the file path and the normalizer from an
 * InputStream over the same file, and both are checked against the originals.
 */
@Test
public void testLoadNormalizersInputStream() throws Exception {
    MultiLayerNetwork net = getNetwork();
    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    tempFile.deleteOnExit();
    ModelSerializer.writeModel(net, tempFile, true);

    NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
    ModelSerializer.addNormalizerToModel(tempFile, normalizer);

    Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
    try (InputStream inputStream = new FileInputStream(tempFile)) {
        Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(inputStream);
        // assertEquals takes (expected, actual): the written network is the expected value.
        assertEquals(net, model);
        assertEquals(normalizer, normalizer1);
    }
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:21,代码来源:ModelGuesserTest.java
示例12: iterationDone
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Every 100th iteration, logs the model score and appends it to the on-screen log.
 *
 * <p>The iteration filter and {@code model.score()} now run on the thread that invokes
 * this listener. Previously a Runnable was posted to the UI thread on every iteration,
 * and the score was read there only later, by which point training may have moved on to
 * a later iteration. Now only the formatted message is handed to the UI thread, and only
 * for logged iterations.
 */
@Override
public void iterationDone(final Model model, final int iteration) {
    if (iteration % 100 != 0) {
        return;
    }
    final double result = model.score();
    final String message = "\nScore at iteration " + iteration + " is " + result;
    runOnUiThread(new Runnable() {
        @Override
        public void run() {
            Log.d(TAG, message);
            loggingArea.append(message);
        }
    });
}
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:16,代码来源:MainActivity.java
示例13: iterationDone
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Every {@code m_printIterations} iterations, calls {@code invoke()} and then prints the
 * iteration number and current model score to the progress bar.
 *
 * <p>A non-positive interval is clamped to 1 (report every iteration). The internal
 * iteration counter always advances.
 */
@Override
public void iterationDone(Model model, int iteration) {
    if (m_printIterations <= 0) {
        m_printIterations = 1;
    }
    boolean reportDue = m_iterCount % m_printIterations == 0;
    if (reportDue) {
        invoke();
        double currentScore = model.score();
        String progressLine = "Iteration: " + m_iterCount + ", Score: " + currentScore;
        m_progressBar.printProgress(progressLine);
    }
    m_iterCount++;
}
开发者ID:braeunlich,项目名称:anagnostes,代码行数:12,代码来源:TrainProgressIterationListener.java
示例14: iterationDone
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Advances the iteration counter.
 *
 * <p>Whenever the counter is a multiple of the configured listener period, calls
 * {@code invoke()}, logs the model score and forwards it to {@code display}.
 */
@Override
public void iterationDone (Model model,
                           int iteration)
{
    iterCount++;

    final boolean periodReached = (iterCount % constants.listenerPeriod.getValue()) == 0;
    if (!periodReached) {
        return;
    }

    invoke();
    final double currentScore = model.score();
    final int iterationNumber = (int) iterCount;
    logger.info(String.format("Score at iteration %d is %.5f", iterationNumber, currentScore));
    display(epoch, iterationNumber, currentScore);
}
开发者ID:Audiveris,项目名称:audiveris,代码行数:16,代码来源:TrainingPanel.java
示例15: iterationDone
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Every {@code printIterations} iterations, saves the model (as a MultiLayerNetwork) to
 * {@code modelSavePath}.
 *
 * <p>A non-positive interval is clamped to 1. The internal counter always advances.
 */
@Override
public void iterationDone(Model model, int i) {
    if (printIterations <= 0) {
        printIterations = 1;
    }
    boolean saveDue = iterCount % printIterations == 0;
    if (saveDue) {
        MultiLayerNetwork network = (MultiLayerNetwork) model;
        saveModel(network, this.modelSavePath);
    }
    iterCount++;
}
开发者ID:jpatanooga,项目名称:strata-2016-nyc-dl4j-rnn,代码行数:10,代码来源:ModelSaver.java
示例16: iterationDone
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Every {@code printIterations} iterations, calls {@code invoke()}, then samples beer
 * reviews from {@code net} and prints each one with its index.
 *
 * <p>Note: sampling uses the {@code net} field, not the {@code model} argument. A
 * non-positive interval is clamped to 1. The internal counter always advances.
 */
@Override
public void iterationDone(Model model, int i) {
    if (printIterations <= 0) {
        printIterations = 1;
    }
    boolean sampleDue = iterCount % printIterations == 0;
    if (sampleDue) {
        invoke();
        String[] generated = sampleBeerRatingFromNetwork(net, reader, rng, temperature, maxCharactersToSample, 1, styleIndex);
        System.out.println("----- Generating Lager Beer Review Samples -----");
        int sampleIndex = 0;
        for (String sample : generated) {
            System.out.println("SAMPLE " + sampleIndex + ": " + sample);
            sampleIndex++;
        }
    }
    iterCount++;
}
开发者ID:jpatanooga,项目名称:strata-2016-nyc-dl4j-rnn,代码行数:16,代码来源:SampleGeneratorListener.java
示例17: InferenceWorker
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Creates inference worker thread number {@code id}.
 *
 * <p>The thread is named {@code InferenceThread-<id>} and marked as a daemon, so it does
 * not keep the JVM alive on its own.
 *
 * @param id worker index, used only in the thread name
 * @param model prototype model this worker serves
 * @param inputQueue queue this worker reads requests from
 * @param rootDevice whether this worker is the root-device worker
 */
private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice) {
    this.setName("InferenceThread-" + id);
    this.setDaemon(true);

    this.protoModel = model;
    this.inputQueue = inputQueue;
    this.rootDevice = rootDevice;
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:10,代码来源:ParallelInference.java
示例18: scoreMinibatch
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Scores one minibatch and multiplies the score by the number of examples in the first
 * feature array.
 *
 * <p>MultiLayerNetwork is scored on a single DataSet built from the first array of each
 * argument. ComputationGraph is scored on a MultiDataSet of all arrays. Any other model
 * type throws a RuntimeException. Presumably {@code score} returns a per-example average,
 * so the product is a minibatch total — confirm.
 */
@Override
protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) {
    double networkScore;
    if (network instanceof MultiLayerNetwork) {
        DataSet firstArrays = new DataSet(get0(features), get0(labels), get0(fMask), get0(lMask));
        networkScore = ((MultiLayerNetwork) network).score(firstArrays, false);
    } else if (network instanceof ComputationGraph) {
        networkScore = ((ComputationGraph) network).score(new MultiDataSet(features, labels, fMask, lMask));
    } else {
        throw new RuntimeException("Unknown model type: " + network.getClass());
    }
    return networkScore * features[0].size(0);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:13,代码来源:DataSetLossCalculator.java
示例19: updateExamplesMinibatchesCounts
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Adds the current minibatch to the model's running example and minibatch counters.
 *
 * <p>Updates both the since-last-report and the total counters.
 */
private void updateExamplesMinibatchesCounts(Model model) {
    ModelInfo info = getModelInfo(model);
    int batchExamples = currentMinibatchExampleCount(model);
    info.examplesSinceLastReport += batchExamples;
    info.totalExamples += batchExamples;
    info.minibatchesSinceLastReport++;
    info.totalMinibatches++;
}

/**
 * Returns the example count of the minibatch {@code model} is currently processing, or 0
 * for an unrecognised model type.
 *
 * <p>The checks run in a fixed order (MultiLayerNetwork, then ComputationGraph, then
 * Layer). Keep that order in case a network type also implements Layer — confirm.
 */
private static int currentMinibatchExampleCount(Model model) {
    if (model instanceof MultiLayerNetwork) {
        return ((MultiLayerNetwork) model).batchSize();
    }
    if (model instanceof ComputationGraph) {
        return ((ComputationGraph) model).batchSize();
    }
    if (model instanceof Layer) {
        return ((Layer) model).getInputMiniBatchSize();
    }
    return 0;
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:16,代码来源:BaseStatsListener.java
示例20: getUpdater
import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Creates the updater that matches the concrete type of {@code layer}.
 *
 * <p>Returns a MultiLayerUpdater for a MultiLayerNetwork and a ComputationGraphUpdater
 * for a ComputationGraph. Any other model is cast to {@code Layer} and gets a
 * LayerUpdater; that cast throws ClassCastException if the model is not a Layer. Check
 * order is significant and must be preserved.
 */
public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) {
    if (layer instanceof MultiLayerNetwork) {
        return new MultiLayerUpdater((MultiLayerNetwork) layer);
    }
    if (layer instanceof ComputationGraph) {
        return new ComputationGraphUpdater((ComputationGraph) layer);
    }
    return new LayerUpdater((Layer) layer);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:10,代码来源:UpdaterCreator.java
注：本文中的org.deeplearning4j.nn.api.Model类示例整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论