Deep Learning: Reading CSV Data to Build a Classifier
This post shows how to work with CSV data in DL4J. You will rarely load data this way in real projects, but it is a good exercise for getting familiar with DL4J's basic data operations and structures.
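For context, the code below assumes the data files from the dl4j-examples project: each row of the animals CSVs holds four features followed by an integer class label, in the order yearsLived,eatsIndex,soundsIndex,weight,classIndex (a purely hypothetical row: 19,0,2,10.0,1), while eats.csv, sounds.csv and classifiers.csv each map an integer code to a display name. The sample row here is only an illustration; the real files ship with the examples project. The full code follows.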
//(the imports below assume a 0.x-era DL4J/DataVec dependency set; exact package names vary by version)
import org.apache.commons.io.IOUtils;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class BasicCSVClassifier {
private static Logger log = LoggerFactory.getLogger(BasicCSVClassifier.class);//get a logger from the SLF4J factory method
private static Map<Integer,String> eats = readEnumCSV("/DataExamples/animals/eats.csv");//readEnumCSV reads the CSV straight into a map
private static Map<Integer,String> sounds = readEnumCSV("/DataExamples/animals/sounds.csv");
private static Map<Integer,String> classifiers = readEnumCSV("/DataExamples/animals/classifiers.csv");
public static void main(String[] args){
try {
//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in the neural network
int labelIndex = 4;    //5 values in each row of the CSV: 4 input features followed by an integer class label, so the label sits at index 4
int numClasses = 3;    //3 classes in the data set, with integer values 0, 1 or 2
int batchSizeTraining = 30;    //the training CSV is loaded as a single batch of 30 examples; loading everything into one DataSet is not recommended for large data sets
DataSet trainingData = readCSVDataset(
"/DataExamples/animals/animals_train.csv",
batchSizeTraining, labelIndex, numClasses);//readCSVDataset reads the CSV straight into a DataSet
// this is the data we want to classify
int batchSizeTest = 44;//test batch of 44 examples, loaded the same way as above
DataSet testData = readCSVDataset("/DataExamples/animals/animals.csv",
batchSizeTest, labelIndex, numClasses);
// build the record model before normalization, because normalization changes the data
Map<Integer,Map<String,Object>> animals = makeAnimalsForTesting(testData);//animals ends up as {0={eats=Mice, sounds=Meow, weight=10.0, yearsLived=19}, 1={eats=Cats, sounds=Bark, weight=60.0, yearsLived=9}...}
//We need to normalize our data. We'll use NormalizerStandardize, which gives us mean 0, unit variance:
DataNormalization normalizer = new NormalizerStandardize();//the normalizer
normalizer.fit(trainingData);          //Collect the statistics (mean/stdev) from the training data. This does not modify the input data: at this point only the statistics have been gathered and trainingData itself is unchanged
normalizer.transform(trainingData);    //Apply normalization to the training data; System.out.println(trainingData) now prints a feature (input) array plus a one-hot label (output) array such as [0.00, 0.00, 1.00]
normalizer.transform(testData);        //Apply normalization to the test data, using the statistics collected from the *training* set
final int numInputs = 4;//4 input features
int outputNum = 3;//3 classes
int iterations = 1000;//1000 training iterations
long seed = 6;//random seed
log.info("");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()//以下套路和上⼀篇⼀致.seed(seed)
.iterations(iterations)
.activation("tanh")
.weightInit(WeightInit.XAVIER)
.learningRate(0.1)
.regularization(true).l2(1e-4)
.list()
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
.build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
.build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)//negative log likelihood loss pairs with softmax for classification
.activation("softmax")//softmax turns the 3 raw outputs into class probabilities
.nIn(3).nOut(outputNum).build())
.backprop(true).pretrain(false)
.build();
//run the model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100));
model.fit(trainingData);
//evaluate the model on the test set
Evaluation eval = new Evaluation(3);//3 classes, so pass 3
INDArray output = model.output(testData.getFeatureMatrix());//predict class labels from the test features
eval.eval(testData.getLabels(), output);//the evaluator compares the actual labels against the predictions
log.info(eval.stats());//print the evaluation stats
setFittedClassifiers(output, animals);
logAnimals(animals);//print each animal map as a string
} catch (Exception e){
e.printStackTrace();
}
}
public static void logAnimals(Map<Integer,Map<String,Object>> animals){
for(Map<String,Object> a:animals.values())
log.info(a.toString());
}
public static void setFittedClassifiers(INDArray output, Map<Integer,Map<String,Object>> animals){
for (int i = 0; i < output.rows() ; i++) {//for each row of output, map the predicted class index back to a class name
// set the classification from the fitted results
animals.get(i).put("classifier",
classifiers.get(maxIndex(getFloatArrayFromSlice(output.slice(i)))));//uses the helper methods defined below
}
}
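// Aside: the two helpers below exist mainly to demonstrate INDArray-to-Java conversions.
// On recent Nd4j versions the same lookup can be a one-liner (a sketch, assuming an Nd4j
// version where INDArray.argMax is available; not part of the original example):
// int predicted = output.getRow(i).argMax().getInt(0);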
/**
* This method is to show how to convert the INDArray to a float array. This is to
* provide some more examples on how to convert INDArray to types that are more java
* centric.
*
* @param rowSlice
* @return
*/
public static float[] getFloatArrayFromSlice(INDArray rowSlice){
float[] result = new float[rowSlice.columns()];//allocate a float array the length of the row and fill it; remember each row of output is a class-probability vector
for (int i = 0; i < rowSlice.columns(); i++) {
result[i] = rowSlice.getFloat(i);
}
return result;
}
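// A possible shortcut (a sketch, not from the original example): duplicate the view so it
// owns a standalone buffer, then pull the whole buffer out as a float[] in one call —
// float[] result = rowSlice.dup().data().asFloat();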
/**
* find the maximum item index. This is used when the data is fitted and we
* want to determine which class to assign the test row to
*
* @param vals
* @return
*/
public static int maxIndex(float[] vals){//straightforward: each output row is a class-probability vector, so the index of the largest value is the predicted class
int maxIndex = 0;
for (int i = 1; i < vals.length; i++){
float newnumber = vals[i];
if ((newnumber > vals[maxIndex])){
maxIndex = i;
}
}
return maxIndex;
}
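// For example (hypothetical values): maxIndex(new float[]{0.10f, 0.75f, 0.15f}) returns 1,
// i.e. the row would be assigned class 1.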
/**
* take the dataset loaded for the matrix and make the record model out of it so
* we can correlate the fitted classifier to the record.//we build this structure so that each record can be matched with its predicted class name
*
* @param testData
* @return
*/
public static Map<Integer,Map<String,Object>> makeAnimalsForTesting(DataSet testData){
Map<Integer,Map<String,Object>> animals = new HashMap<>();//the outer map, keyed by row number
INDArray features = testData.getFeatureMatrix();//grab the feature matrix
for (int i = 0; i < features.rows() ; i++) {//iterate over the rows
INDArray slice = features.slice(i);//slice(i) pulls out row i so we can handle the features row by row
Map<String,Object> animal = new HashMap<>();//a map for this record
//set the attributes: fill animal first, then put animal into animals
animal.put("yearsLived", slice.getInt(0));
animal.put("eats", eats.get(slice.getInt(1)));
animal.put("sounds", sounds.get(slice.getInt(2)));
animal.put("weight", slice.getFloat(3));
animals.put(i,animal);
}
return animals;
}
public static Map<Integer,String> readEnumCSV(String csvFileClasspath) {//reads an enum CSV into a map of integer code -> name
try{
List<String> lines = IOUtils.readLines(new ClassPathResource(csvFileClasspath).getInputStream());//read the file line by line
Map<Integer,String> enums = new HashMap<>();//the result map
for(String line:lines){//fill the map and return it
String[] parts = line.split(",");
enums.put(Integer.parseInt(parts[0]),parts[1]);
}
return enums;
} catch (Exception e){
e.printStackTrace();
return null;
}
}
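// For reference: each line of these enum files is expected to be an "index,name" pair, so a
// hypothetical eats.csv consistent with the output shown earlier might contain lines like:
// 0,Mice
// 1,Cats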
/**
* used for testing and training
*
* @param csvFileClasspath
* @param batchSize
* @param labelIndex
* @param numClasses
* @return
* @throws IOException
* @throws InterruptedException
*/
private static DataSet readCSVDataset(//helper that reads a CSV into a DataSet
String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
throws IOException, InterruptedException{
RecordReader rr = new CSVRecordReader();//create a CSV record reader
rr.initialize(new FileSplit(new ClassPathResource(csvFileClasspath).getFile()));//initialize it with the file
DataSetIterator iterator = new RecordReaderDataSetIterator(rr,batchSize,labelIndex,numClasses);//wrap it in a DataSet iterator, passing batch size, label index and class count
return iterator.next();//return the next batch as a DataSet (note: this returns a DataSet, not the iterator)
}
}
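One thing to note: readCSVDataset calls iterator.next() exactly once, so it returns a single DataSet holding one batch. For a CSV too large to load in one go, the usual pattern is to train batch by batch instead; a minimal sketch under the same API assumptions (not part of the original example):

RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("/DataExamples/animals/animals_train.csv").getFile()));
DataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSizeTraining, labelIndex, numClasses);
while (iterator.hasNext()) {
DataSet batch = iterator.next();   // pull one batch at a time
model.fit(batch);                  // fit on each batch instead of the whole set
}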