Merge remote-tracking branch 'konduit/master'

master
AlexDBlack 2019-11-08 18:11:45 +11:00
commit 0107fb10ab
45 changed files with 762 additions and 281 deletions

View File

@ -49,7 +49,7 @@ check_cuda_version "$VERSION"
case $VERSION in case $VERSION in
10.1) 10.1)
VERSION2="7.6" VERSION2="7.6"
VERSION3="1.5.1" VERSION3="1.5.2"
;; ;;
10.0) 10.0)
VERSION2="7.4" VERSION2="7.4"

View File

@ -28,7 +28,7 @@
<!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml --> <!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
<cuda.version>10.1</cuda.version> <cuda.version>10.1</cuda.version>
<cudnn.version>7.6</cudnn.version> <cudnn.version>7.6</cudnn.version>
<javacpp-presets.cuda.version>1.5.1</javacpp-presets.cuda.version> <javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version>
</properties> </properties>
<dependencyManagement> <dependencyManagement>

View File

@ -31,6 +31,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr; import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -133,16 +134,6 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
} }
is.setMmgr(mmgr); is.setMmgr(mmgr);
if(paramTable != null && paramTable.size() > 0) {
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
}
INDArray result = sameDiff.outputSingle(phMap, outputKey); INDArray result = sameDiff.outputSingle(phMap, outputKey);
//Edge case: "vertex" is just an identity activation, for example //Edge case: "vertex" is just an identity activation, for example
@ -212,17 +203,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
String epsName = fn.getGradPlaceholderName(); String epsName = fn.getGradPlaceholderName();
phMap.put(epsName, epsilon); phMap.put(epsName, epsilon);
List<String> required = new ArrayList<>(config.getVertexParams().getInputs()); //Ensure that the input placeholder gradients are calculated
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
required.addAll(paramTable.keySet()); required.addAll(paramTable.keySet());
required.addAll(inputNames);
Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required); Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required);
for(String s : paramTable.keySet() ){ for(String s : paramTable.keySet() ){
@ -279,6 +261,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
protected void doInit(){ protected void doInit(){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
sameDiff = SameDiff.create(); sameDiff = SameDiff.create();
//Use SingleThreadArrayHolder so we can use views (also don't nede multithreading here, DL4J is not thread safe)
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
inputVars = new LinkedHashMap<>(); inputVars = new LinkedHashMap<>();
LinkedHashMap<String, SDVariable> maskVars = new LinkedHashMap<>(); LinkedHashMap<String, SDVariable> maskVars = new LinkedHashMap<>();

View File

@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.nn.layers.AbstractLayer;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr; import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -100,13 +101,6 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
phMap.put(MASK_KEY, layerConf().onesMaskForInput(input)); phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
} }
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
//Configure memory management for SameDiff instance - use DL4J workspaces //Configure memory management for SameDiff instance - use DL4J workspaces
String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM); String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS); String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
@ -179,13 +173,6 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input); bl.validateInput(input);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
Map<String,INDArray> phMap = new HashMap<>(); Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input); phMap.put(INPUT_KEY, input);
phMap.put(fn.getGradPlaceholderName(), epsilon); phMap.put(fn.getGradPlaceholderName(), epsilon);
@ -300,6 +287,8 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
sameDiff = SameDiff.create(); sameDiff = SameDiff.create();
//Use SingleThreadArrayHolder so we can use views (also don't nede multithreading here, DL4J is not thread safe)
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
Map<String, INDArray> p = paramTable(); Map<String, INDArray> p = paramTable();
long[] inputShape = input.shape().clone(); long[] inputShape = input.shape().clone();

View File

@ -29,6 +29,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr; import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -119,15 +120,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
} }
is.setMmgr(mmgr); is.setMmgr(mmgr);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
Map<String,INDArray> phMap = new HashMap<>(); Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input); phMap.put(INPUT_KEY, input);
if(!activations && layerConf().labelsRequired() && labels != null) { if(!activations && layerConf().labelsRequired() && labels != null) {
@ -193,13 +185,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff.createGradFunction(INPUT_KEY); sameDiff.createGradFunction(INPUT_KEY);
} }
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
List<String> gradVarNames = new ArrayList<>(); List<String> gradVarNames = new ArrayList<>();
gradVarNames.addAll(paramTable.keySet()); gradVarNames.addAll(paramTable.keySet());
gradVarNames.add(INPUT_KEY); gradVarNames.add(INPUT_KEY);
@ -317,6 +302,8 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer bl = layerConf(); org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer bl = layerConf();
sameDiff = SameDiff.create(); sameDiff = SameDiff.create();
//Use SingleThreadArrayHolder so we can use views (also don't nede multithreading here, DL4J is not thread safe)
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
Map<String, INDArray> p = paramTable(); Map<String, INDArray> p = paramTable();
long[] inputShape = input.shape().clone(); long[] inputShape = input.shape().clone();
@ -339,7 +326,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null"); Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
outputVar = layerOutput; outputVar = layerOutput;
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
for (Map.Entry<String, INDArray> e : p.entrySet()) { for (Map.Entry<String, INDArray> e : p.entrySet()) {
INDArray arr = e.getValue(); INDArray arr = e.getValue();
sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey())); sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));

View File

@ -29,32 +29,32 @@ This example application uses a neural network trained on the standard MNIST dat
Deeplearning4J applications requires application specific dependencies in the build.gradle file. The Deeplearning library in turn depends on the libraries of ND4J and OpenBLAS, thus these must also be added to the dependencies declaration. Starting with Android Studio 3.0, annotationProcessors need to be defined as well, thus dependencies for either -x86 or -arm processors should be included, depending on your device, if you are working in Android Studio 3.0 or later. Note that both can be include without conflict as is done in the example app. Deeplearning4J applications requires application specific dependencies in the build.gradle file. The Deeplearning library in turn depends on the libraries of ND4J and OpenBLAS, thus these must also be added to the dependencies declaration. Starting with Android Studio 3.0, annotationProcessors need to be defined as well, thus dependencies for either -x86 or -arm processors should be included, depending on your device, if you are working in Android Studio 3.0 or later. Note that both can be include without conflict as is done in the example app.
```groovy ```groovy
compile (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') { implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
exclude group: 'org.bytedeco', module: 'opencv-platform' exclude group: 'org.bytedeco', module: 'opencv-platform'
exclude group: 'org.bytedeco', module: 'leptonica-platform' exclude group: 'org.bytedeco', module: 'leptonica-platform'
exclude group: 'org.bytedeco', module: 'hdf5-platform' exclude group: 'org.bytedeco', module: 'hdf5-platform'
exclude group: 'org.nd4j', module: 'nd4j-base64' exclude group: 'org.nd4j', module: 'nd4j-base64'
} }
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}' implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1' implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1' implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1' implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
implementation 'com.google.code.gson:gson:2.8.2' implementation 'com.google.code.gson:gson:2.8.2'
annotationProcessor 'org.projectlombok:lombok:1.16.16' annotationProcessor 'org.projectlombok:lombok:1.16.16'

View File

@ -25,32 +25,31 @@ Contents
## <a name="head_link1">Setting the Dependencies</a> ## <a name="head_link1">Setting the Dependencies</a>
Deeplearning4J applications require several dependencies in the build.gradle file. The Deeplearning library in turn depends on the libraries of ND4J and OpenBLAS, thus these must also be added to the dependencies declaration. Starting with Android Studio 3.0, annotationProcessors need to be defined as well, requiring dependencies for -x86 or -arm processors. Deeplearning4J applications require several dependencies in the build.gradle file. The Deeplearning library in turn depends on the libraries of ND4J and OpenBLAS, thus these must also be added to the dependencies declaration. Starting with Android Studio 3.0, annotationProcessors need to be defined as well, requiring dependencies for -x86 or -arm processors.
```groovy ```groovy
compile (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') { implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
exclude group: 'org.bytedeco', module: 'opencv-platform' exclude group: 'org.bytedeco', module: 'opencv-platform'
exclude group: 'org.bytedeco', module: 'leptonica-platform' exclude group: 'org.bytedeco', module: 'leptonica-platform'
exclude group: 'org.bytedeco', module: 'hdf5-platform' exclude group: 'org.bytedeco', module: 'hdf5-platform'
exclude group: 'org.nd4j', module: 'nd4j-base64'
} }
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}' implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1' implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1' implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1' implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
``` ```
Compiling these dependencies involves a large number of files, thus it is necessary to set multiDexEnabled to true in defaultConfig. Compiling these dependencies involves a large number of files, thus it is necessary to set multiDexEnabled to true in defaultConfig.

View File

@ -33,33 +33,32 @@ It is also recommended that you download and install IntelliJ IDEA, Maven, and t
In order to use Deeplearning4J in your Android projects, you will need to add the following dependencies to your app modules build.gradle file. Depending on the type of neural network used in your application, you may need to add additional dependencies. In order to use Deeplearning4J in your Android projects, you will need to add the following dependencies to your app modules build.gradle file. Depending on the type of neural network used in your application, you may need to add additional dependencies.
``` groovy ``` groovy
compile (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') { implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
exclude group: 'org.bytedeco', module: 'opencv-platform' exclude group: 'org.bytedeco', module: 'opencv-platform'
exclude group: 'org.bytedeco', module: 'leptonica-platform' exclude group: 'org.bytedeco', module: 'leptonica-platform'
exclude group: 'org.bytedeco', module: 'hdf5-platform' exclude group: 'org.bytedeco', module: 'hdf5-platform'
exclude group: 'org.nd4j', module: 'nd4j-base64'
} }
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}' implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1' implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1' implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1' implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
testCompile 'junit:junit:4.12' testimplementation 'junit:junit:4.12'
``` ```
DL4J depends on ND4J, which is a library that offers fast n-dimensional arrays. ND4J in turn depends on a platform-specific native code library called JavaCPP, therefore you must load a version of ND4J that matches the architecture of the Android device. Both -x86 and -arm types can be included to support multiple device processor types. DL4J depends on ND4J, which is a library that offers fast n-dimensional arrays. ND4J in turn depends on a platform-specific native code library called JavaCPP, therefore you must load a version of ND4J that matches the architecture of the Android device. Both -x86 and -arm types can be included to support multiple device processor types.

View File

@ -33,35 +33,34 @@ For best results, youll need the following:
## <a name="head_link2">Configuring Your Android Studio Project</a> ## <a name="head_link2">Configuring Your Android Studio Project</a>
To be able to use Deeplearning4J in your project, add the following compile dependencies to your app modules build.gradle file: To be able to use Deeplearning4J in your project, add the following implementation dependencies to your app modules build.gradle file:
``` groovy ``` groovy
compile (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') { implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '{{page.version}}') {
exclude group: 'org.bytedeco', module: 'opencv-platform' exclude group: 'org.bytedeco', module: 'opencv-platform'
exclude group: 'org.bytedeco', module: 'leptonica-platform' exclude group: 'org.bytedeco', module: 'leptonica-platform'
exclude group: 'org.bytedeco', module: 'hdf5-platform' exclude group: 'org.bytedeco', module: 'hdf5-platform'
exclude group: 'org.nd4j', module: 'nd4j-base64'
} }
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}' implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}'
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86"
compile group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1' implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'openblas', version: '0.3.6-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1' implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'opencv', version: '4.1.0-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1' implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-arm64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
compile group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.1', classifier: "android-x86_64" implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
``` ```

View File

@ -61,25 +61,25 @@ Alternatively, in the case of CUDA 10.1, cuDNN comes bundled with the "redist" p
<dependency> <dependency>
<groupId>org.bytedeco</groupId> <groupId>org.bytedeco</groupId>
<artifactId>cuda</artifactId> <artifactId>cuda</artifactId>
<version>10.1-7.6-1.5.1</version> <version>10.1-7.6-1.5.2</version>
<classifier>linux-x86_64-redist</classifier> <classifier>linux-x86_64-redist</classifier>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.bytedeco</groupId> <groupId>org.bytedeco</groupId>
<artifactId>cuda</artifactId> <artifactId>cuda</artifactId>
<version>10.1-7.6-1.5.1</version> <version>10.1-7.6-1.5.2</version>
<classifier>linux-ppc64le-redist</classifier> <classifier>linux-ppc64le-redist</classifier>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.bytedeco</groupId> <groupId>org.bytedeco</groupId>
<artifactId>cuda</artifactId> <artifactId>cuda</artifactId>
<version>10.1-7.6-1.5.1</version> <version>10.1-7.6-1.5.2</version>
<classifier>macosx-x86_64-redist</classifier> <classifier>macosx-x86_64-redist</classifier>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.bytedeco</groupId> <groupId>org.bytedeco</groupId>
<artifactId>cuda</artifactId> <artifactId>cuda</artifactId>
<version>10.1-7.6-1.5.1</version> <version>10.1-7.6-1.5.2</version>
<classifier>windows-x86_64-redist</classifier> <classifier>windows-x86_64-redist</classifier>
</dependency> </dependency>

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at

View File

@ -59,6 +59,7 @@ namespace nd4j {
template <typename T> template <typename T>
static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* create_(nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
template <typename T> template <typename T>
static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());

View File

@ -422,6 +422,11 @@ NDArray NDArrayFactory::create(nd4j::DataType dtype, nd4j::LaunchContext * conte
return res; return res;
} }
NDArray* NDArrayFactory::create_(nd4j::DataType dtype, nd4j::LaunchContext * context) {
auto result = new NDArray();
*result = NDArrayFactory::create(dtype, context);
return result;
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -29,8 +30,8 @@
#include <dll.h> #include <dll.h>
#include <Environment.h> #include <Environment.h>
#include <ArrayOptions.h> #include <ArrayOptions.h>
#include <templatemath.h> //#include <templatemath.h>
#include <shape.h> //#include <shape.h>
#include <helpers/logger.h> #include <helpers/logger.h>
namespace nd4j { namespace nd4j {
@ -128,7 +129,9 @@ namespace nd4j {
// if both dtypes are the same - just return it // if both dtypes are the same - just return it
if (typeX == typeY) if (typeX == typeY)
return typeX; return typeX;
auto nd4j_max = [](nd4j::DataType typeX, nd4j::DataType typeY) {
return typeX > typeY?typeX:typeY;
};
auto rX = isR(typeX); auto rX = isR(typeX);
auto rY = isR(typeY); auto rY = isR(typeY);
@ -144,7 +147,7 @@ namespace nd4j {
if (rX && rY) { if (rX && rY) {
// if we allow precision boost, then we pick bigger data type // if we allow precision boost, then we pick bigger data type
if (nd4j::Environment::getInstance()->precisionBoostAllowed()) { if (nd4j::Environment::getInstance()->precisionBoostAllowed()) {
return nd4j::math::nd4j_max<nd4j::DataType>(typeX, typeY); return nd4j_max(typeX, typeY);
} else { } else {
// and we return first operand otherwise // and we return first operand otherwise
return typeX; return typeX;
@ -155,7 +158,7 @@ namespace nd4j {
// if that's not real type, we apply same rules // if that's not real type, we apply same rules
if (!rX && !rY) { if (!rX && !rY) {
if (nd4j::Environment::getInstance()->precisionBoostAllowed()) { if (nd4j::Environment::getInstance()->precisionBoostAllowed()) {
return nd4j::math::nd4j_max<nd4j::DataType>(typeX, typeY); return nd4j_max(typeX, typeY);
} else { } else {
// and we return first operand otherwise // and we return first operand otherwise
return typeX; return typeX;
@ -367,8 +370,8 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
template <typename T> template <typename T>
FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo) { FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo) {
auto shapeInfoLength = *originalShapeInfo * 2 + 4;
for (int e = 0; e < shape::shapeInfoLength(originalShapeInfo); e++) { for (auto e = 0; e < shapeInfoLength; e++) {
if (originalShapeInfo[e] < static_cast<Nd4jLong>(DataTypeUtils::max<T>())) { if (originalShapeInfo[e] < static_cast<Nd4jLong>(DataTypeUtils::max<T>())) {
newShapeInfo[e] = static_cast<T>(originalShapeInfo[e]); newShapeInfo[e] = static_cast<T>(originalShapeInfo[e]);
} else } else

View File

@ -1,10 +1,26 @@
/**
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
// //
// Created by raver on 5/17/2019. // Created by raver on 5/17/2019.
// //
#include <array/ConstantHolder.h>
#include <DataTypeUtils.h> #include <DataTypeUtils.h>
#include <array/ConstantHolder.h>
#include <shape.h>
namespace nd4j { namespace nd4j {
ConstantHolder::ConstantHolder(const ConstantHolder& other) { ConstantHolder::ConstantHolder(const ConstantHolder& other) {
@ -24,7 +40,7 @@ namespace nd4j {
bool ConstantHolder::hasBuffer() { bool ConstantHolder::hasBuffer() {
return hasBuffer(DataTypeUtils::fromT<T>()); return hasBuffer(DataTypeUtils::fromT<T>());
} }
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT bool ConstantHolder::hasBuffer, (), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT bool ConstantHolder::hasBuffer, (void), LIBND4J_TYPES);
void ConstantHolder::addBuffer(ConstantDataBuffer &pointer, nd4j::DataType dataType) { void ConstantHolder::addBuffer(ConstantDataBuffer &pointer, nd4j::DataType dataType) {
_buffers[dataType] = pointer; _buffers[dataType] = pointer;
@ -34,7 +50,7 @@ namespace nd4j {
void ConstantHolder::addBuffer(ConstantDataBuffer &pointer) { void ConstantHolder::addBuffer(ConstantDataBuffer &pointer) {
addBuffer(pointer, DataTypeUtils::fromT<T>()); addBuffer(pointer, DataTypeUtils::fromT<T>());
} }
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ConstantHolder::addBuffer, (ConstantDataBuffer&), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ConstantHolder::addBuffer, (ConstantDataBuffer& cb), LIBND4J_TYPES);
ConstantDataBuffer* ConstantHolder::getConstantDataBuffer(nd4j::DataType dataType) { ConstantDataBuffer* ConstantHolder::getConstantDataBuffer(nd4j::DataType dataType) {
if (!hasBuffer(dataType)) if (!hasBuffer(dataType))

View File

@ -195,7 +195,21 @@ namespace nd4j {
template <typename T> template <typename T>
_CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index, T from, T to) { _CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index, T from, T to) {
auto t = this->relativeT<T>(index); auto t = this->relativeT<T>(index);
auto z = from + (t * (to - from)); auto z = from + T(t * (to - from));
return z;
}
template <>
_CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeT(Nd4jLong index, Nd4jLong from, Nd4jLong to) {
auto t = this->relativeT<double>(index);
auto z = from + Nd4jLong(t * (to - from));
return z;
}
template <>
_CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index, int from, int to) {
auto t = this->relativeT<float>(index);
auto z = from + float(t * (to - from));
return z; return z;
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -21,6 +22,7 @@
#include <exceptions/cuda_exception.h> #include <exceptions/cuda_exception.h>
#include <ConstantHelper.h> #include <ConstantHelper.h>
#include <DataTypeUtils.h> #include <DataTypeUtils.h>
#include <shape.h>
#include <execution/LaunchContext.h> #include <execution/LaunchContext.h>
#include <specials.h> #include <specials.h>
#include <logger.h> #include <logger.h>

View File

@ -620,14 +620,15 @@ namespace nd4j {
// FIXME: remove this method once we get 1D vectors supported // FIXME: remove this method once we get 1D vectors supported
vectorize(input_shape); vectorize(input_shape);
REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSliceBP: shape calculation failed"); REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSliceBP: shape calculation failed");
//REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: gradOut shape should be equals to output from strided_slice op.");
//Zero output array, so unused elements have 0 gradient //Zero output array, so unused elements have 0 gradient
output->nullify(); output->nullify();
std::sort(indices.begin(), indices.end()); //
if(indices.size() == 3 && (indices[1] - indices[0]) == 1) { // the first case: only for scalar gradient step
if(epsNext->lengthOf() == 1 && (indices.size() == 3 && (indices[1] - indices[0]) == 1 || (indices[2] - indices[0] == 1))) {
output->p(indices[0], *epsNext); output->p(indices[0], *epsNext);
} }
else { else { // else for other cases
auto sub = (*output)(indices, true, true); auto sub = (*output)(indices, true, true);
sub.assign(epsNext); sub.assign(epsNext);
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -23,6 +24,7 @@
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include <helpers/RandomLauncher.h> #include <helpers/RandomLauncher.h>
#include <ops/declarable/helpers/random.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -35,41 +37,59 @@ namespace nd4j {
* TArgs[0] - min for rng * TArgs[0] - min for rng
* TArgs[1] - max for rng * TArgs[1] - max for rng
*/ */
CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 2, 0) { CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 0, 0) {
// uniform distribution // uniform distribution
auto rng = block.randomGenerator(); auto rng = block.randomGenerator();
auto dtype = DataType::FLOAT32;
if (block.getIArguments()->size())
dtype = (DataType)INT_ARG(0);
// FIXME: to be implemented auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr;
/* auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr;
if (rng == nullptr) bool disposable = false;
return Status::THROW("RNG is null, aborting...");
auto x = INPUT_VARIABLE(0); if (min == nullptr && max == nullptr && block.numT() >= 2) {
auto z = OUTPUT_VARIABLE(0); min = NDArrayFactory::create_(dtype);
max = NDArrayFactory::create_(dtype);
min->p(0, T_ARG(0));
max->p(0, T_ARG(1));
disposable = true;
}
functions::random::RandomFunction<T>::template execTransform<randomOps::UniformDistribution<T>>(block.getRNG(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data()); auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
STORE_RESULT(*z); helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
*/
REQUIRE_TRUE(block.numT() > 1, 0, "RandomUniform: to/from must be set");
RandomLauncher::fillUniform(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1)); if (disposable) {
delete min;
delete max;
}
return Status::OK(); return Status::OK();
} }
DECLARE_SHAPE_FN(randomuniform) { DECLARE_SHAPE_FN(randomuniform) {
auto in = INPUT_VARIABLE(0); auto in = INPUT_VARIABLE(0);
//auto min = INPUT_VARIABLE(1);
auto shape = in->template asVectorT<Nd4jLong>(); auto shape = in->template asVectorT<Nd4jLong>();
auto dtype = DataType::FLOAT32; //ArrayOptions::dataType(inputShape->at(1)); // output type is by given min
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', shape); if (block.getIArguments()->size())
dtype = (DataType)INT_ARG(0);
if (block.width() > 1)
REQUIRE_TRUE(dtype == INPUT_VARIABLE(1)->dataType(), 0, "RandomUniform: data type of output and min/max args should be the same");
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape);
return SHAPELIST(newShape); return SHAPELIST(newShape);
} }
DECLARE_TYPES(randomuniform) { DECLARE_TYPES(randomuniform) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes(0, {ALL_INTS})
->setAllowedOutputTypes({ALL_FLOATS}); ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS})
->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS});
} }
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -33,8 +34,20 @@ namespace nd4j {
DECLARE_CUSTOM_OP(get_seed, -2, 1, false, 0, 0); DECLARE_CUSTOM_OP(get_seed, -2, 1, false, 0, 0);
#endif #endif
/*
* random_uniform distribution for types int32,int64, float16, float and double
* by default dtype is float32
*
* input:
* 0 - shape of output (1D int tensor)
* 1 - min val (0D of output type) - optional (0 as default)
* 2 - max val (0D of output type) - optional (inf as default)
*
* output:
* 0 - uniformly distributed values of given type (between min and max)
*/
#if NOT_EXCLUDED(OP_randomuniform) #if NOT_EXCLUDED(OP_randomuniform)
DECLARE_CUSTOM_OP(randomuniform, 1, 1, true, 2, 0); DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0);
#endif #endif
#if NOT_EXCLUDED(OP_random_normal) #if NOT_EXCLUDED(OP_random_normal)
@ -66,6 +79,7 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_random_poisson) #if NOT_EXCLUDED(OP_random_poisson)
DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0); DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0);
#endif #endif
} }
} }

View File

@ -1,5 +1,5 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -23,6 +23,7 @@
#include <memory> #include <memory>
//#include <graph/Context.h> //#include <graph/Context.h>
#include <ShapeUtils.h> #include <ShapeUtils.h>
#include <helpers/RandomLauncher.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -127,6 +128,28 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context,
graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES); graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES);
template <typename T>
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
T minVal = T(0);
T maxVal = DataTypeUtils::max<T>();
if (min)
minVal = min->t<T>(0);
if (max)
maxVal = max->t<T>(0);
if (output->isR())
RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
else {
PRAGMA_OMP_PARALLEL_FOR
for (auto i = 0; i < output->lengthOf(); i++) {
output->t<T>(i) = rng.relativeT<T>(i, minVal, maxVal);
}
}
}
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
}
} }
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -91,7 +92,7 @@ namespace helpers {
val = nd4j::math::nd4j_min<T>(val, input->t<T>(e)); val = nd4j::math::nd4j_min<T>(val, input->t<T>(e));
} }
else { else {
idx = indices->e<int>(e); idx = indices->e<Nd4jLong>(e);
val = input->t<T>(e); val = input->t<T>(e);
} }
output->t<T>(idx) = val; output->t<T>(idx) = val;
@ -111,14 +112,14 @@ namespace helpers {
minT->assign(listOfTensors->at(0)); minT->assign(listOfTensors->at(0));
for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { for (Nd4jLong i = 1; i < indices->lengthOf(); i++) {
if (indices->e<T>(i) == idx) { if (indices->e<Nd4jLong>(i) == idx) {
for (int e = 0; e < minT->lengthOf(); e++) { for (int e = 0; e < minT->lengthOf(); e++) {
minT->p(e, nd4j::math::nd4j_min(minT->e<T>(e), listOfTensors->at(i)->e<T>(e))); minT->p(e, nd4j::math::nd4j_min(minT->e<T>(e), listOfTensors->at(i)->e<T>(e)));
} }
} }
else { else {
idx = indices->e<T>(i); idx = indices->e<Nd4jLong>(i);
minT = listOfOutTensors->at(idx); minT = listOfOutTensors->at(idx);
minT->assign(listOfTensors->at(i)); minT->assign(listOfTensors->at(i));
} }

View File

@ -1,5 +1,5 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -26,6 +26,7 @@
#include <helpers/RandomLauncher.h> #include <helpers/RandomLauncher.h>
#include <ShapeUtils.h> #include <ShapeUtils.h>
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
#include <cuda_exception.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -181,6 +182,72 @@ namespace helpers {
} }
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_NATIVE); BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_NATIVE);
template <typename T>
static __global__ void fillUniformKernel(graph::RandomGenerator* devRng, T from, T to, T* output, Nd4jLong* outputShape) {
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
__shared__ Nd4jLong outputLen;
if (0 == threadIdx.x) {
outputLen = shape::length(outputShape);
}
__syncthreads();
for (auto i = start; i < outputLen; i += step) {
auto zIndex = shape::getIndexOffset(i, outputShape);
output[zIndex] = devRng->relativeT<T>(i, from, to);
}
}
template <typename T>
static void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
T minVal = T(0);
T maxVal = DataTypeUtils::infOrMax<T>();
if (min)
minVal = min->t<T>(0);
if (max)
maxVal = max->t<T>(0);
if (output->isR())
RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
else {
auto stream = context->getCudaStream();
graph::RandomGenerator *devRng;
auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator));
if (err != 0) {
cuda_exception::build("fillRandomUniform_: Cannot allocate device memory for random generator due error", err);
}
err = cudaMemcpy(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice);
if (err != 0) {
cuda_exception::build("fillRandomUniform_: Cannot copy random generator to device", err);
}
auto outputBuf = output->dataBuffer()->specialAsT<T>();
auto outputShape = output->specialShapeInfo();
fillUniformKernel<T><<<128, 128, 128, *stream>>>(devRng, minVal, maxVal, outputBuf, outputShape);
err = cudaStreamSynchronize(*stream);
if (err != 0) {
cuda_exception::build("fillRandomUniform_: Cannot successfully finish kernel call", err);
}
err = cudaFree(devRng);
if (err != 0) {
cuda_exception::build("fillRandomUniform_: Cannot deallocate device memory for random generator", err);
}
}
}
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
} }
} }
} }

View File

@ -1,5 +1,5 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -33,7 +33,7 @@ namespace helpers {
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output); void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output);
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output); void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output);
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output);
} }
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -273,8 +274,8 @@ public:
BlockInformation(Nd4jLong length, int threshold) { BlockInformation(Nd4jLong length, int threshold) {
threads = length / threshold; threads = length / threshold;
threads = nd4j::math::nd4j_max<int>(1, threads); threads = (1 < threads)?threads:1;//nd4j::math::nd4j_max<int>(1, threads);
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads()); threads = (threads < omp_get_max_threads())?threads:omp_get_max_threads();//nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
items = length / threads; items = length / threads;
remainder = length % threads; remainder = length % threads;

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -27,7 +28,7 @@
#include <dll.h> #include <dll.h>
#include <pointercast.h> #include <pointercast.h>
#include <platformmath.h> #include <platformmath.h>
#include <DataTypeUtils.h>
#define BFLOAT16_MAX_VALUE 32737. #define BFLOAT16_MAX_VALUE 32737.
#define HALF_MAX_VALUE 65504. #define HALF_MAX_VALUE 65504.
@ -883,7 +884,7 @@ namespace nd4j {
if (a > 171.624) { if (a > 171.624) {
// Correct answer too large to display. Force +infinity. // Correct answer too large to display. Force +infinity.
return Z(DOUBLE_MAX_VALUE); return Z(DOUBLE_MAX_VALUE);
//DataTypeUtils::infOrMax<Z>(); // return DataTypeUtils::infOrMax<Z>();
} }
return nd4j::math::nd4j_exp<Z,Z>(nd4j::math::nd4j_lgamma<X,Z>(a)); return nd4j::math::nd4j_exp<Z,Z>(nd4j::math::nd4j_lgamma<X,Z>(a));

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at

View File

@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
// auto e = NDArrayFactory::create<int>('c', {1}, {zero}); // auto e = NDArrayFactory::create<int>('c', {1}, {zero});
// auto s = NDArrayFactory::create<int>('c', {1}, {1}); // auto s = NDArrayFactory::create<int>('c', {1}, {1});
auto grad = NDArrayFactory::create<double>('c', {5,4}); auto grad = NDArrayFactory::create<double>('c', {5});
matrix.linspace(1); matrix.linspace(1);
grad.linspace(1); grad.linspace(1);
@ -264,6 +264,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) { TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
int zero = 0; int zero = 0;
auto matrix = NDArrayFactory::create<double>('c', {1, 2}); auto matrix = NDArrayFactory::create<double>('c', {1, 2});
@ -287,6 +288,31 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
int zero = 0;
auto matrix = NDArrayFactory::create<float>('c', {4, 8192});
// auto b = NDArrayFactory::create<int>('c', {1}, {zero});
// auto e = NDArrayFactory::create<int>('c', {1}, {zero});
// auto s = NDArrayFactory::create<int>('c', {1}, {1});
auto grad = NDArrayFactory::create<double>('c', {4, 256});
matrix.linspace(1);
grad.linspace(1);
nd4j::ops::strided_slice_bp op;
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printShapeInfo("Output shape");
z->printIndexedBuffer("Output");
//ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) { TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f}); auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f}); auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});

View File

@ -58,6 +58,11 @@ public:
} }
}; };
TEST_F(PlaygroundTests, test_s_1) {
auto t = ::runLightBenchmarkSuit(true);
delete[] t;
}
/* /*
TEST_F(PlaygroundTests, test_relubp_1) { TEST_F(PlaygroundTests, test_relubp_1) {
auto x = NDArrayFactory::create<float>('c', {128, 64, 224, 224}); auto x = NDArrayFactory::create<float>('c', {128, 64, 224, 224});

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -855,18 +856,39 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
delete result; delete result;
} }
TEST_F(RNGTests, Test_UniformDistribution_04) {
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
auto al = NDArrayFactory::create<int>(1);
auto be = NDArrayFactory::create<int>(20);
auto exp0 = NDArrayFactory::create<float>('c', {10});
nd4j::ops::randomuniform op;
auto result = op.execute({&x, &al, &be}, {}, {DataType::INT32});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printIndexedBuffer("Uniform int distribution");
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
delete result;
}
namespace nd4j { namespace nd4j {
namespace tests { namespace tests {
static void fillList(Nd4jLong seed, int numberOfArrays, std::vector<Nd4jLong> &shape, std::vector<NDArray*> &list, nd4j::graph::RandomGenerator *rng) { static void fillList(Nd4jLong seed, int numberOfArrays, std::vector<Nd4jLong> &shape, std::vector<NDArray*> &list, nd4j::graph::RandomGenerator *rng) {
rng->setSeed((int) seed); rng->setSeed((int) seed);
for (int i = 0; i < numberOfArrays; i++) { for (int i = 0; i < numberOfArrays; i++) {
auto array = NDArrayFactory::create_<double>('c', shape); auto arrayI = NDArrayFactory::create<Nd4jLong>(shape);
auto arrayR = NDArrayFactory::create_<double>('c', shape);
auto min = NDArrayFactory::create(0.0);
auto max = NDArrayFactory::create(1.0);
nd4j::ops::randomuniform op; nd4j::ops::randomuniform op;
op.execute(*rng, {array}, {array}, {0.0, 1.0}, {}, {}, true); op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, false);
list.emplace_back(array); list.emplace_back(arrayR);
} }
}; };
} }
@ -979,3 +1001,13 @@ TEST_F(RNGTests, test_choice_1) {
delete x; delete x;
delete prob; delete prob;
} }
TEST_F(RNGTests, test_uniform_119) {
auto x = NDArrayFactory::create<int>('c', {2}, {1, 5});
auto z = NDArrayFactory::create<float>('c', {1, 5});
nd4j::ops::randomuniform op;
auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {});
ASSERT_EQ(Status::OK(), status);
}

View File

@ -401,8 +401,8 @@ public class DifferentialFunctionFactory {
return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables(); return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables();
} }
public SDVariable randomUniform(double min, double max, SDVariable shape) { public SDVariable randomUniform(double min, double max, SDVariable shape, DataType dataType) {
return new DistributionUniform(sameDiff(), shape, min, max).outputVariable(); return new DistributionUniform(sameDiff(), shape, min, max, dataType).outputVariable();
} }
public SDVariable randomUniform(double min, double max, long... shape) { public SDVariable randomUniform(double min, double max, long... shape) {

View File

@ -0,0 +1,69 @@
package org.nd4j.autodiff.samediff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
/**
* Holds a set of arrays keyed by a String name, functioning essentially like a {@code Map<String,INDArray>}.<br>
* Implementations may have different internal ways of storing arrays, however.<br>
* For example for single threaded applications: {@link org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder}<br>
* And for multi-threaded: {@link org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder}
*
* @author Alex Black
*/
public interface ArrayHolder {
/**
* @return True if an array by that name exists
*/
boolean hasArray(String name);
/**
* @param name Name of the array to get
* @return The array, or null if no array with that name exists
*/
INDArray getArray(String name);
/**
* Set the array for the specified name (new array, or replace if it already exists)
*
* @param name Name of the array
* @param array Array to set
*/
void setArray(String name, INDArray array);
/**
* Remove the array from the ArrayHolder, returning it (if it exists)
*
* @param name Name of the array to return
* @return The now-removed array
*/
INDArray removeArray(String name);
/**
* @return Number of arrays in the ArrayHolder
*/
int size();
/**
* Initialize from the specified array holder.
* This clears all internal arrays, and adds all arrays from the specified array holder
*
* @param arrayHolder Array holder to initialize this based on
*/
void initFrom(ArrayHolder arrayHolder);
/**
* @return Names of the arrays currently in the ArrayHolder
*/
Collection<String> arrayNames();
/**
* Rename the entry with the specified name
*
* @param from Original name
* @param to New name
*/
void rename(String from, String to);
}

View File

@ -30,6 +30,8 @@ import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve; import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.api.OutAndGrad; import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder;
import org.nd4j.autodiff.samediff.config.BatchOutputConfig; import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig; import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig; import org.nd4j.autodiff.samediff.config.FitConfig;
@ -122,8 +124,8 @@ public class SameDiff extends SDBaseOps {
@Getter @Getter
private final Map<Long, InferenceSession> sessions = new ConcurrentHashMap<>(); //Key: thread ID private final Map<Long, InferenceSession> sessions = new ConcurrentHashMap<>(); //Key: thread ID
private final Map<String, DeviceLocalNDArray> constantArrays = new ConcurrentHashMap<>(); private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true);
private final Map<String, DeviceLocalNDArray> variablesArrays = new ConcurrentHashMap<>(); //TODO issues with DeviceLocal + mutable / changed during training? private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true);
private final Map<Long, Map<String, INDArray>> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them private final Map<Long, Map<String, INDArray>> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them
private final List<String> lossVariables = new ArrayList<>(); private final List<String> lossVariables = new ArrayList<>();
@ -346,6 +348,23 @@ public class SameDiff extends SDBaseOps {
return listeners; return listeners;
} }
/**
* Set the array holders for variable and constant arrays<br>
* <b>NOTE:</b> this is usually reserved for developers and internal use, and should not be needed by almost all users<br>
* See {@link ArrayHolder} for more details
*
* @param variableArrayHolder Array holder for variable arrays
* @param constantArrayHolder Array holder for constant arrays
* @param initialize If true: transfer any arrays from the current array holders to the new/specified ones
*/
public void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull ArrayHolder constantArrayHolder, boolean initialize){
if(initialize){
variableArrayHolder.initFrom(this.variablesArrays);
constantArrayHolder.initFrom(this.constantArrays);
}
this.variablesArrays = variableArrayHolder;
this.constantArrays = constantArrayHolder;
}
/** /**
* @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details. * @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details.
@ -674,9 +693,9 @@ public class SameDiff extends SDBaseOps {
SDVariable v = getVariable(varName); SDVariable v = getVariable(varName);
if (v.isConstant()) { if (v.isConstant()) {
constantArrays.put(varName, new DeviceLocalNDArray(arr, true)); constantArrays.setArray(varName, arr);
} else if (v.getVariableType() == VariableType.VARIABLE) { } else if (v.getVariableType() == VariableType.VARIABLE) {
variablesArrays.put(varName, new DeviceLocalNDArray(arr, true)); variablesArrays.setArray(varName, arr);
} else if (v.isPlaceHolder()) { } else if (v.isPlaceHolder()) {
long tid = Thread.currentThread().getId(); long tid = Thread.currentThread().getId();
if (!placeholdersPerThread.containsKey(tid)) { if (!placeholdersPerThread.containsKey(tid)) {
@ -699,12 +718,12 @@ public class SameDiff extends SDBaseOps {
SDVariable var = getVariable(varName); SDVariable var = getVariable(varName);
switch (var.getVariableType()) { switch (var.getVariableType()) {
case VARIABLE: case VARIABLE:
return variablesArrays.containsKey(varName); return variablesArrays.hasArray(varName);
case ARRAY: case ARRAY:
long tid = Thread.currentThread().getId(); long tid = Thread.currentThread().getId();
return sessions.containsKey(tid) && sessions.get(tid).contains(varName, InferenceSession.OUTER_FRAME, 0, null); return sessions.containsKey(tid) && sessions.get(tid).contains(varName, InferenceSession.OUTER_FRAME, 0, null);
case CONSTANT: case CONSTANT:
return constantArrays.containsKey(varName); return constantArrays.hasArray(varName);
case PLACEHOLDER: case PLACEHOLDER:
return placeholdersPerThread.containsKey(Thread.currentThread().getId()) && return placeholdersPerThread.containsKey(Thread.currentThread().getId()) &&
placeholdersPerThread.get(Thread.currentThread().getId()).containsKey(varName); placeholdersPerThread.get(Thread.currentThread().getId()).containsKey(varName);
@ -724,11 +743,11 @@ public class SameDiff extends SDBaseOps {
SDVariable v = variables.get(varName).getVariable(); SDVariable v = variables.get(varName).getVariable();
switch (v.getVariableType()) { switch (v.getVariableType()) {
case VARIABLE: case VARIABLE:
return variablesArrays.get(varName).get(); return variablesArrays.getArray(varName);
case CONSTANT: case CONSTANT:
if (!constantArrays.containsKey(varName)) if (!constantArrays.hasArray(varName))
return null; return null;
return constantArrays.get(varName).get(); return constantArrays.getArray(varName);
case ARRAY: case ARRAY:
//Only stored in inference session... //Only stored in inference session...
InferenceSession s = sessions.get(Thread.currentThread().getId()); InferenceSession s = sessions.get(Thread.currentThread().getId());
@ -781,31 +800,16 @@ public class SameDiff extends SDBaseOps {
sessions.put(Thread.currentThread().getId(), new InferenceSession(this)); sessions.put(Thread.currentThread().getId(), new InferenceSession(this));
} }
boolean duped = false;
if (arr.isAttached()) { if (arr.isAttached()) {
arr = arr.detach(); arr = arr.detach();
duped = true;
}
if (arr.isView()) {
arr = arr.dup();
duped = true;
}
if (!duped && variable.getVariableType() == VariableType.VARIABLE) {
for (DeviceLocalNDArray otherArr : variablesArrays.values()) {
if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
arr = arr.dup();
break;
}
}
} }
switch (variable.getVariableType()) { switch (variable.getVariableType()) {
case VARIABLE: case VARIABLE:
variablesArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads variablesArrays.setArray(variable.name(), arr);
break; break;
case CONSTANT: case CONSTANT:
constantArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); constantArrays.setArray(variable.name(), arr);
break; break;
case ARRAY: case ARRAY:
throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" + throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" +
@ -859,9 +863,9 @@ public class SameDiff extends SDBaseOps {
arr = arr.dup(); arr = arr.dup();
if(variable.getVariableType() == VariableType.VARIABLE ){ if(variable.getVariableType() == VariableType.VARIABLE ){
variablesArrays.get(variable.name()).update(arr); variablesArrays.setArray(variable.name(), arr);
} else { } else {
constantArrays.get(variable.name()).update(arr); constantArrays.setArray(variable.name(), arr);
} }
} }
@ -2715,7 +2719,7 @@ public class SameDiff extends SDBaseOps {
SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType()); SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType());
name = v.name(); name = v.name();
variables.put(name, Variable.builder().name(name).variable(v).build()); variables.put(name, Variable.builder().name(name).variable(v).build());
constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads constantArrays.setArray(name, constant);
return v; return v;
} }
@ -2792,7 +2796,7 @@ public class SameDiff extends SDBaseOps {
if(variableType == VariableType.VARIABLE){ if(variableType == VariableType.VARIABLE){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
INDArray vArr = weightInitScheme.create(dataType, shape); INDArray vArr = weightInitScheme.create(dataType, shape);
variablesArrays.put(name, new DeviceLocalNDArray(vArr, true)); variablesArrays.setArray(name, vArr);
} }
} }
@ -2924,7 +2928,7 @@ public class SameDiff extends SDBaseOps {
SDVariable r = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType()); SDVariable r = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType());
addVariable(r); addVariable(r);
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
variablesArrays.put(v.name(), new DeviceLocalNDArray(v.getArr().dup(), true)); variablesArrays.setArray(v.name(), v.getArr().dup());
} }
return r; return r;
case ARRAY: case ARRAY:
@ -3014,20 +3018,17 @@ public class SameDiff extends SDBaseOps {
arr = arr.detach(); arr = arr.detach();
duped = true; duped = true;
} }
if (arr.isView()) {
arr = arr.dup();
duped = true;
}
if (!duped) { if (!duped) {
for (DeviceLocalNDArray otherArr : variablesArrays.values()) { for (String s : variablesArrays.arrayNames()) {
if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour) if (variablesArrays.getArray(s) == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
arr = arr.dup(); arr = arr.dup();
break; break;
} }
} }
} }
SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType()); SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType());
associateArrayWithVariable(arr, ret); associateArrayWithVariable(arr, ret);
@ -3085,8 +3086,8 @@ public class SameDiff extends SDBaseOps {
INDArray arr = variable.getArr(); INDArray arr = variable.getArr();
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
constantArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads constantArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.remove(n); variablesArrays.removeArray(n);
if (!placeholdersPerThread.isEmpty()) { if (!placeholdersPerThread.isEmpty()) {
for (Map<String, INDArray> m : placeholdersPerThread.values()) { for (Map<String, INDArray> m : placeholdersPerThread.values()) {
m.remove(n); m.remove(n);
@ -3183,8 +3184,8 @@ public class SameDiff extends SDBaseOps {
INDArray arr = variable.getArr(); INDArray arr = variable.getArr();
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
variablesArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads variablesArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.remove(n); constantArrays.removeArray(n);
if (!placeholdersPerThread.isEmpty()) { if (!placeholdersPerThread.isEmpty()) {
for (Map<String, INDArray> m : placeholdersPerThread.values()) { for (Map<String, INDArray> m : placeholdersPerThread.values()) {
m.remove(n); m.remove(n);
@ -3260,16 +3261,14 @@ public class SameDiff extends SDBaseOps {
switch (v.getVariableType()) { switch (v.getVariableType()) {
case VARIABLE: case VARIABLE:
DeviceLocalNDArray dl = variablesArrays.remove(e.getKey()); INDArray arr = variablesArrays.removeArray(e.getKey());
INDArray arr = dl.get();
INDArray newArr = arr.castTo(d); INDArray newArr = arr.castTo(d);
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads variablesArrays.setArray(e.getKey(), newArr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break; break;
case CONSTANT: case CONSTANT:
DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey()); INDArray arr2 = constantArrays.removeArray(e.getKey());
INDArray arr2 = dl2.get();
INDArray newArr2 = arr2.castTo(d); INDArray newArr2 = arr2.castTo(d);
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads constantArrays.setArray(e.getKey(), newArr2); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break; break;
case PLACEHOLDER: case PLACEHOLDER:
Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId()); Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId());
@ -3409,14 +3408,12 @@ public class SameDiff extends SDBaseOps {
variables.remove(from); variables.remove(from);
variables.put(to, v); variables.put(to, v);
if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.containsKey(from)){ if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.hasArray(from)){
DeviceLocalNDArray dl = constantArrays.remove(from); constantArrays.rename(from, to);
constantArrays.put(to, dl);
} }
if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.containsKey(from)){ if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.hasArray(from)){
DeviceLocalNDArray dl = variablesArrays.remove(from); variablesArrays.rename(from, to);
variablesArrays.put(to, dl);
} }
if(v.getVariable().getVariableType() == VariableType.PLACEHOLDER ){ if(v.getVariable().getVariableType() == VariableType.PLACEHOLDER ){
@ -4187,6 +4184,8 @@ public class SameDiff extends SDBaseOps {
@Override @Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) { public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false); //Training isn't thread safe, no need to use DeviceLocal, even with lazy init
//Propagate graph to this samediff instance which will also contain the backward //Propagate graph to this samediff instance which will also contain the backward
if (SameDiff.this.debugMode) { if (SameDiff.this.debugMode) {
sameDiff.enableDebugMode(); sameDiff.enableDebugMode();

View File

@ -0,0 +1,66 @@
package org.nd4j.autodiff.samediff.array;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/**
* A simple {@link ArrayHolder} that uses a simple {@code Map<String, INDArray>} internally.
* No thread safety guarantees
*
* @author Alex Black
*/
public class SingleThreadArrayHolder implements ArrayHolder {
private final Map<String, INDArray> map = new HashMap<>();
@Override
public boolean hasArray(@NonNull String name) {
return map.containsKey(name);
}
@Override
public INDArray getArray(@NonNull String name) {
return map.get(name);
}
@Override
public void setArray(@NonNull String name, @NonNull INDArray array) {
map.put(name, array);
}
@Override
public INDArray removeArray(@NonNull String name) {
return map.remove(name);
}
@Override
public int size() {
return map.size();
}
@Override
public void initFrom(ArrayHolder arrayHolder) {
map.clear();
Collection<String> names = arrayHolder.arrayNames();
for (String n : names) {
map.put(n, arrayHolder.getArray(n));
}
}
@Override
public Collection<String> arrayNames() {
return Collections.unmodifiableCollection(map.keySet());
}
@Override
public void rename(String from, String to) {
INDArray arr = map.remove(from);
map.put(to, arr);
}
}

View File

@ -0,0 +1,85 @@
package org.nd4j.autodiff.samediff.array;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* An {@link ArrayHolder} that uses the thread safe {@link DeviceLocalNDArray} internally
*
* @author Alex Black
*/
public class ThreadSafeArrayHolder implements ArrayHolder {
private final Map<String, DeviceLocalNDArray> map = new ConcurrentHashMap<>();
private final boolean lazyInit;
/**
* @param lazyInit If true: use lazy initialization for {@link DeviceLocalNDArray}
*/
public ThreadSafeArrayHolder(boolean lazyInit) {
this.lazyInit = lazyInit;
}
@Override
public boolean hasArray(@NonNull String name) {
return map.containsKey(name);
}
@Override
public INDArray getArray(@NonNull String name) {
return map.get(name).get();
}
@Override
public void setArray(@NonNull String name, @NonNull INDArray array) {
if (array.isView())
array = array.dup(); //Device local doesn't support views
if (!map.containsKey(name)) {
DeviceLocalNDArray dla = new DeviceLocalNDArray(array, lazyInit);
map.put(name, dla);
} else {
DeviceLocalNDArray dla = map.get(name);
dla.update(array);
}
}
@Override
public INDArray removeArray(@NonNull String name) {
DeviceLocalNDArray arr = map.remove(name);
if (arr == null)
return null;
return arr.get();
}
@Override
public int size() {
return map.size();
}
@Override
public void initFrom(ArrayHolder arrayHolder) {
map.clear();
Collection<String> names = arrayHolder.arrayNames();
for (String n : names) {
setArray(n, arrayHolder.getArray(n));
}
}
@Override
public Collection<String> arrayNames() {
return Collections.unmodifiableCollection(map.keySet());
}
@Override
public void rename(@NonNull String from, @NonNull String to) {
DeviceLocalNDArray dl = map.remove(from);
map.put(to, dl);
}
}

View File

@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.ops;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
@ -237,21 +238,36 @@ public class SDRandom extends SDOps {
return uniform(null, min, max, shape); return uniform(null, min, max, shape);
} }
/**
* @see #uniform(String, double, double, SDVariable)
*/
public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) {
return uniform(null, min, max, shape, dataType);
}
/**
* As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output
*/
public SDVariable uniform(String name, double min, double max, SDVariable shape) {
return uniform(name, min, max, shape, null);
}
/** /**
* Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution, * Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
* U(min,max)<br> * U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned<br>
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is * See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
* specified as a long[] instead * specified as a long[] instead
* *
* @param name Name of the new SDVariable * @param name Name of the new SDVariable
* @param min Minimum value * @param min Minimum value
* @param max Maximum value. Must satisfy max >= min * @param max Maximum value. Must satisfy max >= min
* @param shape Shape of the new random SDVariable, as a 1D array * @param shape Shape of the new random SDVariable, as a 1D array
* @return New SDVariable * @param dataType Data type of the output array (if null: Float32 output is returned)
* @return New SDVariable, of the specified data type
*/ */
public SDVariable uniform(String name, double min, double max, SDVariable shape) { public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) {
validateInteger("uniform random", shape); validateInteger("uniform random", shape);
SDVariable ret = f().randomUniform(min, max, shape); SDVariable ret = f().randomUniform(min, max, shape, dataType);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }

View File

@ -39,6 +39,7 @@ public class LogSumExp extends DynamicCustomOp {
super(sameDiff, i_v); super(sameDiff, i_v);
if(dimensions != null) { if(dimensions != null) {
addIArgument(dimensions); addIArgument(dimensions);
this.dimensions = dimensions;
} }
addTArgument(keepDims ? 1.0 : 0.0); addTArgument(keepDims ? 1.0 : 0.0);
this.keepDims = keepDims; this.keepDims = keepDims;

View File

@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -41,30 +42,47 @@ import java.util.Map;
public class DistributionUniform extends DynamicCustomOp { public class DistributionUniform extends DynamicCustomOp {
private double min = 0.0; private double min = 0.0;
private double max = 1.0; private double max = 1.0;
private DataType dataType;
public DistributionUniform() { public DistributionUniform() {
// //
} }
public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max){ public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max) {
this(sd, shape, min, max, null);
}
public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max, DataType dataType){
super(null, sd, new SDVariable[]{shape}); super(null, sd, new SDVariable[]{shape});
Preconditions.checkState(min <= max, "Minimum (%s) must be <= max (%s)", min, max); Preconditions.checkState(min <= max, "Minimum (%s) must be <= max (%s)", min, max);
addTArgument(min, max); Preconditions.checkState(dataType == null || dataType.isNumerical(), "Only numerical datatypes can be used with DistributionUniform - rquested output datatype: %s", dataType);
this.dataType = dataType;
this.min = min;
this.max = max;
addArgs();
} }
public DistributionUniform(INDArray shape, INDArray out, double min, double max){ public DistributionUniform(INDArray shape, INDArray out, double min, double max){
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null); super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null);
this.min = min;
this.max = max;
} }
@Override @Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph); AttrValue v = attributesForNode.get("dtype");
addArgs(); dataType = TFGraphMapper.convertType(v.getType());
addIArgument(dataType.toInt());
} }
protected void addArgs() { protected void addArgs() {
tArguments.clear();
addTArgument(min, max); addTArgument(min, max);
if(dataType != null){
iArguments.clear();
addIArgument(dataType.toInt());
}
} }
@Override @Override
@ -85,8 +103,10 @@ public class DistributionUniform extends DynamicCustomOp {
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
//Input data type specifies the shape; output data type should be any float //Input data type specifies the shape
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 if(dataType != null){
return Collections.singletonList(dataType);
}
return Collections.singletonList(DataType.FLOAT); return Collections.singletonList(DataType.FLOAT);
} }
} }

View File

@ -29,7 +29,7 @@
<!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml --> <!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
<cuda.version>10.1</cuda.version> <cuda.version>10.1</cuda.version>
<cudnn.version>7.6</cudnn.version> <cudnn.version>7.6</cudnn.version>
<javacpp-presets.cuda.version>1.5.1</javacpp-presets.cuda.version> <javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version>
<nd4j.backend>nd4j-cuda-${cuda.version}</nd4j.backend> <nd4j.backend>nd4j-cuda-${cuda.version}</nd4j.backend>
</properties> </properties>

View File

@ -29,7 +29,7 @@
<!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml --> <!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
<cuda.version>10.1</cuda.version> <cuda.version>10.1</cuda.version>
<cudnn.version>7.6</cudnn.version> <cudnn.version>7.6</cudnn.version>
<javacpp-presets.cuda.version>1.5.1</javacpp-presets.cuda.version> <javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version>
</properties> </properties>
<build> <build>

View File

@ -382,4 +382,21 @@ public class RandomOpValidation extends BaseOpValidation {
INDArray out = Nd4j.exec(all); INDArray out = Nd4j.exec(all);
assertEquals(x, out); assertEquals(x, out);
} }
@Test
public void testUniformDtype(){
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
SameDiff sd = SameDiff.create();
SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100));
SDVariable out = sd.random.uniform(0, 10, shape, t);
INDArray arr = out.eval();
assertEquals(t, arr.dataType());
double min = arr.minNumber().doubleValue();
double max = arr.maxNumber().doubleValue();
double mean = arr.meanNumber().doubleValue();
assertEquals(0, min, 0.5);
assertEquals(10, max, 0.5);
assertEquals(5.5, mean, 1);
}
}
} }

View File

@ -1970,4 +1970,24 @@ public class TransformOpValidation extends BaseOpValidation {
INDArray log = Transforms.log(sum); INDArray log = Transforms.log(sum);
assertEquals(log, out); assertEquals(log, out);
} }
@Test
public void testLogSumExp2(){
for( int dim=0; dim<=2; dim++ ) {
Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 3, 4, 5);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var(inputArr);
SDVariable lse = sd.math().logSumExp(in, dim);
INDArray exp = Transforms.exp(inputArr, true);
INDArray sum = exp.sum(dim);
INDArray log = Transforms.log(sum);
OpValidation.validate(new TestCase(sd)
.expectedOutput(lse.name(), log)
.gradientCheck(true));
}
}
} }

View File

@ -2311,7 +2311,6 @@ public class SameDiffTests extends BaseNd4jTest {
SDVariable loss = out.std("out", true); SDVariable loss = out.std("out", true);
INDArray outArr = loss.eval(); INDArray outArr = loss.eval();
// sd.execBackwards(Collections.emptyMap());
Map<String,INDArray> grads = sd.calculateGradients(null, in.name(), w.name(), out.name()); Map<String,INDArray> grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
Map<String, INDArray> origGrad = new HashMap<>(); Map<String, INDArray> origGrad = new HashMap<>();
@ -2321,7 +2320,6 @@ public class SameDiffTests extends BaseNd4jTest {
in.getArr().assign(Nd4j.rand(in.getArr().shape())); in.getArr().assign(Nd4j.rand(in.getArr().shape()));
INDArray outArr2 = loss.eval(); INDArray outArr2 = loss.eval();
// sd.execBackwards(Collections.emptyMap());
grads = sd.calculateGradients(null, in.name(), w.name(), out.name()); grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
assertNotEquals(outArr, outArr2); assertNotEquals(outArr, outArr2);
@ -2641,8 +2639,7 @@ public class SameDiffTests extends BaseNd4jTest {
.expectedOutput("out", out) .expectedOutput("out", out)
.gradientCheck(true)); .gradientCheck(true));
assertNull(err, err); assertNull(err);
} }
@Test @Test

18
pom.xml
View File

@ -288,21 +288,21 @@
<javacpp.platform.extension/> <!-- -Djavacpp.platform.extension=-avx512 --> <javacpp.platform.extension/> <!-- -Djavacpp.platform.extension=-avx512 -->
<javacpp.platform.properties>${javacpp.platform}</javacpp.platform.properties> <javacpp.platform.properties>${javacpp.platform}</javacpp.platform.properties>
<javacpp.version>1.5.1-1</javacpp.version> <javacpp.version>1.5.2</javacpp.version>
<javacpp-presets.version>1.5.1</javacpp-presets.version> <javacpp-presets.version>1.5.2</javacpp-presets.version>
<javacv.version>1.5.1</javacv.version> <javacv.version>1.5.2</javacv.version>
<python.version>3.7.3</python.version> <python.version>3.7.5</python.version>
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version> <cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
<openblas.version>0.3.6</openblas.version> <openblas.version>0.3.7</openblas.version>
<mkl.version>2019.4</mkl.version> <mkl.version>2019.5</mkl.version>
<opencv.version>4.1.0</opencv.version> <opencv.version>4.1.2</opencv.version>
<ffmpeg.version>4.1.3</ffmpeg.version> <ffmpeg.version>4.2.1</ffmpeg.version>
<leptonica.version>1.78.0</leptonica.version> <leptonica.version>1.78.0</leptonica.version>
<hdf5.version>1.10.5</hdf5.version> <hdf5.version>1.10.5</hdf5.version>
<ale.version>0.6.0</ale.version> <ale.version>0.6.0</ale.version>
<tensorflow.version>1.14.0</tensorflow.version> <tensorflow.version>1.15.0</tensorflow.version>
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version> <tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
<commons-compress.version>1.18</commons-compress.version> <commons-compress.version>1.18</commons-compress.version>