First of all, the images need to be converted to tensors.
The first approach is to create one tensor containing all the features (and, likewise, one tensor containing all the labels).
This is the way to go only if the dataset contains few images.
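const fs = require('fs').promises;
const tf = require('@tensorflow/tfjs-node');

const imageBuffer = await fs.readFile(feature_file); // inside an async function
// decode to a [height, width, 3] tensor and scale pixel values to [0, 1]
const tensorFeature = tf.cast(tf.node.decodeImage(imageBuffer, 3), 'float32').div(255);
// do the same for every image, then stack them into one
// [numberOfImages, height, width, 3] tensor (all images must have the same size)
const tensorFeatures = tf.stack([tensorFeature, tensorFeature2, tensorFeature3]);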
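tf.stack only works when every image tensor has exactly the same shape, and it adds a new leading axis for the number of images. The plain-JS sketch below (no TensorFlow.js needed; stackedShape is a hypothetical helper, not a tfjs function) mimics that shape rule, which is why images of different sizes have to be resized before stacking.

```javascript
// Shape rule of stacking along axis 0: every input must have the same shape,
// and the result gains a new leading axis whose length is the number of inputs.
function stackedShape(shapes) {
  if (shapes.length === 0) throw new Error('need at least one shape');
  const first = shapes[0];
  for (const s of shapes) {
    if (s.length !== first.length || s.some((d, i) => d !== first[i])) {
      throw new Error(`shape [${s}] does not match [${first}]`);
    }
  }
  return [shapes.length, ...first];
}

// three 64x64 RGB images → [3, 64, 64, 3]
console.log(stackedShape([[64, 64, 3], [64, 64, 3], [64, 64, 3]]));
```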
The labels are an array giving the class of each image:
const labelArray = [0, 1, 2]; // e.g. 0 for dog, 1 for cat and 2 for bird
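In practice the integer labels usually come from class names (for example the folder each image sits in). A minimal sketch, assuming a hypothetical classNames list whose positions define the integer labels and a hypothetical imageClasses array with one class name per image:

```javascript
// Hypothetical class list: the position of a name is its integer label.
const classNames = ['dog', 'cat', 'bird'];
// Hypothetical per-image class names, e.g. taken from each image's folder.
const imageClasses = ['cat', 'dog', 'bird', 'cat'];

const labelArray = imageClasses.map(name => {
  const index = classNames.indexOf(name);
  if (index === -1) throw new Error(`unknown class: ${name}`);
  return index;
});

console.log(labelArray); // → [1, 0, 2, 1]
```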
The labels now need to be one-hot encoded:
const tensorLabels = tf.oneHot(tf.tensor1d(labelArray, 'int32'), 3); // 3 = number of classes
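To make the encoding concrete: each integer label becomes a row with a 1 at the label's position and 0 elsewhere. The plain-JS sketch below (a hypothetical oneHot helper, not the tfjs implementation) computes the same values that tf.oneHot produces for valid labels; unlike tf.oneHot, this sketch also rejects labels outside [0, depth).

```javascript
// Turn integer class indices into one-hot rows of length `depth`:
// row i has a 1 at position labels[i] and 0 everywhere else.
function oneHot(labels, depth) {
  return labels.map(label => {
    if (!Number.isInteger(label) || label < 0 || label >= depth) {
      throw new RangeError(`label ${label} is outside [0, ${depth})`);
    }
    return Array.from({length: depth}, (_, k) => (k === label ? 1 : 0));
  });
}

console.log(oneHot([0, 1, 2], 3)); // → [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
```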
Once the tensors are ready, the next step is to create the model to train.
Here is a simple model:
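const model = tf.sequential();
model.add(tf.layers.conv2d({
  inputShape: [height, width, numberOfChannels], // numberOfChannels = 3 for color images, 1 for grayscale
  filters: 32,
  kernelSize: 3,
  activation: 'relu',
}));
// flatten the [h, w, filters] feature maps into a vector before the dense layer
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 3, activation: 'softmax'}));
// a model must be compiled before it can be trained
model.compile({optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy']});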
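A conv2d layer outputs a [h, w, filters] feature map per image, not a vector, so a tf.layers.flatten() layer between it and the dense layer is what turns the map into one 3-way prediction per image. The arithmetic below is a plain-JS sketch, assuming a hypothetical 64×64 RGB input and the conv2d defaults of 'valid' padding and stride 1; describeModel and convOut are illustrative helpers, not tfjs functions.

```javascript
// Output size of one spatial axis for a conv with 'valid' padding
// (the default in tf.layers.conv2d) and the given stride.
const convOut = (size, kernel, stride = 1) => Math.floor((size - kernel) / stride) + 1;

function describeModel(height, width, channels, filters, kernel, units) {
  const h = convOut(height, kernel);
  const w = convOut(width, kernel);
  const convParams = kernel * kernel * channels * filters + filters; // weights + biases
  const flatSize = h * w * filters; // length of the vector fed to the dense layer
  const denseParams = flatSize * units + units; // weights + biases
  return {convShape: [h, w, filters], flatSize, convParams, denseParams};
}

// → convShape [62, 62, 32], flatSize 123008, convParams 896, denseParams 369027
console.log(describeModel(64, 64, 3, 32, 3, 3));
```

Most of the parameters end up in the dense layer, which is why the flattened size grows quickly with the input resolution.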
Then the model can be trained:
await model.fit(tensorFeatures, tensorLabels, {epochs: 10}); // fit returns a Promise; 10 epochs is an arbitrary choice
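Training adjusts the weights to lower a loss that compares the softmax output with the one-hot labels; with this kind of setup the usual choice is categorical cross-entropy (tfjs accepts 'categoricalCrossentropy' as the loss in model.compile). A minimal plain-JS sketch of softmax and that loss for a single example, with hypothetical helper names:

```javascript
// Softmax turns raw scores into probabilities that sum to 1
// (subtracting the max first avoids overflow in Math.exp).
function softmax(logits) {
  const max = Math.max(...logits);
  const exps = logits.map(z => Math.exp(z - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

// Categorical cross-entropy for one example: -sum(y_k * log(p_k)).
// With a one-hot label only the true class's probability contributes.
function categoricalCrossentropy(oneHotLabel, probs) {
  const eps = 1e-7; // guard against log(0)
  return -oneHotLabel.reduce((acc, y, k) => acc + y * Math.log(Math.max(probs[k], eps)), 0);
}

const probs = softmax([0, 0, 0]); // equal scores → each class gets 1/3
console.log(categoricalCrossentropy([1, 0, 0], probs)); // ≈ 1.0986 (= ln 3)
```

An untrained model that spreads probability evenly over 3 classes therefore starts near a loss of ln 3; the loss falls toward 0 as the probability of the true class approaches 1.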
If the dataset contains a lot of images, one should create a tf.data.Dataset instead: it produces the images lazily, batch by batch, so the whole dataset never has to sit in memory at once.
This answer discusses why in more detail.
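A rough memory estimate shows where the limit comes from: a single tensor holding every image needs numImages × height × width × channels values, at 4 bytes each for float32 (int32 is also 4 bytes). The sketch below assumes a hypothetical 10,000 RGB images of 224×224; float32TensorBytes is an illustrative helper.

```javascript
// Rough memory needed to hold a whole image dataset as one float32 tensor:
// numImages * height * width * channels values, 4 bytes each.
function float32TensorBytes(numImages, height, width, channels) {
  return numImages * height * width * channels * 4;
}

const bytes = float32TensorBytes(10000, 224, 224, 3);
console.log(bytes); // 6021120000
console.log(`${(bytes / 1e9).toFixed(2)} GB`); // 6.02 GB
```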
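const {readFileSync} = require('fs');
const tf = require('@tensorflow/tfjs-node');

// read and decode one image; synchronous because a plain generator cannot await
const genFeatureTensor = imagePath => {
  const imageBuffer = readFileSync(imagePath);
  return tf.cast(tf.node.decodeImage(imageBuffer, 3), 'float32').div(255);
};
// one-hot array for a class index, e.g. 1 -> [0, 1, 0] when numberOfClasses = 3
const oneHotLabel = index => Array.from({length: numberOfClasses}, (_, k) => k === index ? 1 : 0);
// assumed inputs: imagePaths[i] is the file of image i and imageLabels[i] its class index
function* dataGenerator() {
  for (let index = 0; index < imagePaths.length; index++) {
    const feature = genFeatureTensor(imagePaths[index]);
    const label = tf.tensor1d(oneHotLabel(imageLabels[index]));
    yield {xs: feature, ys: label};
  }
}
// fitDataset expects batched elements; 32 is an arbitrary batch size
const ds = tf.data.generator(dataGenerator).batch(32);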
Then train the model with await model.fitDataset(ds, {epochs: 10}); unlike fit, fitDataset needs the number of epochs in its config.
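model.fitDataset trains on batches, so the dataset's elements are normally grouped with ds.batch(batchSize), which by default keeps a smaller last batch. The plain-JS sketch below shows only the grouping part on a lazy generator (tf.data additionally stacks the tensors of each batch); batch and elements are hypothetical helpers, not the tf.data implementation.

```javascript
// Lazily group the elements of any iterable into arrays of `size` items;
// the last batch may be smaller. Nothing is pulled before it is requested.
function* batch(iterable, size) {
  let current = [];
  for (const item of iterable) {
    current.push(item);
    if (current.length === size) {
      yield current;
      current = [];
    }
  }
  if (current.length > 0) yield current;
}

// Hypothetical stand-in for the image generator: yields element indices.
function* elements(n) {
  for (let i = 0; i < n; i++) yield i;
}

console.log([...batch(elements(5), 2)]); // → [[0, 1], [2, 3], [4]]
```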