
Java WordVectors Class Code Examples


This article collects typical usage examples of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in Java. If you are wondering how the WordVectors class is used in practice, the selected examples below may help.



The WordVectors class belongs to the org.deeplearning4j.models.embeddings.wordvectors package. Below are 20 code examples of the WordVectors class, sorted by popularity by default.

Example 1: testWriteWordVectorsFromWord2Vec

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
@Test
@Ignore
public void testWriteWordVectorsFromWord2Vec() throws IOException {
    WordVectors vec = WordVectorSerializer.loadGoogleModel(binaryFile, true);
    WordVectorSerializer.writeWordVectors((Word2Vec) vec, pathToWriteto);

    WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(new File(pathToWriteto));
    INDArray wordVector1 = wordVectors.getWordVectorMatrix("Morgan_Freeman");
    INDArray wordVector2 = wordVectors.getWordVectorMatrix("JA_Montalbano");
    assertEquals(vec.getWordVectorMatrix("Morgan_Freeman"), wordVector1);
    assertEquals(vec.getWordVectorMatrix("JA_Montalbano"), wordVector2);
    assertTrue(wordVector1.length() == 300);
    assertTrue(wordVector2.length() == 300);
    assertEquals(wordVector1.getDouble(0), 0.044423, 1e-3);
    assertEquals(wordVector2.getDouble(0), 0.051964, 1e-3);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines: 17, Source: WordVectorSerializerTest.java
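Example 1 writes vectors to a text file with `writeWordVectors` and reads them back with `loadTxtVectors`. The round trip can be illustrated without DL4J by a minimal sketch, assuming a simplified text layout of one word per line followed by its space-separated components. Real word2vec text files may also begin with a header line (for example vocabulary size and dimension), which this sketch does not handle. `TxtVectors`, `write`, and `read` are hypothetical names, not DL4J API.

```java
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical helper: writes and reads a simplified word-vector text format (no header line). */
final class TxtVectors {
    private TxtVectors() {}

    /** Writes one line per word: the word, then its components separated by single spaces. */
    static void write(Map<String, double[]> vectors, Path file) throws IOException {
        try (BufferedWriter out = Files.newBufferedWriter(file, StandardCharsets.UTF_8)) {
            for (Map.Entry<String, double[]> e : vectors.entrySet()) {
                StringBuilder line = new StringBuilder(e.getKey());
                for (double v : e.getValue()) {
                    line.append(' ').append(v);
                }
                out.write(line.toString());
                out.newLine();
            }
        }
    }

    /** Reads the same layout back, preserving the order of lines in the file. */
    static Map<String, double[]> read(Path file) throws IOException {
        Map<String, double[]> result = new LinkedHashMap<>();
        for (String line : Files.readAllLines(file, StandardCharsets.UTF_8)) {
            String[] parts = line.trim().split("\\s+");
            if (parts.length < 2) continue; // skip blank or malformed lines
            double[] vec = new double[parts.length - 1];
            for (int i = 1; i < parts.length; i++) {
                vec[i - 1] = Double.parseDouble(parts[i]);
            }
            result.put(parts[0], vec);
        }
        return result;
    }
}
```

Because `Double.toString` output parses back to the identical `double`, this particular round trip is exact; the assertions in Example 1 use a tolerance of `1e-3`, which also covers formats that store fewer digits.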


Example 2: getDataSetIterator

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
public static DataSetIterator getDataSetIterator(String DATA_PATH, boolean isTraining, WordVectors wordVectors, int minibatchSize,
                                                  int maxSentenceLength, Random rng ){
    String path = FilenameUtils.concat(DATA_PATH, (isTraining ? "aclImdb/train/" : "aclImdb/test/"));
    String positiveBaseDir = FilenameUtils.concat(path, "pos");
    String negativeBaseDir = FilenameUtils.concat(path, "neg");

    File filePositive = new File(positiveBaseDir);
    File fileNegative = new File(negativeBaseDir);

    Map<String,List<File>> reviewFilesMap = new HashMap<>();
    reviewFilesMap.put("Positive", Arrays.asList(filePositive.listFiles()));
    reviewFilesMap.put("Negative", Arrays.asList(fileNegative.listFiles()));

    LabeledSentenceProvider sentenceProvider = new FileLabeledSentenceProvider(reviewFilesMap, rng);

    return new CnnSentenceDataSetIterator.Builder()
            .sentenceProvider(sentenceProvider)
            .wordVectors(wordVectors)
            .minibatchSize(minibatchSize)
            .maxSentenceLength(maxSentenceLength)
            .useNormalizedWordVectors(false)
            .build();
}
 
Developer ID: IsaacChanghau, Project: Word2VecfJava, Lines: 24, Source: CNNSentenceClassification.java
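Example 2 hands sentences and word vectors to `CnnSentenceDataSetIterator`, which stacks each sentence's word vectors into a 2-D "image" for the CNN, roughly one row per word. A minimal sketch of that step for a single sentence, assuming that out-of-vocabulary tokens are dropped, that sentences longer than `maxSentenceLength` are cut off, and that unused rows stay zero; check the iterator's documentation for its actual handling of unknown words and padding. `SentenceMatrix.build` is hypothetical.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: turn one tokenized sentence into a CNN input matrix. */
final class SentenceMatrix {
    private SentenceMatrix() {}

    /**
     * Returns a maxLen x dim matrix whose row t holds the vector of the t-th known token.
     * Unknown tokens are dropped, extra tokens are truncated, and unused rows stay zero.
     */
    static double[][] build(List<String> tokens, Map<String, double[]> vectors, int maxLen, int dim) {
        double[][] matrix = new double[maxLen][dim];
        int row = 0;
        for (String token : tokens) {
            if (row == maxLen) break;       // truncate at the maximum sentence length
            double[] vec = vectors.get(token);
            if (vec == null) continue;      // skip out-of-vocabulary tokens
            System.arraycopy(vec, 0, matrix[row], 0, dim);
            row++;
        }
        return matrix;
    }
}
```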


Example 3: SentimentExampleIterator

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
/**
 * @param dataDirectory the directory of the IMDB review data set
 * @param wordVectors WordVectors object
 * @param batchSize Size of each minibatch for training
 * @param truncateLength If a review exceeds this many tokens, it is truncated
 * @param train If true: return the training data. If false: return the testing data.
 */
public SentimentExampleIterator(String dataDirectory, WordVectors wordVectors, int batchSize, int truncateLength, boolean train) throws IOException {
	this.batchSize = batchSize;
	this.vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;


	File p = new File(FilenameUtils.concat(dataDirectory, "aclImdb/" + (train ? "train" : "test") + "/pos/") + "/");
	File n = new File(FilenameUtils.concat(dataDirectory, "aclImdb/" + (train ? "train" : "test") + "/neg/") + "/");
	positiveFiles = p.listFiles();
	negativeFiles = n.listFiles();

	this.wordVectors = wordVectors;
	this.truncateLength = truncateLength;

	tokenizerFactory = new DefaultTokenizerFactory();
	tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
}
 
Developer ID: IsaacChanghau, Project: NeuralNetworksLite, Lines: 24, Source: SentimentExampleIterator.java
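The `truncateLength` parameter above caps the number of time steps per review. In the common DL4J sentiment-RNN example pattern, which this iterator appears to follow (an assumption, since its `next()` method is not shown), padded steps are hidden by a feature mask and only the last real step carries the label. A minimal sketch of those two masks for one review, where `maxSteps` is the longest review in the minibatch capped at `truncateLength`; `ReviewMasks` is hypothetical.

```java
/** Hypothetical sketch of per-review masks for a many-to-one sentiment RNN. */
final class ReviewMasks {
    private ReviewMasks() {}

    /** Feature mask: 1 for each real time step (up to maxSteps), 0 for padding. */
    static double[] featureMask(int numTokens, int maxSteps) {
        double[] mask = new double[maxSteps];
        int len = Math.min(numTokens, maxSteps);
        for (int t = 0; t < len; t++) {
            mask[t] = 1.0;
        }
        return mask;
    }

    /** Label mask: only the last real time step carries the review's label. */
    static double[] labelMask(int numTokens, int maxSteps) {
        double[] mask = new double[maxSteps];
        int len = Math.min(numTokens, maxSteps);
        if (len > 0) {
            mask[len - 1] = 1.0;
        }
        return mask;
    }
}
```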


Example 4: RnnTextEmbeddingDataSetIterator

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
/**
 * @param data Instances with documents and labels
 * @param wordVectors WordVectors object
 * @param tokenFact Tokenizer factory
 * @param tpp Token preprocessor
 * @param stopWords Stop word object
 * @param sentenceProvider Provider of labeled sentences
 * @param batchSize Size of each minibatch for training
 * @param truncateLength If a review exceeds this many tokens, it is truncated
 */
public RnnTextEmbeddingDataSetIterator(
    Instances data,
    WordVectors wordVectors,
    TokenizerFactory tokenFact,
    TokenPreProcess tpp,
    AbstractStopwords stopWords,
    LabeledSentenceProvider sentenceProvider,
    int batchSize,
    int truncateLength) {
  this.batchSize = batchSize;
  this.vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;

  this.data = data;

  this.wordVectors = wordVectors;
  this.truncateLength = truncateLength;

  this.tokenizerFactory = tokenFact;
  this.tokenizerFactory.setTokenPreProcessor(tpp);
  this.stopWords = stopWords;
  this.sentenceProvider = sentenceProvider;
}
 
Developer ID: Waikato, Project: wekaDeeplearning4j, Lines: 32, Source: RnnTextEmbeddingDataSetIterator.java
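Example 4 combines a tokenizer factory, a token preprocessor, and a stop-word list. The following minimal sketch shows one plausible order for those steps in plain Java: normalize each whitespace-separated token (lowercasing and stripping characters that are not letters or digits, which only approximates preprocessors such as `CommonPreprocessor`), drop stop words, and keep tokens that have a vector. `TokenPipeline` is hypothetical and does not reproduce DL4J's tokenizers.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Set;

/** Hypothetical pipeline: normalize tokens, drop stop words, keep tokens that have vectors. */
final class TokenPipeline {
    private TokenPipeline() {}

    static List<String> process(String text, Set<String> stopWords, Set<String> vocabulary) {
        List<String> kept = new ArrayList<>();
        for (String raw : text.split("\\s+")) {
            // Approximate preprocessing: lowercase, then remove anything that is not a letter or digit
            String token = raw.toLowerCase(Locale.ROOT).replaceAll("[^\\p{L}\\p{Nd}]", "");
            if (token.isEmpty() || stopWords.contains(token)) continue;
            if (vocabulary.contains(token)) {
                kept.add(token); // only tokens with a word vector reach the iterator
            }
        }
        return kept;
    }
}
```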


Example 5: SentimentRecurrentIterator

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
/**
 * @param dataDirectory the directory of the IMDB review data set
 * @param wordVectors WordVectors object
 * @param batchSize Size of each minibatch for training
 * @param truncateLength If a review exceeds this many tokens, it is truncated
 * @param train If true: return the training data. If false: return the testing data.
 */
public SentimentRecurrentIterator(String dataDirectory, WordVectors wordVectors, int batchSize, int truncateLength, boolean train) throws IOException {
  this.batchSize = batchSize;
  this.vectorSize = wordVectors.lookupTable().layerSize();

  File p = new File(FilenameUtils.concat(dataDirectory, (train ? "train" : "test") + "/positive/") + "/");
  File n = new File(FilenameUtils.concat(dataDirectory, (train ? "train" : "test") + "/negative/") + "/");
  positiveFiles = p.listFiles();
  negativeFiles = n.listFiles();
  numPositives  = positiveFiles.length;
  numNegatives  = negativeFiles.length;
  numTotals     = numPositives+numNegatives;
  rnd           = new Random(1);

  this.wordVectors = wordVectors;
  this.truncateLength = truncateLength;

  tokenizerFactory = new DefaultTokenizerFactory();
  tokenizerFactory.setTokenPreProcessor(new LowCasePreProcessor());
}
 
Developer ID: keigohtr, Project: sentiment-rnn, Lines: 27, Source: SentimentRecurrentIterator.java


Example 6: main

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
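/**
 * args[0] input: word2vec file name
 * args[1] input: sentiment model name
 * args[2] input: parent folder of the test data
 *
 * @param args command-line arguments
 * @throws Exception if the vectors, the model, or the test data cannot be loaded
 */
public static void main (final String[] args) throws Exception {
  // Elements of args are never null; check the argument count instead
  if (args.length < 3)
    System.exit(1);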

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1],false);

  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,100,300,false),1);
  Evaluation evaluation = new Evaluation();
  while(test.hasNext()) {
    DataSet t = test.next();
    INDArray features = t.getFeatures();
    INDArray labels = t.getLabels();
    INDArray inMask = t.getFeaturesMaskArray();
    INDArray outMask = t.getLabelsMaskArray();
    INDArray predicted = model.output(features,false,inMask,outMask);
    evaluation.evalTimeSeries(labels,predicted,outMask);
  }
  System.out.println(evaluation.stats());
}
 
Developer ID: keigohtr, Project: sentiment-rnn, Lines: 30, Source: SentimentRecurrentTestCmd.java
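Example 6 calls `evalTimeSeries(labels, predicted, outMask)`, so only time steps whose output mask is 1 (for a many-to-one sentiment model, typically the last real step of each review) contribute to the statistics. A minimal plain-Java sketch of masked accuracy; `MaskedAccuracy` is hypothetical, and for readability it indexes arrays as `[example][time][class]`, whereas DL4J recurrent arrays are laid out as `[example, class, time]`.

```java
/** Hypothetical sketch: accuracy over time series, counting only steps where the mask is 1. */
final class MaskedAccuracy {
    private MaskedAccuracy() {}

    /**
     * @param labels    one-hot labels, indexed [example][time][class]
     * @param predicted class scores with the same shape as labels
     * @param mask      1 where a step should be evaluated, indexed [example][time]
     */
    static double accuracy(double[][][] labels, double[][][] predicted, double[][] mask) {
        int correct = 0;
        int counted = 0;
        for (int ex = 0; ex < labels.length; ex++) {
            for (int t = 0; t < labels[ex].length; t++) {
                if (mask[ex][t] == 0.0) continue; // padded or unlabeled step
                if (argMax(labels[ex][t]) == argMax(predicted[ex][t])) correct++;
                counted++;
            }
        }
        return counted == 0 ? 0.0 : (double) correct / counted;
    }

    private static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) best = i;
        }
        return best;
    }
}
```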


Example 7: testWriteWordVectors

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
@Test
@Ignore
public void testWriteWordVectors() throws IOException {
    WordVectors vec = WordVectorSerializer.loadGoogleModel(binaryFile, true);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
    InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();
    WordVectorSerializer.writeWordVectors(lookupTable, lookupCache, pathToWriteto);

    WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(new File(pathToWriteto));
    double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
    double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");
    assertTrue(wordVector1.length == 300);
    assertTrue(wordVector2.length == 300);
    assertEquals(Doubles.asList(wordVector1).get(0), 0.044423, 1e-3);
    assertEquals(Doubles.asList(wordVector2).get(0), 0.051964, 1e-3);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines: 17, Source: WordVectorSerializerTest.java


Example 8: testFromTableAndVocab

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
@Test
@Ignore
public void testFromTableAndVocab() throws IOException {

    WordVectors vec = WordVectorSerializer.loadGoogleModel(textFile, false);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
    InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();

    WordVectors wordVectors = WordVectorSerializer.fromTableAndVocab(lookupTable, lookupCache);
    double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
    double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");
    assertTrue(wordVector1.length == 300);
    assertTrue(wordVector2.length == 300);
    assertEquals(Doubles.asList(wordVector1).get(0), 0.044423, 1e-3);
    assertEquals(Doubles.asList(wordVector2).get(0), 0.051964, 1e-3);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines: 17, Source: WordVectorSerializerTest.java
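Examples 7 and 8 split a model into its lookup table (`InMemoryLookupTable`) and vocabulary (`InMemoryLookupCache`) and then rebuild or serialize it from those two parts. Conceptually, the vocabulary maps each word to an index, and the table's weight matrix (`syn0`) stores that word's vector in the row with the same index. A hypothetical minimal sketch of that relationship, not DL4J's implementation:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical minimal word-vector model built from a vocabulary and a weight table. */
final class TableWordVectors {
    private final Map<String, Integer> indexOf = new HashMap<>();
    private final double[][] syn0; // row i holds the vector of the word with index i

    TableWordVectors(List<String> vocab, double[][] syn0) {
        if (vocab.size() != syn0.length) {
            throw new IllegalArgumentException("vocabulary size must equal the number of rows");
        }
        for (int i = 0; i < vocab.size(); i++) {
            indexOf.put(vocab.get(i), i);
        }
        this.syn0 = syn0;
    }

    boolean hasWord(String word) {
        return indexOf.containsKey(word);
    }

    /** Returns a copy of the word's row, or null for out-of-vocabulary words. */
    double[] getWordVector(String word) {
        Integer idx = indexOf.get(word);
        return idx == null ? null : syn0[idx].clone();
    }

    /** Embedding dimension, i.e. the number of columns in the table. */
    int layerSize() {
        return syn0.length == 0 ? 0 : syn0[0].length;
    }
}
```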


Example 9: testStaticLoaderArchive

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
/**
 * This method tests loading a ZIP file as a static model
 *
 * @throws Exception
 */
@Test
public void testStaticLoaderArchive() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    File w2v = new ClassPathResource("word2vec.dl4j/file.w2v").getFile();

    WordVectors vectorsLive = WordVectorSerializer.readWord2Vec(w2v);
    WordVectors vectorsStatic = WordVectorSerializer.loadStaticModel(w2v);

    INDArray arrayLive = vectorsLive.getWordVectorMatrix("night");
    INDArray arrayStatic = vectorsStatic.getWordVectorMatrix("night");

    assertNotEquals(null, arrayLive);
    assertEquals(arrayLive, arrayStatic);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines: 21, Source: WordVectorSerializerTest.java


Example 10: testUnifiedLoaderArchive1

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
@Test
public void testUnifiedLoaderArchive1() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    File w2v = new ClassPathResource("word2vec.dl4j/file.w2v").getFile();

    WordVectors vectorsLive = WordVectorSerializer.readWord2Vec(w2v);
    WordVectors vectorsUnified = WordVectorSerializer.readWord2VecModel(w2v, false);

    INDArray arrayLive = vectorsLive.getWordVectorMatrix("night");
    INDArray arrayStatic = vectorsUnified.getWordVectorMatrix("night");

    assertNotEquals(null, arrayLive);
    assertEquals(arrayLive, arrayStatic);

    assertEquals(null, ((InMemoryLookupTable) vectorsUnified.lookupTable()).getSyn1());
    assertEquals(null, ((InMemoryLookupTable) vectorsUnified.lookupTable()).getSyn1Neg());
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines: 19, Source: WordVectorSerializerTest.java


Example 11: testUnifiedLoaderArchive2

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
@Test
public void testUnifiedLoaderArchive2() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    File w2v = new ClassPathResource("word2vec.dl4j/file.w2v").getFile();

    WordVectors vectorsLive = WordVectorSerializer.readWord2Vec(w2v);
    WordVectors vectorsUnified = WordVectorSerializer.readWord2VecModel(w2v, true);

    INDArray arrayLive = vectorsLive.getWordVectorMatrix("night");
    INDArray arrayStatic = vectorsUnified.getWordVectorMatrix("night");

    assertNotEquals(null, arrayLive);
    assertEquals(arrayLive, arrayStatic);

    assertNotEquals(null, ((InMemoryLookupTable) vectorsUnified.lookupTable()).getSyn1());
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines: 18, Source: WordVectorSerializerTest.java


Example 12: testUnifiedLoaderText

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
/**
 * This method tests CSV file loading via unified loader
 *
 * @throws Exception
 */
@Test
public void testUnifiedLoaderText() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    WordVectors vectorsLive = WordVectorSerializer.loadTxtVectors(textFile);
    WordVectors vectorsUnified = WordVectorSerializer.readWord2VecModel(textFile, true);

    INDArray arrayLive = vectorsLive.getWordVectorMatrix("Morgan_Freeman");
    INDArray arrayStatic = vectorsUnified.getWordVectorMatrix("Morgan_Freeman");

    assertNotEquals(null, arrayLive);
    assertEquals(arrayLive, arrayStatic);

    // we're trying EXTENDED model, but file doesn't have syn1/huffman info, so it should be silently degraded to simplified model
    assertEquals(null, ((InMemoryLookupTable) vectorsUnified.lookupTable()).getSyn1());
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines: 22, Source: WordVectorSerializerTest.java
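Examples 9 to 12 load the same file through different loaders and assert that the vector for a probe word ("night" or "Morgan_Freeman") is identical. When such a check should cover many words, or formats that round stored values, comparing with an absolute tolerance is often more practical than exact equality. A hypothetical sketch over plain maps (`ModelAgreement` is not DL4J API):

```java
import java.util.Collection;
import java.util.Map;

/** Hypothetical check that two loaded models return matching vectors for the given words. */
final class ModelAgreement {
    private ModelAgreement() {}

    static boolean agree(Map<String, double[]> a, Map<String, double[]> b,
                         Collection<String> words, double eps) {
        for (String w : words) {
            double[] va = a.get(w);
            double[] vb = b.get(w);
            // A word missing from either model, or a dimension mismatch, counts as disagreement
            if (va == null || vb == null || va.length != vb.length) return false;
            for (int i = 0; i < va.length; i++) {
                if (Math.abs(va[i] - vb[i]) > eps) return false;
            }
        }
        return true;
    }
}
```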


Example 13: windows

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
/**
 * Constructs a list of windows of size windowSize.
 * Note that padding for each window is created as well.
 * @param words the words to tokenize and construct windows from
 * @param tokenizerFactory tokenizer factory to use
 * @param windowSize the window size to generate
 * @param vectors word vectors; tokens without a vector are skipped
 * @return the list of windows for the tokenized string
 */
public static List<Window> windows(String words, @NonNull TokenizerFactory tokenizerFactory, int windowSize,
                WordVectors vectors) {
    Tokenizer tokenizer = tokenizerFactory.create(words);
    List<String> list = new ArrayList<>();
    while (tokenizer.hasMoreTokens()) {
        String token = tokenizer.nextToken();

        // if we don't have UNK word defined - we have to skip this word
        if (vectors.getWordVectorMatrix(token) != null)
            list.add(token);
    }

    if (list.isEmpty())
        throw new IllegalStateException("No tokens found for windows");

    return windows(list, windowSize);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines: 26, Source: Windows.java
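Example 13 keeps only tokens that have a vector and then delegates to `windows(list, windowSize)`, whose body is not shown. A minimal sketch of padded windows, one centered on each token, assuming an odd `windowSize` and using `<s>` and `</s>` as padding markers; both the padding tokens and the odd-size requirement are assumptions of this sketch, not confirmed DL4J behavior. `ContextWindows` is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of padded context windows, one window centered on each token. */
final class ContextWindows {
    static final String START = "<s>";  // assumed padding token before the sentence
    static final String END = "</s>";   // assumed padding token after the sentence

    private ContextWindows() {}

    static List<List<String>> windows(List<String> tokens, int windowSize) {
        if (windowSize < 1 || windowSize % 2 == 0) {
            throw new IllegalArgumentException("windowSize must be a positive odd number");
        }
        int half = windowSize / 2;
        List<List<String>> result = new ArrayList<>();
        for (int center = 0; center < tokens.size(); center++) {
            List<String> window = new ArrayList<>(windowSize);
            for (int i = center - half; i <= center + half; i++) {
                if (i < 0) window.add(START);                  // pad before the first token
                else if (i >= tokens.size()) window.add(END);  // pad after the last token
                else window.add(tokens.get(i));
            }
            result.add(window);
        }
        return result;
    }
}
```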


Example 14: testGoogleModelForInference

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
@Ignore
@Test
public void testGoogleModelForInference() throws Exception {
    WordVectors googleVectors = WordVectorSerializer.loadGoogleModelNonNormalized(
                    new File("/ext/GoogleNews-vectors-negative300.bin.gz"), true, false);

    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    ParagraphVectors pv =
                    new ParagraphVectors.Builder().tokenizerFactory(t).iterations(10).useHierarchicSoftmax(false)
                                    .trainWordVectors(false).useExistingWordVectors(googleVectors)
                                    .negativeSample(10).sequenceLearningAlgorithm(new DM<VocabWord>()).build();

    INDArray vec1 = pv.inferVector("This text is pretty awesome");
    INDArray vec2 = pv.inferVector("Fantastic process of crazy things happening inside just for history purposes");

    log.info("vec1/vec2: {}", Transforms.cosineSim(vec1, vec2));
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines: 20, Source: ParagraphVectorsTest.java
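Example 14 compares two inferred paragraph vectors with `Transforms.cosineSim`. Cosine similarity is the dot product of two vectors divided by the product of their Euclidean norms, so it ranges from -1 to 1 and ignores vector length. A plain-Java sketch; the class name `Cosine` is hypothetical, and returning 0 for a zero vector is a convention chosen here.

```java
/** Cosine similarity between two dense vectors of equal length. */
final class Cosine {
    private Cosine() {}

    static double similarity(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same length");
        }
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0 || normB == 0) return 0.0; // convention for zero vectors
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
```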


Example 15: testGlove

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
@Test
public void testGlove() throws Exception {
    Glove glove = new Glove(true, 5, 100);
    JavaRDD<String> corpus = sc.textFile(new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath())
                    .map(new Function<String, String>() {
                        @Override
                        public String call(String s) throws Exception {
                            return s.toLowerCase();
                        }
                    });


    Pair<VocabCache<VocabWord>, GloveWeightLookupTable> table = glove.train(corpus);
    WordVectors vectors = WordVectorSerializer
                    .fromPair(new Pair<>((InMemoryLookupTable) table.getSecond(), (VocabCache) table.getFirst()));
    Collection<String> words = vectors.wordsNearest("day", 20);
    assertTrue(words.contains("week"));
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines: 19, Source: GloveTest.java
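Example 15 checks that "week" is among the 20 words nearest to "day". A brute-force version of such a query ranks every other vocabulary word by cosine similarity to the query vector. DL4J's `wordsNearest` may be implemented differently internally, so treat this as a hypothetical illustration (`NearestWords` is not DL4J API).

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/** Hypothetical brute-force nearest-neighbor search over an in-memory vocabulary. */
final class NearestWords {
    private NearestWords() {}

    static List<String> nearest(Map<String, double[]> vectors, String query, int k) {
        double[] q = vectors.get(query);
        if (q == null) return Collections.emptyList(); // unknown query word
        List<String> candidates = new ArrayList<>(vectors.keySet());
        candidates.remove(query); // exclude the query word itself
        // Sort by descending cosine similarity to the query vector
        candidates.sort(Comparator.comparingDouble((String w) -> -cosine(q, vectors.get(w))));
        return new ArrayList<>(candidates.subList(0, Math.min(k, candidates.size())));
    }

    private static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return (na == 0 || nb == 0) ? 0.0 : dot / (Math.sqrt(na) * Math.sqrt(nb));
    }
}
```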


Example 16: getSenseEmbedding

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
/**
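 * Computes (and caches) the sense embedding of a synset from the word vectors of its sense bag.
 *
 * @param wordVector The word embeddings dictionary
 * @param synset The synset whose sense embedding is computed
 * @param word The target word used to build the sense bag
 * @param senseComputation The strategy that combines the word vectors into one sense vector
 * @return The sense embedding of the synset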
 */
public static double[] getSenseEmbedding(WordVectors wordVector, Synset synset, String word, SenseComputation senseComputation) {
    if(wordEmbeddings.containsKey(synset)){
        return ArrayUtils.toPrimitive(wordEmbeddings.get(synset));
    }

    String[] words = getSenseBag(synset, word);

    double[] senseEmbedding, tmpEmbedding;
    Double[] tmpEmbedding2, tmpSenseEmbedding;
    ArrayList<Double[]> senseEmbeddings = new ArrayList<>();

    // For each word in the sense bag, get the corresponding word embedding and store it in a list
    for (String w : words) {
        if (w != null) {
            if (wordVector.hasWord(w)) {
                tmpEmbedding = wordVector.getWordVector(w);

                tmpEmbedding2 = new Double[tmpEmbedding.length];
                for (int i = 0; i < tmpEmbedding.length; i++) {
                    tmpEmbedding2[i] = tmpEmbedding[i];
                }
                senseEmbeddings.add(tmpEmbedding2);
            }
        }
    }

    senseEmbedding = senseComputation.compute(senseEmbeddings);

    tmpSenseEmbedding = new Double[senseEmbedding.length];
    for (int i = 0; i < tmpSenseEmbedding.length; i++) {
        tmpSenseEmbedding[i] = senseEmbedding[i];
    }
    wordEmbeddings.put(synset, tmpSenseEmbedding);

    return senseEmbedding;
}
 
Developer ID: butnaruandrei, Project: ShotgunWSD, Lines: 44, Source: SenseEmbedding.java
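In Example 16, `senseComputation.compute(...)` combines the word vectors of a sense bag into a single sense vector, and the result is cached per synset. The strategy is pluggable and its implementations are not shown; the element-wise mean is one common choice, sketched here as a hypothetical `MeanSense` class. It caches by a string key instead of a `Synset`, uses primitive arrays (which avoids the manual boxing loops in the original), and returns `null` when no word in the bag has a vector; all three are choices of this sketch.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sense-embedding strategy: the mean of the sense bag's word vectors, cached per key. */
final class MeanSense {
    private final Map<String, double[]> cache = new HashMap<>();

    double[] senseEmbedding(String senseKey, List<String> senseBag, Map<String, double[]> vectors) {
        double[] cached = cache.get(senseKey);
        if (cached != null) return cached.clone(); // reuse a previously computed sense vector

        double[] sum = null;
        int count = 0;
        for (String w : senseBag) {
            double[] v = (w == null) ? null : vectors.get(w);
            if (v == null) continue;              // skip null or out-of-vocabulary words
            if (sum == null) sum = new double[v.length];
            for (int i = 0; i < v.length; i++) sum[i] += v[i];
            count++;
        }
        if (sum == null) return null;             // no word in the bag had a vector
        for (int i = 0; i < sum.length; i++) sum[i] /= count;
        cache.put(senseKey, sum.clone());
        return sum;
    }
}
```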


Example 17: useExistingWordVectors

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
/**
 * This method allows you to use pre-built WordVectors model (Word2Vec or GloVe) for Par2Hier.
 * The existing model will be transferred into the new model before training starts.
 *
 * PLEASE NOTE: A non-normalized model is recommended here.
 *
 * @param vec existing WordVectors model
 * @return a builder
 */
@Override
@SuppressWarnings("unchecked")
public Builder useExistingWordVectors(@NonNull WordVectors vec) {
  if (((InMemoryLookupTable<VocabWord>) vec.lookupTable()).getSyn1() == null &&
      ((InMemoryLookupTable<VocabWord>) vec.lookupTable()).getSyn1Neg() == null) {
    throw new ND4JIllegalStateException("Model being passed as existing has no syn1/syn1Neg available");
  }

  this.existingVectors = vec;
  return this;
}
 
Developer ID: tteofili, Project: par2hier, Lines: 21, Source: Par2Hier.java


Example 18: testTextCnnTextFilesRegression

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
@Test
public void testTextCnnTextFilesRegression() throws Exception {
  CnnTextFilesEmbeddingInstanceIterator cnnTextIter = new CnnTextFilesEmbeddingInstanceIterator();
  cnnTextIter.setTrainBatchSize(64);
  cnnTextIter.setWordVectorLocation(DatasetLoader.loadGoogleNewsVectors());
  cnnTextIter.setTextsLocation(DatasetLoader.loadAngerFilesDir());
  clf.setInstanceIterator(cnnTextIter);

  cnnTextIter.initialize();
  final WordVectors wordVectors = cnnTextIter.getWordVectors();
  int vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;

  ConvolutionLayer conv1 = new ConvolutionLayer();
  conv1.setKernelSize(new int[] {3, vectorSize});
  conv1.setNOut(10);
  conv1.setStride(new int[] {1, vectorSize});
  conv1.setConvolutionMode(ConvolutionMode.Same);

  ConvolutionLayer conv2 = new ConvolutionLayer();
  conv2.setKernelSize(new int[] {2, vectorSize});
  conv2.setNOut(10);
  conv2.setStride(new int[] {1, vectorSize});
  conv2.setConvolutionMode(ConvolutionMode.Same);

  GlobalPoolingLayer gpl = new GlobalPoolingLayer();

  OutputLayer out = new OutputLayer();
  out.setLossFn(new LossMSE());
  out.setActivationFn(new ActivationIdentity());

  clf.setLayers(conv1, conv2, gpl, out);
  clf.setCacheMode(CacheMode.MEMORY);
  final Instances data = DatasetLoader.loadAngerMeta();
  TestUtil.holdout(clf, data);
}
 
Developer ID: Waikato, Project: wekaDeeplearning4j, Lines: 36, Source: Dl4jMlpTest.java


Example 19: testTextCnnTextFilesClassification

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
@Test
public void testTextCnnTextFilesClassification() throws Exception {
  CnnTextFilesEmbeddingInstanceIterator cnnTextIter = new CnnTextFilesEmbeddingInstanceIterator();
  cnnTextIter.setTrainBatchSize(64);
  cnnTextIter.setWordVectorLocation(DatasetLoader.loadGoogleNewsVectors());
  cnnTextIter.setTextsLocation(DatasetLoader.loadAngerFilesDir());
  clf.setInstanceIterator(cnnTextIter);

  cnnTextIter.initialize();
  final WordVectors wordVectors = cnnTextIter.getWordVectors();
  int vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;

  ConvolutionLayer conv1 = new ConvolutionLayer();
  conv1.setKernelSize(new int[] {4, vectorSize});
  conv1.setNOut(10);
  conv1.setStride(new int[] {1, vectorSize});
  conv1.setConvolutionMode(ConvolutionMode.Same);
  conv1.setDropOut(0.2);
  conv1.setActivationFn(new ActivationReLU());

  ConvolutionLayer conv2 = new ConvolutionLayer();
  conv2.setKernelSize(new int[] {3, vectorSize});
  conv2.setNOut(10);
  conv2.setStride(new int[] {1, vectorSize});
  conv2.setConvolutionMode(ConvolutionMode.Same);
  conv2.setDropOut(0.2);
  conv2.setActivationFn(new ActivationReLU());

  GlobalPoolingLayer gpl = new GlobalPoolingLayer();
  gpl.setDropOut(0.33);

  OutputLayer out = new OutputLayer();

  clf.setLayers(conv1, conv2, gpl, out);
  clf.setCacheMode(CacheMode.MEMORY);
  final Instances data = DatasetLoader.loadAngerMetaClassification();
  TestUtil.holdout(clf, data);
}
 
Developer ID: Waikato, Project: wekaDeeplearning4j, Lines: 39, Source: Dl4jMlpTest.java
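Examples 18 and 19 size the convolution kernels from the embedding dimension: each kernel spans 2 to 4 consecutive tokens and the full embedding width, and the stride along the embedding axis equals `vectorSize`. Assuming the `ConvolutionMode.Same` rule that each output dimension is `ceil(input / stride)` regardless of kernel size, the shape arithmetic works out as in this hypothetical helper (`TextCnnShapes` is not part of DL4J or WEKA):

```java
/** Hypothetical worked example of text-CNN output shapes under "Same" convolution mode. */
final class TextCnnShapes {
    private TextCnnShapes() {}

    /** Same mode (as assumed here): output size = ceil(inputSize / stride), whatever the kernel size. */
    static int sameOutputSize(int inputSize, int stride) {
        return (inputSize + stride - 1) / stride; // integer ceiling division
    }

    /** Returns {height, width, channels} of one convolution layer applied to a sentence matrix. */
    static int[] convOutput(int sentenceLength, int vectorSize, int strideH, int strideW, int nOut) {
        return new int[] {
            sameOutputSize(sentenceLength, strideH), // one output row per token when strideH == 1
            sameOutputSize(vectorSize, strideW),     // 1 when strideW == vectorSize
            nOut                                     // one channel per filter
        };
    }
}
```

For example, with 300-dimensional vectors (the size of the Google News embeddings) and a 50-token sentence, a layer with `nOut = 10` yields a 50 x 1 map per filter, which global pooling then reduces to one value per filter.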


Example 20: makeData

import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; // import the required package/class
public Instances makeData() throws Exception {
  final Instances data = TestUtil.makeTestDataset(42,
      100,
      0,
      0,
      1,
      0,
      0,
      1,
      Attribute.NUMERIC,
      1,
      false);

  WordVectors wordVectors = WordVectorSerializer.loadStaticModel(DatasetLoader.loadGoogleNewsVectors());
  String[] words = (String[]) wordVectors.vocab().words().toArray(new String[0]);

  Random rand = new Random(42);
  for (Instance inst : data) {
    StringBuilder sentence = new StringBuilder();
    for(int i = 0; i < 10; i++){
      final int idx = rand.nextInt(words.length);
      sentence.append(" ").append(words[idx]);
    }
    inst.setValue(0, sentence.toString());
  }
  return data;
}
 
Developer ID: Waikato, Project: wekaDeeplearning4j, Lines: 28, Source: CnnTextFilesEmbeddingInstanceIteratorTest.java
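Example 20 fills each test instance with ten words drawn at random from the model's vocabulary, using a fixed seed so the dataset is reproducible. A hypothetical plain-Java helper for the same idea (`RandomSentences` is not part of the project); unlike the `StringBuilder` loop in `makeData`, `String.join` does not leave a leading space.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: build a sentence of words drawn uniformly from a vocabulary. */
final class RandomSentences {
    private RandomSentences() {}

    static String sentence(String[] words, int length, Random rand) {
        if (words.length == 0) {
            throw new IllegalArgumentException("empty vocabulary");
        }
        List<String> picked = new ArrayList<>(length);
        for (int i = 0; i < length; i++) {
            picked.add(words[rand.nextInt(words.length)]); // uniform draw, with replacement
        }
        return String.join(" ", picked);
    }
}
```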



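Note: The org.deeplearning4j.models.embeddings.wordvectors.WordVectors examples in this article were collected from source-code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their developers, and copyright remains with the original authors; consult each project's license before distributing or using the code. Do not republish without permission.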

