Merge remote-tracking branch 'konduit/master'
commit
0107fb10ab
|
@ -49,7 +49,7 @@ check_cuda_version "$VERSION"
|
|||
case $VERSION in
|
||||
10.1)
|
||||
VERSION2="7.6"
|
||||
VERSION3="1.5.1"
|
||||
VERSION3="1.5.2"
|
||||
;;
|
||||
10.0)
|
||||
VERSION2="7.4"
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
<!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
|
||||
<cuda.version>10.1</cuda.version>
|
||||
<cudnn.version>7.6</cudnn.version>
|
||||
<javacpp-presets.cuda.version>1.5.1</javacpp-presets.cuda.version>
|
||||
<javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version>
|
||||
</properties>
|
||||
|
||||
<dependencyManagement>
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
|
|||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
|
||||
import org.nd4j.autodiff.samediff.internal.InferenceSession;
|
||||
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -133,16 +134,6 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
|||
}
|
||||
is.setMmgr(mmgr);
|
||||
|
||||
|
||||
|
||||
if(paramTable != null && paramTable.size() > 0) {
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
}
|
||||
INDArray result = sameDiff.outputSingle(phMap, outputKey);
|
||||
|
||||
//Edge case: "vertex" is just an identity activation, for example
|
||||
|
@ -212,17 +203,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
|||
String epsName = fn.getGradPlaceholderName();
|
||||
phMap.put(epsName, epsilon);
|
||||
|
||||
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
List<String> required = new ArrayList<>(config.getVertexParams().getInputs()); //Ensure that the input placeholder gradients are calculated
|
||||
required.addAll(paramTable.keySet());
|
||||
required.addAll(inputNames);
|
||||
|
||||
Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required);
|
||||
for(String s : paramTable.keySet() ){
|
||||
|
@ -279,6 +261,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
|||
protected void doInit(){
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
sameDiff = SameDiff.create();
|
||||
//Use SingleThreadArrayHolder so we can use views (also don't nede multithreading here, DL4J is not thread safe)
|
||||
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
|
||||
|
||||
inputVars = new LinkedHashMap<>();
|
||||
LinkedHashMap<String, SDVariable> maskVars = new LinkedHashMap<>();
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
|||
import org.deeplearning4j.nn.layers.AbstractLayer;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
|
||||
import org.nd4j.autodiff.samediff.internal.InferenceSession;
|
||||
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -100,13 +101,6 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
|||
phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
|
||||
}
|
||||
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
//Configure memory management for SameDiff instance - use DL4J workspaces
|
||||
String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
|
||||
String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
|
||||
|
@ -179,13 +173,6 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
|||
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
|
||||
bl.validateInput(input);
|
||||
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
Map<String,INDArray> phMap = new HashMap<>();
|
||||
phMap.put(INPUT_KEY, input);
|
||||
phMap.put(fn.getGradPlaceholderName(), epsilon);
|
||||
|
@ -300,6 +287,8 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
|||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
|
||||
sameDiff = SameDiff.create();
|
||||
//Use SingleThreadArrayHolder so we can use views (also don't nede multithreading here, DL4J is not thread safe)
|
||||
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
|
||||
Map<String, INDArray> p = paramTable();
|
||||
|
||||
long[] inputShape = input.shape().clone();
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
|
|||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
|
||||
import org.nd4j.autodiff.samediff.internal.InferenceSession;
|
||||
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -119,15 +120,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
|||
}
|
||||
is.setMmgr(mmgr);
|
||||
|
||||
|
||||
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
Map<String,INDArray> phMap = new HashMap<>();
|
||||
phMap.put(INPUT_KEY, input);
|
||||
if(!activations && layerConf().labelsRequired() && labels != null) {
|
||||
|
@ -193,13 +185,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
|||
sameDiff.createGradFunction(INPUT_KEY);
|
||||
}
|
||||
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
List<String> gradVarNames = new ArrayList<>();
|
||||
gradVarNames.addAll(paramTable.keySet());
|
||||
gradVarNames.add(INPUT_KEY);
|
||||
|
@ -317,6 +302,8 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
|||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer bl = layerConf();
|
||||
sameDiff = SameDiff.create();
|
||||
//Use SingleThreadArrayHolder so we can use views (also don't nede multithreading here, DL4J is not thread safe)
|
||||
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
|
||||
Map<String, INDArray> p = paramTable();
|
||||
|
||||
long[] inputShape = input.shape().clone();
|
||||
|
@ -339,7 +326,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
|||
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
|
||||
outputVar = layerOutput;
|
||||
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
for (Map.Entry<String, INDArray> e : p.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
|
||||
|
|
|
@ -29,32 +29,32 @@ This example application uses a neural network trained on the standard MNIST dat
|
|||
|
||||
Deeplearning4J applications requires application specific dependencies in the build.gradle file. The Deeplearning library in turn depends on the libraries of ND4J and OpenBLAS, thus these must also be added to the dependencies declaration. Starting with Android Studio 3.0, annotationProcessors need to be defined as well, thus dependencies for either -x86 or -arm processors should be included, depending on your device, if you are working in Android Studio 3.0 or later. Note that both can be include without conflict as is done in the example app.
|
||||
```groovy
|
||||
compile (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
|
||||
implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
|
||||
exclude group: 'org.bytedeco', module: 'opencv-platform'
|
||||
exclude group: 'org.bytedeco', module: 'leptonica-platform'
|
||||
exclude group: 'org.bytedeco', module: 'hdf5-platform'
|
||||
exclude group: 'org.nd4j', module: 'nd4j-base64'
|
||||
}
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86_64"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
|
||||
|
||||
implementation 'com.google.code.gson:gson:2.8.2'
|
||||
annotationProcessor 'org.projectlombok:lombok:1.16.16'
|
||||
|
|
|
@ -25,32 +25,31 @@ Contents
|
|||
## <a name="head_link1">Setting the Dependencies</a>
|
||||
Deeplearning4J applications require several dependencies in the build.gradle file. The Deeplearning library in turn depends on the libraries of ND4J and OpenBLAS, thus these must also be added to the dependencies declaration. Starting with Android Studio 3.0, annotationProcessors need to be defined as well, requiring dependencies for -x86 or -arm processors.
|
||||
```groovy
|
||||
compile (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
|
||||
implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
|
||||
exclude group: 'org.bytedeco', module: 'opencv-platform'
|
||||
exclude group: 'org.bytedeco', module: 'leptonica-platform'
|
||||
exclude group: 'org.bytedeco', module: 'hdf5-platform'
|
||||
exclude group: 'org.nd4j', module: 'nd4j-base64'
|
||||
}
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86_64"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
|
||||
```
|
||||
|
||||
Compiling these dependencies involves a large number of files, thus it is necessary to set multiDexEnabled to true in defaultConfig.
|
||||
|
|
|
@ -33,33 +33,32 @@ It is also recommended that you download and install IntelliJ IDEA, Maven, and t
|
|||
In order to use Deeplearning4J in your Android projects, you will need to add the following dependencies to your app module’s build.gradle file. Depending on the type of neural network used in your application, you may need to add additional dependencies.
|
||||
|
||||
``` groovy
|
||||
compile (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
|
||||
implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
|
||||
exclude group: 'org.bytedeco', module: 'opencv-platform'
|
||||
exclude group: 'org.bytedeco', module: 'leptonica-platform'
|
||||
exclude group: 'org.bytedeco', module: 'hdf5-platform'
|
||||
exclude group: 'org.nd4j', module: 'nd4j-base64'
|
||||
}
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86_64"
|
||||
testCompile 'junit:junit:4.12'
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
|
||||
testimplementation 'junit:junit:4.12'
|
||||
```
|
||||
|
||||
DL4J depends on ND4J, which is a library that offers fast n-dimensional arrays. ND4J in turn depends on a platform-specific native code library called JavaCPP, therefore you must load a version of ND4J that matches the architecture of the Android device. Both -x86 and -arm types can be included to support multiple device processor types.
|
||||
|
|
|
@ -33,35 +33,34 @@ For best results, you’ll need the following:
|
|||
|
||||
## <a name="head_link2">Configuring Your Android Studio Project</a>
|
||||
|
||||
To be able to use Deeplearning4J in your project, add the following compile dependencies to your app module’s build.gradle file:
|
||||
To be able to use Deeplearning4J in your project, add the following implementation dependencies to your app module’s build.gradle file:
|
||||
|
||||
``` groovy
|
||||
compile (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
|
||||
implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
|
||||
exclude group: 'org.bytedeco', module: 'opencv-platform'
|
||||
exclude group: 'org.bytedeco', module: 'leptonica-platform'
|
||||
exclude group: 'org.bytedeco', module: 'hdf5-platform'
|
||||
exclude group: 'org.nd4j', module: 'nd4j-base64'
|
||||
}
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
|
||||
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86_64"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1'
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm64"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86"
|
||||
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86_64"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
|
||||
implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
|
||||
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -61,25 +61,25 @@ Alternatively, in the case of CUDA 10.1, cuDNN comes bundled with the "redist" p
|
|||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>cuda</artifactId>
|
||||
<version>10.1-7.6-1.5.1</version>
|
||||
<version>10.1-7.6-1.5.2</version>
|
||||
<classifier>linux-x86_64-redist</classifier>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>cuda</artifactId>
|
||||
<version>10.1-7.6-1.5.1</version>
|
||||
<version>10.1-7.6-1.5.2</version>
|
||||
<classifier>linux-ppc64le-redist</classifier>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>cuda</artifactId>
|
||||
<version>10.1-7.6-1.5.1</version>
|
||||
<version>10.1-7.6-1.5.2</version>
|
||||
<classifier>macosx-x86_64-redist</classifier>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>cuda</artifactId>
|
||||
<version>10.1-7.6-1.5.1</version>
|
||||
<version>10.1-7.6-1.5.2</version>
|
||||
<classifier>windows-x86_64-redist</classifier>
|
||||
</dependency>
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
|
|
@ -59,6 +59,7 @@ namespace nd4j {
|
|||
|
||||
template <typename T>
|
||||
static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray* create_(nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
|
||||
template <typename T>
|
||||
static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
|
|
|
@ -422,6 +422,11 @@ NDArray NDArrayFactory::create(nd4j::DataType dtype, nd4j::LaunchContext * conte
|
|||
return res;
|
||||
}
|
||||
|
||||
NDArray* NDArrayFactory::create_(nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
||||
auto result = new NDArray();
|
||||
*result = NDArrayFactory::create(dtype, context);
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -29,8 +30,8 @@
|
|||
#include <dll.h>
|
||||
#include <Environment.h>
|
||||
#include <ArrayOptions.h>
|
||||
#include <templatemath.h>
|
||||
#include <shape.h>
|
||||
//#include <templatemath.h>
|
||||
//#include <shape.h>
|
||||
#include <helpers/logger.h>
|
||||
|
||||
namespace nd4j {
|
||||
|
@ -128,7 +129,9 @@ namespace nd4j {
|
|||
// if both dtypes are the same - just return it
|
||||
if (typeX == typeY)
|
||||
return typeX;
|
||||
|
||||
auto nd4j_max = [](nd4j::DataType typeX, nd4j::DataType typeY) {
|
||||
return typeX > typeY?typeX:typeY;
|
||||
};
|
||||
auto rX = isR(typeX);
|
||||
auto rY = isR(typeY);
|
||||
|
||||
|
@ -144,7 +147,7 @@ namespace nd4j {
|
|||
if (rX && rY) {
|
||||
// if we allow precision boost, then we pick bigger data type
|
||||
if (nd4j::Environment::getInstance()->precisionBoostAllowed()) {
|
||||
return nd4j::math::nd4j_max<nd4j::DataType>(typeX, typeY);
|
||||
return nd4j_max(typeX, typeY);
|
||||
} else {
|
||||
// and we return first operand otherwise
|
||||
return typeX;
|
||||
|
@ -155,7 +158,7 @@ namespace nd4j {
|
|||
// if that's not real type, we apply same rules
|
||||
if (!rX && !rY) {
|
||||
if (nd4j::Environment::getInstance()->precisionBoostAllowed()) {
|
||||
return nd4j::math::nd4j_max<nd4j::DataType>(typeX, typeY);
|
||||
return nd4j_max(typeX, typeY);
|
||||
} else {
|
||||
// and we return first operand otherwise
|
||||
return typeX;
|
||||
|
@ -367,8 +370,8 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
|||
|
||||
template <typename T>
|
||||
FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo) {
|
||||
|
||||
for (int e = 0; e < shape::shapeInfoLength(originalShapeInfo); e++) {
|
||||
auto shapeInfoLength = *originalShapeInfo * 2 + 4;
|
||||
for (auto e = 0; e < shapeInfoLength; e++) {
|
||||
if (originalShapeInfo[e] < static_cast<Nd4jLong>(DataTypeUtils::max<T>())) {
|
||||
newShapeInfo[e] = static_cast<T>(originalShapeInfo[e]);
|
||||
} else
|
||||
|
|
|
@ -1,10 +1,26 @@
|
|||
/**
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by raver on 5/17/2019.
|
||||
//
|
||||
|
||||
#include <array/ConstantHolder.h>
|
||||
#include <DataTypeUtils.h>
|
||||
|
||||
#include <array/ConstantHolder.h>
|
||||
#include <shape.h>
|
||||
|
||||
namespace nd4j {
|
||||
ConstantHolder::ConstantHolder(const ConstantHolder& other) {
|
||||
|
@ -24,7 +40,7 @@ namespace nd4j {
|
|||
bool ConstantHolder::hasBuffer() {
|
||||
return hasBuffer(DataTypeUtils::fromT<T>());
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT bool ConstantHolder::hasBuffer, (), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT bool ConstantHolder::hasBuffer, (void), LIBND4J_TYPES);
|
||||
|
||||
void ConstantHolder::addBuffer(ConstantDataBuffer &pointer, nd4j::DataType dataType) {
|
||||
_buffers[dataType] = pointer;
|
||||
|
@ -34,7 +50,7 @@ namespace nd4j {
|
|||
void ConstantHolder::addBuffer(ConstantDataBuffer &pointer) {
|
||||
addBuffer(pointer, DataTypeUtils::fromT<T>());
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ConstantHolder::addBuffer, (ConstantDataBuffer&), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ConstantHolder::addBuffer, (ConstantDataBuffer& cb), LIBND4J_TYPES);
|
||||
|
||||
ConstantDataBuffer* ConstantHolder::getConstantDataBuffer(nd4j::DataType dataType) {
|
||||
if (!hasBuffer(dataType))
|
||||
|
|
|
@ -195,7 +195,21 @@ namespace nd4j {
|
|||
template <typename T>
|
||||
_CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index, T from, T to) {
|
||||
auto t = this->relativeT<T>(index);
|
||||
auto z = from + (t * (to - from));
|
||||
auto z = from + T(t * (to - from));
|
||||
return z;
|
||||
}
|
||||
|
||||
template <>
|
||||
_CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeT(Nd4jLong index, Nd4jLong from, Nd4jLong to) {
|
||||
auto t = this->relativeT<double>(index);
|
||||
auto z = from + Nd4jLong(t * (to - from));
|
||||
return z;
|
||||
}
|
||||
|
||||
template <>
|
||||
_CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index, int from, int to) {
|
||||
auto t = this->relativeT<float>(index);
|
||||
auto z = from + float(t * (to - from));
|
||||
return z;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -21,6 +22,7 @@
|
|||
#include <exceptions/cuda_exception.h>
|
||||
#include <ConstantHelper.h>
|
||||
#include <DataTypeUtils.h>
|
||||
#include <shape.h>
|
||||
#include <execution/LaunchContext.h>
|
||||
#include <specials.h>
|
||||
#include <logger.h>
|
||||
|
|
|
@ -620,14 +620,15 @@ namespace nd4j {
|
|||
// FIXME: remove this method once we get 1D vectors supported
|
||||
vectorize(input_shape);
|
||||
REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSliceBP: shape calculation failed");
|
||||
|
||||
//REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: gradOut shape should be equals to output from strided_slice op.");
|
||||
//Zero output array, so unused elements have 0 gradient
|
||||
output->nullify();
|
||||
std::sort(indices.begin(), indices.end());
|
||||
if(indices.size() == 3 && (indices[1] - indices[0]) == 1) {
|
||||
//
|
||||
// the first case: only for scalar gradient step
|
||||
if(epsNext->lengthOf() == 1 && (indices.size() == 3 && (indices[1] - indices[0]) == 1 || (indices[2] - indices[0] == 1))) {
|
||||
output->p(indices[0], *epsNext);
|
||||
}
|
||||
else {
|
||||
else { // else for other cases
|
||||
auto sub = (*output)(indices, true, true);
|
||||
sub.assign(epsNext);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -23,6 +24,7 @@
|
|||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <helpers/RandomLauncher.h>
|
||||
#include <ops/declarable/helpers/random.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -35,41 +37,59 @@ namespace nd4j {
|
|||
* TArgs[0] - min for rng
|
||||
* TArgs[1] - max for rng
|
||||
*/
|
||||
CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 2, 0) {
|
||||
CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 0, 0) {
|
||||
// uniform distribution
|
||||
auto rng = block.randomGenerator();
|
||||
auto dtype = DataType::FLOAT32;
|
||||
if (block.getIArguments()->size())
|
||||
dtype = (DataType)INT_ARG(0);
|
||||
|
||||
// FIXME: to be implemented
|
||||
/*
|
||||
if (rng == nullptr)
|
||||
return Status::THROW("RNG is null, aborting...");
|
||||
auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr;
|
||||
auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr;
|
||||
bool disposable = false;
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
if (min == nullptr && max == nullptr && block.numT() >= 2) {
|
||||
min = NDArrayFactory::create_(dtype);
|
||||
max = NDArrayFactory::create_(dtype);
|
||||
min->p(0, T_ARG(0));
|
||||
max->p(0, T_ARG(1));
|
||||
disposable = true;
|
||||
}
|
||||
|
||||
functions::random::RandomFunction<T>::template execTransform<randomOps::UniformDistribution<T>>(block.getRNG(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data());
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
|
||||
|
||||
STORE_RESULT(*z);
|
||||
*/
|
||||
REQUIRE_TRUE(block.numT() > 1, 0, "RandomUniform: to/from must be set");
|
||||
helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
|
||||
|
||||
RandomLauncher::fillUniform(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1));
|
||||
if (disposable) {
|
||||
delete min;
|
||||
delete max;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
DECLARE_SHAPE_FN(randomuniform) {
|
||||
auto in = INPUT_VARIABLE(0);
|
||||
//auto min = INPUT_VARIABLE(1);
|
||||
auto shape = in->template asVectorT<Nd4jLong>();
|
||||
auto dtype = DataType::FLOAT32; //ArrayOptions::dataType(inputShape->at(1)); // output type is by given min
|
||||
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', shape);
|
||||
if (block.getIArguments()->size())
|
||||
dtype = (DataType)INT_ARG(0);
|
||||
if (block.width() > 1)
|
||||
REQUIRE_TRUE(dtype == INPUT_VARIABLE(1)->dataType(), 0, "RandomUniform: data type of output and min/max args should be the same");
|
||||
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape);
|
||||
return SHAPELIST(newShape);
|
||||
}
|
||||
|
||||
DECLARE_TYPES(randomuniform) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
->setAllowedInputTypes(0, {ALL_INTS})
|
||||
->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS})
|
||||
->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
|
||||
->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -33,8 +34,20 @@ namespace nd4j {
|
|||
DECLARE_CUSTOM_OP(get_seed, -2, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* random_uniform distribution for types int32,int64, float16, float and double
|
||||
* by default dtype is float32
|
||||
*
|
||||
* input:
|
||||
* 0 - shape of output (1D int tensor)
|
||||
* 1 - min val (0D of output type) - optional (0 as default)
|
||||
* 2 - max val (0D of output type) - optional (inf as default)
|
||||
*
|
||||
* output:
|
||||
* 0 - uniformly distributed values of given type (between min and max)
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_randomuniform)
|
||||
DECLARE_CUSTOM_OP(randomuniform, 1, 1, true, 2, 0);
|
||||
DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
#if NOT_EXCLUDED(OP_random_normal)
|
||||
|
@ -66,6 +79,7 @@ namespace nd4j {
|
|||
#if NOT_EXCLUDED(OP_random_poisson)
|
||||
DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -23,6 +23,7 @@
|
|||
#include <memory>
|
||||
//#include <graph/Context.h>
|
||||
#include <ShapeUtils.h>
|
||||
#include <helpers/RandomLauncher.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -127,6 +128,28 @@ namespace helpers {
|
|||
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context,
|
||||
graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES);
|
||||
|
||||
template <typename T>
|
||||
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||
T minVal = T(0);
|
||||
T maxVal = DataTypeUtils::max<T>();
|
||||
if (min)
|
||||
minVal = min->t<T>(0);
|
||||
if (max)
|
||||
maxVal = max->t<T>(0);
|
||||
|
||||
if (output->isR())
|
||||
RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
|
||||
else {
|
||||
PRAGMA_OMP_PARALLEL_FOR
|
||||
for (auto i = 0; i < output->lengthOf(); i++) {
|
||||
output->t<T>(i) = rng.relativeT<T>(i, minVal, maxVal);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -91,7 +92,7 @@ namespace helpers {
|
|||
val = nd4j::math::nd4j_min<T>(val, input->t<T>(e));
|
||||
}
|
||||
else {
|
||||
idx = indices->e<int>(e);
|
||||
idx = indices->e<Nd4jLong>(e);
|
||||
val = input->t<T>(e);
|
||||
}
|
||||
output->t<T>(idx) = val;
|
||||
|
@ -111,14 +112,14 @@ namespace helpers {
|
|||
minT->assign(listOfTensors->at(0));
|
||||
|
||||
for (Nd4jLong i = 1; i < indices->lengthOf(); i++) {
|
||||
if (indices->e<T>(i) == idx) {
|
||||
if (indices->e<Nd4jLong>(i) == idx) {
|
||||
|
||||
for (int e = 0; e < minT->lengthOf(); e++) {
|
||||
minT->p(e, nd4j::math::nd4j_min(minT->e<T>(e), listOfTensors->at(i)->e<T>(e)));
|
||||
}
|
||||
}
|
||||
else {
|
||||
idx = indices->e<T>(i);
|
||||
idx = indices->e<Nd4jLong>(i);
|
||||
minT = listOfOutTensors->at(idx);
|
||||
minT->assign(listOfTensors->at(i));
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -26,6 +26,7 @@
|
|||
#include <helpers/RandomLauncher.h>
|
||||
#include <ShapeUtils.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <cuda_exception.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -181,6 +182,72 @@ namespace helpers {
|
|||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_NATIVE);
|
||||
|
||||
template <typename T>
|
||||
static __global__ void fillUniformKernel(graph::RandomGenerator* devRng, T from, T to, T* output, Nd4jLong* outputShape) {
|
||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
|
||||
__shared__ Nd4jLong outputLen;
|
||||
|
||||
if (0 == threadIdx.x) {
|
||||
outputLen = shape::length(outputShape);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto i = start; i < outputLen; i += step) {
|
||||
auto zIndex = shape::getIndexOffset(i, outputShape);
|
||||
output[zIndex] = devRng->relativeT<T>(i, from, to);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||
T minVal = T(0);
|
||||
T maxVal = DataTypeUtils::infOrMax<T>();
|
||||
if (min)
|
||||
minVal = min->t<T>(0);
|
||||
if (max)
|
||||
maxVal = max->t<T>(0);
|
||||
|
||||
if (output->isR())
|
||||
RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
|
||||
else {
|
||||
auto stream = context->getCudaStream();
|
||||
graph::RandomGenerator *devRng;
|
||||
auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator));
|
||||
if (err != 0) {
|
||||
cuda_exception::build("fillRandomUniform_: Cannot allocate device memory for random generator due error", err);
|
||||
}
|
||||
|
||||
err = cudaMemcpy(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice);
|
||||
if (err != 0) {
|
||||
cuda_exception::build("fillRandomUniform_: Cannot copy random generator to device", err);
|
||||
}
|
||||
auto outputBuf = output->dataBuffer()->specialAsT<T>();
|
||||
auto outputShape = output->specialShapeInfo();
|
||||
fillUniformKernel<T><<<128, 128, 128, *stream>>>(devRng, minVal, maxVal, outputBuf, outputShape);
|
||||
|
||||
err = cudaStreamSynchronize(*stream);
|
||||
if (err != 0) {
|
||||
cuda_exception::build("fillRandomUniform_: Cannot successfully finish kernel call", err);
|
||||
}
|
||||
|
||||
err = cudaFree(devRng);
|
||||
if (err != 0) {
|
||||
cuda_exception::build("fillRandomUniform_: Cannot deallocate device memory for random generator", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
|
||||
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -33,7 +33,7 @@ namespace helpers {
|
|||
|
||||
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output);
|
||||
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output);
|
||||
|
||||
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -273,8 +274,8 @@ public:
|
|||
BlockInformation(Nd4jLong length, int threshold) {
|
||||
|
||||
threads = length / threshold;
|
||||
threads = nd4j::math::nd4j_max<int>(1, threads);
|
||||
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
|
||||
threads = (1 < threads)?threads:1;//nd4j::math::nd4j_max<int>(1, threads);
|
||||
threads = (threads < omp_get_max_threads())?threads:omp_get_max_threads();//nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
|
||||
|
||||
items = length / threads;
|
||||
remainder = length % threads;
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -27,7 +28,7 @@
|
|||
#include <dll.h>
|
||||
#include <pointercast.h>
|
||||
#include <platformmath.h>
|
||||
|
||||
#include <DataTypeUtils.h>
|
||||
|
||||
#define BFLOAT16_MAX_VALUE 32737.
|
||||
#define HALF_MAX_VALUE 65504.
|
||||
|
@ -883,7 +884,7 @@ namespace nd4j {
|
|||
if (a > 171.624) {
|
||||
// Correct answer too large to display. Force +infinity.
|
||||
return Z(DOUBLE_MAX_VALUE);
|
||||
//DataTypeUtils::infOrMax<Z>();
|
||||
// return DataTypeUtils::infOrMax<Z>();
|
||||
}
|
||||
|
||||
return nd4j::math::nd4j_exp<Z,Z>(nd4j::math::nd4j_lgamma<X,Z>(a));
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
|
|
@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
|
|||
// auto e = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||
// auto s = NDArrayFactory::create<int>('c', {1}, {1});
|
||||
|
||||
auto grad = NDArrayFactory::create<double>('c', {5,4});
|
||||
auto grad = NDArrayFactory::create<double>('c', {5});
|
||||
|
||||
matrix.linspace(1);
|
||||
grad.linspace(1);
|
||||
|
@ -264,6 +264,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
|
|||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
|
||||
int zero = 0;
|
||||
auto matrix = NDArrayFactory::create<double>('c', {1, 2});
|
||||
|
@ -287,6 +288,31 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
|
|||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
|
||||
int zero = 0;
|
||||
auto matrix = NDArrayFactory::create<float>('c', {4, 8192});
|
||||
// auto b = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||
// auto e = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||
// auto s = NDArrayFactory::create<int>('c', {1}, {1});
|
||||
|
||||
auto grad = NDArrayFactory::create<double>('c', {4, 256});
|
||||
|
||||
matrix.linspace(1);
|
||||
grad.linspace(1);
|
||||
|
||||
nd4j::ops::strided_slice_bp op;
|
||||
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
z->printShapeInfo("Output shape");
|
||||
z->printIndexedBuffer("Output");
|
||||
//ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
|
||||
auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
|
||||
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});
|
||||
|
|
|
@ -58,6 +58,11 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
TEST_F(PlaygroundTests, test_s_1) {
|
||||
auto t = ::runLightBenchmarkSuit(true);
|
||||
delete[] t;
|
||||
}
|
||||
|
||||
/*
|
||||
TEST_F(PlaygroundTests, test_relubp_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {128, 64, 224, 224});
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -855,18 +856,39 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_UniformDistribution_04) {
|
||||
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
|
||||
auto al = NDArrayFactory::create<int>(1);
|
||||
auto be = NDArrayFactory::create<int>(20);
|
||||
auto exp0 = NDArrayFactory::create<float>('c', {10});
|
||||
|
||||
|
||||
nd4j::ops::randomuniform op;
|
||||
auto result = op.execute({&x, &al, &be}, {}, {DataType::INT32});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
z->printIndexedBuffer("Uniform int distribution");
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
namespace nd4j {
|
||||
namespace tests {
|
||||
static void fillList(Nd4jLong seed, int numberOfArrays, std::vector<Nd4jLong> &shape, std::vector<NDArray*> &list, nd4j::graph::RandomGenerator *rng) {
|
||||
rng->setSeed((int) seed);
|
||||
|
||||
for (int i = 0; i < numberOfArrays; i++) {
|
||||
auto array = NDArrayFactory::create_<double>('c', shape);
|
||||
|
||||
auto arrayI = NDArrayFactory::create<Nd4jLong>(shape);
|
||||
auto arrayR = NDArrayFactory::create_<double>('c', shape);
|
||||
auto min = NDArrayFactory::create(0.0);
|
||||
auto max = NDArrayFactory::create(1.0);
|
||||
nd4j::ops::randomuniform op;
|
||||
op.execute(*rng, {array}, {array}, {0.0, 1.0}, {}, {}, true);
|
||||
op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, false);
|
||||
|
||||
list.emplace_back(array);
|
||||
list.emplace_back(arrayR);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -979,3 +1001,13 @@ TEST_F(RNGTests, test_choice_1) {
|
|||
delete x;
|
||||
delete prob;
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, test_uniform_119) {
|
||||
auto x = NDArrayFactory::create<int>('c', {2}, {1, 5});
|
||||
auto z = NDArrayFactory::create<float>('c', {1, 5});
|
||||
|
||||
|
||||
nd4j::ops::randomuniform op;
|
||||
auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
|
@ -401,8 +401,8 @@ public class DifferentialFunctionFactory {
|
|||
return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables();
|
||||
}
|
||||
|
||||
public SDVariable randomUniform(double min, double max, SDVariable shape) {
|
||||
return new DistributionUniform(sameDiff(), shape, min, max).outputVariable();
|
||||
public SDVariable randomUniform(double min, double max, SDVariable shape, DataType dataType) {
|
||||
return new DistributionUniform(sameDiff(), shape, min, max, dataType).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable randomUniform(double min, double max, long... shape) {
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
package org.nd4j.autodiff.samediff;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.Collection;
|
||||
|
||||
/**
|
||||
* Holds a set of arrays keyed by a String name, functioning essentially like a {@code Map<String,INDArray>}.<br>
|
||||
* Implementations may have different internal ways of storing arrays, however.<br>
|
||||
* For example for single threaded applications: {@link org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder}<br>
|
||||
* And for multi-threaded: {@link org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder}
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public interface ArrayHolder {
|
||||
|
||||
/**
|
||||
* @return True if an array by that name exists
|
||||
*/
|
||||
boolean hasArray(String name);
|
||||
|
||||
/**
|
||||
* @param name Name of the array to get
|
||||
* @return The array, or null if no array with that name exists
|
||||
*/
|
||||
INDArray getArray(String name);
|
||||
|
||||
/**
|
||||
* Set the array for the specified name (new array, or replace if it already exists)
|
||||
*
|
||||
* @param name Name of the array
|
||||
* @param array Array to set
|
||||
*/
|
||||
void setArray(String name, INDArray array);
|
||||
|
||||
/**
|
||||
* Remove the array from the ArrayHolder, returning it (if it exists)
|
||||
*
|
||||
* @param name Name of the array to return
|
||||
* @return The now-removed array
|
||||
*/
|
||||
INDArray removeArray(String name);
|
||||
|
||||
/**
|
||||
* @return Number of arrays in the ArrayHolder
|
||||
*/
|
||||
int size();
|
||||
|
||||
/**
|
||||
* Initialize from the specified array holder.
|
||||
* This clears all internal arrays, and adds all arrays from the specified array holder
|
||||
*
|
||||
* @param arrayHolder Array holder to initialize this based on
|
||||
*/
|
||||
void initFrom(ArrayHolder arrayHolder);
|
||||
|
||||
/**
|
||||
* @return Names of the arrays currently in the ArrayHolder
|
||||
*/
|
||||
Collection<String> arrayNames();
|
||||
|
||||
/**
|
||||
* Rename the entry with the specified name
|
||||
*
|
||||
* @param from Original name
|
||||
* @param to New name
|
||||
*/
|
||||
void rename(String from, String to);
|
||||
}
|
|
@ -30,6 +30,8 @@ import org.nd4j.autodiff.listeners.impl.HistoryListener;
|
|||
import org.nd4j.autodiff.listeners.records.History;
|
||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||
import org.nd4j.autodiff.samediff.api.OutAndGrad;
|
||||
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
|
||||
import org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder;
|
||||
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
|
||||
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
|
||||
import org.nd4j.autodiff.samediff.config.FitConfig;
|
||||
|
@ -122,8 +124,8 @@ public class SameDiff extends SDBaseOps {
|
|||
@Getter
|
||||
private final Map<Long, InferenceSession> sessions = new ConcurrentHashMap<>(); //Key: thread ID
|
||||
|
||||
private final Map<String, DeviceLocalNDArray> constantArrays = new ConcurrentHashMap<>();
|
||||
private final Map<String, DeviceLocalNDArray> variablesArrays = new ConcurrentHashMap<>(); //TODO issues with DeviceLocal + mutable / changed during training?
|
||||
private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true);
|
||||
private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true);
|
||||
private final Map<Long, Map<String, INDArray>> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them
|
||||
|
||||
private final List<String> lossVariables = new ArrayList<>();
|
||||
|
@ -346,6 +348,23 @@ public class SameDiff extends SDBaseOps {
|
|||
return listeners;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the array holders for variable and constant arrays<br>
|
||||
* <b>NOTE:</b> this is usually reserved for developers and internal use, and should not be needed by almost all users<br>
|
||||
* See {@link ArrayHolder} for more details
|
||||
*
|
||||
* @param variableArrayHolder Array holder for variable arrays
|
||||
* @param constantArrayHolder Array holder for constant arrays
|
||||
* @param initialize If true: transfer any arrays from the current array holders to the new/specified ones
|
||||
*/
|
||||
public void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull ArrayHolder constantArrayHolder, boolean initialize){
|
||||
if(initialize){
|
||||
variableArrayHolder.initFrom(this.variablesArrays);
|
||||
constantArrayHolder.initFrom(this.constantArrays);
|
||||
}
|
||||
this.variablesArrays = variableArrayHolder;
|
||||
this.constantArrays = constantArrayHolder;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details.
|
||||
|
@ -674,9 +693,9 @@ public class SameDiff extends SDBaseOps {
|
|||
|
||||
SDVariable v = getVariable(varName);
|
||||
if (v.isConstant()) {
|
||||
constantArrays.put(varName, new DeviceLocalNDArray(arr, true));
|
||||
constantArrays.setArray(varName, arr);
|
||||
} else if (v.getVariableType() == VariableType.VARIABLE) {
|
||||
variablesArrays.put(varName, new DeviceLocalNDArray(arr, true));
|
||||
variablesArrays.setArray(varName, arr);
|
||||
} else if (v.isPlaceHolder()) {
|
||||
long tid = Thread.currentThread().getId();
|
||||
if (!placeholdersPerThread.containsKey(tid)) {
|
||||
|
@ -699,12 +718,12 @@ public class SameDiff extends SDBaseOps {
|
|||
SDVariable var = getVariable(varName);
|
||||
switch (var.getVariableType()) {
|
||||
case VARIABLE:
|
||||
return variablesArrays.containsKey(varName);
|
||||
return variablesArrays.hasArray(varName);
|
||||
case ARRAY:
|
||||
long tid = Thread.currentThread().getId();
|
||||
return sessions.containsKey(tid) && sessions.get(tid).contains(varName, InferenceSession.OUTER_FRAME, 0, null);
|
||||
case CONSTANT:
|
||||
return constantArrays.containsKey(varName);
|
||||
return constantArrays.hasArray(varName);
|
||||
case PLACEHOLDER:
|
||||
return placeholdersPerThread.containsKey(Thread.currentThread().getId()) &&
|
||||
placeholdersPerThread.get(Thread.currentThread().getId()).containsKey(varName);
|
||||
|
@ -724,11 +743,11 @@ public class SameDiff extends SDBaseOps {
|
|||
SDVariable v = variables.get(varName).getVariable();
|
||||
switch (v.getVariableType()) {
|
||||
case VARIABLE:
|
||||
return variablesArrays.get(varName).get();
|
||||
return variablesArrays.getArray(varName);
|
||||
case CONSTANT:
|
||||
if (!constantArrays.containsKey(varName))
|
||||
if (!constantArrays.hasArray(varName))
|
||||
return null;
|
||||
return constantArrays.get(varName).get();
|
||||
return constantArrays.getArray(varName);
|
||||
case ARRAY:
|
||||
//Only stored in inference session...
|
||||
InferenceSession s = sessions.get(Thread.currentThread().getId());
|
||||
|
@ -781,31 +800,16 @@ public class SameDiff extends SDBaseOps {
|
|||
sessions.put(Thread.currentThread().getId(), new InferenceSession(this));
|
||||
}
|
||||
|
||||
boolean duped = false;
|
||||
if (arr.isAttached()) {
|
||||
arr = arr.detach();
|
||||
duped = true;
|
||||
}
|
||||
if (arr.isView()) {
|
||||
arr = arr.dup();
|
||||
duped = true;
|
||||
}
|
||||
|
||||
if (!duped && variable.getVariableType() == VariableType.VARIABLE) {
|
||||
for (DeviceLocalNDArray otherArr : variablesArrays.values()) {
|
||||
if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
|
||||
arr = arr.dup();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch (variable.getVariableType()) {
|
||||
case VARIABLE:
|
||||
variablesArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
variablesArrays.setArray(variable.name(), arr);
|
||||
break;
|
||||
case CONSTANT:
|
||||
constantArrays.put(variable.name(), new DeviceLocalNDArray(arr, true));
|
||||
constantArrays.setArray(variable.name(), arr);
|
||||
break;
|
||||
case ARRAY:
|
||||
throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" +
|
||||
|
@ -859,9 +863,9 @@ public class SameDiff extends SDBaseOps {
|
|||
arr = arr.dup();
|
||||
|
||||
if(variable.getVariableType() == VariableType.VARIABLE ){
|
||||
variablesArrays.get(variable.name()).update(arr);
|
||||
variablesArrays.setArray(variable.name(), arr);
|
||||
} else {
|
||||
constantArrays.get(variable.name()).update(arr);
|
||||
constantArrays.setArray(variable.name(), arr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2715,7 +2719,7 @@ public class SameDiff extends SDBaseOps {
|
|||
SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType());
|
||||
name = v.name();
|
||||
variables.put(name, Variable.builder().name(name).variable(v).build());
|
||||
constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
constantArrays.setArray(name, constant);
|
||||
return v;
|
||||
}
|
||||
|
||||
|
@ -2792,7 +2796,7 @@ public class SameDiff extends SDBaseOps {
|
|||
if(variableType == VariableType.VARIABLE){
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
INDArray vArr = weightInitScheme.create(dataType, shape);
|
||||
variablesArrays.put(name, new DeviceLocalNDArray(vArr, true));
|
||||
variablesArrays.setArray(name, vArr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2924,7 +2928,7 @@ public class SameDiff extends SDBaseOps {
|
|||
SDVariable r = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType());
|
||||
addVariable(r);
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
||||
variablesArrays.put(v.name(), new DeviceLocalNDArray(v.getArr().dup(), true));
|
||||
variablesArrays.setArray(v.name(), v.getArr().dup());
|
||||
}
|
||||
return r;
|
||||
case ARRAY:
|
||||
|
@ -3014,20 +3018,17 @@ public class SameDiff extends SDBaseOps {
|
|||
arr = arr.detach();
|
||||
duped = true;
|
||||
}
|
||||
if (arr.isView()) {
|
||||
arr = arr.dup();
|
||||
duped = true;
|
||||
}
|
||||
|
||||
if (!duped) {
|
||||
for (DeviceLocalNDArray otherArr : variablesArrays.values()) {
|
||||
if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
|
||||
for (String s : variablesArrays.arrayNames()) {
|
||||
if (variablesArrays.getArray(s) == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
|
||||
arr = arr.dup();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType());
|
||||
associateArrayWithVariable(arr, ret);
|
||||
|
||||
|
@ -3085,8 +3086,8 @@ public class SameDiff extends SDBaseOps {
|
|||
INDArray arr = variable.getArr();
|
||||
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
|
||||
|
||||
constantArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
variablesArrays.remove(n);
|
||||
constantArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
variablesArrays.removeArray(n);
|
||||
if (!placeholdersPerThread.isEmpty()) {
|
||||
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
|
||||
m.remove(n);
|
||||
|
@ -3183,8 +3184,8 @@ public class SameDiff extends SDBaseOps {
|
|||
INDArray arr = variable.getArr();
|
||||
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
|
||||
|
||||
variablesArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
constantArrays.remove(n);
|
||||
variablesArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
constantArrays.removeArray(n);
|
||||
if (!placeholdersPerThread.isEmpty()) {
|
||||
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
|
||||
m.remove(n);
|
||||
|
@ -3260,16 +3261,14 @@ public class SameDiff extends SDBaseOps {
|
|||
|
||||
switch (v.getVariableType()) {
|
||||
case VARIABLE:
|
||||
DeviceLocalNDArray dl = variablesArrays.remove(e.getKey());
|
||||
INDArray arr = dl.get();
|
||||
INDArray arr = variablesArrays.removeArray(e.getKey());
|
||||
INDArray newArr = arr.castTo(d);
|
||||
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
variablesArrays.setArray(e.getKey(), newArr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
break;
|
||||
case CONSTANT:
|
||||
DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey());
|
||||
INDArray arr2 = dl2.get();
|
||||
INDArray arr2 = constantArrays.removeArray(e.getKey());
|
||||
INDArray newArr2 = arr2.castTo(d);
|
||||
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
constantArrays.setArray(e.getKey(), newArr2); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
break;
|
||||
case PLACEHOLDER:
|
||||
Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId());
|
||||
|
@ -3409,14 +3408,12 @@ public class SameDiff extends SDBaseOps {
|
|||
variables.remove(from);
|
||||
variables.put(to, v);
|
||||
|
||||
if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.containsKey(from)){
|
||||
DeviceLocalNDArray dl = constantArrays.remove(from);
|
||||
constantArrays.put(to, dl);
|
||||
if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.hasArray(from)){
|
||||
constantArrays.rename(from, to);
|
||||
}
|
||||
|
||||
if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.containsKey(from)){
|
||||
DeviceLocalNDArray dl = variablesArrays.remove(from);
|
||||
variablesArrays.put(to, dl);
|
||||
if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.hasArray(from)){
|
||||
variablesArrays.rename(from, to);
|
||||
}
|
||||
|
||||
if(v.getVariable().getVariableType() == VariableType.PLACEHOLDER ){
|
||||
|
@ -4187,6 +4184,8 @@ public class SameDiff extends SDBaseOps {
|
|||
|
||||
@Override
|
||||
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
|
||||
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false); //Training isn't thread safe, no need to use DeviceLocal, even with lazy init
|
||||
|
||||
//Propagate graph to this samediff instance which will also contain the backward
|
||||
if (SameDiff.this.debugMode) {
|
||||
sameDiff.enableDebugMode();
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
package org.nd4j.autodiff.samediff.array;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.ArrayHolder;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A simple {@link ArrayHolder} that uses a simple {@code Map<String, INDArray>} internally.
|
||||
* No thread safety guarantees
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class SingleThreadArrayHolder implements ArrayHolder {
|
||||
|
||||
private final Map<String, INDArray> map = new HashMap<>();
|
||||
|
||||
@Override
|
||||
public boolean hasArray(@NonNull String name) {
|
||||
return map.containsKey(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getArray(@NonNull String name) {
|
||||
return map.get(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setArray(@NonNull String name, @NonNull INDArray array) {
|
||||
map.put(name, array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray removeArray(@NonNull String name) {
|
||||
return map.remove(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return map.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFrom(ArrayHolder arrayHolder) {
|
||||
map.clear();
|
||||
Collection<String> names = arrayHolder.arrayNames();
|
||||
for (String n : names) {
|
||||
map.put(n, arrayHolder.getArray(n));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<String> arrayNames() {
|
||||
return Collections.unmodifiableCollection(map.keySet());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void rename(String from, String to) {
|
||||
INDArray arr = map.remove(from);
|
||||
map.put(to, arr);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
package org.nd4j.autodiff.samediff.array;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.ArrayHolder;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.util.DeviceLocalNDArray;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* An {@link ArrayHolder} that uses the thread safe {@link DeviceLocalNDArray} internally
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class ThreadSafeArrayHolder implements ArrayHolder {
|
||||
|
||||
private final Map<String, DeviceLocalNDArray> map = new ConcurrentHashMap<>();
|
||||
private final boolean lazyInit;
|
||||
|
||||
/**
|
||||
* @param lazyInit If true: use lazy initialization for {@link DeviceLocalNDArray}
|
||||
*/
|
||||
public ThreadSafeArrayHolder(boolean lazyInit) {
|
||||
this.lazyInit = lazyInit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasArray(@NonNull String name) {
|
||||
return map.containsKey(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getArray(@NonNull String name) {
|
||||
return map.get(name).get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setArray(@NonNull String name, @NonNull INDArray array) {
|
||||
if (array.isView())
|
||||
array = array.dup(); //Device local doesn't support views
|
||||
if (!map.containsKey(name)) {
|
||||
DeviceLocalNDArray dla = new DeviceLocalNDArray(array, lazyInit);
|
||||
map.put(name, dla);
|
||||
} else {
|
||||
DeviceLocalNDArray dla = map.get(name);
|
||||
dla.update(array);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray removeArray(@NonNull String name) {
|
||||
DeviceLocalNDArray arr = map.remove(name);
|
||||
if (arr == null)
|
||||
return null;
|
||||
return arr.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return map.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFrom(ArrayHolder arrayHolder) {
|
||||
map.clear();
|
||||
Collection<String> names = arrayHolder.arrayNames();
|
||||
for (String n : names) {
|
||||
setArray(n, arrayHolder.getArray(n));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<String> arrayNames() {
|
||||
return Collections.unmodifiableCollection(map.keySet());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void rename(@NonNull String from, @NonNull String to) {
|
||||
DeviceLocalNDArray dl = map.remove(from);
|
||||
map.put(to, dl);
|
||||
}
|
||||
}
|
|
@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.ops;
|
|||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
|
||||
|
||||
|
@ -237,9 +238,23 @@ public class SDRandom extends SDOps {
|
|||
return uniform(null, min, max, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #uniform(String, double, double, SDVariable)
|
||||
*/
|
||||
public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) {
|
||||
return uniform(null, min, max, shape, dataType);
|
||||
}
|
||||
|
||||
/**
|
||||
* As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output
|
||||
*/
|
||||
public SDVariable uniform(String name, double min, double max, SDVariable shape) {
|
||||
return uniform(name, min, max, shape, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
|
||||
* U(min,max)<br>
|
||||
* U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned<br>
|
||||
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
|
||||
* specified as a long[] instead
|
||||
*
|
||||
|
@ -247,11 +262,12 @@ public class SDRandom extends SDOps {
|
|||
* @param min Minimum value
|
||||
* @param max Maximum value. Must satisfy max >= min
|
||||
* @param shape Shape of the new random SDVariable, as a 1D array
|
||||
* @return New SDVariable
|
||||
* @param dataType Data type of the output array (if null: Float32 output is returned)
|
||||
* @return New SDVariable, of the specified data type
|
||||
*/
|
||||
public SDVariable uniform(String name, double min, double max, SDVariable shape) {
|
||||
public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) {
|
||||
validateInteger("uniform random", shape);
|
||||
SDVariable ret = f().randomUniform(min, max, shape);
|
||||
SDVariable ret = f().randomUniform(min, max, shape, dataType);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ public class LogSumExp extends DynamicCustomOp {
|
|||
super(sameDiff, i_v);
|
||||
if(dimensions != null) {
|
||||
addIArgument(dimensions);
|
||||
this.dimensions = dimensions;
|
||||
}
|
||||
addTArgument(keepDims ? 1.0 : 0.0);
|
||||
this.keepDims = keepDims;
|
||||
|
|
|
@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
@ -41,30 +42,47 @@ import java.util.Map;
|
|||
public class DistributionUniform extends DynamicCustomOp {
|
||||
private double min = 0.0;
|
||||
private double max = 1.0;
|
||||
private DataType dataType;
|
||||
|
||||
public DistributionUniform() {
|
||||
//
|
||||
}
|
||||
|
||||
public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max){
|
||||
public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max) {
|
||||
this(sd, shape, min, max, null);
|
||||
}
|
||||
|
||||
public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max, DataType dataType){
|
||||
super(null, sd, new SDVariable[]{shape});
|
||||
Preconditions.checkState(min <= max, "Minimum (%s) must be <= max (%s)", min, max);
|
||||
addTArgument(min, max);
|
||||
Preconditions.checkState(dataType == null || dataType.isNumerical(), "Only numerical datatypes can be used with DistributionUniform - rquested output datatype: %s", dataType);
|
||||
this.dataType = dataType;
|
||||
this.min = min;
|
||||
this.max = max;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public DistributionUniform(INDArray shape, INDArray out, double min, double max){
|
||||
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null);
|
||||
this.min = min;
|
||||
this.max = max;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
|
||||
addArgs();
|
||||
AttrValue v = attributesForNode.get("dtype");
|
||||
dataType = TFGraphMapper.convertType(v.getType());
|
||||
addIArgument(dataType.toInt());
|
||||
}
|
||||
|
||||
protected void addArgs() {
|
||||
tArguments.clear();
|
||||
addTArgument(min, max);
|
||||
if(dataType != null){
|
||||
iArguments.clear();
|
||||
addIArgument(dataType.toInt());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -85,8 +103,10 @@ public class DistributionUniform extends DynamicCustomOp {
|
|||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
||||
//Input data type specifies the shape; output data type should be any float
|
||||
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
|
||||
//Input data type specifies the shape
|
||||
if(dataType != null){
|
||||
return Collections.singletonList(dataType);
|
||||
}
|
||||
return Collections.singletonList(DataType.FLOAT);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
<!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
|
||||
<cuda.version>10.1</cuda.version>
|
||||
<cudnn.version>7.6</cudnn.version>
|
||||
<javacpp-presets.cuda.version>1.5.1</javacpp-presets.cuda.version>
|
||||
<javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version>
|
||||
<nd4j.backend>nd4j-cuda-${cuda.version}</nd4j.backend>
|
||||
</properties>
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
<!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
|
||||
<cuda.version>10.1</cuda.version>
|
||||
<cudnn.version>7.6</cudnn.version>
|
||||
<javacpp-presets.cuda.version>1.5.1</javacpp-presets.cuda.version>
|
||||
<javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version>
|
||||
</properties>
|
||||
|
||||
<build>
|
||||
|
|
|
@ -382,4 +382,21 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
INDArray out = Nd4j.exec(all);
|
||||
assertEquals(x, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUniformDtype(){
|
||||
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100));
|
||||
SDVariable out = sd.random.uniform(0, 10, shape, t);
|
||||
INDArray arr = out.eval();
|
||||
assertEquals(t, arr.dataType());
|
||||
double min = arr.minNumber().doubleValue();
|
||||
double max = arr.maxNumber().doubleValue();
|
||||
double mean = arr.meanNumber().doubleValue();
|
||||
assertEquals(0, min, 0.5);
|
||||
assertEquals(10, max, 0.5);
|
||||
assertEquals(5.5, mean, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1970,4 +1970,24 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
INDArray log = Transforms.log(sum);
|
||||
assertEquals(log, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLogSumExp2(){
|
||||
|
||||
for( int dim=0; dim<=2; dim++ ) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 3, 4, 5);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var(inputArr);
|
||||
SDVariable lse = sd.math().logSumExp(in, dim);
|
||||
|
||||
INDArray exp = Transforms.exp(inputArr, true);
|
||||
INDArray sum = exp.sum(dim);
|
||||
INDArray log = Transforms.log(sum);
|
||||
|
||||
OpValidation.validate(new TestCase(sd)
|
||||
.expectedOutput(lse.name(), log)
|
||||
.gradientCheck(true));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2311,7 +2311,6 @@ public class SameDiffTests extends BaseNd4jTest {
|
|||
SDVariable loss = out.std("out", true);
|
||||
|
||||
INDArray outArr = loss.eval();
|
||||
// sd.execBackwards(Collections.emptyMap());
|
||||
Map<String,INDArray> grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
|
||||
|
||||
Map<String, INDArray> origGrad = new HashMap<>();
|
||||
|
@ -2321,7 +2320,6 @@ public class SameDiffTests extends BaseNd4jTest {
|
|||
|
||||
in.getArr().assign(Nd4j.rand(in.getArr().shape()));
|
||||
INDArray outArr2 = loss.eval();
|
||||
// sd.execBackwards(Collections.emptyMap());
|
||||
grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
|
||||
|
||||
assertNotEquals(outArr, outArr2);
|
||||
|
@ -2641,8 +2639,7 @@ public class SameDiffTests extends BaseNd4jTest {
|
|||
.expectedOutput("out", out)
|
||||
.gradientCheck(true));
|
||||
|
||||
assertNull(err, err);
|
||||
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
18
pom.xml
18
pom.xml
|
@ -288,21 +288,21 @@
|
|||
<javacpp.platform.extension/> <!-- -Djavacpp.platform.extension=-avx512 -->
|
||||
<javacpp.platform.properties>${javacpp.platform}</javacpp.platform.properties>
|
||||
|
||||
<javacpp.version>1.5.1-1</javacpp.version>
|
||||
<javacpp-presets.version>1.5.1</javacpp-presets.version>
|
||||
<javacv.version>1.5.1</javacv.version>
|
||||
<javacpp.version>1.5.2</javacpp.version>
|
||||
<javacpp-presets.version>1.5.2</javacpp-presets.version>
|
||||
<javacv.version>1.5.2</javacv.version>
|
||||
|
||||
<python.version>3.7.3</python.version>
|
||||
<python.version>3.7.5</python.version>
|
||||
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
|
||||
|
||||
<openblas.version>0.3.6</openblas.version>
|
||||
<mkl.version>2019.4</mkl.version>
|
||||
<opencv.version>4.1.0</opencv.version>
|
||||
<ffmpeg.version>4.1.3</ffmpeg.version>
|
||||
<openblas.version>0.3.7</openblas.version>
|
||||
<mkl.version>2019.5</mkl.version>
|
||||
<opencv.version>4.1.2</opencv.version>
|
||||
<ffmpeg.version>4.2.1</ffmpeg.version>
|
||||
<leptonica.version>1.78.0</leptonica.version>
|
||||
<hdf5.version>1.10.5</hdf5.version>
|
||||
<ale.version>0.6.0</ale.version>
|
||||
<tensorflow.version>1.14.0</tensorflow.version>
|
||||
<tensorflow.version>1.15.0</tensorflow.version>
|
||||
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
|
||||
|
||||
<commons-compress.version>1.18</commons-compress.version>
|
||||
|
|
Loading…
Reference in New Issue