Index: tika-dl/pom.xml
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- tika-dl/pom.xml (date 1529070083000)
+++ tika-dl/pom.xml (date 1529352854000)
@@ -36,8 +36,8 @@
UTF-8
- 0.8.0
- 0.8.0-2
+ 1.0.0-beta
+ 0.9.1
@@ -96,11 +96,16 @@
org.apache.commons
commons-math3
3.4.1
+
+
+ org.deeplearning4j
+ deeplearning4j-zoo
+ ${dl4j.version}
org.deeplearning4j
deeplearning4j-modelimport
- ${dl4j.model.version}
+ ${dl4j.version}
org.deeplearning4j
@@ -153,11 +158,21 @@
javacpp
+
+
+ org.nd4j
+ nd4j-api
+ ${dl4j.version}
org.bytedeco
javacpp
- 1.3.2
+ 1.4.1
+
+
+ org.bytedeco
+ javacpp
+ 1.4.1
org.apache.commons
Index: tika-dl/src/main/java/org/apache/tika/dl/imagerec/DL4JInceptionV3Net.java
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- tika-dl/src/main/java/org/apache/tika/dl/imagerec/DL4JInceptionV3Net.java (date 1529070083000)
+++ tika-dl/src/main/java/org/apache/tika/dl/imagerec/DL4JInceptionV3Net.java (date 1529352854000)
@@ -46,9 +46,9 @@
import org.apache.tika.parser.recognition.RecognisedObject;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
-import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
+import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
+import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
@@ -267,8 +267,7 @@
modelWeightsPath, false);
long time = System.currentTimeMillis() - st;
LOG.info("Loaded the Inception model. Time taken={}ms", time);
- } catch (IOException | InvalidKerasConfigurationException
- | UnsupportedKerasConfigurationException e) {
+ } catch (IOException|UnsupportedKerasConfigurationException|InvalidKerasConfigurationException e) {
throw new TikaConfigException(e.getMessage(), e);
}
}
Index: tika-dl/src/main/java/org/apache/tika/dl/imagerec/DL4JVGG16Net.java
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- tika-dl/src/main/java/org/apache/tika/dl/imagerec/DL4JVGG16Net.java (date 1529070083000)
+++ tika-dl/src/main/java/org/apache/tika/dl/imagerec/DL4JVGG16Net.java (date 1529352854000)
@@ -29,9 +29,11 @@
import org.apache.tika.parser.recognition.RecognisedObject;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModelHelper;
-import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.util.ModelSerializer;
+import org.deeplearning4j.zoo.PretrainedType;
+import org.deeplearning4j.zoo.ZooModel;
+import org.deeplearning4j.zoo.model.VGG16;
+import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
@@ -40,7 +42,6 @@
import org.slf4j.LoggerFactory;
import org.xml.sax.ContentHandler;
import org.xml.sax.SAXException;
-import org.deeplearning4j.nn.modelimport.keras.trainedmodels.Utils.ImageNetLabels;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
@@ -58,7 +59,6 @@
private static final String BASE_DIR = ".dl4j" + File.separator + "trainedmodels";
private static String MODEL_DIR = HOME_DIR + File.separator + BASE_DIR;
private static String MODEL_DIR_PREPROCESSED = MODEL_DIR + File.separator + "tikaPreprocessed" + File.separator;
- private static TrainedModelHelper MODEL_HELPER = new TrainedModelHelper(TrainedModels.VGG16);
@Field
private File modelFile = new File(MODEL_DIR_PREPROCESSED + File.separator + "vgg16.zip");
@@ -71,6 +71,10 @@
private boolean serialize = true;
@Field
private int topN;
+
+ private ImageNetLabels imageNetLabels;
+
+ private VGG16 vgg16;
private NativeImageLoader imageLoader = new NativeImageLoader(224, 224, 3);
private DataNormalization preProcessor = new VGG16ImagePreProcessor();
private boolean available = false;
@@ -99,19 +103,22 @@
} else {
LOG.warn("Preprocessed Model doesn't exist at {}", locationToSave);
locationToSave.getParentFile().mkdirs();
- model = MODEL_HELPER.loadModel();
+ ZooModel zooModel = VGG16.builder().build();
+ model = (ComputationGraph)zooModel.initPretrained(PretrainedType.IMAGENET);
LOG.info("Saving the Loaded model for future use. Saved models are more optimised to consume less resources.");
ModelSerializer.writeModel(model, locationToSave, true);
}
} else {
LOG.info("Weight graph model loaded via dl4j Helper functions");
- model = MODEL_HELPER.loadModel();
+ ZooModel zooModel = VGG16.builder().build();
+ model = (ComputationGraph)zooModel.initPretrained(PretrainedType.IMAGENET);
}
+ imageNetLabels = new ImageNetLabels();
available = true;
} catch (Exception e) {
available = false;
LOG.warn(e.getMessage(), e);
- throw new TikaConfigException(e.getMessage(), e);
+ throw new TikaConfigException(e.getMessage(), e);
}
}
@@ -126,8 +133,6 @@
}
private List predict(INDArray predictions)
{
- ArrayList labels;
- labels=ImageNetLabels.getLabels();
List objects = new ArrayList<>();
int[] topNPredictions = new int[topN];
float[] topNProb = new float[topN];
@@ -140,7 +145,7 @@
topNPredictions[i] = Nd4j.argMax(currentBatch, 1).getInt(0, 0);
topNProb[i] = currentBatch.getFloat(batch, topNPredictions[i]);
currentBatch.putScalar(0, topNPredictions[i], 0);
- outLabels[i]= labels.get(topNPredictions[i]);
+ outLabels[i]= imageNetLabels.getLabel(topNPredictions[i]);
objects.add(new RecognisedObject(outLabels[i], "eng", outLabels[i], topNProb[i]));
i++;
}
Index: tika-dl/src/test/java/org/apache/tika/dl/imagerec/DL4JInceptionV3NetTest.java
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- tika-dl/src/test/java/org/apache/tika/dl/imagerec/DL4JInceptionV3NetTest.java (date 1529070083000)
+++ tika-dl/src/test/java/org/apache/tika/dl/imagerec/DL4JInceptionV3NetTest.java (date 1529352854000)
@@ -38,6 +38,7 @@
//skip test
return;
}
+ e.printStackTrace();
}
if (config != null) {
Tika tika = new Tika(config);
Index: tika-dl/src/test/java/org/apache/tika/dl/imagerec/DL4JVGG16NetTest.java
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- tika-dl/src/test/java/org/apache/tika/dl/imagerec/DL4JVGG16NetTest.java (date 1529070083000)
+++ tika-dl/src/test/java/org/apache/tika/dl/imagerec/DL4JVGG16NetTest.java (date 1529352854000)
@@ -36,8 +36,10 @@
&& (e.getMessage().contains("Connection refused")
|| e.getMessage().contains("connect timed out"))) {
//skip test
+ e.printStackTrace();
return;
}
+ e.printStackTrace();
}
if(config != null) {