SameDiff multi-threaded inference (#263)

* #8682 Don't log openmp BLAS threads for CUDA
* #8654 Add SameDiff multi-threaded tests
* Switching to op context for SameDiff exec
* Next steps
* Most back to passing
* Fixes
* Better tests, test refactoring
* Small tweak
* Code duplication reduction
* More code deduplication
* CUDA fixes
* More CUDA fixes
* More fixes
* Small fix
* ND4S small fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

parent b23ebee432
commit f79207033b
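Note (editorial): the thrust of this commit is to make a single SameDiff instance usable for inference from multiple threads at once. Per-call state (input/output arrays and op arguments) moves off the shared op instances and onto per-session OpContext objects. A minimal sketch of the usage this enables — the tiny graph here (placeHolder/var/mmul) and the variable names are illustrative assumptions, not taken from this commit:

import java.util.Collections;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MultiThreadedInferenceSketch {
    public static void main(String[] args) throws InterruptedException {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 4);
        SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 2));
        sd.mmul("out", in, w);

        //Both threads share one SameDiff instance; each thread gets its own
        //session, and each op execution now goes through a per-op OpContext
        //instead of mutating the op's own input/output fields
        Runnable worker = () -> {
            Map<String, INDArray> ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 1, 4));
            INDArray out = sd.output(ph, "out").get("out");
            System.out.println(Thread.currentThread().getName() + ": " + out.shapeInfoToString());
        };
        Thread t1 = new Thread(worker), t2 = new Thread(worker);
        t1.start(); t2.start();
        t1.join(); t2.join();
    }
}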
@@ -31,6 +31,7 @@ import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.shade.jackson.annotation.JsonIgnore;
@@ -708,6 +709,10 @@ public abstract class DifferentialFunction {
         throw new ND4JIllegalStateException("calculateOutputShape() method leaked out for [" + this.opName() + "]");
     }
 
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc){
+        throw new ND4JIllegalStateException("calculateOutputShape(OpContext) method leaked out for [" + this.opName() + "]");
+    }
+
     /**
      * Calculate the data types for the output arrays.
      * Though datatypes can also be inferred from {@link #calculateOutputShape()}, this method differs in that it does not
@@ -5,6 +5,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.autodiff.samediff.internal.Variable;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 
 /**
@@ -60,12 +61,12 @@ public abstract class BaseListener implements Listener {
     }
 
     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
         //No op
     }
 
     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         //No op
     }
 
@@ -5,6 +5,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.autodiff.samediff.internal.Variable;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 
 /**
@@ -104,7 +105,7 @@ public interface Listener {
      * @param at Current iteration/epoch etc
      * @param op Operation that has just been executed
      */
-    void preOpExecution(SameDiff sd, At at, SameDiffOp op);
+    void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext);
 
     /**
      * Called at the end of each operation execution<br>
@@ -117,7 +118,7 @@ public interface Listener {
      * @param op Operation that has just been executed
      * @param outputs The output arrays for the just-executed operation
      */
-    void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs);
+    void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs);
 
     /**
      * Called when any activation becomes available.
@@ -127,7 +128,7 @@ public interface Listener {
      * Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}<br>
      * It is guaranteed to be called for variables from requiredVariables().<br>
      * <br>
-     * Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, INDArray[])} -
+     * Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, OpContext, INDArray[])} -
     * both contain the same information/arrays
     *
     * @param sd The SameDiff instance
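Note (editorial): the two Listener callbacks above are breaking interface changes — any external Listener implementation must add the OpContext parameter. A minimal sketch of a listener written against the new signatures (the logging itself is illustrative):

import java.util.Arrays;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class InputShapeLoggingListener extends BaseListener {
    @Override
    public boolean isActive(Operation operation) {
        return true;    //Listen during all operations (training, inference, etc)
    }

    @Override
    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
        //The context, not the op instance, now carries the per-call arrays.
        //It can be null for control-flow ops (see InferenceSession below)
        if (opContext != null && opContext.numInputArguments() > 0) {
            System.out.println(op.getName() + " input 0 shape: "
                    + Arrays.toString(opContext.getInputArray(0).shape()));
        }
    }

    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op,
                            OpContext opContext, INDArray[] outputs) {
        System.out.println(op.getName() + " produced " + outputs.length + " output(s)");
    }
}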
@@ -9,6 +9,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -44,7 +45,7 @@ public class ArraySavingListener extends BaseListener {
 
 
     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         List<String> outNames = op.getOutputsOfOp();
         for(int i=0; i<outputs.length; i++ ){
             String filename = (count++) + "_" + outNames.get(i).replaceAll("/", "__") + ".bin";
@@ -11,6 +11,7 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOp;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.ops.ScalarOp;
 
 import java.util.Arrays;
@@ -77,7 +78,7 @@ public class ExecDebuggingListener extends BaseListener {
     }
 
     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
         if(lastIter != at.iteration()){
             lastIter = at.iteration();
             stepThisIter = 0;
@@ -9,6 +9,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.util.ArrayUtil;
@@ -79,12 +80,12 @@ public class OpBenchmarkListener extends BaseListener {
     }
 
     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
         start = System.currentTimeMillis();
     }
 
     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         long now = System.currentTimeMillis();
 
         if (mode == Mode.SINGLE_ITER_PRINT && printActive && (now-start) > this.minRuntime) {
@@ -19,6 +19,7 @@ import org.nd4j.graph.UIInfoType;
 import org.nd4j.graph.UIStaticInfoRecord;
 import org.nd4j.graph.ui.LogFileWriter;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.learning.config.IUpdater;
 import org.nd4j.linalg.primitives.Pair;
@@ -410,7 +411,7 @@ public class UIListener extends BaseListener {
 
 
     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
 
 
         //Do training set evaluation, if required
@@ -30,6 +30,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.AtomicBoolean;
@@ -192,7 +193,7 @@ public class ProfilingListener extends BaseListener {
     }
 
     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
         if (logActive) {
             opStartNano = System.nanoTime();
 
@@ -202,7 +203,7 @@ public class ProfilingListener extends BaseListener {
     }
 
     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         if (logActive) {
             long now = System.nanoTime();
 
@@ -105,7 +105,6 @@ import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
  * In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)}
  */
 @AllArgsConstructor
-@Builder
 @Slf4j
 public class SameDiff extends SDBaseOps {
     protected static final String GRAD_FN_KEY = "grad";
@@ -1232,25 +1231,6 @@ public class SameDiff extends SDBaseOps {
         return result;
     }
 
-
-    /**
-     * Create a new SameDiff instance from an existing instance.
-     * Note that state (variables and functions) is shared between the two SameDiff instance
-     *
-     * @param originalSameDiff Original SameDiff instance
-     * @return Copy
-     */
-    public static SameDiff create(SameDiff originalSameDiff) {
-        SameDiff ret = SameDiff.builder()
-                .sameDiffFunctionInstances(originalSameDiff.sameDiffFunctionInstances)
-                .build();
-        ret.variables.putAll(originalSameDiff.variables);
-        //ensuring proper sameDiff reference
-        DifferentialFunctionFactory differentialFunctionFactory = new DifferentialFunctionFactory(ret);
-        ret.functionFactory = differentialFunctionFactory;
-        return ret;
-    }
-
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;
@@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.internal;
 
 import lombok.*;
 import lombok.extern.slf4j.Slf4j;
+import org.bytedeco.javacpp.Pointer;
 import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.listeners.At;
 import org.nd4j.autodiff.listeners.Listener;
@@ -46,6 +47,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.INDArrayIndex;
 import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.util.ArrayUtil;
 
 import java.util.*;
@@ -65,7 +67,7 @@ import java.util.*;
  * @author Alex Black
 */
 @Slf4j
-public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
+public class InferenceSession extends AbstractSession<INDArray, Pair<SameDiffOp,OpContext>> {
     private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\n" +
             "Alternatively, arrays defined in a workspace must be replaced after the workspace has been closed.";
 
@@ -83,6 +85,8 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
     private IdentityDependencyTracker<INDArray, Dep> arrayUseTracker = new IdentityDependencyTracker<>();
 
+
+    private Map<String,OpContext> opContexts = new HashMap<>();
 
     public InferenceSession(@NonNull SameDiff sameDiff) {
         super(sameDiff);
         mmgr = new ArrayCacheMemoryMgr();
@@ -204,18 +208,19 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
     }
 
     @Override
-    public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
+    public INDArray[] getOutputs(Pair<SameDiffOp,OpContext> opPair, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                                  Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables) {
+        SameDiffOp op = opPair.getFirst();
         at.setFrameIter(outputFrameIter);
         if (listeners != null && listeners.size() > 0) {
             SameDiffOp sdOp = sameDiff.getOps().get(op.getOp().getOwnName());
             for (Listener l : listeners) {
                 if (l.isActive(at.operation()))
-                    l.preOpExecution(sameDiff, at, sdOp);
+                    l.preOpExecution(sameDiff, at, sdOp, opPair.getSecond());
             }
         }
 
-        INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
+        INDArray[] out = doExec(op.getOp(), opPair.getRight(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
 
         if (log.isTraceEnabled()) {
             StringBuilder sb = new StringBuilder();
@@ -246,7 +251,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
                     }
 
 
-                    l.opExecution(sameDiff, at, batch, op, out);
+                    l.opExecution(sameDiff, at, batch, op, opPair.getSecond(), out);
 
                     for (String varName : namedOuts.keySet()) {
                         l.activationAvailable(sameDiff, at, batch, op, varName, namedOuts.get(varName));
@@ -255,6 +260,8 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             }
         }
         op.getOp().clearArrays();
+        if(opPair.getSecond() != null)
+            opPair.getSecond().purge();
 
 
         //Record array uses for memory management/deallocation
@@ -343,7 +350,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
         return out;
     }
 
-    public INDArray[] doExec(DifferentialFunction op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
+    public INDArray[] doExec(DifferentialFunction op, OpContext opContext, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                              Set<String> constAndPhInputs) {
 
         int totalInputs = (opInputs == null ? 0 : opInputs.size()) + (constAndPhInputs == null ? 0 : constAndPhInputs.size())
@@ -467,31 +474,31 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             return new INDArray[]{out};
         } else if (op instanceof Assert) {
             Assert a = (Assert)op;
-            boolean condition = a.getInputArgument(0).getDouble(0) != 0.0;
+            boolean condition = opContext.getInputArray(0).getDouble(0) != 0.0;
             if(!condition){
                 //Assertion failed
                 String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution";
                 if(a.numInputArguments() >= 3) {
-                    INDArray msg = a.getInputArgument(2);
+                    INDArray msg = opContext.getInputArray(2);
                     if (msg != null && msg.dataType() == DataType.UTF8) {
                         s += ": " + msg.getString(0);
                     }
                 }
                 if(a.numInputArguments() >= 5){
-                    INDArray arr = a.getInputArgument(4);
+                    INDArray arr = opContext.getInputArray(4);
                     s += "\n" + arr;
                 }
                 throw new IllegalStateException(s);
             }
-            return ((Assert) op).outputArguments().toArray(new INDArray[0]);
+            return opContext.getOutputArrays().toArray(new INDArray[0]);
         } else if (op instanceof CustomOp) {
             CustomOp c = (CustomOp) op;
-            Nd4j.exec(c);
-            return c.outputArguments().toArray(new INDArray[0]);
+            Nd4j.exec(c, opContext);
+            return opContext.getOutputArrays().toArray(new INDArray[0]);
         } else if (op instanceof Op) {
             Op o = (Op) op;
-            Nd4j.exec(o);
-            return new INDArray[]{o.z()};
+            Nd4j.exec(o, opContext);
+            return new INDArray[]{opContext.getOutputArray(0)};
         } else {
             throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
         }
@@ -774,7 +781,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
     }
 
     @Override
-    public SameDiffOp getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
+    public Pair<SameDiffOp,OpContext> getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                                            Set<String> constAndPhInputs, Map<String, INDArray> placeholderValues, Set<String> allReqVariables) {
         SameDiffOp sdo = sameDiff.getOps().get(opName);
         DifferentialFunction df = sdo.getOp();
@@ -786,7 +793,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
         if (df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration ||
                 df instanceof Merge || df instanceof Switch || df instanceof BaseTensorOp) {
             //Control dependencies and tensor ops (like TensorArray, TensorArrayRead etc) don't need inputs set, execution is a special case
-            return sdo;
+            return new Pair<>(sdo, null);
         }
 
         //Infer the args based on the inputs (variable + frame + iteration)
@@ -839,24 +846,39 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
         //TODO let's find a way to use in-place modification for loops where possible to reduce memory requirements
         boolean isLoop = !frameIter.getFrame().equals(OUTER_FRAME) && frameIter.getIteration() > 0;
 
+        OpContext oc = opContexts.get(opName);
+        if(oc == null){
+            oc = Nd4j.getExecutioner().buildContext();
+            opContexts.put(opName, oc);
+        }
+
         if (df instanceof CustomOp) {
             DynamicCustomOp customOp = (DynamicCustomOp) df;
             if (args != null) {
-                customOp.setInputArguments(args);
+                oc.setInputArrays(args);
             }
 
             if (df instanceof Identity) {
                 //We don't need to allocate an output array for Identity, we pass through the input array without copying
-                return sdo;
+                return new Pair<>(sdo, oc);
             }
 
-            List<LongShapeDescriptor> outShape = customOp.calculateOutputShape();
+            if(customOp.numIArguments() > 0)
+                oc.setIArguments(customOp.iArgs());
+            if(customOp.numDArguments() > 0)
+                oc.setDArguments(customOp.dArgs());
+            if(customOp.numTArguments() > 0)
+                oc.setTArguments(customOp.tArgs());
+            if(customOp.numBArguments() > 0)
+                oc.setBArguments(customOp.bArgs());
+
+
+            List<LongShapeDescriptor> outShape = customOp.calculateOutputShape(oc);
             Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
             String[] outNames = df.outputVariablesNames();
             Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation" +
                     " with %s outputs (number of shapes and outputs must be equal)", df.opName(), outShape.size(), outNames.length);
             for (int i = 0; i < outShape.size(); i++) {
-                INDArray currOutput = (customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i));
                 LongShapeDescriptor reqShape = outShape.get(i);
 
                 //Issue: many ops have multiple valid output datatypes, and output shape calc can't at present know which: https://github.com/deeplearning4j/deeplearning4j/issues/6872
@@ -870,7 +892,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
                 //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc
                 boolean isOutput = allReqVariables.contains(outNames[i]);
                 INDArray out = mmgr.allocate(isOutput, reqShape);
-                customOp.setOutputArgument(i, out);
+                oc.setOutputArray(i, out);
             }
 
         } else if (df instanceof Op) {
@@ -909,9 +931,9 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             }
 
             if (args != null && args.length > 0) {
-                op.setX(args[0]);
+                oc.setInputArray(0, args[0]);
                 if (args.length == 2 && !axisArg)
-                    op.setY(args[1]);
+                    oc.setInputArray(1, args[1]);
             }
 
 
@@ -920,18 +942,18 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]);
             if (emptyReduce) {
                 //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc
-                INDArray z = mmgr.allocate(false, op.x().dataType(), op.x().shape());
-                op.setZ(z);
+                INDArray z = mmgr.allocate(false, oc.getInputArray(0).dataType(), oc.getInputArray(0).shape());
+                oc.setOutputArray(0, z);
             } else {
-                List<LongShapeDescriptor> outputShape = ((BaseOp) op).calculateOutputShape();
+                List<LongShapeDescriptor> outputShape = ((BaseOp) op).calculateOutputShape(oc);
                 Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass());
                 LongShapeDescriptor lsd = outputShape.get(0);
                 INDArray z = mmgr.allocate(isOutput, lsd);
-                op.setZ(z);
+                oc.setOutputArray(0, z);
            }
        }
 
-        return sdo;
+        return new Pair<>(sdo, oc);
     }
 
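Note (editorial): stripped of the session machinery, the execution pattern InferenceSession adopts above looks like the following. This is a sketch: Transpose stands in for any DynamicCustomOp, and every call used here (buildContext, setInputArray/setOutputArray, exec(CustomOp, OpContext), purge) appears in the diff above.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.shape.Transpose;
import org.nd4j.linalg.factory.Nd4j;

public class OpContextSketch {
    public static void main(String[] args) {
        INDArray in = Nd4j.rand(3, 4);
        INDArray out = Nd4j.create(in.dataType(), 4, 3);

        //One cached context per op name per session: the per-call arrays live
        //on the context, so the op instance itself stays shareable across threads
        OpContext oc = Nd4j.getExecutioner().buildContext();
        oc.setInputArray(0, in);
        oc.setOutputArray(0, out);
        Nd4j.exec(new Transpose(), oc);

        System.out.println(out.shapeInfoToString());
        oc.purge();    //Release per-call state once outputs are consumed
    }
}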
@@ -11,10 +11,12 @@ import org.nd4j.autodiff.samediff.TrainingConfig;
 import org.nd4j.autodiff.samediff.VariableType;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.learning.GradientUpdater;
 import org.nd4j.linalg.learning.regularization.Regularization;
 import org.nd4j.linalg.primitives.AtomicDouble;
+import org.nd4j.linalg.primitives.Pair;
 
 import java.util.*;
 
@@ -135,10 +137,11 @@ public class TrainingSession extends InferenceSession {
     }
 
     @Override
-    public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
+    public INDArray[] getOutputs(Pair<SameDiffOp, OpContext> opPair, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                                  Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables) {
         //Get outputs from InferenceSession
-        INDArray[] out = super.getOutputs(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables);
+        INDArray[] out = super.getOutputs(opPair, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables);
+        SameDiffOp op = opPair.getFirst();
 
         List<String> outputs = op.getOutputsOfOp();
         int outIdx = 0;
@@ -12,6 +12,8 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.List;
+
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 
 /**
@@ -36,7 +38,7 @@ public class ActivationGradientCheckListener extends BaseListener {
     }
 
     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         Preconditions.checkState(variableName != null, "No variable name has been set yet. Variable name must be set before using this listener");
         Preconditions.checkState(eps != 0.0, "Epsilon has not been set");
 
@@ -14,7 +14,10 @@ import org.nd4j.linalg.api.ops.Op;
 
 import java.security.MessageDigest;
 import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 
 public class NonInplaceValidationListener extends BaseListener {
@@ -33,25 +36,25 @@ public class NonInplaceValidationListener extends BaseListener {
     }
 
     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext oc) {
         if(op.getOp().isInPlace()){
             //Don't check inplace op
             return;
         }
         if(op.getOp() instanceof Op){
             Op o = (Op)op.getOp();
-            if(o.x() == null){
+            if(oc.getInputArray(0) == null){
                 //No input op
                 return;
-            } else if(o.y() == null){
-                opInputsOrig = new INDArray[]{o.x()};
-                opInputs = new INDArray[]{o.x().dup()};
+            } else if(oc.getInputArray(1) == null){
+                opInputsOrig = new INDArray[]{oc.getInputArray(0)};
+                opInputs = new INDArray[]{oc.getInputArray(0).dup()};
             } else {
-                opInputsOrig = new INDArray[]{o.x(), o.y()};
-                opInputs = new INDArray[]{o.x().dup(), o.y().dup()};
+                opInputsOrig = new INDArray[]{oc.getInputArray(0), oc.getInputArray(1)};
+                opInputs = new INDArray[]{oc.getInputArray(0).dup(), oc.getInputArray(1).dup()};
             }
         } else if(op.getOp() instanceof DynamicCustomOp){
-            val arr = ((DynamicCustomOp) op.getOp()).inputArguments();
+            List<INDArray> arr = oc.getInputArrays(); // ((DynamicCustomOp) op.getOp()).inputArguments();
             opInputs = new INDArray[arr.size()];
             opInputsOrig = new INDArray[arr.size()];
             for( int i=0; i<arr.size(); i++ ){
@@ -64,7 +67,7 @@ public class NonInplaceValidationListener extends BaseListener {
     }
 
     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         if(op.getOp().isInPlace()){
             //Don't check inplace op
             return;
@@ -93,6 +93,12 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc){
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();
 
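Note (editorial): the pair of methods added here is the template repeated in every op base class below: the existing no-arg calculateOutputShape() delegates to the new OpContext overload with null, and the overload prefers the context's arrays over the op's own fields. In isolation (assembled from the overrides in this commit; only the final shape/type line differs per op type):

@Override
public List<LongShapeDescriptor> calculateOutputShape() {
    return calculateOutputShape(null);    //Old entry point keeps working
}

@Override
public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
    //Per-call array from the context if present, shared op field otherwise
    INDArray x = oc != null ? oc.getInputArray(0) : x();
    if (x == null)
        return Collections.emptyList();
    return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), x.dataType()));
}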
@@ -55,6 +55,11 @@ public abstract class BaseOpContext implements OpContext {
         return fastpath_i;
     }
 
+    @Override
+    public int numIArguments() {
+        return fastpath_i.size();
+    }
+
     @Override
     public void setTArguments(double... arguments) {
         fastpath_t.clear();
@@ -67,6 +72,11 @@ public abstract class BaseOpContext implements OpContext {
         return fastpath_t;
     }
 
+    @Override
+    public int numTArguments() {
+        return fastpath_t.size();
+    }
+
     @Override
     public void setBArguments(boolean... arguments) {
         fastpath_b.clear();
@@ -79,6 +89,11 @@ public abstract class BaseOpContext implements OpContext {
         return fastpath_b;
     }
 
+    @Override
+    public int numBArguments() {
+        return fastpath_b.size();
+    }
+
     @Override
     public void setDArguments(DataType... arguments) {
         fastpath_d.clear();
@@ -91,6 +106,11 @@ public abstract class BaseOpContext implements OpContext {
         return fastpath_d;
     }
 
+    @Override
+    public int numDArguments() {
+        return fastpath_d.size();
+    }
+
     @Override
     public void setInputArray(int index, @NonNull INDArray array) {
         fastpath_in.put(index, array);
@@ -110,6 +130,16 @@ public abstract class BaseOpContext implements OpContext {
         return result;
     }
 
+    @Override
+    public int numInputArguments() {
+        return fastpath_in.size();
+    }
+
+    @Override
+    public INDArray getInputArray(int idx) {
+        return fastpath_in.get(idx);
+    }
+
     @Override
     public List<INDArray> getOutputArrays() {
         val result = new ArrayList<INDArray>();
@@ -129,6 +159,15 @@ public abstract class BaseOpContext implements OpContext {
         fastpath_out.put(index, array);
     }
 
+    @Override
+    public INDArray getOutputArray(int i) {
+        return fastpath_out.get(i);
+    }
+
+    @Override
+    public int numOutputArguments() {
+        return fastpath_out.size();
+    }
+
     @Override
     public void setInputArrays(@NonNull List<INDArray> arrays) {
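Note (editorial): with the count and indexed accessors added above, an OpContext can be interrogated symmetrically with the CustomOp API. For instance, mirroring an op's static arguments onto its per-call context uses the same guard-then-copy shape as InferenceSession.getAndParameterizeOp() earlier in this diff (a sketch; the helper and class names are hypothetical):

import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext;

public class OpContextArgs {
    //Hypothetical helper: copy a DynamicCustomOp's integer/floating/boolean/datatype
    //arguments onto the context, skipping empty argument lists
    static void copyArgs(DynamicCustomOp op, OpContext oc) {
        if (op.numIArguments() > 0) oc.setIArguments(op.iArgs());
        if (op.numTArguments() > 0) oc.setTArguments(op.tArgs());
        if (op.numBArguments() > 0) oc.setBArguments(op.bArgs());
        if (op.numDArguments() > 0) oc.setDArguments(op.dArgs());
    }
}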
@@ -72,19 +72,33 @@ public abstract class BaseReduceBoolOp extends BaseReduceOp implements ReduceBoo
     }
 
     @Override
-    public boolean validateDataTypes() {
-        if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(),"Op.X type must be the same as Op.Y:" +
+    public DataType resultType(OpContext oc) {
+        return DataType.BOOL;
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        if (y != null)
+            Preconditions.checkArgument(x.dataType() == y.dataType(),"Op.X type must be the same as Op.Y:" +
                     " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName());
 
-        if (z() != null)
-            Preconditions.checkArgument(z().isB(), "Op.X type must be bool: got type %s for op %s", x.dataType(), getClass());
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (z != null)
+            Preconditions.checkArgument(z.isB(), "Op.Z type must be bool: got type %s for op %s", z.dataType(), getClass());
 
         return true;
     }
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();
 
@@ -90,27 +90,43 @@ public abstract class BaseReduceFloatOp extends BaseReduceOp implements ReduceFl
 
     @Override
     public DataType resultType() {
-        if (this.x() != null && this.x().isR())
-            return this.x().dataType();
+        return resultType(null);
+    }
+
+    @Override
+    public DataType resultType(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        if (x != null && x.isR())
+            return x.dataType();
 
         return Nd4j.defaultFloatingPointType();
     }
 
     @Override
-    public boolean validateDataTypes() {
-        if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(),
-                    "Op.X [%s] type must be the same as Op.Y [%s] for op %s: x.shape=%ndShape, y.shape=%ndShape", x().dataType(),
-                    y().dataType(), getClass().getName(), x(), y() );
+    public boolean validateDataTypes(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        if (y != null)
+            Preconditions.checkArgument(x.dataType() == y.dataType(),
+                    "Op.X [%s] type must be the same as Op.Y [%s] for op %s: x.shape=%ndShape, y.shape=%ndShape", x.dataType(),
+                    y.dataType(), getClass().getName(), x, y );
 
-        if (z() != null)
-            Preconditions.checkArgument(z().isR(),"Op.Z (result array) must be one of floating types: z datatype = %s", z().dataType());
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (z != null)
+            Preconditions.checkArgument(z.isR(),"Op.Z (result array) must be one of floating types: z datatype = %s", z.dataType());
 
         return true;
     }
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
 
         if(x == null)
             return Collections.emptyList();
 
@@ -69,19 +69,33 @@ public abstract class BaseReduceLongOp extends BaseReduceOp implements ReduceLon
     }
 
     @Override
-    public boolean validateDataTypes() {
-        if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X type must be the same as Op.Y:" +
+    public DataType resultType(OpContext oc) {
+        return DataType.LONG;
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        if (y != null)
+            Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X type must be the same as Op.Y:" +
                     " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName());
 
-        if (z() != null)
-            Preconditions.checkArgument( z().dataType() == DataType.LONG,"Op.Z must be long: has type %s for op %s", z().dataType(), getClass());
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (z != null)
+            Preconditions.checkArgument( z.dataType() == DataType.LONG,"Op.Z must be long: has type %s for op %s", z.dataType(), getClass());
 
         return true;
     }
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();
 
@@ -77,26 +77,42 @@ public abstract class BaseReduceSameOp extends BaseReduceOp implements ReduceSam
     }
 
     @Override
-    public boolean validateDataTypes() {
-        if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(),"Op.X type must be the same as Op.Y type:" +
+    public DataType resultType(OpContext oc){
+        return oc.getInputArray(0).dataType();
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        if (y != null)
+            Preconditions.checkArgument(x.dataType() == y.dataType(),"Op.X type must be the same as Op.Y type:" +
                     " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName());
 
-        if (z() != null)
-            Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must be the same as Op.X type. Op.X.datatype=%s, " +
-                    "Op.Z.datatype=%s", x().dataType(), z.dataType());
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (z != null)
+            Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must be the same as Op.X type. Op.X.datatype=%s, " +
+                    "Op.Z.datatype=%s", x.dataType(), z.dataType());
 
         return true;
     }
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
 
         if(x == null)
             return Collections.emptyList();
 
         //Calculate reduction shape. Note that reduction on scalar - returns a scalar
         long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims());
-        return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, this.resultType()));
+        DataType rt = oc != null ? resultType(oc) : resultType();
+        return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, rt));
     }
 
     @Override
@@ -98,6 +98,12 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp {
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();
 
@@ -115,6 +115,13 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp {
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+
         val ret = new ArrayList<LongShapeDescriptor>(1);
 
         long[] s;
@@ -89,7 +89,12 @@ public abstract class BaseTransformAnyOp extends BaseTransformOp implements Tran
     }
 
     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
+    public DataType resultType(OpContext oc) {
+        return oc.getInputArray(0).dataType();
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
         return true;
     }
 
@@ -88,20 +88,34 @@ public abstract class BaseTransformBoolOp extends BaseTransformOp implements Tra
     }
 
     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
+    public DataType resultType(OpContext oc) {
+        return DataType.BOOL;
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
         if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X must be the same type as Op.Y: " +
-                    "x.datatype=%s, y.datatype=%s", x().dataType(), y.dataType());
+            Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must be the same type as Op.Y: " +
                    "x.datatype=%s, y.datatype=%s", x.dataType(), y.dataType());
 
 
-        if (z() != null)
-            Preconditions.checkArgument(z().isB(),"Op.Z type must be bool: z.datatype=%s for op %s", z().dataType(), getClass());
+        if (z != null)
+            Preconditions.checkArgument(z.isB(),"Op.Z type must be bool: z.datatype=%s for op %s", z.dataType(), getClass());
 
         return true;
     }
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();
         return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), DataType.BOOL));
@@ -72,19 +72,37 @@ public abstract class BaseTransformFloatOp extends BaseTransformOp implements Tr
     }
 
     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
-        if (y() != null && !experimentalMode) {
+    public DataType resultType(OpContext oc) {
+        if (oc.getInputArray(0) != null && oc.getInputArray(0).isR())
+            return oc.getInputArray(0).dataType();
+
+        return Nd4j.defaultFloatingPointType();
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+
+        if (y != null && !experimentalMode) {
             Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must have same data type as Op.Y");
         }
 
-        if (z() != null)
-            Preconditions.checkArgument(z().isR(),"Op.Z must be one of floating types: z.datatype=%s for op %s", z().dataType(), getClass());
+        if (z != null)
+            Preconditions.checkArgument(z.isR(),"Op.Z must be one of floating types: z.datatype=%s for op %s", z.dataType(), getClass());
 
         return true;
     }
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();
         return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), x.isR() ? x.dataType() : Nd4j.defaultFloatingPointType()));
@@ -89,22 +89,36 @@ public abstract class BaseTransformSameOp extends BaseTransformOp implements Tra
     }

     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
-        if (y() != null) {
-            Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X type must be the same as Op.Y type: x.datatype=%s, y.datatype=%s for op %s",
-                    x().dataType(), y().dataType(), getClass());
+    public DataType resultType(OpContext oc) {
+        return oc.getInputArray(0).dataType();
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (y != null) {
+            Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X type must be the same as Op.Y type: x.datatype=%s, y.datatype=%s for op %s",
+                    x.dataType(), y.dataType(), getClass());
         }


-        if (z() != null)
-            Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must be the same as Op.X type: x.datatype=%s, z.datatype=%s for op %s",
-                    x().dataType(), z.dataType(), getClass());
+        if (z != null)
+            Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must be the same as Op.X type: x.datatype=%s, z.datatype=%s for op %s",
+                    x.dataType(), z.dataType(), getClass());

         return true;
     }

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();

@@ -76,20 +76,28 @@ public abstract class BaseTransformStrictOp extends BaseTransformOp implements T
         return this.x().dataType();
     }

+    @Override
+    public DataType resultType(OpContext opContext) {
+        return opContext.getInputArray(0).dataType();
+    }
+

     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
-        Preconditions.checkArgument(x().isR(), "Op.X must be one of floating types: x.datatype=%s for op %s", x().dataType(), getClass());
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        Preconditions.checkArgument(x.isR(), "Op.X must be one of floating types: x.datatype=%s for op %s", x.dataType(), getClass());

-        if (y() != null) {
-            Preconditions.checkArgument(y().isR(), "Op.Y must be one of floating types: y.datatype=%s for op %s", y().dataType(), getClass());
+        if (y != null) {
+            Preconditions.checkArgument(y.isR(), "Op.Y must be one of floating types: y.datatype=%s for op %s", y.dataType(), getClass());

             if (!experimentalMode)
                 Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must have same data type as Op.Y");
         }

         if (z() != null)
-            Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must have the same type as Op.X: x.datatype=%s, z.datatype=%s for op %s",
+            Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must have the same type as Op.X: x.datatype=%s, z.datatype=%s for op %s",
                     x.dataType(), z.dataType(), getClass());

         return true;

@@ -102,6 +110,13 @@ public abstract class BaseTransformStrictOp extends BaseTransformOp implements T
         return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), x.dataType()));
     }

+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        if(oc.getInputArray(0) == null)
+            return Collections.emptyList();
+        return Collections.singletonList(LongShapeDescriptor.fromShape(oc.getInputArray(0).shape(), oc.getInputArray(0).dataType()));
+    }
+
     @Override
     public List<org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes(List<org.nd4j.linalg.api.buffer.DataType> dataTypes){
         //All strict tranform ops: FP in, FP out

@@ -108,10 +108,16 @@ public interface CustomOp {

     /**
      * Calculate the output shape for this op
-     * @return
+     * @return Output array shapes
      */
     List<LongShapeDescriptor> calculateOutputShape();

+    /**
+     * Calculate the output shape for this op
+     * @return Output array shapes
+     */
+    List<LongShapeDescriptor> calculateOutputShape(OpContext opContext);
+
     /**
      * Get the custom op descriptor if one is available.
      * @return

@@ -493,6 +493,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
         val descriptor = getDescriptor();
         if (outputShapes != null && !outputShapes.isEmpty())
             return outputShapes;

@@ -504,34 +509,41 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {


         //not fully initialized: missing integer args
-        if (descriptor.getNumIArgs() >= 0 && numIArguments() < descriptor.getNumIArgs()) {
+        int nI = oc != null ? oc.numIArguments() : numIArguments();
+        if (descriptor.getNumIArgs() >= 0 && nI < descriptor.getNumIArgs()) {
             if(log.isTraceEnabled()){
                 log.trace("Could not calculate output shape for op {}: not fully initialized ({} IArgs specified, " +
-                        "{} required)", getClass().getName(),numIArguments(), descriptor.getNumIArgs());
+                        "{} required)", getClass().getName(), nI, descriptor.getNumIArgs());
             }
             return Collections.emptyList();
         }


         //not fully initialized: missing floating point args
-        if (descriptor.getNumTArgs() >= 0 && numTArguments() < descriptor.getNumTArgs()) {
+        int nT = oc != null ? oc.numTArguments() : numTArguments();
+        if (descriptor.getNumTArgs() >= 0 && nT < descriptor.getNumTArgs()) {
             if(log.isTraceEnabled()){
                 log.trace("Could not calculate output shape for op {}: not fully initialized ({} TArgs specified, " +
-                        "{} required)", getClass().getName(),numTArguments(), descriptor.getNumTArgs());
+                        "{} required)", getClass().getName(), nT, descriptor.getNumTArgs());
             }
             return Collections.emptyList();
         }

         //not fully initialized: missing INDArray input args
-        if(descriptor.getNumInputs() >= 0 && numInputArguments() < descriptor.getNumInputs()){
+        int nIn = oc != null ? oc.numInputArguments() : numInputArguments();
+        if(descriptor.getNumInputs() >= 0 && nIn < descriptor.getNumInputs()){
             if(log.isTraceEnabled()){
                 log.trace("Could not calculate output shape for op {}: not fully initialized ({} input (INDArray) args specified, " +
-                        "{} required)", getClass().getName(),numInputArguments(), descriptor.getNumInputs());
+                        "{} required)", getClass().getName(), nIn, descriptor.getNumInputs());
             }
             return Collections.emptyList();
         }

-        List<LongShapeDescriptor> ret = Nd4j.getExecutioner().calculateOutputShape(this);
+        List<LongShapeDescriptor> ret;
+        if(oc == null)
+            ret = Nd4j.getExecutioner().calculateOutputShape(this);
+        else
+            ret = Nd4j.getExecutioner().calculateOutputShape(this, oc);
         return ret;
     }

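For a DynamicCustomOp, the context now also supplies the argument counts, so the "not fully initialized" checks work whether the arguments live on the op or on a per-thread context. A hedged usage sketch (op and ctx are assumed to be already configured; not code from this PR):

    // Returns an empty list while iArgs/tArgs/input arrays are still missing
    // relative to the op descriptor; otherwise defers to the native executioner.
    List<LongShapeDescriptor> outShapes = op.calculateOutputShape(ctx);
    if (outShapes.isEmpty()) {
        // op not fully initialized for this context yet - populate ctx and retry
    }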
@@ -89,6 +89,14 @@ public class NoOp extends DynamicCustomOp {
         return Collections.singletonList(Nd4j.empty(DataType.BOOL).shapeDescriptor());
     }

+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc){
+        if(oc.getInputArrays() != null && !oc.getInputArrays().isEmpty()){
+            return Collections.singletonList(oc.getInputArray(0).shapeDescriptor());
+        }
+        return Collections.singletonList(Nd4j.empty(DataType.BOOL).shapeDescriptor());
+    }
+
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         return Collections.singletonList(DataType.BOOL);

@@ -39,12 +39,15 @@ public interface OpContext extends AutoCloseable {

     List<Long> getIArguments();

+    int numIArguments();
+
     /**
      * This method sets floating point arguments required for operation
      * @param arguments
      */
     void setTArguments(double... arguments);
     List<Double> getTArguments();
+    int numTArguments();

     /**
      * This method sets data type arguments required for operation

@@ -52,14 +55,15 @@ public interface OpContext extends AutoCloseable {
      */
     void setDArguments(DataType... arguments);
     List<DataType> getDArguments();
+    int numDArguments();

     /**
      * This method sets boolean arguments required for operation
      * @param arguments
      */
     void setBArguments(boolean... arguments);

     List<Boolean> getBArguments();
+    int numBArguments();

     /**
      * This method sets root-level seed for rng

@@ -99,6 +103,10 @@ public interface OpContext extends AutoCloseable {
      */
     List<INDArray> getInputArrays();

+    int numInputArguments();
+
+    INDArray getInputArray(int idx);
+
     /**
      * This method adds INDArray as output for future op call
      * @param index

@@ -124,6 +132,10 @@ public interface OpContext extends AutoCloseable {
      */
     List<INDArray> getOutputArrays();

+    INDArray getOutputArray(int i);
+
+    int numOutputArguments();
+
     /**
      * This method returns pointer to context, to be used during native op execution
      * @return

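The new num*Arguments() and indexed get*Array(...) accessors let callers interrogate a context directly rather than reading back the full argument lists. A minimal sketch of the interface in use, assuming a buildContext() factory on the executioner (not shown in this diff):

    OpContext ctx = Nd4j.getExecutioner().buildContext();  // assumed factory method
    ctx.setTArguments(0.5);                  // one floating point arg
    ctx.setBArguments(true, false);          // two boolean args
    int nT  = ctx.numTArguments();           // 1
    int nB  = ctx.numBArguments();           // 2
    int nIn = ctx.numInputArguments();       // 0 - no input arrays set yet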
@@ -86,7 +86,9 @@ public interface ReduceOp extends Op {
      */
     DataType resultType();

-    boolean validateDataTypes();
+    DataType resultType(OpContext oc);
+
+    boolean validateDataTypes(OpContext oc);

     Number getFinalResult();

@@ -31,7 +31,9 @@ public interface TransformOp extends Op {
      */
     DataType resultType();

+    DataType resultType(OpContext opContext);
+
     Type getOpType();

-    boolean validateDataTypes(boolean experimentalMode);
+    boolean validateDataTypes(OpContext opContext, boolean experimentalMode);
 }

@@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOp;
 import org.nd4j.linalg.api.ops.CustomOpDescriptor;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;

@@ -237,6 +238,11 @@ public class ScatterUpdate implements CustomOp {
         return Nd4j.getExecutioner().calculateOutputShape(this);
     }

+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
+        return Nd4j.getExecutioner().calculateOutputShape(this, opContext);
+    }
+
     @Override
     public CustomOpDescriptor getDescriptor() {
         return op.getDescriptor();

@@ -55,7 +55,7 @@ import java.util.*;
  * @author Adam Gibson
  */
 @Slf4j
-public class DefaultOpExecutioner implements OpExecutioner {
+public abstract class DefaultOpExecutioner implements OpExecutioner {

     private static final String SCOPE_PANIC_MSG = "For more details, see the ND4J User Guide: deeplearning4j.org/docs/latest/nd4j-overview#workspaces-panic";

@@ -108,9 +108,10 @@ public class DefaultOpExecutioner implements OpExecutioner {
     }

     @Override
-    public INDArray exec(Op op) {
-        throw new IllegalStateException("Java computation no longer supported");
-    }
+    public abstract INDArray exec(Op op);
+
+    @Override
+    public abstract INDArray exec(Op op, OpContext opContext);

     @Override
     public Op execAndReturn(Op op) {

@@ -175,24 +176,16 @@ public class DefaultOpExecutioner implements OpExecutioner {
     }

     @Override
-    public INDArray exec(ReduceOp op) {
-        throw new UnsupportedOperationException("Java computation no longer supported");
-    }
+    public abstract INDArray exec(ReduceOp op);

     @Override
-    public INDArray exec(Variance accumulation) {
-        throw new UnsupportedOperationException("Operation should use exec special");
-    }
+    public abstract INDArray exec(Variance accumulation);

     @Override
-    public INDArray exec(IndexAccumulation op) {
-        throw new UnsupportedOperationException("Operation should use exec special");
-    }
+    public abstract INDArray exec(IndexAccumulation op);

     @Override
-    public INDArray exec(BroadcastOp broadcast) {
-        throw new IllegalStateException("Java computation no longer supported");
-    }
+    public abstract INDArray exec(BroadcastOp broadcast);

     @Override
     public void exec(MetaOp op) {

@@ -215,9 +208,7 @@ public class DefaultOpExecutioner implements OpExecutioner {
     }

     @Override
-    public INDArray exec(ScalarOp op) {
-        throw new UnsupportedOperationException();
-    }
+    public abstract INDArray exec(ScalarOp op);

     @Override
     public void exec(List<Aggregate> batch) {

@@ -241,9 +232,7 @@ public class DefaultOpExecutioner implements OpExecutioner {
      * @param rng
      */
     @Override
-    public INDArray exec(RandomOp op, Random rng) {
-        throw new UnsupportedOperationException();
-    }
+    public abstract INDArray exec(RandomOp op, Random rng);


     @Deprecated

@@ -741,6 +730,11 @@ public class DefaultOpExecutioner implements OpExecutioner {
         throw new UnsupportedOperationException();
     }

+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(CustomOp op, OpContext opContext) {
+        throw new UnsupportedOperationException();
+    }
+
     @Override
     public INDArray[] allocateOutputArrays(CustomOp op){
         List<LongShapeDescriptor> shapes = calculateOutputShape(op);

@@ -946,4 +940,44 @@ public class DefaultOpExecutioner implements OpExecutioner {
     public String runFullBenchmarkSuit(boolean printOut) {
         throw new UnsupportedOperationException();
     }

+    public void setX(INDArray x, Op op, OpContext oc){
+        if(oc != null)
+            oc.setInputArray(0, x);
+        else
+            op.setX(x);
+    }
+
+    public INDArray getX(Op op, OpContext oc){
+        if( oc != null )
+            return oc.getInputArray(0);
+        return op.x();
+    }
+
+    public void setY(INDArray y, Op op, OpContext oc){
+        if(oc != null)
+            oc.setInputArray(1, y);
+        else
+            op.setY(y);
+    }
+
+    public INDArray getY(Op op, OpContext oc){
+        if( oc != null )
+            return oc.getInputArray(1);
+        return op.y();
+    }
+
+    public void setZ(INDArray z, Op op, OpContext oc){
+        if(oc != null)
+            oc.setOutputArray(0, z);
+        else
+            op.setZ(z);
+    }
+
+    public INDArray getZ(Op op, OpContext oc){
+        if( oc != null )
+            return oc.getOutputArray(0);
+        return op.z();
+    }
 }

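These getX/getY/getZ and setX/setY/setZ helpers are the deduplication point for the backend executioners: a concrete invoke method resolves arrays once at the top, and every later reference uses the locals, so the same code path serves both op-held and context-held arrays. A sketch of the intended call pattern inside a subclass (variable names illustrative):

    // Inside a concrete executioner's invoke(...) method:
    INDArray x = getX(op, oc);    // oc != null ? oc.getInputArray(0) : op.x()
    INDArray z = getZ(op, oc);    // oc != null ? oc.getOutputArray(0) : op.z()
    if (z == null) {
        z = Nd4j.createUninitialized(x.dataType(), x.shape());
        setZ(z, op, oc);          // written to the context when present, else the op
    }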
@@ -98,6 +98,13 @@ public interface OpExecutioner {
      */
     INDArray exec(Op op);

+    /**
+     * Execute the operation
+     *
+     * @param op the operation to execute
+     */
+    INDArray exec(Op op, OpContext opContext);
+
     /**Execute a TransformOp and return the result
      * @param op the operation to execute
      */

@@ -364,6 +371,8 @@ public interface OpExecutioner {

     List<LongShapeDescriptor> calculateOutputShape(CustomOp op);

+    List<LongShapeDescriptor> calculateOutputShape(CustomOp op, OpContext opContext);
+
     /**
      * Equivalent to calli
      */

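Together, the two new OpExecutioner methods support a fully context-driven flow: calculate shapes against the context, allocate outputs into it, then execute. A hedged sketch (customOp and ctx assumed configured; Nd4j.create(LongShapeDescriptor) assumed available):

    List<LongShapeDescriptor> shapes = Nd4j.getExecutioner().calculateOutputShape(customOp, ctx);
    for (int i = 0; i < shapes.size(); i++) {
        ctx.setOutputArray(i, Nd4j.create(shapes.get(i)));   // per-thread outputs
    }
    Nd4j.getExecutioner().exec(customOp, ctx);               // exec(CustomOp, OpContext) as called below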
@@ -26,6 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.factory.Nd4j;
 import org.tensorflow.framework.AttrValue;

@@ -150,6 +151,11 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
         return OUT_SHAPE;
     }

+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc){
+        return OUT_SHAPE;
+    }
+
     public Op.Type opType() {
         return Op.Type.LOGIC;
     }

@@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseReduceOp;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;

@@ -131,8 +132,14 @@ public class Variance extends BaseReduceOp {

     @Override
     public DataType resultType() {
-        if (this.x() != null && this.x().isR())
-            return this.x().dataType();
+        return resultType(null);
+    }
+
+    @Override
+    public DataType resultType(OpContext oc){
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        if (x != null && x.isR())
+            return x.dataType();

         if(this.arg() != null){
             return this.arg().dataType();

@@ -142,14 +149,18 @@ public class Variance extends BaseReduceOp {
     }

     @Override
-    public boolean validateDataTypes() {
-        if (!x().isR())
+    public boolean validateDataTypes(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        if (x != null && !x.isR()) {
+            return false;
+        }
+
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        if (y != null && !y.isR())
             return false;

-        if (y() != null && !y().isR())
-            return false;
-
-        if (z() != null && !z().isR())
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (z != null && !z.isR())
             return false;

         return true;

@@ -157,15 +168,22 @@ public class Variance extends BaseReduceOp {

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
-        if(args().length < 1) {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+
+        if(oc == null && args().length < 1) {
             throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
         }

         long[] argShape = arg().getShape();
-        if (argShape == null && x() == null) {
+        if (argShape == null && x == null) {
             return Collections.emptyList();
         }
-        long[] inputShape = (argShape == null || Shape.isPlaceholderShape(argShape) ? x().shape() : argShape);
+        long[] inputShape = (argShape == null || Shape.isPlaceholderShape(argShape) ? x.shape() : argShape);

         val ret = new ArrayList<LongShapeDescriptor>(1);
         val reducedShape = Shape.getReducedShape(inputShape,dimensions, isKeepDims());

@@ -23,6 +23,7 @@ import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseTransformOp;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;

@@ -94,20 +95,29 @@ public class MaxOut extends BaseTransformOp {
         return Nd4j.defaultFloatingPointType();
     }

+    @Override
+    public DataType resultType(OpContext oc) {
+        return Nd4j.defaultFloatingPointType();
+    }
+
     @Override
     public Type getOpType() {
         return Type.TRANSFORM_STRICT;
     }

     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
-        if (!x().isR())
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+
+        if (!x.isR())
             return false;

-        if (y() != null && !y().isR())
+        if (y != null && !y().isR())
             return false;

-        if (z() != null && z().dataType() != x().dataType())
+        if (z != null && z().dataType() != x().dataType())
             return false;

         return true;

@@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseOp;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.ops.RandomOp;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.api.shape.Shape;

@@ -65,6 +66,11 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp {

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
         if(shape != null){
             return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Nd4j.defaultFloatingPointType()));
         } else {

@@ -83,4 +89,8 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp {
     public boolean isInPlace(){
         return x == null || x == z || x.data().pointer().address() == z.data().pointer().address();
     }
+
+    public boolean isTripleArgRngOp(){
+        return false;
+    }
 }

@@ -139,4 +139,9 @@ public class BinomialDistribution extends BaseRandomOp {
         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
         return Collections.singletonList(DataType.DOUBLE);
     }
+
+    @Override
+    public boolean isTripleArgRngOp() {
+        return true;
+    }
 }

@@ -138,4 +138,9 @@ public class GaussianDistribution extends BaseRandomOp {
         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
         return Collections.singletonList(DataType.DOUBLE);
     }
+
+    @Override
+    public boolean isTripleArgRngOp() {
+        return true;
+    }
 }

@@ -135,4 +135,9 @@ public class LogNormalDistribution extends BaseRandomOp {
         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
         return Collections.singletonList(DataType.DOUBLE);
     }
+
+    @Override
+    public boolean isTripleArgRngOp() {
+        return true;
+    }
 }

@@ -136,4 +136,9 @@ public class TruncatedNormalDistribution extends BaseRandomOp {
         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
         return Collections.singletonList(DataType.DOUBLE);
     }
+
+    @Override
+    public boolean isTripleArgRngOp() {
+        return true;
+    }
 }

@@ -6556,6 +6556,10 @@ public class Nd4j {
         return getExecutioner().exec(op);
     }

+    public static INDArray exec(Op op, OpContext context){
+        return getExecutioner().exec(op, context);
+    }
+
     /**
      * Execute the operation and return the result
      *

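This static overload is the user-facing entry point for the multi-threaded case this PR targets: a single op instance can be shared across threads, with each thread carrying its own OpContext so no per-op x/y/z state is mutated concurrently. A hedged sketch of the intent (sharedOp and buildContext() are assumptions for illustration):

    // "sharedOp" is any already-configured Op instance reused across threads.
    final Op sharedOp = buildMyOp();   // hypothetical factory
    Runnable worker = () -> {
        OpContext ctx = Nd4j.getExecutioner().buildContext();       // per-thread context (assumed factory)
        ctx.setInputArray(0, Nd4j.rand(DataType.FLOAT, 1, 10));     // thread-local input
        ctx.setOutputArray(0, Nd4j.create(DataType.FLOAT, 1, 10));  // thread-local output
        Nd4j.exec(sharedOp, ctx);                                   // op itself is never mutated
    };
    new Thread(worker).start();
    new Thread(worker).start();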
@@ -54,7 +54,7 @@ public abstract class Nd4jBlas implements Blas {
         }

         String logInit = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION);
-        if(logInit == null || logInit.isEmpty() || Boolean.parseBoolean(logInit)) {
+        if(logOpenMPBlasThreads() && (logInit == null || logInit.isEmpty() || Boolean.parseBoolean(logInit))) {
             log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads());
         }
     }

@@ -74,4 +74,8 @@ public abstract class Nd4jBlas implements Blas {
         }
         return Vendor.values()[vendor];
     }
+
+    public boolean logOpenMPBlasThreads(){
+        return true;
+    }
 }

@@ -134,4 +134,9 @@ public class CudaBlas extends Nd4jBlas {
     public int getBlasVendorId() {
         return 1;
     }
+
+    @Override
+    public boolean logOpenMPBlasThreads() {
+        return false;
+    }
 }

@@ -44,6 +44,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
 import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp;
 import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
+import org.nd4j.linalg.api.ops.random.BaseRandomOp;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.api.shape.Shape;

@@ -229,7 +230,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         INDArray ret = op.z();

         checkForCompression(op);
-        op.validateDataTypes();
+        op.validateDataTypes(null);
         //validateDataType(Nd4j.dataType(), op);

         for (int i = 0; i < dimension.length; i++)

@@ -614,8 +615,15 @@ public class CudaExecutioner extends DefaultOpExecutioner {

     @Override
     public INDArray exec(Op op) {
+        return exec(op, null);
+    }
+
+    @Override
+    public INDArray exec(Op op, OpContext oc) {
         checkForCompression(op);

+        /*
+        //TODO this never would have worked
         //linear views and oblong offsets can't be handled by the gpu (due to the way the buffers are interpreted as vectors)
         if ( op instanceof CopyOp) {
             // we dont' care about op.Z sync state, since it'll be overwritten

@@ -631,27 +639,27 @@ public class CudaExecutioner extends DefaultOpExecutioner {
             //AtomicAllocator.getInstance().tickHostWrite(op.z());

             return null;
-        }
+        }*/

         if (op instanceof TransformOp) {
             TransformOp t = (TransformOp) op;
-            invoke(t);
+            invoke(t, oc);
         } else if (op instanceof ReduceOp) {
             ReduceOp acc = (ReduceOp) op;
-            invoke(acc, acc.dimensions().toIntVector());
+            invoke(acc, oc, acc.dimensions().toIntVector());
         } else if (op instanceof ScalarOp) {
             ScalarOp sc = (ScalarOp) op;
-            invoke(sc);
+            invoke(sc, oc);
         } else if (op instanceof BroadcastOp) {
             BroadcastOp broadcastOp = (BroadcastOp) op;
-            invoke(broadcastOp);
+            invoke(broadcastOp, oc);
         } else if (op instanceof IndexAccumulation) {
             IndexAccumulation indexAccumulation = (IndexAccumulation) op;
-            invoke(indexAccumulation, indexAccumulation.dimensions().toIntVector());
+            invoke(indexAccumulation, oc, indexAccumulation.dimensions().toIntVector());
         } else if (op instanceof RandomOp) {
-            exec((RandomOp) op);
+            exec((RandomOp) op, oc, Nd4j.getRandom());
         } else if (op instanceof CustomOp) {
-            exec((CustomOp) op);
+            exec((CustomOp) op, oc);
         }

@@ -659,19 +667,22 @@ public class CudaExecutioner extends DefaultOpExecutioner {
     }

     @Override
     public TransformOp execAndReturn(TransformOp op) {
         checkForCompression(op);
-        invoke(op);
+        invoke(op, null);
         return op;
     }

-    protected CudaContext invoke(BroadcastOp op) {
+    protected CudaContext invoke(BroadcastOp op, OpContext oc) {
         long st = profilingConfigurableHookIn(op);

+        INDArray x = getX(op, oc);
+        INDArray y = getY(op, oc);
+        INDArray z = getZ(op, oc);
+
         checkForCompression(op);

         //validateDataType(Nd4j.dataType(), op);

@@ -684,17 +695,17 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (CudaEnvironment.getInstance().getConfiguration().isDebug())
             lastOp.set(op.opName());

-        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
+        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);

         val hostXShapeInfo =
-                op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
+                x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
         val hostYShapeInfo =
-                op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
+                y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
         val hostZShapeInfo =
-                op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
+                z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

-        val tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), op.getDimension());
+        val tadBuffers = tadManager.getTADOnlyShapeInfo(x, op.getDimension());

         val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
         val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

@@ -706,13 +717,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         Pointer devTadOffsetsZ = null;

         // that's the place where we're going to have second TAD in place
-        val tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), op.getDimension());
+        val tadBuffersZ = tadManager.getTADOnlyShapeInfo(z, op.getDimension());

         devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getFirst(), context);
         devTadOffsetsZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getSecond(), context);

         PointerPointer xShapeInfoHostPointer = extraz.get().put(
-                AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), // 0
+                AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), // 0
                 context.getOldStream(), // 1
                 AtomicAllocator.getInstance().getDeviceIdPointer(), // 2
                 context.getBufferAllocation(), // 3

@@ -727,30 +738,30 @@ public class CudaExecutioner extends DefaultOpExecutioner {
                 devTadShapeInfoZ, // 12
                 devTadOffsetsZ); // 13

-        Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
+        Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context);

-        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
+        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);
         Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(op.getDimension()), context);

-        val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer();
-        val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer();
-        val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer();
+        val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
+        val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
+        val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();

         //log.info("X: {}; Y: {}; Z: {}; dTS: {}, dTO: {}; dTSz: {}; dTOz: {};", x.address(), y.address(), z.address(), devTadShapeInfo.address(), devTadOffsets.address(), devTadShapeInfoZ.address(), devTadOffsetsZ.address());

         switch (op.getOpType()) {
             case BROADCAST:
                 nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(),
-                        x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                        y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
-                        z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                        xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                        yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
+                        zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                         ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                 break;
             case BROADCAST_BOOL:
                 nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(),
-                        x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                        y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
-                        z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                        xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                        yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
+                        zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                         null,
                         ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                 break;

@ -768,11 +779,16 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
protected CudaContext invoke(IndexAccumulation op, int[] dimension) {
|
protected CudaContext invoke(IndexAccumulation op, OpContext oc, int[] dimension) {
|
||||||
dimension = Shape.normalizeAxis(op.x().rank(), dimension);
|
INDArray x = getX(op, oc);
|
||||||
|
INDArray y = getY(op, oc);
|
||||||
|
INDArray z = getZ(op, oc);
|
||||||
|
|
||||||
|
dimension = Shape.normalizeAxis(x.rank(), dimension);
|
||||||
if (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)) {
|
if (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)) {
|
||||||
if(op.z() == op.x() || op.z() == null) {
|
if(z == x || z == null) {
|
||||||
op.setZ(Nd4j.createUninitialized(DataType.LONG, new long[0], 'c'));
|
z = Nd4j.createUninitialized(DataType.LONG, new long[0], 'c');
|
||||||
|
setZ(z, op, oc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -790,46 +806,45 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
CudaEnvironment.getInstance().getConfiguration().enableDebug(true);
|
CudaEnvironment.getInstance().getConfiguration().enableDebug(true);
|
||||||
if (dimension != null)
|
if (dimension != null)
|
||||||
for (int i = 0; i < dimension.length; i++)
|
for (int i = 0; i < dimension.length; i++)
|
||||||
if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE)
|
if (dimension[i] >= x.rank() && dimension[i] != Integer.MAX_VALUE)
|
||||||
throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
|
throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + x.rank() + "]");
|
||||||
|
|
||||||
val context = AtomicAllocator.getInstance().getDeviceContext();
|
val context = AtomicAllocator.getInstance().getDeviceContext();
|
||||||
|
|
||||||
Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
|
Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);
|
||||||
Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null;
|
Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(x.dataType()), context) : null;
|
||||||
|
|
||||||
val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
|
val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
|
||||||
val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
|
val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
|
||||||
val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
|
val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
|
||||||
|
|
||||||
int fdimension[] = dimension;
|
int fdimension[] = dimension;
|
||||||
if (fdimension == null)
|
if (fdimension == null)
|
||||||
fdimension = new int[] {0};
|
fdimension = new int[] {0};
|
||||||
|
|
||||||
Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), fdimension);
|
Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(x, fdimension);
|
||||||
|
|
||||||
Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
|
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

        DataBuffer offsets = tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);

-       val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
+       val zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);

-       val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer();
-       val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer();
-       val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer();
+       val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
+       val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();

        PointerPointer xShapeInfoHostPointer = extraz.get().put(
-               AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(),
+               AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(),
                context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(),
                hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets);

-       if (op.z().isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) {
+       if (z.isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) {
            nativeOps.execIndexReduceScalar(xShapeInfoHostPointer, op.opNum(),
-                   x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                   xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                    extraArgs,
-                   z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
+                   zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
        } else {
            if (dimension != null && dimension.length > 1)
                Arrays.sort(dimension);

@@ -839,9 +854,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
                    .getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension));

            nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(),
-                   x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                   xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                    extraArgs,
-                   z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                   zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                    ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
        }

@@ -855,30 +870,34 @@ public class CudaExecutioner extends DefaultOpExecutioner {
    }


-   protected CudaContext invoke(ReduceOp op, int[] dimension) {
+   protected CudaContext invoke(ReduceOp op, OpContext oc, int[] dimension) {
        val context = AtomicAllocator.getInstance().getDeviceContext();

+       INDArray x = getX(op, oc);
+       INDArray y = getY(op, oc);
+       INDArray z = getZ(op, oc);

        if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){
            //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y]
            //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions"
-           if(op.z() != null){
-               Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." +
-                       " Got: x=%ndShape, z=%ndShape", op.x(), op.z());
-               op.z().assign(op.x());
+           if(z != null){
+               Preconditions.checkState(x.equalShapes(z), "For empty reductions, result (z) array must have same shape as x shape." +
+                       " Got: x=%ndShape, z=%ndShape", x, z);
+               z.assign(x);
                return context;
            } else {
-               op.setZ(op.x().dup());
+               op.setZ(x.dup());
                return context;
            }
        }

        // FIXME: this should be moved down to C++ on per-op basis
        // reduce to scalar case, ReduceBool ops require special treatment
-       if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
-           if (op.z() == null) {
+       if (op instanceof BaseReduceBoolOp && x.isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
+           if (z == null) {
                op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue()));
            } else {
-               op.z().assign(((BaseReduceBoolOp) op).emptyValue());
+               z.assign(((BaseReduceBoolOp) op).emptyValue());
            }

            return context;
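
Note on the pattern introduced above: from this method onward the executioner reads its operands through getX/getY/getZ(op, oc) and writes results through setZ(...), instead of touching op.x()/op.y()/op.z() directly, so a per-call OpContext can override the arrays attached to a shared op instance. The helpers themselves are not part of these hunks; the following is only a minimal sketch of the assumed behaviour, inferred from the OpContext accessors used elsewhere in this commit:

    // Sketch only - the real helpers live in DefaultOpExecutioner; the exact
    // signatures and null handling here are assumptions, not the committed code.
    protected static INDArray getX(Op op, OpContext oc) {
        // A non-null OpContext takes precedence over the op's own arrays
        return oc != null ? oc.getInputArray(0) : op.x();
    }

    protected static INDArray getZ(Op op, OpContext oc) {
        return oc != null ? oc.getOutputArray(0) : op.z();
    }

    protected static void setZ(INDArray z, Op op, OpContext oc) {
        if (oc != null)
            oc.setOutputArray(0, z);    // result stays local to this context
        else
            op.setZ(z);                 // legacy path: mutate the op itself
    }
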
@@ -888,7 +907,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {

        checkForCompression(op);

-       dimension = Shape.normalizeAxis(op.x().rank(), dimension);
+       dimension = Shape.normalizeAxis(x.rank(), dimension);

        //validateDataType(Nd4j.dataType(), op);

@@ -903,130 +922,131 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            Arrays.sort(dimension);

        for (int i = 0; i < dimension.length; i++)
-           if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE)
+           if (dimension[i] >= x.rank() && dimension[i] != Integer.MAX_VALUE)
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension)
-                       + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
+                       + " contains element that higher then rank of op.X: [" + x.rank() + "]");

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

-       val tadBuffers = op.x().isEmpty() ? Pair.<DataBuffer, DataBuffer>makePair(op.x().data(), null) : tadManager.getTADOnlyShapeInfo(op.x(), dimension);
+       val tadBuffers = x.isEmpty() ? Pair.<DataBuffer, DataBuffer>makePair(x.data(), null) : tadManager.getTADOnlyShapeInfo(x, dimension);

        val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
        val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

-       val offsets = op.x().isEmpty() ? null : tadBuffers.getSecond();
+       val offsets = x.isEmpty() ? null : tadBuffers.getSecond();
        val devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);

-       Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
+       Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);

-       long[] retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims());
+       long[] retShape = Shape.reductionShape(x, dimension, true, op.isKeepDims());

-       if (op.y() != null) {
+       if (y != null) {
            //2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y
-           if (op.x().length() == op.y().length()) {
+           if (x.length() == y.length()) {
                //Pairwise
-               if (op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) {
+               if (x.tensorsAlongDimension(dimension) != y.tensorsAlongDimension(dimension)) {
                    throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " +
-                           Arrays.toString(op.x().shape()) + ", y shape = " + Arrays.toString(op.y().shape()) +
+                           Arrays.toString(x.shape()) + ", y shape = " + Arrays.toString(y.shape()) +
                            ", dimension = " + Arrays.toString(dimension) + ")");
                }
            } else {
                //Every X TAD vs. entirety of Y
-               val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension);
+               val xTADSize = x.length() / x.tensorsAlongDimension(dimension);

-               if (xTADSize != op.y().length()) {
+               if (xTADSize != y.length()) {
                    throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" +
-                           " (x TAD size = " + xTADSize + ", y size = " + op.y().length());
+                           " (x TAD size = " + xTADSize + ", y size = " + y.length());
                }
            }
        }

-       //if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
+       //if (x.isVector() && x.length() == ArrayUtil.prod(retShape)) {
        //    return null;
        //}

-       val dataType = op.resultType();
+       val dataType = oc != null ? op.resultType(oc) : op.resultType();

-       if( op.z() == null ){
+       if( z == null ){
            val ret = Nd4j.createUninitialized(dataType, retShape);
-           op.setZ(ret);
-       } else if(op.z().dataType() != dataType || !Arrays.equals(retShape, op.z().shape())){
+           setZ(ret, op, oc);
+           z = ret;
+       } else if(z.dataType() != dataType || !Arrays.equals(retShape, z.shape())){
            throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape)
-                   + " but has datatype " + op.z().dataType() + " and shape " + Arrays.toString(op.z().shape()));
+                   + " but has datatype " + z.dataType() + " and shape " + Arrays.toString(z.shape()));
        }

-       val eb = op.extraArgsDataBuff(op.z().dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? op.x().dataType() : op.z().dataType());
+       val eb = op.extraArgsDataBuff(z.dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? x.dataType() : z.dataType());
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null;

-       val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
-       val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
-       val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
+       val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
+       val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
+       val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

        val xShapeInfoHostPointer = extraz.get().put(
-               AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(),
+               AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(),
                context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(),
                hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets);

-       val yTadBuffers = op.y() == null ? null : tadManager.getTADOnlyShapeInfo(op.y(), dimension);
+       val yTadBuffers = y == null ? null : tadManager.getTADOnlyShapeInfo(y, dimension);

-       val yDevTadShapeInfo = op.y() == null ? null : AtomicAllocator.getInstance().getPointer(yTadBuffers.getFirst(), context);
-       val yOffsets = op.y() == null ? null : yTadBuffers.getSecond();
+       val yDevTadShapeInfo = y == null ? null : AtomicAllocator.getInstance().getPointer(yTadBuffers.getFirst(), context);
+       val yOffsets = y == null ? null : yTadBuffers.getSecond();
        val yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context);

-       if (op.y() != null) {
+       if (y != null) {
            xShapeInfoHostPointer.put(12, yDevTadShapeInfo);
            xShapeInfoHostPointer.put(13, yDevTadOffsets);
        }

-       val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
+       val zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);

-       val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer();
-       val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer();
-       val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer();
+       val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
+       val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
+       val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();

-       op.validateDataTypes();
+       op.validateDataTypes(null);

-       if (op.z().isScalar()) {
+       if (z.isScalar()) {
            if (op instanceof Variance) {
                nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(),
-                       x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                       xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        extraArgs,
-                       z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                       zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                        ((Variance) op).isBiasCorrected());
-           } else if (op.y() != null) {
-               Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
+           } else if (y != null) {
+               Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context);
                nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(),
-                       x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                       xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        extraArgs,
-                       y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
-                       z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
+                       yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
+                       zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
            } else {
                switch (op.getOpType()) {
                    case REDUCE_FLOAT:
                        nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(),
-                               x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                               xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
-                               z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
+                               zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
                        break;
                    case REDUCE_BOOL:
                        nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(),
-                               x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                               xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
-                               z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
+                               zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
                        break;
                    case REDUCE_SAME:
                        nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(),
-                               x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                               xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
-                               z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
+                               zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
                        break;
                    case REDUCE_LONG:
                        nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(),
-                               x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                               xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
-                               z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
+                               zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
                        break;
                    default:
                        throw new UnsupportedOperationException();
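
The practical effect of threading OpContext through invoke(...) above: two threads can execute the same ReduceOp instance concurrently, each with its own input and output arrays, because nothing is written back to the shared op when a context is supplied. A hypothetical usage sketch follows; buildContext/setInputArray/setOutputArray are the public OpContext API, but this particular calling sequence is illustrative and not taken from this commit:

    // Per-thread execution sketch; "reduceOp", "input" and "out" are hypothetical names.
    OpContext ctx = Nd4j.getExecutioner().buildContext();   // one context per thread
    ctx.setInputArray(0, input);                            // overrides op.x()
    ctx.setOutputArray(0, out);                             // overrides op.z()
    Nd4j.getExecutioner().exec(reduceOp, ctx);              // shared op instance stays untouched
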
@@ -1035,21 +1055,21 @@ public class CudaExecutioner extends DefaultOpExecutioner {
        } else {
            val dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); //AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context);

-           if (op.y() != null) {
-               val yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
+           if (y != null) {
+               val yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context);
                nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(),
-                       x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                       xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        extraArgs,
-                       y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
-                       z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                       yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
+                       zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                        ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null,
                        (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets);
            } else {
                if (op instanceof Variance) {
                    nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(),
-                           x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                           xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                            extraArgs,
-                           z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                           zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null,
                            ((Variance) op).isBiasCorrected(),
                            (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets);

@@ -1057,30 +1077,30 @@ public class CudaExecutioner extends DefaultOpExecutioner {
                    switch (op.getOpType()) {
                        case REDUCE_FLOAT:
                            nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(),
-                                   x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                                   xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                    extraArgs,
-                                   z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                                   zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                                    ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                            break;
                        case REDUCE_SAME:
                            nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(),
-                                   x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                                   xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                    extraArgs,
-                                   z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                                   zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                                    ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                            break;
                        case REDUCE_BOOL:
                            nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(),
-                                   x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                                   xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                    extraArgs,
-                                   z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                                   zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                                    ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                            break;
                        case REDUCE_LONG:
                            nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(),
-                                   x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                                   xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                    extraArgs,
-                                   z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                                   zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                                    ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                            break;
                        default:
@@ -1187,34 +1207,40 @@ public class CudaExecutioner extends DefaultOpExecutioner {

    @Override
    public INDArray exec(ScalarOp op) {
-       invoke(op);
+       invoke(op, null);
        return op.z();
    }

-   protected CudaContext invoke(ScalarOp op) {
+   protected CudaContext invoke(ScalarOp op, OpContext oc) {
        long st = profilingConfigurableHookIn(op);

        checkForCompression(op);

+       INDArray x = getX(op, oc);
+       INDArray y = getY(op, oc);
+       INDArray z = getZ(op, oc);

        // validateDataType(Nd4j.dataType(), op);

-       if(op.z() == null){
+       if(z == null){
            switch (op.getOpType()) {
                case SCALAR:
-                   op.setZ(op.x().ulike());
+                   z = x.ulike();
+                   setZ(x.ulike(), op, oc);
                    break;
                case SCALAR_BOOL:
-                   op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape()));
+                   z = Nd4j.createUninitialized(DataType.BOOL, x.shape());
+                   setZ(z, op, oc);
                    break;
                default:
                    throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]");
            }
        }

-       if (op.x().length() != op.z().length())
+       if (x.length() != z.length())
            throw new ND4JIllegalStateException("op.X length should be equal to op.Y length: ["
-                   + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != ["
-                   + Arrays.toString(op.z().shapeInfoDataBuffer().asInt()) + "]");
+                   + Arrays.toString(x.shapeInfoDataBuffer().asInt()) + "] != ["
+                   + Arrays.toString(z.shapeInfoDataBuffer().asInt()) + "]");

        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));
@@ -1229,38 +1255,38 @@ public class CudaExecutioner extends DefaultOpExecutioner {

        val context = AtomicAllocator.getInstance().getDeviceContext();

-       val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
+       val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        val hostYShapeInfo = op.scalar() == null ? null : AddressRetriever.retrieveHostPointer(op.scalar().shapeInfoDataBuffer());
-       val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
+       val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

-       Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
-       Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ? op.x().dataType() : op.z().dataType()), context) : null;
+       Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);
+       Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ? x.dataType() : z.dataType()), context) : null;

-       Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
+       Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);

        PointerPointer xShapeInfoHostPointer = extraz.get().put(
-               AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(),
+               AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(),
                context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(),
                hostYShapeInfo, hostZShapeInfo, null, null);

-       val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer();
-       val y = op.scalar() == null ? null : ((BaseCudaDataBuffer) op.scalar().data()).getOpaqueDataBuffer();
-       val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer();
+       val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
+       val yb = op.scalar() == null ? null : ((BaseCudaDataBuffer) op.scalar().data()).getOpaqueDataBuffer();
+       val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();

        switch (op.getOpType()) {
            case SCALAR_BOOL:
                nativeOps.execScalarBool(xShapeInfoHostPointer, op.opNum(),
-                       x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                       z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
-                       y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context),
+                       xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                       zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                       yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context),
                        extraArgs);
                break;
            case SCALAR:
                nativeOps.execScalar(xShapeInfoHostPointer, op.opNum(),
-                       x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                       z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
-                       y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context),
+                       xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                       zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                       yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context),
                        extraArgs);
                break;
            default:
@@ -1275,9 +1301,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {
        return null;
    }

-   protected CudaContext invoke(TransformOp op) {
+   protected CudaContext invoke(TransformOp op, OpContext oc) {
        long st = profilingConfigurableHookIn(op);

+       INDArray x = getX(op, oc);
+       INDArray y = getY(op, oc);
+       INDArray z = getZ(op, oc);

        checkForCompression(op);

        //validateDataType(Nd4j.dataType(), op);

@@ -1295,7 +1325,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
        // special temp array for IsMax along dimension
        INDArray ret = null;

-       Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);
+       Pointer xShapeInfo = allocator.getPointer(x.shapeInfoDataBuffer(), context);


        Pointer dimensionDevPointer = null;

@@ -1304,17 +1334,18 @@ public class CudaExecutioner extends DefaultOpExecutioner {
        Pointer retHostShape = null;
        int dimension[] = null;

-       val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
-       var hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
+       val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
+       var hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());


-       if (op.z() == null) {
-           ret = Nd4j.createUninitialized(op.resultType(), op.x().shape(), op.x().ordering());
-           op.setZ(ret);
+       if (z == null) {
+           ret = Nd4j.createUninitialized(op.resultType(), x.shape(), x.ordering());
+           setZ(ret, op, oc);
+           z = ret;
        }

-       var extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.TRANSFORM_BOOL || op.getOpType() == Op.Type.PAIRWISE_BOOL ? op.x().dataType() : op.z().dataType()), context) : null;
-       val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
+       var extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.TRANSFORM_BOOL || op.getOpType() == Op.Type.PAIRWISE_BOOL ? x.dataType() : z.dataType()), context) : null;
+       val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

        Pointer hostTadShapeInfo = null;
        Pointer devTadShapeInfo = null;

@@ -1328,13 +1359,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {
        Pointer devTadOffsets = null;
        Pointer devMaxTadOffsets = null;

-       op.validateDataTypes(experimentalMode.get());
+       op.validateDataTypes(oc, experimentalMode.get());

-       Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
+       Pointer zShapeInfo = allocator.getPointer(z.shapeInfoDataBuffer(), context);


        PointerPointer xShapeInfoHostPointer =
-               extraz.get().put(AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), // 0
+               extraz.get().put(AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), // 0
                        context.getOldStream(), // 1
                        allocator.getDeviceIdPointer(), // 2
                        context.getBufferAllocation(), // 3

@@ -1356,30 +1387,30 @@ public class CudaExecutioner extends DefaultOpExecutioner {
                        retHostShape);


-       val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer();
-       val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer();
-       val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer();
+       val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
+       val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
+       val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();

-       if (op.y() != null) {
-           Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
+       if (y != null) {
+           Pointer yShapeInfo = allocator.getPointer(y.shapeInfoDataBuffer(), context);

-           if (op.x().length() != op.y().length() || op.x().length() != op.z().length())
+           if (x.length() != y.length() || x.length() != z.length())
                throw new ND4JIllegalStateException("X, Y and Z arguments should have the same length for PairwiseTransform");

            switch (op.getOpType()) {
                case TRANSFORM_BOOL:
                case PAIRWISE_BOOL:
                    nativeOps.execPairwiseTransformBool(xShapeInfoHostPointer, op.opNum(),
-                           x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                           y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
-                           z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                           xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                           yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
+                           zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                default:
                    nativeOps.execPairwiseTransform(xShapeInfoHostPointer, op.opNum(),
-                           x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                           y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
-                           z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                           xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                           yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
+                           zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
            }

@@ -1387,32 +1418,32 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            switch (op.getOpType()) {
                case TRANSFORM_ANY:
                    nativeOps.execTransformAny(xShapeInfoHostPointer, op.opNum(),
-                           x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                           z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                           xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                           zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                case TRANSFORM_FLOAT:
                    nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(),
-                           x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                           z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                           xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                           zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                case TRANSFORM_BOOL:
                    nativeOps.execTransformBool(xShapeInfoHostPointer, op.opNum(),
-                           x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                           z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                           xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                           zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                case TRANSFORM_SAME:
                    nativeOps.execTransformSame(xShapeInfoHostPointer, op.opNum(),
-                           x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                           z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                           xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                           zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                case TRANSFORM_STRICT:
                    nativeOps.execTransformStrict(xShapeInfoHostPointer, op.opNum(),
-                           x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
-                           z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
+                           xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
+                           zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                default:
@@ -1478,6 +1509,21 @@ public class CudaExecutioner extends DefaultOpExecutioner {

    @Override
    public INDArray exec(RandomOp op, Random rng) {
+       return exec(op, null, rng);
+   }

+   public INDArray exec(RandomOp op, OpContext oc, Random rng){
+       INDArray x = getX(op, oc);
+       INDArray y = getY(op, oc);
+       INDArray z = getZ(op, oc);

+       if(op instanceof BaseRandomOp && ((BaseRandomOp)op).isTripleArgRngOp() && z != null && x == null && y == null){
+           //Ugly hack to ensure the triple arg call occurs
+           //See GaussianDistribution.setZ etc
+           x = z;
+           y = z;
+       }

        long st = profilingConfigurableHookIn(op);

        checkForCompression(op);

@@ -1496,38 +1542,38 @@ public class CudaExecutioner extends DefaultOpExecutioner {

        val context = AtomicAllocator.getInstance().getDeviceContext();

-       PointerPointer extraZZ = extraz.get().put(AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()),
+       PointerPointer extraZZ = extraz.get().put(AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()),
                context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());

-       val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
-       val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
-       val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
+       val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
+       val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
+       val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

-       val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer();
-       val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer();
-       val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer();
+       val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
+       val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
+       val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();

-       if (op.x() != null && op.y() != null && op.z() != null) {
+       if (x != null && y != null && z != null) {
            // triple arg call
            nativeOps.execRandom3(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr
-                   x, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context),
-                   y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context),
-                   z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
-                   AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context));
+                   xb, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
+                   yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context),
+                   zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context),
+                   AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context));

-       } else if (op.x() != null && op.z() != null) {
+       } else if (x != null && z != null) {
            //double arg call
            nativeOps.execRandom2(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr
-                   x, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context),
-                   z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
-                   AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()),context));
+                   xb, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
+                   zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context),
+                   AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()),context));


        } else {
            // single arg call
            nativeOps.execRandom(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr
-                   z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
-                   AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context));
+                   zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context),
+                   AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context));
        }

        if (nativeOps.lastErrorCode() != 0)

@@ -1535,7 +1581,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {

        profilingConfigurableHookOut(op, st);

-       return op.z();
+       return z;
    }
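
The "ugly hack" above exists because some random ops are constructed with only z set: the native execRandom3 path still expects x, y and z, and previously GaussianDistribution.setZ aliased all three internally. A hedged illustration of the situation being handled (the array name is hypothetical; the constructor and exec call are the existing public API):

    // z-only construction: mean and stddev are op arguments, there is no separate x or y.
    INDArray out = Nd4j.createUninitialized(DataType.FLOAT, 100);
    GaussianDistribution op = new GaussianDistribution(out, 0.0, 1.0);
    // When executed via an OpContext that carries only the output array, getX/getY
    // return null, so exec(...) aliases x = y = z before the triple-arg dispatch.
    Nd4j.getExecutioner().exec(op, Nd4j.getRandom());
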

    /**
@@ -1888,6 +1934,11 @@ public class CudaExecutioner extends DefaultOpExecutioner {

    @Override
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp op) {
+       return calculateOutputShape(op, null);
+   }

+   @Override
+   public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp op, OpContext opContext){

        Nd4j.getExecutioner().commit();

@@ -1895,7 +1946,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
        val hash = op.opHash();

        val result = new ArrayList<LongShapeDescriptor>();
-       if(op.numInputArguments() < 1 && op.getDescriptor().getNumInputs() != -2) {
+       int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
+       if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) {
            if(log.isTraceEnabled()){
                log.trace("Could not calculate output shape for op {}: number of input args was 0",
                        op.getClass().getName());

@@ -1903,47 +1955,75 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            return Collections.emptyList();
        }

-       val inputBuffers = new PointerPointer<>(op.inputArguments().size() * 2);
-       val inputShapes = new PointerPointer<>(op.inputArguments().size());
+       val inputBuffers = new PointerPointer<>(nIn * 2);
+       val inputShapes = new PointerPointer<>(nIn);

+       val inputArgs = opContext != null ? opContext.getInputArrays() : op.inputArguments();
        int cnt= 0;
-       for (val in: op.inputArguments()) {
+       for (val in: inputArgs) {
            // NOT A TYPO: shape functions work on host side only
            if (!in.isEmpty()) {
                inputBuffers.put(cnt, in.data().addressPointer());
-               inputBuffers.put(cnt + op.inputArguments().size(), AtomicAllocator.getInstance().getPointer(in.data()));
+               inputBuffers.put(cnt + nIn, AtomicAllocator.getInstance().getPointer(in.data()));
            }

            inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer());
        }


-       val iArgs = op.iArgs().length > 0 ? new LongPointer(op.iArgs().length) : null;
+       int nIArgs = opContext != null ? opContext.numIArguments() : op.numIArguments();
+       val iArgs = nIArgs > 0 ? new LongPointer(nIArgs) : null;
        cnt = 0;
+       if(opContext != null){
+           for (val i: opContext.getIArguments())
+               iArgs.put(cnt++, i);
+       } else {
            for (val i: op.iArgs())
                iArgs.put(cnt++, i);
+       }


-       val tArgs = op.tArgs().length > 0 ? new DoublePointer(op.tArgs().length) : null;
+       int nTArgs = opContext != null ? opContext.numTArguments() : op.numTArguments();
+       val tArgs = nTArgs > 0 ? new DoublePointer(nTArgs) : null;

-       val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.bArgs().length) : null;
+       int nBArgs = opContext != null ? opContext.numBArguments() : op.numBArguments();
+       val bArgs = nBArgs > 0 ? new BooleanPointer(nBArgs) : null;

-       val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null;
+       int nDArgs = opContext != null ? opContext.numDArguments() : op.numDArguments();
+       val dArgs = nDArgs > 0 ? new IntPointer(nDArgs) : null;

        cnt = 0;
+       if(opContext != null){
+           for (val b: opContext.getBArguments())
+               bArgs.put(cnt++, b);
+       } else {
            for (val b: op.bArgs())
                bArgs.put(cnt++, b);
+       }


        cnt = 0;
-       for (val t: op.tArgs())
-           tArgs.put(cnt++, t);
+       if(opContext != null){
+           for (val b: opContext.getTArguments())
+               tArgs.put(cnt++, b);
+       } else {
+           for (val b: op.tArgs())
+               tArgs.put(cnt++, b);
+       }

        cnt = 0;
-       val dArgs1 = op.dArgs();
-       for (val d: dArgs1)
-           dArgs.put(cnt++, d.toInt());
+       if(opContext != null){
+           for (val b: opContext.getDArguments())
+               dArgs.put(cnt++, b.toInt());
+       } else {
+           for (val b: op.dArgs())
+               dArgs.put(cnt++, b.toInt());
+       }

-       OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments());
+       OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null,
+               hash, inputBuffers, inputShapes, nIn, tArgs, nTArgs,
+               iArgs, nIArgs, bArgs, nBArgs, dArgs, nDArgs);
+//        OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments());

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

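
With the new overload above, shape calculation can be driven entirely by an OpContext, so the same CustomOp instance never has to be repopulated per call. A hypothetical usage sketch (the op and input names are illustrative, not from this commit; the context setters are the existing OpContext API):

    // Shape inference from a per-call context rather than from mutable op state.
    OpContext ctx = Nd4j.getExecutioner().buildContext();
    ctx.setInputArrays(inputs);                       // "inputs" (INDArray[]) is hypothetical
    ctx.setIArguments(op.iArgs());                    // mirror the op's static integer args
    List<LongShapeDescriptor> shapes =
            Nd4j.getExecutioner().calculateOutputShape(op, ctx);
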
|
@ -20,6 +20,7 @@ import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
|
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
@ -127,7 +128,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio
|
||||||
// the only entry place for TADless ops
|
// the only entry place for TADless ops
|
||||||
processAsGridOp(op);
|
processAsGridOp(op);
|
||||||
} else if (op instanceof BroadcastOp) {
|
} else if (op instanceof BroadcastOp) {
|
||||||
invoke((BroadcastOp) op);
|
+                    invoke((BroadcastOp) op, null);
                 } else {
                     //logger.info("Random op: {}", op.getClass().getSimpleName());
                     pushToGrid(new OpDescriptor(op));

@@ -238,7 +239,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutioner {
             flushQueue();

             //logger.info("Sending TransformOp to CudaExecutioner");
-            super.invoke(t);
+            super.invoke(t, null);
         } else if (op instanceof Variance) {
             Variance acc = (Variance) op;
             if (flush)
@@ -258,7 +259,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutioner {
             flushQueue();

             //logger.info("Sending ScalarOp to CudaExecutioner");
-            super.invoke(sc);
+            super.invoke(sc, null);
         } else if (op instanceof BroadcastOp) {
             BroadcastOp broadcastOp = (BroadcastOp) op;
             if (flush)
@@ -268,7 +269,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutioner {
             if (dimensions != null) {
                 super.exec(broadcastOp);
             } else {
-                super.invoke(broadcastOp);
+                super.invoke(broadcastOp, null);
             }
         } else if (op instanceof IndexAccumulation) {
             IndexAccumulation indexAccumulation = (IndexAccumulation) op;
@@ -690,7 +691,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutioner {
             flushQueue();

             buildZ(op, new int[] {Integer.MAX_VALUE});
-            super.invoke(op, new int[] {Integer.MAX_VALUE});
+            super.invoke(op, null, new int[] {Integer.MAX_VALUE});
         } else {
             buildZ(op, dimension);
             processAsGridOp(op, dimension);
@@ -708,7 +709,8 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutioner {

     // FIXME: remove CudaContext return opType. We just don't need it
     @Override
-    protected CudaContext invoke(BroadcastOp op) {
+    protected CudaContext invoke(BroadcastOp op, OpContext oc) {
+        Preconditions.checkState(oc == null);
         processAsGridOp(op, op.getDimension());

         return null;
@@ -716,7 +718,8 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutioner {

     // FIXME: remove CudaContext return opType. We just don't need it
     @Override
-    protected CudaContext invoke(ScalarOp op) {
+    protected CudaContext invoke(ScalarOp op, OpContext oc) {
+        Preconditions.checkState(oc == null);
         processAsGridOp(op, null);

         return null;
@@ -724,7 +727,8 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutioner {

     // FIXME: remove CudaContext return opType. We just don't need it
     @Override
-    protected CudaContext invoke(TransformOp op) {
+    protected CudaContext invoke(TransformOp op, OpContext oc) {
+        Preconditions.checkState(oc == null);
         processAsGridOp(op, null);
         return null;
     }
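Note: the hunks above all apply one mechanical pattern: each legacy invoke(op) entry point gains an OpContext parameter, existing grid-executioner call sites pass null, and the new overloads reject a non-null context because grid execution does not go through op contexts. A minimal, self-contained sketch of that pattern follows; Op and OpContext here are illustrative stand-ins, not the actual nd4j types.

// Sketch only: stand-in types, not the nd4j Op/OpContext interfaces
public class ContextThreadingSketch {

    interface Op { String name(); }
    interface OpContext { }

    // Legacy single-argument entry point kept as a thin delegate,
    // so existing callers compile unchanged
    void invoke(Op op) {
        invoke(op, null);
    }

    // New signature: a null context means "legacy path". The grid executioner
    // above guards the same way, via Preconditions.checkState(oc == null).
    void invoke(Op op, OpContext ctx) {
        if (ctx != null)
            throw new IllegalStateException("OpContext not supported for op: " + op.name());
        // ... dispatch the op exactly as before ...
        System.out.println("dispatched " + op.name());
    }

    public static void main(String[] args) {
        new ContextThreadingSketch().invoke(() -> "exampleOp");
    }
}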
File diff suppressed because it is too large
@@ -385,6 +385,7 @@ public class RandomOpValidation extends BaseOpValidation {

     @Test
     public void testUniformDtype(){
+        Nd4j.getRandom().setSeed(12345);
         for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
             SameDiff sd = SameDiff.create();
             SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100));
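The added setSeed(12345) pins the global RNG so the uniform-distribution assertions that follow are reproducible run to run. A small hedged sketch of the effect, assuming the standard Nd4j RNG API on a single backend:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Minimal sketch: same seed, same backend => same random sequence
public class SeedSketch {
    public static void main(String[] args) {
        Nd4j.getRandom().setSeed(12345);
        INDArray a = Nd4j.rand(DataType.FLOAT, 1, 100);

        Nd4j.getRandom().setSeed(12345);
        INDArray b = Nd4j.rand(DataType.FLOAT, 1, 100);

        System.out.println(a.equals(b)); // expected: true
    }
}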
@@ -0,0 +1,169 @@
+package org.nd4j.autodiff.samediff;
+
+import lombok.extern.slf4j.Slf4j;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.nd4j.BaseND4JTest;
+import org.nd4j.imports.TFGraphs.TFGraphTestZooModels;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.primitives.AtomicBoolean;
+import org.nd4j.resources.Resources;
+
+import java.io.File;
+import java.util.Collections;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+@Slf4j
+public class SameDiffMultiThreadTests extends BaseND4JTest {
+
+    @Rule
+    public TemporaryFolder testDir = new TemporaryFolder();
+
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 60000L;
+    }
+
+    @Test
+    public void testSimple() throws Exception {
+
+        int nThreads = 4;
+        int nRuns = 1000;
+
+        SameDiff sd = SameDiff.create();
+        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10);
+        SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);
+
+        SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 10, 10));
+        SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10));
+        SDVariable w2 = sd.var("w2", Nd4j.rand(DataType.FLOAT, 10, 10));
+        SDVariable b2 = sd.var("b2", Nd4j.rand(DataType.FLOAT, 10));
+        SDVariable w3 = sd.var("w3", Nd4j.rand(DataType.FLOAT, 10, 10));
+        SDVariable b3 = sd.var("b3", Nd4j.rand(DataType.FLOAT, 10));
+
+        SDVariable l1 = sd.nn.tanh(in.mmul(w1).add(b1));
+        SDVariable l2 = sd.nn.sigmoid(l1.mmul(w2).add(b2));
+        SDVariable l3 = sd.nn.softmax("out", l2.mmul(w3).add(b3));
+
+        SDVariable loss = sd.loss.logLoss("loss", label, l3);
+
+        INDArray[] inputArrs = new INDArray[nThreads];
+        INDArray[] expOut = new INDArray[nThreads];
+        for( int i=0; i<nThreads; i++ ){
+            inputArrs[i] = Nd4j.rand(DataType.FLOAT, i+1, 10);
+            expOut[i] = sd.outputSingle(Collections.singletonMap("in", inputArrs[i]), "out");
+        }
+
+        Semaphore s = new Semaphore(nThreads);
+        CountDownLatch latch = new CountDownLatch(nThreads);
+
+        AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads];
+        AtomicInteger[] counters = new AtomicInteger[nThreads];
+        doTest(sd, nThreads, nRuns, inputArrs, expOut, "in", "out", failuresByThread, counters, s, latch);
+
+        s.release(nThreads);
+        latch.await();
+
+        for(int i=0; i<nThreads; i++ ){
+            assertFalse("Thread " + i + " failed", failuresByThread[i].get());
+        }
+
+        for(int i=0; i<nThreads; i++ ){
+            assertEquals("Thread " + i + " number of runs", nRuns, counters[i].get());
+        }
+    }
+
+    @Test
+    public void testMobilenet() throws Exception {
+        TFGraphTestZooModels.currentTestDir = testDir.newFolder();
+        File f = Resources.asFile("tf_graphs/zoo_models/mobilenet_v2_1.0_224/tf_model.txt");
+        SameDiff sd = TFGraphTestZooModels.LOADER.apply(f, "mobilenet_v2_1.0_224");
+//        System.out.println(sd.summary());
+
+        int nThreads = 4;
+        int nRuns = 30;
+        INDArray[] inputArrs = new INDArray[nThreads];
+        INDArray[] expOut = new INDArray[nThreads];
+        for( int i=0; i<nThreads; i++ ){
+            if(i == 0 || i > 2)
+                inputArrs[i] = Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3);
+            else if(i == 1)
+                inputArrs[i] = Nd4j.zeros(DataType.FLOAT, 1, 224, 224, 3);
+            else if(i == 2)
+                inputArrs[i] = Nd4j.ones(DataType.FLOAT, 1, 224, 224, 3);
+
+            expOut[i] = sd.outputSingle(Collections.singletonMap("input", inputArrs[i]), "MobilenetV2/Predictions/Reshape_1");
+            Nd4j.getExecutioner().commit();
+        }
+
+        AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads];
+        AtomicInteger[] counters = new AtomicInteger[nThreads];
+        Semaphore s = new Semaphore(nThreads);
+        CountDownLatch latch = new CountDownLatch(nThreads);
+
+        doTest(sd, nThreads, nRuns, inputArrs, expOut, "input", "MobilenetV2/Predictions/Reshape_1", failuresByThread, counters, s, latch);
+
+        s.release(nThreads);
+        latch.await();
+
+        for(int i=0; i<nThreads; i++ ){
+            assertFalse("Thread " + i + " failed", failuresByThread[i].get());
+        }
+
+        for(int i=0; i<nThreads; i++ ){
+            assertEquals("Thread " + i + " number of runs", nRuns, counters[i].get());
+        }
+    }
+
+
+    public static void doTest(SameDiff sd, int nThreads, int nRuns, INDArray[] inputArrs, INDArray[] expOut,
+                              String inName, String outName,
+                              AtomicBoolean[] failuresByThread, AtomicInteger[] counters, Semaphore s, CountDownLatch latch){
+
+        for( int i=0; i<nThreads; i++ ){
+            failuresByThread[i] = new AtomicBoolean(false);
+            counters[i] = new AtomicInteger(0);
+            final int j=i;
+            Thread t = new Thread(new Runnable() {
+                @Override
+                public void run() {
+                    try{
+                        s.acquire(1);
+                        for( int i=0; i<nRuns; i++ ){
+                            INDArray out = sd.outputSingle(Collections.singletonMap(inName, inputArrs[j]), outName);
+                            Nd4j.getExecutioner().commit();
+                            INDArray exp = expOut[j];
+
+                            if(!exp.equals(out)){
+                                failuresByThread[j].set(true);
+                                log.error("Failure in thread: {}/{} - iteration {}\nExpected ={}\nActual={}", Thread.currentThread().getId(), j, i, exp, out);
+                                break;
+                            }
+
+                            if(out.closeable())
+                                out.close();
+
+//                            if(i % 100 == 0){
+//                                log.info("Thread {} at {}", Thread.currentThread().getId(), i);
+//                            }
+                            counters[j].addAndGet(1);
+                        }
+                    } catch (Throwable t){
+                        log.error("Error in thread: {}", Thread.currentThread().getId(), t);
+                    } finally {
+                        latch.countDown();
+                    }
+                }
+            });
+            t.start();
+        }
+    }
+}
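The new test computes each thread's expected output single-threaded first, then replays inference from nThreads workers and compares against those baselines. Coordination uses a semaphore as a start gate plus a countdown latch for completion. Worth noting: the test constructs new Semaphore(nThreads), so the workers never actually block before s.release(nThreads); the usual start-gate idiom begins with zero permits. A minimal, self-contained sketch of that idiom (all names hypothetical, not from the diff):

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;

public class StartGateSketch {
    public static void main(String[] args) throws InterruptedException {
        int nThreads = 4;
        Semaphore gate = new Semaphore(0);          // zero permits: workers block at acquire()
        CountDownLatch done = new CountDownLatch(nThreads);

        for (int i = 0; i < nThreads; i++) {
            new Thread(() -> {
                try {
                    gate.acquire();                  // wait for the release below
                    // ... run inference against the shared SameDiff instance ...
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    done.countDown();                // signal completion either way
                }
            }).start();
        }

        gate.release(nThreads);                      // open the gate for all workers at once
        done.await();                                // wait until every worker finishes
    }
}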
@@ -99,6 +99,10 @@ public class SameDiffTests extends BaseNd4jTest {
     @ClassRule
     public static TemporaryFolder folder = new TemporaryFolder();

+    @Override
+    public long getTimeoutMilliseconds() {
+        return 999999999L;
+    }

     @Before
     public void before() {
@@ -36,6 +36,7 @@ import org.nd4j.evaluation.classification.Evaluation.Metric;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.IrisDataSetIterator;
 import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
@@ -336,12 +337,12 @@ public class ListenerTest extends BaseNd4jTest {
         }

         @Override
-        public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+        public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
             preOpExecutionCount++;
         }

         @Override
-        public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+        public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
             opExecutionCount++;
         }
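This and the next three hunks are the same API migration: the preOpExecution and opExecution listener callbacks now receive the OpContext the op was executed with, since under op-context execution the input/output arrays are attached to the per-call context rather than to the shared op instance. A hedged sketch of a custom listener against the new signature (the isActive override is assumed from the surrounding listener API of this release line):

import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.dataset.api.MultiDataSet;

import java.util.Arrays;

// Sketch: logs each op's output shapes via the OpContext-aware callback
public class ShapeLoggingListener extends BaseListener {

    @Override
    public boolean isActive(Operation operation) {
        return true;    // listen during all operations (training, inference, ...)
    }

    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op,
                            OpContext opContext, INDArray[] outputs) {
        for (INDArray out : outputs) {
            System.out.println(op.getName() + " -> " + Arrays.toString(out.shape()));
        }
    }
}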
@@ -10,6 +10,8 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.linalg.api.ndarray.INDArray;

 import java.util.*;

+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;

 public class OpExecOrderListener extends BaseListener {
@@ -24,7 +26,7 @@ public class OpExecOrderListener extends BaseListener {
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         String opName = op.getName();
         if(!opSet.contains(opName)){
             opNamesList.add(opName);
@@ -6,6 +6,7 @@ import org.nd4j.autodiff.listeners.Operation;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;

 /**
@@ -20,7 +21,7 @@ public class ExecPrintListener extends BaseListener {
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         System.out.println("------ Op: " + op.getName() + " - opName = " + op.getOp().opName() + ", class = " + op.getOp().getClass().getName() + " ------");
         for(INDArray arr : outputs){
             System.out.println(arr);
@@ -24,6 +24,7 @@ import org.nd4j.autodiff.listeners.Operation;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.factory.Nd4j;

@@ -56,7 +57,7 @@ public class ImportDebugListener extends BaseListener {
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         //No op

         for( int i=0; i<outputs.length; i++ ) {
@@ -750,7 +750,8 @@ public class NDArrayTestsFortran extends BaseNd4jTest {


         INDArray toPermute = Nd4j.create(Nd4j.linspace(0, 7, 8, DataType.DOUBLE).data(), new long[] {2, 2, 2});
-        INDArray permuted = toPermute.permute(2, 1, 0);
+        INDArray permuted = toPermute.dup().permute(2, 1, 0);
+        boolean eq = toPermute.equals(permuted);
         assertNotEquals(toPermute, permuted);

         INDArray permuteOther = toPermute.permute(1, 2, 0);
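The dup() matters because permute(...) returns a view over the same underlying buffer (an assumption consistent with this test change): permuting a dup() gives the test an independent copy, so the later toPermute.permute(1, 2, 0) cannot alias permuted. A hedged sketch of the distinction:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class PermuteViewSketch {
    public static void main(String[] args) {
        INDArray base = Nd4j.linspace(0, 7, 8, DataType.DOUBLE).reshape(2, 2, 2);

        INDArray view = base.permute(2, 1, 0);       // shares base's buffer (assumption)
        INDArray copy = base.dup().permute(2, 1, 0); // permute of a fresh, independent copy

        view.putScalar(new int[]{0, 0, 1}, 100.0);   // writes through to base
        System.out.println(base.getDouble(1, 0, 0)); // 100.0: the view aliased base
        System.out.println(copy.getDouble(0, 0, 1)); // 4.0: the copy was unaffected
    }
}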
@@ -86,6 +86,9 @@ class FunctionalOpExecutioner extends OpExecutioner {
       case _ => op.z()
     }

+  def exec(op: Op, context: OpContext): INDArray =
+    Nd4j.getExecutioner.exec(op, context)
+
   def exec(op: FilterOps): INDArray = {
     val retVal: INDArray = Nd4j.create(op.x.dataType(), op.x.shape().map(_.toLong): _*)
     for (i <- 0 until op.x().length().toInt) {
@@ -408,6 +411,9 @@ class FunctionalOpExecutioner extends OpExecutioner {
   def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] =
     Nd4j.getExecutioner.calculateOutputShape(op)

+  def calculateOutputShape(op: CustomOp, ctx: OpContext): java.util.List[LongShapeDescriptor] =
+    Nd4j.getExecutioner.calculateOutputShape(op, ctx)
+
   /**
    * Equivalent to calli
    */
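These ND4S additions are thin delegates to the new OpContext-aware executioner entry points. A hedged Java sketch of driving those entry points directly; buildContext, setInputArray, setOutputArray, ulike and the no-arg SoftMax constructor are assumptions about the nd4j API of this release line, not taken from the diff:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.factory.Nd4j;

public class OpContextSketch {
    public static void main(String[] args) {
        INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3);

        // Build a per-call context: inputs and outputs live on the context,
        // not on the op instance, which is what makes concurrent use possible
        OpContext ctx = Nd4j.getExecutioner().buildContext();
        ctx.setInputArray(0, in);
        ctx.setOutputArray(0, in.ulike());

        // Shapes can also be computed against the context, mirroring the
        // calculateOutputShape(op, ctx) delegate added above
        System.out.println(Nd4j.getExecutioner().calculateOutputShape(new SoftMax(), ctx));

        INDArray[] out = Nd4j.getExecutioner().exec(new SoftMax(), ctx);
        System.out.println(out[0]);
    }
}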