From f79207033b06322ab35bd6000e43a00d472fb9e8 Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Fri, 20 Mar 2020 21:24:39 +1100
Subject: [PATCH] SameDiff multi-threaded inference (#263)

* #8682 Don't log openmp BLAS threads for CUDA

Signed-off-by: Alex Black

* #8654 Add SameDiff multi-threaded tests

Signed-off-by: Alex Black

* Switching to op context for SameDiff exec

Signed-off-by: Alex Black

* Next steps

Signed-off-by: Alex Black

* Most back to passing

Signed-off-by: Alex Black

* Fixes

Signed-off-by: Alex Black

* Better tests, test refactoring

Signed-off-by: Alex Black

* Small tweak

Signed-off-by: Alex Black

* Code duplication reduction

Signed-off-by: Alex Black

* More code deduplication

Signed-off-by: Alex Black

* CUDA fixes

Signed-off-by: Alex Black

* More CUDA fixes

Signed-off-by: Alex Black

* More fixes

Signed-off-by: Alex Black

* Small fix

Signed-off-by: Alex Black

* ND4S small fixes

Signed-off-by: Alex Black
---
 .../functions/DifferentialFunction.java | 5 +
 .../nd4j/autodiff/listeners/BaseListener.java | 5 +-
 .../org/nd4j/autodiff/listeners/Listener.java | 7 +-
 .../debugging/ArraySavingListener.java | 3 +-
 .../debugging/ExecDebuggingListener.java | 3 +-
 .../debugging/OpBenchmarkListener.java | 5 +-
 .../autodiff/listeners/impl/UIListener.java | 3 +-
 .../listeners/profiler/ProfilingListener.java | 5 +-
 .../org/nd4j/autodiff/samediff/SameDiff.java | 20 -
 .../samediff/internal/InferenceSession.java | 78 +-
 .../samediff/internal/TrainingSession.java | 7 +-
 .../ActivationGradientCheckListener.java | 4 +-
 .../NonInplaceValidationListener.java | 21 +-
 .../linalg/api/ops/BaseIndexAccumulation.java | 6 +
 .../nd4j/linalg/api/ops/BaseOpContext.java | 39 +
 .../nd4j/linalg/api/ops/BaseReduceBoolOp.java | 24 +-
 .../linalg/api/ops/BaseReduceFloatOp.java | 34 +-
 .../nd4j/linalg/api/ops/BaseReduceLongOp.java | 24 +-
 .../nd4j/linalg/api/ops/BaseReduceSameOp.java | 30 +-
 .../nd4j/linalg/api/ops/BaseScalarBoolOp.java | 6 +
 .../org/nd4j/linalg/api/ops/BaseScalarOp.java | 7 +
 .../linalg/api/ops/BaseTransformAnyOp.java | 7 +-
 .../linalg/api/ops/BaseTransformBoolOp.java | 24 +-
 .../linalg/api/ops/BaseTransformFloatOp.java | 26 +-
 .../linalg/api/ops/BaseTransformSameOp.java | 28 +-
 .../linalg/api/ops/BaseTransformStrictOp.java | 25 +-
 .../org/nd4j/linalg/api/ops/CustomOp.java | 8 +-
 .../nd4j/linalg/api/ops/DynamicCustomOp.java | 26 +-
 .../java/org/nd4j/linalg/api/ops/NoOp.java | 8 +
 .../org/nd4j/linalg/api/ops/OpContext.java | 14 +-
 .../org/nd4j/linalg/api/ops/ReduceOp.java | 4 +-
 .../org/nd4j/linalg/api/ops/TransformOp.java | 4 +-
 .../linalg/api/ops/custom/ScatterUpdate.java | 6 +
 .../ops/executioner/DefaultOpExecutioner.java | 78 +-
 .../api/ops/executioner/OpExecutioner.java | 9 +
 .../impl/layers/ExternalErrorsFunction.java | 6 +
 .../api/ops/impl/summarystats/Variance.java | 40 +-
 .../api/ops/impl/transforms/MaxOut.java | 18 +-
 .../linalg/api/ops/random/BaseRandomOp.java | 10 +
 .../ops/random/impl/BinomialDistribution.java | 5 +
 .../ops/random/impl/GaussianDistribution.java | 5 +
 .../random/impl/LogNormalDistribution.java | 5 +
 .../impl/TruncatedNormalDistribution.java | 5 +
 .../java/org/nd4j/linalg/factory/Nd4j.java | 4 +
 .../java/org/nd4j/nativeblas/Nd4jBlas.java | 6 +-
 .../nd4j/linalg/jcublas/blas/CudaBlas.java | 5 +
 .../ops/executioner/CudaExecutioner.java | 522 ++++++++------
 .../ops/executioner/CudaGridExecutioner.java | 20 +-
 .../nativecpu/ops/NativeOpExecutioner.java | 665 ++++++++++--------
 .../opvalidation/RandomOpValidation.java | 1 +
 .../samediff/SameDiffMultiThreadTests.java | 169 +++++
 .../nd4j/autodiff/samediff/SameDiffTests.java | 4 +
 .../samediff/listeners/ListenerTest.java | 5 +-
 .../listener/OpExecOrderListener.java | 4 +-
 .../imports/listeners/ExecPrintListener.java | 3 +-
 .../listeners/ImportDebugListener.java | 3 +-
 .../org/nd4j/linalg/NDArrayTestsFortran.java | 3 +-
 .../nd4s/ops/FunctionalOpExecutioner.scala | 6 +
 58 files changed, 1426 insertions(+), 691 deletions(-)
 create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
index 655e4159f..94bda0b78 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
@@ -31,6 +31,7 @@ import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.shade.jackson.annotation.JsonIgnore;
@@ -708,6 +709,10 @@ public abstract class DifferentialFunction {
         throw new ND4JIllegalStateException("calculateOutputShape() method leaked out for [" + this.opName() + "]");
     }
 
+    public List calculateOutputShape(OpContext oc){
+        throw new ND4JIllegalStateException("calculateOutputShape(OpContext) method leaked out for [" + this.opName() + "]");
+    }
+
     /**
      * Calculate the data types for the output arrays.
      * Though datatypes can also be inferred from {@link #calculateOutputShape()}, this method differs in that it does not
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java
index 61a5e75a3..6978a79d0 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java
@@ -5,6 +5,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.autodiff.samediff.internal.Variable;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 
 /**
@@ -60,12 +61,12 @@ public abstract class BaseListener implements Listener {
     }
 
     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
         //No op
     }
 
     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         //No op
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java
index 18e3b934b..4ed7df6c3 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java
@@ -5,6 +5,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.autodiff.samediff.internal.Variable;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 
 /**
@@ -104,7 +105,7 @@ public interface Listener {
      * @param at Current iteration/epoch etc
-     * @param op Operation that has just been executed
+     * @param op Operation that is about to be executed
      */
-    void preOpExecution(SameDiff sd, At at, SameDiffOp op);
+    void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext);
 
     /**
      * Called at the end of each operation execution
@@ -117,7 +118,7 @@ public interface Listener { * @param op Operation that has just been executed * @param outputs The output arrays for the just-executed operation */ - void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs); + void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs); /** * Called when any activation becomes available. @@ -127,7 +128,7 @@ public interface Listener { * Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}
* It is guaranteed to be called for variables from requiredVariables().
*
- * Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, INDArray[])} - + * Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, OpContext, INDArray[])} - * both contain the same information/arrays * * @param sd The SameDiff instance diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java index 6b64c69d8..9770dc50c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java @@ -9,6 +9,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; @@ -44,7 +45,7 @@ public class ArraySavingListener extends BaseListener { @Override - public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) { List outNames = op.getOutputsOfOp(); for(int i=0; i this.minRuntime) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java index 6c38c6c9c..b4f3a371b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java @@ -19,6 +19,7 @@ import org.nd4j.graph.UIInfoType; import org.nd4j.graph.UIStaticInfoRecord; import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.primitives.Pair; @@ -410,7 +411,7 @@ public class UIListener extends BaseListener { @Override - public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) { //Do training set evaluation, if required diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java index 9b92b0412..3dc21876e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java @@ -30,6 +30,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import 
org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.AtomicBoolean; @@ -192,7 +193,7 @@ public class ProfilingListener extends BaseListener { } @Override - public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { + public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) { if (logActive) { opStartNano = System.nanoTime(); @@ -202,7 +203,7 @@ public class ProfilingListener extends BaseListener { } @Override - public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) { if (logActive) { long now = System.nanoTime(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index de421b297..7ca809b2d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -105,7 +105,6 @@ import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; * In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)} */ @AllArgsConstructor -@Builder @Slf4j public class SameDiff extends SDBaseOps { protected static final String GRAD_FN_KEY = "grad"; @@ -1232,25 +1231,6 @@ public class SameDiff extends SDBaseOps { return result; } - - /** - * Create a new SameDiff instance from an existing instance. 
- * Note that state (variables and functions) is shared between the two SameDiff instance - * - * @param originalSameDiff Original SameDiff instance - * @return Copy - */ - public static SameDiff create(SameDiff originalSameDiff) { - SameDiff ret = SameDiff.builder() - .sameDiffFunctionInstances(originalSameDiff.sameDiffFunctionInstances) - .build(); - ret.variables.putAll(originalSameDiff.variables); - //ensuring proper sameDiff reference - DifferentialFunctionFactory differentialFunctionFactory = new DifferentialFunctionFactory(ret); - ret.functionFactory = differentialFunctionFactory; - return ret; - } - @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index 9b8d751eb..26bf82893 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.internal; import lombok.*; import lombok.extern.slf4j.Slf4j; +import org.bytedeco.javacpp.Pointer; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.Listener; @@ -46,6 +47,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import java.util.*; @@ -65,7 +67,7 @@ import java.util.*; * @author Alex Black */ @Slf4j -public class InferenceSession extends AbstractSession { +public class InferenceSession extends AbstractSession> { private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\n" + "Alternatively, arrays defined in a workspace must be replaced after the workspace has been closed."; @@ -83,6 +85,8 @@ public class InferenceSession extends AbstractSession { private IdentityDependencyTracker arrayUseTracker = new IdentityDependencyTracker<>(); + private Map opContexts = new HashMap<>(); + public InferenceSession(@NonNull SameDiff sameDiff) { super(sameDiff); mmgr = new ArrayCacheMemoryMgr(); @@ -204,18 +208,19 @@ public class InferenceSession extends AbstractSession { } @Override - public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + public INDArray[] getOutputs(Pair opPair, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, Set constAndPhInputs, List listeners, At at, MultiDataSet batch, Set allReqVariables) { + SameDiffOp op = opPair.getFirst(); at.setFrameIter(outputFrameIter); if (listeners != null && listeners.size() > 0) { SameDiffOp sdOp = sameDiff.getOps().get(op.getOp().getOwnName()); for (Listener l : listeners) { if (l.isActive(at.operation())) - l.preOpExecution(sameDiff, at, sdOp); + l.preOpExecution(sameDiff, at, sdOp, opPair.getSecond()); } } - INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs); + INDArray[] out = doExec(op.getOp(), opPair.getRight(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs); if 
(log.isTraceEnabled()) { StringBuilder sb = new StringBuilder(); @@ -246,7 +251,7 @@ public class InferenceSession extends AbstractSession { } - l.opExecution(sameDiff, at, batch, op, out); + l.opExecution(sameDiff, at, batch, op, opPair.getSecond(), out); for (String varName : namedOuts.keySet()) { l.activationAvailable(sameDiff, at, batch, op, varName, namedOuts.get(varName)); @@ -255,6 +260,8 @@ public class InferenceSession extends AbstractSession { } } op.getOp().clearArrays(); + if(opPair.getSecond() != null) + opPair.getSecond().purge(); //Record array uses for memory management/deallocation @@ -343,7 +350,7 @@ public class InferenceSession extends AbstractSession { return out; } - public INDArray[] doExec(DifferentialFunction op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + public INDArray[] doExec(DifferentialFunction op, OpContext opContext, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, Set constAndPhInputs) { int totalInputs = (opInputs == null ? 0 : opInputs.size()) + (constAndPhInputs == null ? 0 : constAndPhInputs.size()) @@ -467,31 +474,31 @@ public class InferenceSession extends AbstractSession { return new INDArray[]{out}; } else if (op instanceof Assert) { Assert a = (Assert)op; - boolean condition = a.getInputArgument(0).getDouble(0) != 0.0; + boolean condition = opContext.getInputArray(0).getDouble(0) != 0.0; if(!condition){ //Assertion failed String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution"; if(a.numInputArguments() >= 3) { - INDArray msg = a.getInputArgument(2); + INDArray msg = opContext.getInputArray(2); if (msg != null && msg.dataType() == DataType.UTF8) { s += ": " + msg.getString(0); } } if(a.numInputArguments() >= 5){ - INDArray arr = a.getInputArgument(4); + INDArray arr = opContext.getInputArray(4); s += "\n" + arr; } throw new IllegalStateException(s); } - return ((Assert) op).outputArguments().toArray(new INDArray[0]); + return opContext.getOutputArrays().toArray(new INDArray[0]); } else if (op instanceof CustomOp) { CustomOp c = (CustomOp) op; - Nd4j.exec(c); - return c.outputArguments().toArray(new INDArray[0]); + Nd4j.exec(c, opContext); + return opContext.getOutputArrays().toArray(new INDArray[0]); } else if (op instanceof Op) { Op o = (Op) op; - Nd4j.exec(o); - return new INDArray[]{o.z()}; + Nd4j.exec(o, opContext); + return new INDArray[]{opContext.getOutputArray(0)}; } else { throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName()); } @@ -774,7 +781,7 @@ public class InferenceSession extends AbstractSession { } @Override - public SameDiffOp getAndParameterizeOp(String opName, FrameIter frameIter, Set opInputs, Set allIterInputs, + public Pair getAndParameterizeOp(String opName, FrameIter frameIter, Set opInputs, Set allIterInputs, Set constAndPhInputs, Map placeholderValues, Set allReqVariables) { SameDiffOp sdo = sameDiff.getOps().get(opName); DifferentialFunction df = sdo.getOp(); @@ -786,7 +793,7 @@ public class InferenceSession extends AbstractSession { if (df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration || df instanceof Merge || df instanceof Switch || df instanceof BaseTensorOp) { //Control dependencies and tensor ops (like TensorArray, TensorArrayRead etc) don't need inputs set, execution is a special case - return sdo; + return new Pair<>(sdo, null); } //Infer the args based on the inputs (variable + frame + iteration) @@ -839,24 +846,39 @@ public class InferenceSession 
extends AbstractSession { //TODO let's find a way to use in-place modification for loops where possible to reduce memory requirements boolean isLoop = !frameIter.getFrame().equals(OUTER_FRAME) && frameIter.getIteration() > 0; + OpContext oc = opContexts.get(opName); + if(oc == null){ + oc = Nd4j.getExecutioner().buildContext(); + opContexts.put(opName, oc); + } + if (df instanceof CustomOp) { DynamicCustomOp customOp = (DynamicCustomOp) df; if (args != null) { - customOp.setInputArguments(args); + oc.setInputArrays(args); } if (df instanceof Identity) { //We don't need to allocate an output array for Identity, we pass through the input array without copying - return sdo; + return new Pair<>(sdo, oc); } - List outShape = customOp.calculateOutputShape(); + if(customOp.numIArguments() > 0) + oc.setIArguments(customOp.iArgs()); + if(customOp.numDArguments() > 0) + oc.setDArguments(customOp.dArgs()); + if(customOp.numTArguments() > 0) + oc.setTArguments(customOp.tArgs()); + if(customOp.numBArguments() > 0) + oc.setBArguments(customOp.bArgs()); + + + List outShape = customOp.calculateOutputShape(oc); Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName()); String[] outNames = df.outputVariablesNames(); Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation" + " with %s outputs (number of shapes and outputs must be equal)", df.opName(), outShape.size(), outNames.length); for (int i = 0; i < outShape.size(); i++) { - INDArray currOutput = (customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i)); LongShapeDescriptor reqShape = outShape.get(i); //Issue: many ops have multiple valid output datatypes, and output shape calc can't at present know which: https://github.com/deeplearning4j/deeplearning4j/issues/6872 @@ -870,7 +892,7 @@ public class InferenceSession extends AbstractSession { //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc boolean isOutput = allReqVariables.contains(outNames[i]); INDArray out = mmgr.allocate(isOutput, reqShape); - customOp.setOutputArgument(i, out); + oc.setOutputArray(i, out); } } else if (df instanceof Op) { @@ -909,9 +931,9 @@ public class InferenceSession extends AbstractSession { } if (args != null && args.length > 0) { - op.setX(args[0]); + oc.setInputArray(0, args[0]); if (args.length == 2 && !axisArg) - op.setY(args[1]); + oc.setInputArray(1, args[1]); } @@ -920,18 +942,18 @@ public class InferenceSession extends AbstractSession { boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]); if (emptyReduce) { //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc - INDArray z = mmgr.allocate(false, op.x().dataType(), op.x().shape()); - op.setZ(z); + INDArray z = mmgr.allocate(false, oc.getInputArray(0).dataType(), oc.getInputArray(0).shape()); + oc.setOutputArray(0, z); } else { - List outputShape = ((BaseOp) op).calculateOutputShape(); + List outputShape = ((BaseOp) op).calculateOutputShape(oc); Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass()); LongShapeDescriptor lsd = outputShape.get(0); INDArray z = mmgr.allocate(isOutput, lsd); - op.setZ(z); + oc.setOutputArray(0, 
z); } } - return sdo; + return new Pair<>(sdo, oc); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java index 992a747a0..e683acc47 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java @@ -11,10 +11,12 @@ import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.learning.GradientUpdater; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.primitives.AtomicDouble; +import org.nd4j.linalg.primitives.Pair; import java.util.*; @@ -135,10 +137,11 @@ public class TrainingSession extends InferenceSession { } @Override - public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + public INDArray[] getOutputs(Pair opPair, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, Set constAndPhInputs, List listeners, At at, MultiDataSet batch, Set allReqVariables) { //Get outputs from InferenceSession - INDArray[] out = super.getOutputs(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables); + INDArray[] out = super.getOutputs(opPair, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables); + SameDiffOp op = opPair.getFirst(); List outputs = op.getOutputsOfOp(); int outIdx = 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java index a8865f972..d1137746d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java @@ -12,6 +12,8 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.List; + +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; /** @@ -36,7 +38,7 @@ public class ActivationGradientCheckListener extends BaseListener { } @Override - public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) { Preconditions.checkState(variableName != null, "No variable name has been set yet. 
Variable name must be set before using this listener"); Preconditions.checkState(eps != 0.0, "Epsilon has not been set"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java index 9eee099a5..3d28b29fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java @@ -14,7 +14,10 @@ import org.nd4j.linalg.api.ops.Op; import java.security.MessageDigest; import java.util.Arrays; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; + +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; public class NonInplaceValidationListener extends BaseListener { @@ -33,25 +36,25 @@ public class NonInplaceValidationListener extends BaseListener { } @Override - public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { + public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext oc) { if(op.getOp().isInPlace()){ //Don't check inplace op return; } if(op.getOp() instanceof Op){ Op o = (Op)op.getOp(); - if(o.x() == null){ + if(oc.getInputArray(0) == null){ //No input op return; - } else if(o.y() == null){ - opInputsOrig = new INDArray[]{o.x()}; - opInputs = new INDArray[]{o.x().dup()}; + } else if(oc.getInputArray(1) == null){ + opInputsOrig = new INDArray[]{oc.getInputArray(0)}; + opInputs = new INDArray[]{oc.getInputArray(0).dup()}; } else { - opInputsOrig = new INDArray[]{o.x(), o.y()}; - opInputs = new INDArray[]{o.x().dup(), o.y().dup()}; + opInputsOrig = new INDArray[]{oc.getInputArray(0), oc.getInputArray(1)}; + opInputs = new INDArray[]{oc.getInputArray(0).dup(), oc.getInputArray(1).dup()}; } } else if(op.getOp() instanceof DynamicCustomOp){ - val arr = ((DynamicCustomOp) op.getOp()).inputArguments(); + List arr = oc.getInputArrays(); // ((DynamicCustomOp) op.getOp()).inputArguments(); opInputs = new INDArray[arr.size()]; opInputsOrig = new INDArray[arr.size()]; for( int i=0; i calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc){ + INDArray x = oc != null ? oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java index 0139a9db5..c7d71db04 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java @@ -55,6 +55,11 @@ public abstract class BaseOpContext implements OpContext { return fastpath_i; } + @Override + public int numIArguments() { + return fastpath_i.size(); + } + @Override public void setTArguments(double... arguments) { fastpath_t.clear(); @@ -67,6 +72,11 @@ public abstract class BaseOpContext implements OpContext { return fastpath_t; } + @Override + public int numTArguments() { + return fastpath_t.size(); + } + @Override public void setBArguments(boolean... 
arguments) { fastpath_b.clear(); @@ -79,6 +89,11 @@ public abstract class BaseOpContext implements OpContext { return fastpath_b; } + @Override + public int numBArguments() { + return fastpath_b.size(); + } + @Override public void setDArguments(DataType... arguments) { fastpath_d.clear(); @@ -91,6 +106,11 @@ public abstract class BaseOpContext implements OpContext { return fastpath_d; } + @Override + public int numDArguments() { + return fastpath_d.size(); + } + @Override public void setInputArray(int index, @NonNull INDArray array) { fastpath_in.put(index, array); @@ -110,6 +130,16 @@ public abstract class BaseOpContext implements OpContext { return result; } + @Override + public int numInputArguments() { + return fastpath_in.size(); + } + + @Override + public INDArray getInputArray(int idx) { + return fastpath_in.get(idx); + } + @Override public List getOutputArrays() { val result = new ArrayList(); @@ -129,6 +159,15 @@ public abstract class BaseOpContext implements OpContext { fastpath_out.put(index, array); } + @Override + public INDArray getOutputArray(int i) { + return fastpath_out.get(i); + } + + @Override + public int numOutputArguments() { + return fastpath_out.size(); + } @Override public void setInputArrays(@NonNull List arrays) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java index dd2072758..af022c86f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java @@ -72,19 +72,33 @@ public abstract class BaseReduceBoolOp extends BaseReduceOp implements ReduceBoo } @Override - public boolean validateDataTypes() { - if (y() != null) - Preconditions.checkArgument(x().dataType() == y().dataType(),"Op.X type must be the same as Op.Y:" + + public DataType resultType(OpContext oc) { + return DataType.BOOL; + } + + @Override + public boolean validateDataTypes(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + if (y != null) + Preconditions.checkArgument(x.dataType() == y.dataType(),"Op.X type must be the same as Op.Y:" + " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName()); - if (z() != null) - Preconditions.checkArgument(z().isB(), "Op.X type must be bool: got type %s for op %s", x.dataType(), getClass()); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (z != null) + Preconditions.checkArgument(z.isB(), "Op.Z type must be bool: got type %s for op %s", z.dataType(), getClass()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? 
oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java index 6f3722011..29860aee9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java @@ -90,27 +90,43 @@ public abstract class BaseReduceFloatOp extends BaseReduceOp implements ReduceFl @Override public DataType resultType() { - if (this.x() != null && this.x().isR()) - return this.x().dataType(); + return resultType(null); + } + + @Override + public DataType resultType(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + if (x != null && x.isR()) + return x.dataType(); return Nd4j.defaultFloatingPointType(); } @Override - public boolean validateDataTypes() { - if (y() != null) - Preconditions.checkArgument(x().dataType() == y().dataType(), - "Op.X [%s] type must be the same as Op.Y [%s] for op %s: x.shape=%ndShape, y.shape=%ndShape", x().dataType(), - y().dataType(), getClass().getName(), x(), y() ); + public boolean validateDataTypes(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + if (y != null) + Preconditions.checkArgument(x.dataType() == y.dataType(), + "Op.X [%s] type must be the same as Op.Y [%s] for op %s: x.shape=%ndShape, y.shape=%ndShape", x.dataType(), + y.dataType(), getClass().getName(), x, y ); - if (z() != null) - Preconditions.checkArgument(z().isR(),"Op.Z (result array) must be one of floating types: z datatype = %s", z().dataType()); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (z != null) + Preconditions.checkArgument(z.isR(),"Op.Z (result array) must be one of floating types: z datatype = %s", z.dataType()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java index 9f82bb6b4..b5131eb61 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java @@ -69,19 +69,33 @@ public abstract class BaseReduceLongOp extends BaseReduceOp implements ReduceLon } @Override - public boolean validateDataTypes() { - if (y() != null) - Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X type must be the same as Op.Y:" + + public DataType resultType(OpContext oc) { + return DataType.LONG; + } + + @Override + public boolean validateDataTypes(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? 
oc.getInputArray(1) : y(); + if (y != null) + Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X type must be the same as Op.Y:" + " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName()); - if (z() != null) - Preconditions.checkArgument( z().dataType() == DataType.LONG,"Op.Z must be long: has type %s for op %s", z().dataType(), getClass()); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (z != null) + Preconditions.checkArgument( z.dataType() == DataType.LONG,"Op.Z must be long: has type %s for op %s", z.dataType(), getClass()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java index 0aa4460c3..015b87b5d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java @@ -77,26 +77,42 @@ public abstract class BaseReduceSameOp extends BaseReduceOp implements ReduceSam } @Override - public boolean validateDataTypes() { - if (y() != null) - Preconditions.checkArgument(x().dataType() == y().dataType(),"Op.X type must be the same as Op.Y type:" + + public DataType resultType(OpContext oc){ + return oc.getInputArray(0).dataType(); + } + + @Override + public boolean validateDataTypes(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + if (y != null) + Preconditions.checkArgument(x.dataType() == y.dataType(),"Op.X type must be the same as Op.Y type:" + " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName()); - if (z() != null) - Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must be the same as Op.X type. Op.X.datatype=%s, " + - "Op.Z.datatype=%s", x().dataType(), z.dataType()); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (z != null) + Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must be the same as Op.X type. Op.X.datatype=%s, " + + "Op.Z.datatype=%s", x.dataType(), z.dataType()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + if(x == null) return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); - return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, this.resultType())); + DataType rt = oc != null ? 
resultType(oc) : resultType(); + return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, rt)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java index 082465cbe..8cb7e50b4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java @@ -98,6 +98,12 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp { @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java index e6df6ceec..254069929 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java @@ -115,6 +115,13 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp { @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + val ret = new ArrayList(1); long[] s; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java index 71749bdda..7aa24a60b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java @@ -89,7 +89,12 @@ public abstract class BaseTransformAnyOp extends BaseTransformOp implements Tran } @Override - public boolean validateDataTypes(boolean experimentalMode) { + public DataType resultType(OpContext oc) { + return oc.getInputArray(0).dataType(); + } + + @Override + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { return true; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java index fd19f23d0..d4e69db5d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java @@ -88,20 +88,34 @@ public abstract class BaseTransformBoolOp extends BaseTransformOp implements Tra } @Override - public boolean validateDataTypes(boolean experimentalMode) { + public DataType resultType(OpContext oc) { + return DataType.BOOL; + } + + @Override + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? 
oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
-        if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X must be the same type as Op.Y: " +
-                    "x.datatype=%s, y.datatype=%s", x().dataType(), y.dataType());
+        if (y != null)
+            Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must be the same type as Op.Y: " +
+                    "x.datatype=%s, y.datatype=%s", x.dataType(), y.dataType());
-        if (z() != null)
-            Preconditions.checkArgument(z().isB(),"Op.Z type must be bool: z.datatype=%s for op %s", z().dataType(), getClass());
+        if (z != null)
+            Preconditions.checkArgument(z.isB(),"Op.Z type must be bool: z.datatype=%s for op %s", z.dataType(), getClass());
 
         return true;
     }
 
     @Override
     public List calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();
 
         return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), DataType.BOOL));
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java
index 12516577c..42e2ef278 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java
@@ -72,19 +72,37 @@ public abstract class BaseTransformFloatOp extends BaseTransformOp implements Tr
     }
 
     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
-        if (y() != null && !experimentalMode) {
+    public DataType resultType(OpContext oc) {
+        if (oc.getInputArray(0) != null && oc.getInputArray(0).isR())
+            return oc.getInputArray(0).dataType();
+
+        return Nd4j.defaultFloatingPointType();
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+
+        if (y != null && !experimentalMode) {
             Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must have same data type as Op.Y");
         }
 
-        if (z() != null)
-            Preconditions.checkArgument(z().isR(),"Op.Z must be one of floating types: z.datatype=%s for op %s", z().dataType(), getClass());
+        if (z != null)
+            Preconditions.checkArgument(z.isR(),"Op.Z must be one of floating types: z.datatype=%s for op %s", z.dataType(), getClass());
 
         return true;
     }
 
     @Override
     public List calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();
         return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), x.isR() ?
x.dataType() : Nd4j.defaultFloatingPointType())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java index b04c24c8c..b7d0ff4ff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java @@ -89,22 +89,36 @@ public abstract class BaseTransformSameOp extends BaseTransformOp implements Tra } @Override - public boolean validateDataTypes(boolean experimentalMode) { - if (y() != null) { - Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X type must be the same as Op.Y type: x.datatype=%s, y.datatype=%s for op %s", - x().dataType(), y().dataType(), getClass()); + public DataType resultType(OpContext oc) { + return oc.getInputArray(0).dataType(); + } + + @Override + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (y != null) { + Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X type must be the same as Op.Y type: x.datatype=%s, y.datatype=%s for op %s", + x.dataType(), y.dataType(), getClass()); } - if (z() != null) - Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must be the same as Op.X type: x.datatype=%s, z.datatype=%s for op %s", - x().dataType(), z.dataType(), getClass()); + if (z != null) + Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must be the same as Op.X type: x.datatype=%s, z.datatype=%s for op %s", + x.dataType(), z.dataType(), getClass()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java index d2a4dccc3..963138880 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java @@ -76,20 +76,28 @@ public abstract class BaseTransformStrictOp extends BaseTransformOp implements T return this.x().dataType(); } + @Override + public DataType resultType(OpContext opContext) { + return opContext.getInputArray(0).dataType(); + } + @Override - public boolean validateDataTypes(boolean experimentalMode) { - Preconditions.checkArgument(x().isR(), "Op.X must be one of floating types: x.datatype=%s for op %s", x().dataType(), getClass()); + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + INDArray z = oc != null ? 
oc.getOutputArray(0) : z();
+        Preconditions.checkArgument(x.isR(), "Op.X must be one of floating types: x.datatype=%s for op %s", x.dataType(), getClass());
 
-        if (y() != null) {
-            Preconditions.checkArgument(y().isR(), "Op.Y must be one of floating types: y.datatype=%s for op %s", y().dataType(), getClass());
+        if (y != null) {
+            Preconditions.checkArgument(y.isR(), "Op.Y must be one of floating types: y.datatype=%s for op %s", y.dataType(), getClass());
             if (!experimentalMode)
                 Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must have same data type as Op.Y");
         }
 
-        if (z() != null)
-            Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must have the same type as Op.X: x.datatype=%s, z.datatype=%s for op %s",
+        if (z != null)
+            Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must have the same type as Op.X: x.datatype=%s, z.datatype=%s for op %s",
                     x.dataType(), z.dataType(), getClass());
 
         return true;
@@ -102,6 +110,13 @@ public abstract class BaseTransformStrictOp extends BaseTransformOp implements T
         return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), x.dataType()));
     }
 
+    @Override
+    public List calculateOutputShape(OpContext oc) {
+        if(oc.getInputArray(0) == null)
+            return Collections.emptyList();
+        return Collections.singletonList(LongShapeDescriptor.fromShape(oc.getInputArray(0).shape(), oc.getInputArray(0).dataType()));
+    }
+
     @Override
     public List calculateOutputDataTypes(List dataTypes){
         //All strict transform ops: FP in, FP out
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java
index befdfb605..cdf8e3b36 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java
@@ -108,10 +108,16 @@ public interface CustomOp {
 
     /**
      * Calculate the output shape for this op
-     * @return
+     * @return Output array shapes
      */
     List calculateOutputShape();
 
+    /**
+     * Calculate the output shape for this op
+     * @return Output array shapes
+     */
+    List calculateOutputShape(OpContext opContext);
+
     /**
      * Get the custom op descriptor if one is available.
      * @return
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java
index f4116ba3e..3fe90bdbb 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java
@@ -493,6 +493,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
 
     @Override
     public List calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List calculateOutputShape(OpContext oc) {
         val descriptor = getDescriptor();
         if (outputShapes != null && !outputShapes.isEmpty())
             return outputShapes;
@@ -504,34 +509,41 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
 
         //not fully initialized: missing integer args
-        if (descriptor.getNumIArgs() >= 0 && numIArguments() < descriptor.getNumIArgs()) {
+        int nI = oc != null ?
oc.numIArguments() : numIArguments(); + if (descriptor.getNumIArgs() >= 0 && nI < descriptor.getNumIArgs()) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: not fully initialized ({} IArgs specified, " + - "{} required)", getClass().getName(),numIArguments(), descriptor.getNumIArgs()); + "{} required)", getClass().getName(), nI, descriptor.getNumIArgs()); } return Collections.emptyList(); } //not fully initialized: missing floating point args - if (descriptor.getNumTArgs() >= 0 && numTArguments() < descriptor.getNumTArgs()) { + int nT = oc != null ? oc.numTArguments() : numTArguments(); + if (descriptor.getNumTArgs() >= 0 && nT < descriptor.getNumTArgs()) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: not fully initialized ({} TArgs specified, " + - "{} required)", getClass().getName(),numTArguments(), descriptor.getNumTArgs()); + "{} required)", getClass().getName(), nT, descriptor.getNumTArgs()); } return Collections.emptyList(); } //not fully initialized: missing INDArray input args - if(descriptor.getNumInputs() >= 0 && numInputArguments() < descriptor.getNumInputs()){ + int nIn = oc != null ? oc.numInputArguments() : numInputArguments(); + if(descriptor.getNumInputs() >= 0 && nIn < descriptor.getNumInputs()){ if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: not fully initialized ({} input (INDArray) args specified, " + - "{} required)", getClass().getName(),numInputArguments(), descriptor.getNumInputs()); + "{} required)", getClass().getName(), nIn, descriptor.getNumInputs()); } return Collections.emptyList(); } - List ret = Nd4j.getExecutioner().calculateOutputShape(this); + List ret; + if(oc == null) + ret = Nd4j.getExecutioner().calculateOutputShape(this); + else + ret = Nd4j.getExecutioner().calculateOutputShape(this, oc); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java index 554ad917e..b4cf2d05a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java @@ -89,6 +89,14 @@ public class NoOp extends DynamicCustomOp { return Collections.singletonList(Nd4j.empty(DataType.BOOL).shapeDescriptor()); } + @Override + public List calculateOutputShape(OpContext oc){ + if(oc.getInputArrays() != null && !oc.getInputArrays().isEmpty()){ + return Collections.singletonList(oc.getInputArray(0).shapeDescriptor()); + } + return Collections.singletonList(Nd4j.empty(DataType.BOOL).shapeDescriptor()); + } + @Override public List calculateOutputDataTypes(List inputDataTypes){ return Collections.singletonList(DataType.BOOL); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index 62a4906a7..4bda3701e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -39,12 +39,15 @@ public interface OpContext extends AutoCloseable { List getIArguments(); + int numIArguments(); + /** * This method sets floating point arguments required for operation * @param arguments */ void setTArguments(double... 
arguments); List getTArguments(); + int numTArguments(); /** * This method sets data type arguments required for operation @@ -52,14 +55,15 @@ public interface OpContext extends AutoCloseable { */ void setDArguments(DataType... arguments); List getDArguments(); + int numDArguments(); /** * This method sets boolean arguments required for operation * @param arguments */ void setBArguments(boolean... arguments); - List getBArguments(); + int numBArguments(); /** * This method sets root-level seed for rng @@ -99,6 +103,10 @@ public interface OpContext extends AutoCloseable { */ List getInputArrays(); + int numInputArguments(); + + INDArray getInputArray(int idx); + /** * This method adds INDArray as output for future op call * @param index @@ -124,6 +132,10 @@ public interface OpContext extends AutoCloseable { */ List getOutputArrays(); + INDArray getOutputArray(int i); + + int numOutputArguments(); + /** * This method returns pointer to context, to be used during native op execution * @return diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java index 8f1814dfd..23d81c5b4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java @@ -86,7 +86,9 @@ public interface ReduceOp extends Op { */ DataType resultType(); - boolean validateDataTypes(); + DataType resultType(OpContext oc); + + boolean validateDataTypes(OpContext oc); Number getFinalResult(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/TransformOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/TransformOp.java index 9c3f9b423..f50116d32 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/TransformOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/TransformOp.java @@ -31,7 +31,9 @@ public interface TransformOp extends Op { */ DataType resultType(); + DataType resultType(OpContext opContext); + Type getOpType(); - boolean validateDataTypes(boolean experimentalMode); + boolean validateDataTypes(OpContext opContext, boolean experimentalMode); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java index 313b7ccb4..50c5db75e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.CustomOpDescriptor; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -237,6 +238,11 @@ public class ScatterUpdate implements CustomOp { return Nd4j.getExecutioner().calculateOutputShape(this); } + @Override + public List calculateOutputShape(OpContext opContext) { + return 
Nd4j.getExecutioner().calculateOutputShape(this, opContext); + } + @Override public CustomOpDescriptor getDescriptor() { return op.getDescriptor(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java index aea251ebd..c60b11d23 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java @@ -55,7 +55,7 @@ import java.util.*; * @author Adam Gibson */ @Slf4j -public class DefaultOpExecutioner implements OpExecutioner { +public abstract class DefaultOpExecutioner implements OpExecutioner { private static final String SCOPE_PANIC_MSG = "For more details, see the ND4J User Guide: deeplearning4j.org/docs/latest/nd4j-overview#workspaces-panic"; @@ -108,9 +108,10 @@ public class DefaultOpExecutioner implements OpExecutioner { } @Override - public INDArray exec(Op op) { - throw new IllegalStateException("Java computation no longer supported"); - } + public abstract INDArray exec(Op op); + + @Override + public abstract INDArray exec(Op op, OpContext opContext); @Override public Op execAndReturn(Op op) { @@ -175,24 +176,16 @@ public class DefaultOpExecutioner implements OpExecutioner { } @Override - public INDArray exec(ReduceOp op) { - throw new UnsupportedOperationException("Java computation no longer supported"); - } + public abstract INDArray exec(ReduceOp op); @Override - public INDArray exec(Variance accumulation) { - throw new UnsupportedOperationException("Operation should use exec special"); - } + public abstract INDArray exec(Variance accumulation); @Override - public INDArray exec(IndexAccumulation op) { - throw new UnsupportedOperationException("Operation should use exec special"); - } + public abstract INDArray exec(IndexAccumulation op); @Override - public INDArray exec(BroadcastOp broadcast) { - throw new IllegalStateException("Java computation no longer supported"); - } + public abstract INDArray exec(BroadcastOp broadcast); @Override public void exec(MetaOp op) { @@ -215,9 +208,7 @@ public class DefaultOpExecutioner implements OpExecutioner { } @Override - public INDArray exec(ScalarOp op) { - throw new UnsupportedOperationException(); - } + public abstract INDArray exec(ScalarOp op); @Override public void exec(List batch) { @@ -241,9 +232,7 @@ public class DefaultOpExecutioner implements OpExecutioner { * @param rng */ @Override - public INDArray exec(RandomOp op, Random rng) { - throw new UnsupportedOperationException(); - } + public abstract INDArray exec(RandomOp op, Random rng); @Deprecated @@ -741,6 +730,11 @@ public class DefaultOpExecutioner implements OpExecutioner { throw new UnsupportedOperationException(); } + @Override + public List calculateOutputShape(CustomOp op, OpContext opContext) { + throw new UnsupportedOperationException(); + } + @Override public INDArray[] allocateOutputArrays(CustomOp op){ List shapes = calculateOutputShape(op); @@ -946,4 +940,44 @@ public class DefaultOpExecutioner implements OpExecutioner { public String runFullBenchmarkSuit(boolean printOut) { throw new UnsupportedOperationException(); } + + + public void setX(INDArray x, Op op, OpContext oc){ + if(oc != null) + oc.setInputArray(0, x); + else + op.setX(x); + } + + public INDArray getX(Op op, 
OpContext oc){ + if( oc != null ) + return oc.getInputArray(0); + return op.x(); + } + + public void setY(INDArray y, Op op, OpContext oc){ + if(oc != null) + oc.setInputArray(1, y); + else + op.setY(y); + } + + public INDArray getY(Op op, OpContext oc){ + if( oc != null ) + return oc.getInputArray(1); + return op.y(); + } + + public void setZ(INDArray z, Op op, OpContext oc){ + if(oc != null) + oc.setOutputArray(0, z); + else + op.setZ(z); + } + + public INDArray getZ(Op op, OpContext oc){ + if( oc != null ) + return oc.getOutputArray(0); + return op.z(); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java index c4af57864..fcf8bfd3f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java @@ -98,6 +98,13 @@ public interface OpExecutioner { */ INDArray exec(Op op); + /** + * Execute the operation + * + * @param op the operation to execute + */ + INDArray exec(Op op, OpContext opContext); + /**Execute a TransformOp and return the result * @param op the operation to execute */ @@ -364,6 +371,8 @@ public interface OpExecutioner { List calculateOutputShape(CustomOp op); + List calculateOutputShape(CustomOp op, OpContext opContext); + /** * Equivalent to calli */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java index 496943b45..8b8cc5f53 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; @@ -150,6 +151,11 @@ public class ExternalErrorsFunction extends DynamicCustomOp { return OUT_SHAPE; } + @Override + public List calculateOutputShape(OpContext oc){ + return OUT_SHAPE; + } + public Op.Type opType() { return Op.Type.LOGIC; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java index 504012703..b44b11cf6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import 
org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -131,8 +132,14 @@ public class Variance extends BaseReduceOp { @Override public DataType resultType() { - if (this.x() != null && this.x().isR()) - return this.x().dataType(); + return resultType(null); + } + + @Override + public DataType resultType(OpContext oc){ + INDArray x = oc != null ? oc.getInputArray(0) : x(); + if (x != null && x.isR()) + return x.dataType(); if(this.arg() != null){ return this.arg().dataType(); @@ -142,14 +149,18 @@ public class Variance extends BaseReduceOp { } @Override - public boolean validateDataTypes() { - if (!x().isR()) + public boolean validateDataTypes(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + if (x != null && !x.isR()) { + return false; + } + + INDArray y = oc != null ? oc.getInputArray(1) : y(); + if (y != null && !y.isR()) return false; - if (y() != null && !y().isR()) - return false; - - if (z() != null && !z().isR()) + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (z != null && !z.isR()) return false; return true; @@ -157,15 +168,22 @@ public class Variance extends BaseReduceOp { @Override public List calculateOutputShape() { - if(args().length < 1) { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + + if(oc == null && args().length < 1) { throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found."); } long[] argShape = arg().getShape(); - if (argShape == null && x() == null) { + if (argShape == null && x == null) { return Collections.emptyList(); } - long[] inputShape = (argShape == null || Shape.isPlaceholderShape(argShape) ? x().shape() : argShape); + long[] inputShape = (argShape == null || Shape.isPlaceholderShape(argShape) ? x.shape() : argShape); val ret = new ArrayList(1); val reducedShape = Shape.getReducedShape(inputShape,dimensions, isKeepDims()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java index 9c8607b98..765ab3341 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java @@ -23,6 +23,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -94,20 +95,29 @@ public class MaxOut extends BaseTransformOp { return Nd4j.defaultFloatingPointType(); } + @Override + public DataType resultType(OpContext oc) { + return Nd4j.defaultFloatingPointType(); + } + @Override public Type getOpType() { return Type.TRANSFORM_STRICT; } @Override - public boolean validateDataTypes(boolean experimentalMode) { - if (!x().isR()) + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + INDArray z = oc != null ? 
oc.getOutputArray(0) : z(); + + if (!x.isR()) return false; - if (y() != null && !y().isR()) + if (y != null && !y.isR()) return false; - if (z() != null && z().dataType() != x().dataType()) + if (z != null && z.dataType() != x.dataType()) return false; return true; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java index 51f682876..752881c6e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.RandomOp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -65,6 +66,11 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp { @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext opContext) { if(shape != null){ return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Nd4j.defaultFloatingPointType())); } else { @@ -83,4 +89,8 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp { public boolean isInPlace(){ return x == null || x == z || x.data().pointer().address() == z.data().pointer().address(); } + + public boolean isTripleArgRngOp(){ + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java index 35c6ee05e..b08f56be3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java @@ -139,4 +139,9 @@ public class BinomialDistribution extends BaseRandomOp { //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 return Collections.singletonList(DataType.DOUBLE); } + + @Override + public boolean isTripleArgRngOp() { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java index ed43f807d..1081e141b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java @@ -138,4 +138,9 @@ public class GaussianDistribution extends BaseRandomOp { //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 return Collections.singletonList(DataType.DOUBLE); } + + @Override + public boolean isTripleArgRngOp() { + return true; + } } diff --git
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java index 080a7305a..c007d4e92 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java @@ -135,4 +135,9 @@ public class LogNormalDistribution extends BaseRandomOp { //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 return Collections.singletonList(DataType.DOUBLE); } + + @Override + public boolean isTripleArgRngOp() { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java index 24e52a532..ba09a2d29 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java @@ -136,4 +136,9 @@ public class TruncatedNormalDistribution extends BaseRandomOp { //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 return Collections.singletonList(DataType.DOUBLE); } + + @Override + public boolean isTripleArgRngOp() { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index bafee4003..5da64dadb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -6556,6 +6556,10 @@ public class Nd4j { return getExecutioner().exec(op); } + public static INDArray exec(Op op, OpContext context){ + return getExecutioner().exec(op, context); + } + /** * Execute the operation and return the result * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java index fa92f94f5..d34e24def 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java @@ -54,7 +54,7 @@ public abstract class Nd4jBlas implements Blas { } String logInit = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION); - if(logInit == null || logInit.isEmpty() || Boolean.parseBoolean(logInit)) { + if(logOpenMPBlasThreads() && (logInit == null || logInit.isEmpty() || Boolean.parseBoolean(logInit))) { log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads()); } } @@ -74,4 +74,8 @@ public abstract class Nd4jBlas implements Blas { } return Vendor.values()[vendor]; } + + public boolean logOpenMPBlasThreads(){ + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/CudaBlas.java 
b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/CudaBlas.java index ce5ac2a0d..624460b50 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/CudaBlas.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/CudaBlas.java @@ -134,4 +134,9 @@ public class CudaBlas extends Nd4jBlas { public int getBlasVendorId() { return 1; } + + @Override + public boolean logOpenMPBlasThreads() { + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index f18bd1459..a6ccd25ed 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -44,6 +44,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; +import org.nd4j.linalg.api.ops.random.BaseRandomOp; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -229,7 +230,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { INDArray ret = op.z(); checkForCompression(op); - op.validateDataTypes(); + op.validateDataTypes(null); //validateDataType(Nd4j.dataType(), op); for (int i = 0; i < dimension.length; i++) @@ -614,8 +615,15 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray exec(Op op) { + return exec(op, null); + } + + @Override + public INDArray exec(Op op, OpContext oc) { checkForCompression(op); + /* + //TODO this never would have worked //linear views and oblong offsets can't be handled by the gpu (due to the way the buffers are interpreted as vectors) if ( op instanceof CopyOp) { // we dont' care about op.Z sync state, since it'll be overwritten @@ -631,27 +639,27 @@ public class CudaExecutioner extends DefaultOpExecutioner { //AtomicAllocator.getInstance().tickHostWrite(op.z()); return null; - } + }*/ if (op instanceof TransformOp) { TransformOp t = (TransformOp) op; - invoke(t); + invoke(t, oc); } else if (op instanceof ReduceOp) { ReduceOp acc = (ReduceOp) op; - invoke(acc, acc.dimensions().toIntVector()); + invoke(acc, oc, acc.dimensions().toIntVector()); } else if (op instanceof ScalarOp) { ScalarOp sc = (ScalarOp) op; - invoke(sc); + invoke(sc, oc); } else if (op instanceof BroadcastOp) { BroadcastOp broadcastOp = (BroadcastOp) op; - invoke(broadcastOp); + invoke(broadcastOp, oc); } else if (op instanceof IndexAccumulation) { IndexAccumulation indexAccumulation = (IndexAccumulation) op; - invoke(indexAccumulation, indexAccumulation.dimensions().toIntVector()); + invoke(indexAccumulation, oc, indexAccumulation.dimensions().toIntVector()); } else if (op instanceof RandomOp) { - exec((RandomOp) op); + exec((RandomOp) op, oc, Nd4j.getRandom()); } else if (op instanceof CustomOp) { - exec((CustomOp) op); + exec((CustomOp) op, oc); } @@ -659,19 +667,22 @@ public class CudaExecutioner extends DefaultOpExecutioner { } - @Override public TransformOp 
execAndReturn(TransformOp op) { checkForCompression(op); - invoke(op); + invoke(op, null); return op; } - protected CudaContext invoke(BroadcastOp op) { + protected CudaContext invoke(BroadcastOp op, OpContext oc) { long st = profilingConfigurableHookIn(op); + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + checkForCompression(op); //validateDataType(Nd4j.dataType(), op); @@ -684,17 +695,17 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); + Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context); val hostXShapeInfo = - op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); + x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); val hostYShapeInfo = - op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); + y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); val hostZShapeInfo = - op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); - val tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), op.getDimension()); + val tadBuffers = tadManager.getTADOnlyShapeInfo(x, op.getDimension()); val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst()); val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context); @@ -706,13 +717,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer devTadOffsetsZ = null; // that's the place where we're going to have second TAD in place - val tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), op.getDimension()); + val tadBuffersZ = tadManager.getTADOnlyShapeInfo(z, op.getDimension()); devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getFirst(), context); devTadOffsetsZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getSecond(), context); PointerPointer xShapeInfoHostPointer = extraz.get().put( - AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), // 0 + AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), // 0 context.getOldStream(), // 1 AtomicAllocator.getInstance().getDeviceIdPointer(), // 2 context.getBufferAllocation(), // 3 @@ -727,30 +738,30 @@ public class CudaExecutioner extends DefaultOpExecutioner { devTadShapeInfoZ, // 12 devTadOffsetsZ); // 13 - Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); + Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context); - Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); + Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context); Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(op.getDimension()), context); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? 
null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); //log.info("X: {}; Y: {}; Z: {}; dTS: {}, dTO: {}; dTSz: {}; dTOz: {};", x.address(), y.address(), z.address(), devTadShapeInfo.address(), devTadOffsets.address(), devTadShapeInfoZ.address(), devTadOffsetsZ.address()); switch (op.getOpType()) { case BROADCAST: nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case BROADCAST_BOOL: nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, null, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; @@ -768,11 +779,16 @@ public class CudaExecutioner extends DefaultOpExecutioner { - protected CudaContext invoke(IndexAccumulation op, int[] dimension) { - dimension = Shape.normalizeAxis(op.x().rank(), dimension); + protected CudaContext invoke(IndexAccumulation op, OpContext oc, int[] dimension) { + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + + dimension = Shape.normalizeAxis(x.rank(), dimension); if (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)) { - if(op.z() == op.x() || op.z() == null) { - op.setZ(Nd4j.createUninitialized(DataType.LONG, new long[0], 'c')); + if(z == x || z == null) { + z = Nd4j.createUninitialized(DataType.LONG, new long[0], 'c'); + setZ(z, op, oc); } } @@ -790,46 +806,45 @@ public class CudaExecutioner extends DefaultOpExecutioner { CudaEnvironment.getInstance().getConfiguration().enableDebug(true); if (dimension != null) for (int i = 0; i < dimension.length; i++) - if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE) - throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]"); + if (dimension[i] >= x.rank() && dimension[i] != Integer.MAX_VALUE) + throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + x.rank() + "]"); val context = AtomicAllocator.getInstance().getDeviceContext(); - Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); - Pointer extraArgs = op.extraArgs() != null ? 
AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null; + Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context); + Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(x.dataType()), context) : null; - val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); - val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); - val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); + val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); + val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); int fdimension[] = dimension; if (fdimension == null) fdimension = new int[] {0}; - Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), fdimension); + Pair tadBuffers = tadManager.getTADOnlyShapeInfo(x, fdimension); Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst()); Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context); DataBuffer offsets = tadBuffers.getSecond(); Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); - val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); + val zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val zb = z == null ? 
null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); PointerPointer xShapeInfoHostPointer = extraz.get().put( - AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), + AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets); - if (op.z().isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) { + if (z.isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) { nativeOps.execIndexReduceScalar(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); } else { if (dimension != null && dimension.length > 1) Arrays.sort(dimension); @@ -839,9 +854,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { .getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension)); nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); } @@ -855,30 +870,34 @@ public class CudaExecutioner extends DefaultOpExecutioner { } - protected CudaContext invoke(ReduceOp op, int[] dimension) { + protected CudaContext invoke(ReduceOp op, OpContext oc, int[] dimension) { val context = AtomicAllocator.getInstance().getDeviceContext(); + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions" - if(op.z() != null){ - Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." + - " Got: x=%ndShape, z=%ndShape", op.x(), op.z()); - op.z().assign(op.x()); + if(z != null){ + Preconditions.checkState(x.equalShapes(z), "For empty reductions, result (z) array must have same shape as x shape." 
+ + " Got: x=%ndShape, z=%ndShape", x, z); + z.assign(x); return context; } else { - op.setZ(op.x().dup()); + op.setZ(x.dup()); return context; } } // FIXME: this should be moved down to C++ on per-op basis // reduce to scalar case, ReduceBool ops require special treatment - if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) { - if (op.z() == null) { + if (op instanceof BaseReduceBoolOp && x.isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) { + if (z == null) { op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue())); } else { - op.z().assign(((BaseReduceBoolOp) op).emptyValue()); + z.assign(((BaseReduceBoolOp) op).emptyValue()); } return context; @@ -888,7 +907,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { checkForCompression(op); - dimension = Shape.normalizeAxis(op.x().rank(), dimension); + dimension = Shape.normalizeAxis(x.rank(), dimension); //validateDataType(Nd4j.dataType(), op); @@ -903,130 +922,131 @@ public class CudaExecutioner extends DefaultOpExecutioner { Arrays.sort(dimension); for (int i = 0; i < dimension.length; i++) - if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE) + if (dimension[i] >= x.rank() && dimension[i] != Integer.MAX_VALUE) throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) - + " contains element that higher then rank of op.X: [" + op.x().rank() + "]"); + + " contains element that higher then rank of op.X: [" + x.rank() + "]"); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - val tadBuffers = op.x().isEmpty() ? Pair.makePair(op.x().data(), null) : tadManager.getTADOnlyShapeInfo(op.x(), dimension); + val tadBuffers = x.isEmpty() ? Pair.makePair(x.data(), null) : tadManager.getTADOnlyShapeInfo(x, dimension); val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst()); val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context); - val offsets = op.x().isEmpty() ? null : tadBuffers.getSecond(); + val offsets = x.isEmpty() ? null : tadBuffers.getSecond(); val devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); - Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); + Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context); - long[] retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims()); + long[] retShape = Shape.reductionShape(x, dimension, true, op.isKeepDims()); - if (op.y() != null) { + if (y != null) { //2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y - if (op.x().length() == op.y().length()) { + if (x.length() == y.length()) { //Pairwise - if (op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) { + if (x.tensorsAlongDimension(dimension) != y.tensorsAlongDimension(dimension)) { throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + - Arrays.toString(op.x().shape()) + ", y shape = " + Arrays.toString(op.y().shape()) + + Arrays.toString(x.shape()) + ", y shape = " + Arrays.toString(y.shape()) + ", dimension = " + Arrays.toString(dimension) + ")"); } } else { //Every X TAD vs. 
entirety of Y - val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension); + val xTADSize = x.length() / x.tensorsAlongDimension(dimension); - if (xTADSize != op.y().length()) { + if (xTADSize != y.length()) { throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" + - " (x TAD size = " + xTADSize + ", y size = " + op.y().length()); + " (x TAD size = " + xTADSize + ", y size = " + y.length()); } } } - //if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) { + //if (x.isVector() && x.length() == ArrayUtil.prod(retShape)) { // return null; //} - val dataType = op.resultType(); + val dataType = oc != null ? op.resultType(oc) : op.resultType(); - if( op.z() == null ){ + if( z == null ){ val ret = Nd4j.createUninitialized(dataType, retShape); - op.setZ(ret); - } else if(op.z().dataType() != dataType || !Arrays.equals(retShape, op.z().shape())){ + setZ(ret, op, oc); + z = ret; + } else if(z.dataType() != dataType || !Arrays.equals(retShape, z.shape())){ throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape) - + " but has datatype " + op.z().dataType() + " and shape " + Arrays.toString(op.z().shape())); + + " but has datatype " + z.dataType() + " and shape " + Arrays.toString(z.shape())); } - val eb = op.extraArgsDataBuff(op.z().dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? op.x().dataType() : op.z().dataType()); + val eb = op.extraArgsDataBuff(z.dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? x.dataType() : z.dataType()); Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null; - val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); - val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); - val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); + val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); + val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); val xShapeInfoHostPointer = extraz.get().put( - AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), + AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets); - val yTadBuffers = op.y() == null ? null : tadManager.getTADOnlyShapeInfo(op.y(), dimension); + val yTadBuffers = y == null ? null : tadManager.getTADOnlyShapeInfo(y, dimension); - val yDevTadShapeInfo = op.y() == null ? null : AtomicAllocator.getInstance().getPointer(yTadBuffers.getFirst(), context); - val yOffsets = op.y() == null ? null : yTadBuffers.getSecond(); + val yDevTadShapeInfo = y == null ? null : AtomicAllocator.getInstance().getPointer(yTadBuffers.getFirst(), context); + val yOffsets = y == null ? null : yTadBuffers.getSecond(); val yDevTadOffsets = yOffsets == null ? 
null : AtomicAllocator.getInstance().getPointer(yOffsets, context); - if (op.y() != null) { + if (y != null) { xShapeInfoHostPointer.put(12, yDevTadShapeInfo); xShapeInfoHostPointer.put(13, yDevTadOffsets); } - val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); + val zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); - op.validateDataTypes(); + op.validateDataTypes(null); - if (op.z().isScalar()) { + if (z.isScalar()) { if (op instanceof Variance) { nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((Variance) op).isBiasCorrected()); - } else if (op.y() != null) { - Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); + } else if (y != null) { + Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context); nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); } else { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_BOOL: nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_SAME: nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_LONG: nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; 
default: throw new UnsupportedOperationException(); @@ -1035,21 +1055,21 @@ public class CudaExecutioner extends DefaultOpExecutioner { } else { val dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); //AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context); - if (op.y() != null) { - val yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); + if (y != null) { + val yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context); nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets); } else { if (op instanceof Variance) { nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, ((Variance) op).isBiasCorrected(), (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets); @@ -1057,30 +1077,30 @@ public class CudaExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: 
nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: @@ -1187,34 +1207,40 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray exec(ScalarOp op) { - invoke(op); + invoke(op, null); return op.z(); } - protected CudaContext invoke(ScalarOp op) { + protected CudaContext invoke(ScalarOp op, OpContext oc) { long st = profilingConfigurableHookIn(op); checkForCompression(op); + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + // validateDataType(Nd4j.dataType(), op); - if(op.z() == null){ + if(z == null){ switch (op.getOpType()) { case SCALAR: - op.setZ(op.x().ulike()); + z = x.ulike(); + setZ(z, op, oc); break; case SCALAR_BOOL: - op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape())); + z = Nd4j.createUninitialized(DataType.BOOL, x.shape()); + setZ(z, op, oc); break; default: throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); } } - if (op.x().length() != op.z().length()) + if (x.length() != z.length()) throw new ND4JIllegalStateException("op.X length should be equal to op.Y length: [" - + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != [" - + Arrays.toString(op.z().shapeInfoDataBuffer().asInt()) + "]"); + + Arrays.toString(x.shapeInfoDataBuffer().asInt()) + "] != [" + + Arrays.toString(z.shapeInfoDataBuffer().asInt()) + "]"); if (extraz.get() == null) extraz.set(new PointerPointer(32)); @@ -1229,38 +1255,38 @@ val context = AtomicAllocator.getInstance().getDeviceContext(); - val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); + val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); val hostYShapeInfo = op.scalar() == null ? null : AddressRetriever.retrieveHostPointer(op.scalar().shapeInfoDataBuffer()); - val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); - Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); - Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ? op.x().dataType() : op.z().dataType()), context) : null; + Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context); + Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ?
x.dataType() : z.dataType()), context) : null; - Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); + Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context); PointerPointer xShapeInfoHostPointer = extraz.get().put( - AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), + AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, null, null); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.scalar() == null ? null : ((BaseCudaDataBuffer) op.scalar().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = op.scalar() == null ? null : ((BaseCudaDataBuffer) op.scalar().data()).getOpaqueDataBuffer(); + val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case SCALAR_BOOL: nativeOps.execScalarBool(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs); break; case SCALAR: nativeOps.execScalar(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs); break; default: @@ -1275,9 +1301,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { return null; } - protected CudaContext invoke(TransformOp op) { + protected CudaContext invoke(TransformOp op, OpContext oc) { long st = profilingConfigurableHookIn(op); + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + checkForCompression(op); //validateDataType(Nd4j.dataType(), op); @@ -1295,7 +1325,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { // special temp array for IsMax along dimension INDArray ret = null; - Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context); + Pointer xShapeInfo = allocator.getPointer(x.shapeInfoDataBuffer(), context); Pointer dimensionDevPointer = null; @@ -1304,17 +1334,18 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer retHostShape = null; int dimension[] = null; - val hostXShapeInfo = op.x() == null ? 
null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); - var hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); + val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); + var hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); - if (op.z() == null) { - ret = Nd4j.createUninitialized(op.resultType(), op.x().shape(), op.x().ordering()); - op.setZ(ret); + if (z == null) { + ret = Nd4j.createUninitialized(op.resultType(), x.shape(), x.ordering()); + setZ(ret, op, oc); + z = ret; } - var extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.TRANSFORM_BOOL || op.getOpType() == Op.Type.PAIRWISE_BOOL ? op.x().dataType() : op.z().dataType()), context) : null; - val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + var extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.TRANSFORM_BOOL || op.getOpType() == Op.Type.PAIRWISE_BOOL ? x.dataType() : z.dataType()), context) : null; + val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); Pointer hostTadShapeInfo = null; Pointer devTadShapeInfo = null; @@ -1328,13 +1359,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer devTadOffsets = null; Pointer devMaxTadOffsets = null; - op.validateDataTypes(experimentalMode.get()); + op.validateDataTypes(oc, experimentalMode.get()); - Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context); + Pointer zShapeInfo = allocator.getPointer(z.shapeInfoDataBuffer(), context); PointerPointer xShapeInfoHostPointer = - extraz.get().put(AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), // 0 + extraz.get().put(AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), // 0 context.getOldStream(), // 1 allocator.getDeviceIdPointer(), // 2 context.getBufferAllocation(), // 3 @@ -1356,30 +1387,30 @@ public class CudaExecutioner extends DefaultOpExecutioner { retHostShape); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = z == null ? 
null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); - if (op.y() != null) { - Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context); + if (y != null) { + Pointer yShapeInfo = allocator.getPointer(y.shapeInfoDataBuffer(), context); - if (op.x().length() != op.y().length() || op.x().length() != op.z().length()) + if (x.length() != y.length() || x.length() != z.length()) throw new ND4JIllegalStateException("X, Y and Z arguments should have the same length for PairwiseTransform"); switch (op.getOpType()) { case TRANSFORM_BOOL: case PAIRWISE_BOOL: nativeOps.execPairwiseTransformBool(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; default: nativeOps.execPairwiseTransform(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; } @@ -1387,32 +1418,32 @@ public class CudaExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case TRANSFORM_ANY: nativeOps.execTransformAny(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_FLOAT: nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_BOOL: nativeOps.execTransformBool(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_SAME: nativeOps.execTransformSame(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_STRICT: nativeOps.execTransformStrict(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; default: @@ -1478,6 +1509,21 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray exec(RandomOp op, Random rng) { + return exec(op, null, rng); + } 
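// Note: the getX/getY/getZ and setZ/setY helpers used throughout this patch resolve
// operand arrays from the OpContext when one is supplied, and fall back to the op's
// own fields otherwise, so a single code path serves both the legacy API and the new
// context-based (thread-safe) execution. A minimal sketch of that resolution logic,
// assuming only the OpContext accessors already visible in this diff
// (getInputArray, getInputArrays, getOutputArray, setOutputArray); the bodies below
// are illustrative, not the verbatim helpers:
protected INDArray getX(Op op, OpContext oc) {
    // input 0 lives in the context in OpContext mode, on the op otherwise
    return oc != null ? oc.getInputArray(0) : op.x();
}

protected INDArray getY(Op op, OpContext oc) {
    // single-input ops simply have no second input array in the context
    if (oc != null)
        return oc.getInputArrays().size() > 1 ? oc.getInputArray(1) : null;
    return op.y();
}

protected INDArray getZ(Op op, OpContext oc) {
    return oc != null ? oc.getOutputArray(0) : op.z();
}

protected void setZ(INDArray z, Op op, OpContext oc) {
    // a freshly allocated result must be published to whichever holder is in use;
    // callers then refresh any local z reference they hold
    if (oc != null)
        oc.setOutputArray(0, z);
    else
        op.setZ(z);
}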
+
+    public INDArray exec(RandomOp op, OpContext oc, Random rng){
+        INDArray x = getX(op, oc);
+        INDArray y = getY(op, oc);
+        INDArray z = getZ(op, oc);
+
+        if(op instanceof BaseRandomOp && ((BaseRandomOp)op).isTripleArgRngOp() && z != null && x == null && y == null){
+            //Ugly hack to ensure the triple arg call occurs
+            //See GaussianDistribution.setZ etc
+            x = z;
+            y = z;
+        }
+
        long st = profilingConfigurableHookIn(op);
        checkForCompression(op);
@@ -1496,38 +1542,38 @@ public class CudaExecutioner extends DefaultOpExecutioner {
        val context = AtomicAllocator.getInstance().getDeviceContext();
 
-        PointerPointer extraZZ = extraz.get().put(AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()),
+        PointerPointer extraZZ = extraz.get().put(AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()),
                        context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());
 
-        val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
-        val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
-        val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
+        val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
+        val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
+        val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
 
-        val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer();
-        val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer();
-        val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer();
+        val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
+        val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
+        val zb = z == null ?
null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); - if (op.x() != null && op.y() != null && op.z() != null) { + if (x != null && y != null && z != null) { // triple arg call nativeOps.execRandom3(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - x, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), - y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), - z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context)); + xb, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context), + yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context), + zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context), + AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context)); - } else if (op.x() != null && op.z() != null) { + } else if (x != null && z != null) { //double arg call nativeOps.execRandom2(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - x, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), - z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()),context)); + xb, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context), + zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context), + AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()),context)); } else { // single arg call nativeOps.execRandom(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context)); + zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context), + AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context)); } if (nativeOps.lastErrorCode() != 0) @@ -1535,7 +1581,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { profilingConfigurableHookOut(op, st); - return op.z(); + return z; } /** @@ -1888,6 +1934,11 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public List calculateOutputShape(@NonNull CustomOp op) { + return calculateOutputShape(op, null); + } + + @Override + public List calculateOutputShape(@NonNull CustomOp op, OpContext opContext){ Nd4j.getExecutioner().commit(); @@ -1895,7 +1946,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { val hash = op.opHash(); val result = new ArrayList(); - if(op.numInputArguments() < 1 && op.getDescriptor().getNumInputs() != -2) { + int nIn = opContext != null ? 
opContext.numInputArguments() : op.numInputArguments(); + if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: number of input args was 0", op.getClass().getName()); @@ -1903,47 +1955,75 @@ public class CudaExecutioner extends DefaultOpExecutioner { return Collections.emptyList(); } - val inputBuffers = new PointerPointer<>(op.inputArguments().size() * 2); - val inputShapes = new PointerPointer<>(op.inputArguments().size()); + val inputBuffers = new PointerPointer<>(nIn * 2); + val inputShapes = new PointerPointer<>(nIn); + val inputArgs = opContext != null ? opContext.getInputArrays() : op.inputArguments(); int cnt= 0; - for (val in: op.inputArguments()) { + for (val in: inputArgs) { // NOT A TYPO: shape functions work on host side only if (!in.isEmpty()) { inputBuffers.put(cnt, in.data().addressPointer()); - inputBuffers.put(cnt + op.inputArguments().size(), AtomicAllocator.getInstance().getPointer(in.data())); + inputBuffers.put(cnt + nIn, AtomicAllocator.getInstance().getPointer(in.data())); } inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer()); } - val iArgs = op.iArgs().length > 0 ? new LongPointer(op.iArgs().length) : null; + int nIArgs = opContext != null ? opContext.numIArguments() : op.numIArguments(); + val iArgs = nIArgs > 0 ? new LongPointer(nIArgs) : null; cnt = 0; - for (val i: op.iArgs()) - iArgs.put(cnt++, i); + if(opContext != null){ + for (val i: opContext.getIArguments()) + iArgs.put(cnt++, i); + } else { + for (val i: op.iArgs()) + iArgs.put(cnt++, i); + } - val tArgs = op.tArgs().length > 0 ? new DoublePointer(op.tArgs().length) : null; + int nTArgs = opContext != null ? opContext.numTArguments() : op.numTArguments(); + val tArgs = nTArgs > 0 ? new DoublePointer(nTArgs) : null; - val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.bArgs().length) : null; + int nBArgs = opContext != null ? opContext.numBArguments() : op.numBArguments(); + val bArgs = nBArgs > 0 ? new BooleanPointer(nBArgs) : null; - val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null; + int nDArgs = opContext != null ? opContext.numDArguments() : op.numDArguments(); + val dArgs = nDArgs > 0 ? 
new IntPointer(nDArgs) : null; cnt = 0; - for (val b: op.bArgs()) - bArgs.put(cnt++, b); + if(opContext != null){ + for (val b: opContext.getBArguments()) + bArgs.put(cnt++, b); + } else { + for (val b: op.bArgs()) + bArgs.put(cnt++, b); + } + cnt = 0; - for (val t: op.tArgs()) - tArgs.put(cnt++, t); + if(opContext != null){ + for (val b: opContext.getTArguments()) + tArgs.put(cnt++, b); + } else { + for (val b: op.tArgs()) + tArgs.put(cnt++, b); + } cnt = 0; - val dArgs1 = op.dArgs(); - for (val d: dArgs1) - dArgs.put(cnt++, d.toInt()); + if(opContext != null){ + for (val b: opContext.getDArguments()) + dArgs.put(cnt++, b.toInt()); + } else { + for (val b: op.dArgs()) + dArgs.put(cnt++, b.toInt()); + } - OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments()); + OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, + hash, inputBuffers, inputShapes, nIn, tArgs, nTArgs, + iArgs, nIArgs, bArgs, nBArgs, dArgs, nDArgs); +// OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments()); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java index 850096359..1d8a3de65 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java @@ -20,6 +20,7 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; import lombok.val; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.primitives.Pair; @@ -127,7 +128,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio // the only entry place for TADless ops processAsGridOp(op); } else if (op instanceof BroadcastOp) { - invoke((BroadcastOp) op); + invoke((BroadcastOp) op, null); } else { //logger.info("Random op: {}", op.getClass().getSimpleName()); pushToGrid(new OpDescriptor(op)); @@ -238,7 +239,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio flushQueue(); //logger.info("Sending TransformOp to CudaExecutioner"); - super.invoke(t); + super.invoke(t, null); } else if (op instanceof Variance) { Variance acc = (Variance) op; if (flush) @@ -258,7 +259,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio flushQueue(); //logger.info("Sending ScalarOp to CudaExecutioner"); - super.invoke(sc); + super.invoke(sc, null); } else if (op instanceof BroadcastOp) { BroadcastOp broadcastOp = (BroadcastOp) op; if (flush) @@ -268,7 +269,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio if (dimensions != null) { super.exec(broadcastOp); } else { - super.invoke(broadcastOp); + super.invoke(broadcastOp, null); } } else if (op instanceof 
IndexAccumulation) {
            IndexAccumulation indexAccumulation = (IndexAccumulation) op;
@@ -690,7 +691,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio
                flushQueue();
 
            buildZ(op, new int[] {Integer.MAX_VALUE});
-            super.invoke(op, new int[] {Integer.MAX_VALUE});
+            super.invoke(op, null, new int[] {Integer.MAX_VALUE});
        } else {
            buildZ(op, dimension);
            processAsGridOp(op, dimension);
@@ -708,7 +709,8 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio
    // FIXME: remove CudaContext return type. We just don't need it
    @Override
-    protected CudaContext invoke(BroadcastOp op) {
+    protected CudaContext invoke(BroadcastOp op, OpContext oc) {
+        Preconditions.checkState(oc == null);
        processAsGridOp(op, op.getDimension());
 
        return null;
@@ -716,7 +718,8 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio
    // FIXME: remove CudaContext return type. We just don't need it
    @Override
-    protected CudaContext invoke(ScalarOp op) {
+    protected CudaContext invoke(ScalarOp op, OpContext oc) {
+        Preconditions.checkState(oc == null);
        processAsGridOp(op, null);
 
        return null;
@@ -724,7 +727,8 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio
    // FIXME: remove CudaContext return type. We just don't need it
    @Override
-    protected CudaContext invoke(TransformOp op) {
+    protected CudaContext invoke(TransformOp op, OpContext oc) {
+        Preconditions.checkState(oc == null);
        processAsGridOp(op, null);
        return null;
    }
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java
index cc3d17b5f..7a29f71d7 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java
@@ -44,6 +44,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
 import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
 import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
 import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
+import org.nd4j.linalg.api.ops.random.BaseRandomOp;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.api.shape.Shape;
@@ -135,26 +136,31 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
    @Override
    public INDArray exec(Op op) {
+        return exec(op, null);
+    }
+
+    @Override
+    public INDArray exec(Op op, OpContext opContext) {
        checkForCompression(op);
 
        if (op instanceof ScalarOp) {
            ScalarOp s = (ScalarOp) op;
-            exec(s);
+            exec(s, opContext);
        } else if (op instanceof TransformOp) {
            TransformOp t = (TransformOp) op;
-            exec(t);
+            exec(t, opContext);
        } else if (op instanceof ReduceOp) {
            ReduceOp ac = (ReduceOp) op;
-            exec(ac);
+            exec(ac, opContext);
        } else if (op instanceof IndexAccumulation) {
            IndexAccumulation iac = (IndexAccumulation) op;
-            exec(iac); //Currently using DefaultOpExecutioner
+            exec(iac, opContext); //Currently using DefaultOpExecutioner
        } else if (op instanceof BroadcastOp) {
            BroadcastOp broadcastOp = (BroadcastOp) op;
-            exec(broadcastOp);
+            exec(broadcastOp, opContext);
        } else if (op instanceof RandomOp) {
            RandomOp rngOp = (RandomOp) op;
-            exec(rngOp, Nd4j.getRandom());
+            exec(rngOp, opContext,
Nd4j.getRandom()); } return op.z(); @@ -163,36 +169,44 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public INDArray exec(IndexAccumulation op) { + return exec(op, null); + } + + public INDArray exec(IndexAccumulation op, OpContext oc) { checkForCompression(op); + INDArray x = getX(op, oc); + INDArray z = getZ(op, oc); + if (extraz.get() == null) extraz.set(new PointerPointer(32)); - val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector()); + val dimension = Shape.normalizeAxis(x.rank(), op.dimensions().toIntVector()); - if (op.x().isEmpty()) { + if (x.isEmpty()) { for (val d:dimension) { - Preconditions.checkArgument(op.x().shape()[d] != 0, "IndexReduce can't be issued along axis with 0 in shape"); + Preconditions.checkArgument(x.shape()[d] != 0, "IndexReduce can't be issued along axis with 0 in shape"); } } boolean keepDims = op.isKeepDims(); - long[] retShape = Shape.reductionShape(op.x(), dimension, true, keepDims); + long[] retShape = Shape.reductionShape(x, dimension, true, keepDims); - if(op.z() == null || op.x() == op.z()) { + if(z == null || x == z) { val ret = Nd4j.createUninitialized(DataType.LONG, retShape); - op.setZ(ret); - } else if(!Arrays.equals(retShape, op.z().shape())){ + setZ(ret, op, oc); + z = ret; + } else if(!Arrays.equals(retShape, z.shape())){ throw new IllegalStateException("Z array shape does not match expected return type for op " + op - + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(op.z().shape())); + + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(z.shape())); } op.validateDataTypes(); Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer(); - Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension); + Pair tadBuffers = tadManager.getTADOnlyShapeInfo(x, dimension); Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer(); @@ -203,19 +217,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { long st = profilingConfigurableHookIn(op, tadBuffers.getFirst()); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); - if (op.z().isScalar()) { + if (z.isScalar()) { loop.execIndexReduceScalar(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null); } else { loop.execIndexReduce(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); } @@ -223,7 +237,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { throw new 
RuntimeException(loop.lastErrorMessage());
        profilingConfigurableHookOut(op, st);
-        return op.z();
+        return getZ(op, oc);
    }
 
    @Override
@@ -233,34 +247,41 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
    @Override
    public INDArray exec(ReduceOp op) {
-        Preconditions.checkNotNull(op.x(), "Op.x() cannot be null: Was null for op %s", op);
-        op.validateDataTypes();
+        return exec(op, null);
+    }
+
+    public INDArray exec(ReduceOp op, OpContext oc) {
+        INDArray x = getX(op, oc);
+        INDArray y = getY(op, oc);
+        INDArray z = getZ(op, oc);
+        Preconditions.checkNotNull(x, "Op.x() cannot be null: Was null for op %s", op);
+        op.validateDataTypes(oc);
 
        if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){
            //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y]
            //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions"
-            if(op.z() != null){
-                Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." +
-                        " Got: x=%ndShape, z=%ndShape", op.x(), op.z());
-                op.z().assign(op.x());
-                return op.z();
+            if(z != null){
+                Preconditions.checkState(x.equalShapes(z), "For empty reductions, result (z) array must have same shape as x shape." +
+                        " Got: x=%ndShape, z=%ndShape", x, z);
+                z.assign(x);
+                return z;
            } else {
-                op.setZ(op.x().dup());
-                return op.z();
+                z = x.dup();
+                setZ(z, op, oc);
+                return z;
            }
        }
 
        // FIXME: this should be moved down to C++ on per-op basis
-        val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector());
+        val dimension = Shape.normalizeAxis(x.rank(), op.dimensions().toIntVector());
 
        // reduce to scalar case, ReduceBool ops require special treatment
-        if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
-            if (op.z() == null) {
-                op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue()));
+        if (op instanceof BaseReduceBoolOp && x.isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
+            if (z == null) {
+                z = Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue());
+                setZ(z, op, oc);
            } else {
-                op.z().assign(((BaseReduceBoolOp) op).emptyValue());
+                z.assign(((BaseReduceBoolOp) op).emptyValue());
            }
-            return op.z();
+            return z;
        }
 
        //validateDataType(Nd4j.dataType(), op);
@@ -269,10 +290,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
            extraz.set(new PointerPointer(32));
 
        boolean keepDims = op.isKeepDims();
-        long[] retShape = Shape.reductionShape(op.x(), dimension, true, keepDims);
+        long[] retShape = Shape.reductionShape(x, dimension, true, keepDims);
 
-        if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1 && op.y() == null)
+        if (x.isVector() && x.length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1 && y == null)
            return op.noOp();
 
        /**
         * We create it only if we hadn't provided it before
         */
        INDArray ret;
-        if (op.z() == null || op.z() == op.x()) {
+        if (z == null || z == x) {
            if (op.isComplexAccumulation()) {
-                long xT = op.x().tensorsAlongDimension(dimension);
-                long yT = op.y().tensorsAlongDimension(dimension);
+                long xT = x.tensorsAlongDimension(dimension);
+                long yT = y.tensorsAlongDimension(dimension);
 
                ret = Nd4j.create(op.resultType(), new long[]{xT, yT});
            } else {
-                if (op.y() != null) {
+                if (y != null) {
                    //2 options here:
either pairwise, equal sizes - OR every X TAD vs. entirety of Y - if(op.x().length() == op.y().length()) { + if(x.length() == y.length()) { //Pairwise - if (op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) { + if (x.tensorsAlongDimension(dimension) != y.tensorsAlongDimension(dimension)) { throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + - Arrays.toString(op.x().shape()) + ", y shape = " + Arrays.toString(op.y().shape()) + + Arrays.toString(x.shape()) + ", y shape = " + Arrays.toString(y.shape()) + ", dimension = " + Arrays.toString(dimension) + ")"); } } else { //Every X TAD vs. entirety of Y - val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension); + val xTADSize = x.length() / x.tensorsAlongDimension(dimension); - if (xTADSize != op.y().length()) { + if (xTADSize != y.length()) { throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" + - " (x TAD size = " + xTADSize + ", y size = " + op.y().length()); + " (x TAD size = " + xTADSize + ", y size = " + y.length()); } } } - ret = Nd4j.create(op.resultType(), retShape); + DataType dt = oc != null ? op.resultType(oc) : op.resultType(); + ret = Nd4j.create(dt, retShape); } - op.setZ(ret); + setZ(ret, op, oc); + z = ret; } else { // compare length long shapeProduct = (retShape.length == 0 ? 1 : ArrayUtil.prodLong(retShape)); - if (!op.isComplexAccumulation() && op.z().length() != shapeProduct) { - if(!(op.x().isEmpty() && op.isKeepDims())){ + if (!op.isComplexAccumulation() && z.length() != shapeProduct) { + if(!(x.isEmpty() && op.isKeepDims())){ //Empty reductions are special case: [1,0].sum(0,1,keep=true) -> shape [1,1] - throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]"); + throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(z.shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]"); } } else if (op.isComplexAccumulation()) { - long xT = op.x().tensorsAlongDimension(dimension); - long yT = op.y().tensorsAlongDimension(dimension); + long xT = x.tensorsAlongDimension(dimension); + long yT = y.tensorsAlongDimension(dimension); - if (op.z().length() != xT * yT) - throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + (xT * yT) + "]"); + if (z.length() != xT * yT) + throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(z.shape()) + "] doesn't match expected [" + (xT * yT) + "]"); } - ret = op.z(); + ret = z; } - //log.info("X dtype: {}; Z dtype: {}", op.x().dataType(), op.z().dataType()); + //log.info("X dtype: {}; Z dtype: {}", x.dataType(), z.dataType()); /** * Returns the {@link Shape#createShapeInformation(int[], int[], int, int, char)} * and the associated offsets for each {@link INDArray#tensorAlongDimension(int, int...)} * The first item is the shape information. The second one is the offsets. */ - Pair tadBuffers = op.x().isEmpty() ? Pair.makePair(op.x().data(), null): tadManager.getTADOnlyShapeInfo(op.x(), dimension); + Pair tadBuffers = x.isEmpty() ? Pair.makePair(x.data(), null): tadManager.getTADOnlyShapeInfo(x, dimension); Pair yTadBuffers = null; /** * Note that we use addresses in libnd4j. * We use reinterpret cast in c to take the long * we pass to JNI. This manages overhead. 
*/ - Pointer hostTadShapeInfo = op.x().isEmpty() ? op.x().shapeInfoDataBuffer().addressPointer() : tadBuffers.getFirst().addressPointer(); + Pointer hostTadShapeInfo = x.isEmpty() ? x.shapeInfoDataBuffer().addressPointer() : tadBuffers.getFirst().addressPointer(); - DataBuffer offsets = op.x().isEmpty() ? null : tadBuffers.getSecond(); + DataBuffer offsets = x.isEmpty() ? null : tadBuffers.getSecond(); Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer(); // we're going to check, if that's TAD vs TAD comparison or TAD vs full array. if later - we're going slightly different route boolean tvf = false; - if (op.y() != null) { - if (op.x().tensorAlongDimension(0, dimension).length() == op.y().length()) { + if (y != null) { + if (x.tensorAlongDimension(0, dimension).length() == y.length()) { tvf = true; } } if (op.isComplexAccumulation()) { - yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension); + yTadBuffers = tadManager.getTADOnlyShapeInfo(y, dimension); - if (op.x().tensorAlongDimension(0, dimension).length() != op.y().tensorAlongDimension(0, dimension).length()) + if (x.tensorAlongDimension(0, dimension).length() != y.tensorAlongDimension(0, dimension).length()) throw new ND4JIllegalStateException("Impossible to issue AllDistances operation: TAD lengths mismatch along given dimension: " + - "x TAD length = " + op.x().tensorAlongDimension(0, dimension).length() + ", y TAD length " + - op.y().tensorAlongDimension(0, dimension).length()); + "x TAD length = " + x.tensorAlongDimension(0, dimension).length() + ", y TAD length " + + y.tensorAlongDimension(0, dimension).length()); } /** @@ -383,23 +406,23 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * This gives us a pointer which is passed around in libnd4j. 
*/ Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer(); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); if (op instanceof Variance) { if (ret.isScalar()) { loop.execSummaryStatsScalar(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((Variance) op).isBiasCorrected()); } else { Variance var = (Variance) op; try { loop.execSummaryStatsTad(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, var.isBiasCorrected(), null, null); } catch (Throwable t){ @@ -410,15 +433,15 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } //pairwise reduction like similarity of two arrays - else if (op.y() != null && op.getOpType() == Op.Type.REDUCE3) { - val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); + else if (y != null && op.getOpType() == Op.Type.REDUCE3) { + val yb = ((BaseCpuDataBuffer) y.data()).getOpaqueDataBuffer(); if (op.isComplexAccumulation()) { try { loop.execReduce3All(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) tadBuffers.getFirst().addressPointer(), new LongPointerWrapper(tadBuffers.getSecond().addressPointer()), (LongPointer) yTadBuffers.getFirst().addressPointer(), new LongPointerWrapper(yTadBuffers.getSecond().addressPointer()) @@ -429,17 +452,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } } else if (ret.isScalar()) { loop.execReduce3Scalar(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) 
ret.shapeInfoDataBuffer().addressPointer(), null); } else { try { loop.execReduce3Tad(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, null, null, null, null); } catch (Throwable t){ @@ -453,27 +476,27 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case REDUCE_FLOAT: loop.execReduceFloat(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + zb, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: loop.execReduceBool(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: loop.execReduceSame(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: loop.execReduceLong(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException("Unsupported op used in reduce: "+ op.getOpType()); @@ -482,32 +505,32 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case REDUCE_FLOAT: loop.execReduceFloat2(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: loop.execReduceLong2(null, op.opNum(), - x, (LongPointer) 
op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: loop.execReduceSame2(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: loop.execReduceBool2(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; @@ -520,7 +543,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - return ret; + return getZ(op, oc); } /** @@ -528,6 +551,14 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * @param op Op to execute */ private void invokeScalarAlongDimension(ScalarOp op) { + invokeScalarAlongDimension(op, null); + } + + private void invokeScalarAlongDimension(ScalarOp op, OpContext oc) { + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + val dimension = op.dimensions().toIntVector(); //dimension = Shape.normalizeAxis(op.x().rank(), dimension); // do tad magic @@ -561,16 +592,16 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (extraz.get() == null) extraz.set(new PointerPointer(32)); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = ((BaseCpuDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case SCALAR: loop.execScalarTad(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), 
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),null, (LongPointer) hostTadShapeInfo, (LongPointer) hostTadOffsets, @@ -578,9 +609,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { break; case SCALAR_BOOL: loop.execScalarBoolTad(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) hostTadShapeInfo, (LongPointer) hostTadOffsets, @@ -594,56 +625,63 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { throw new RuntimeException(loop.lastErrorMessage()); } - public INDArray exec(ScalarOp op) { + public INDArray exec(ScalarOp op){ + return exec(op, null); + } + + public INDArray exec(ScalarOp op, OpContext oc) { long st = profilingConfigurableHookIn(op); //validateDataType(Nd4j.dataType(), op); - if(op.z() == null){ + if((oc != null && oc.getOutputArray(0) == null) || getZ(op, oc) == null){ switch (op.getOpType()) { case SCALAR: - op.setZ(op.x().ulike()); + setZ(getX(op, oc).ulike(), op, oc); +// op.setZ(op.x().ulike()); break; case SCALAR_BOOL: - op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape())); +// op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape())); + setZ(Nd4j.createUninitialized(DataType.BOOL, getX(op, oc).shape()), op, oc); break; default: throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); } } - if (op.x().length() != op.z().length()) +// if (op.x().length() != op.z().length()) + if (getX(op, oc).length() != getZ(op, oc).length()) throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: " + - "x.length()=" + op.x().length() + ", z.length()=" + op.z().length() + " - x shape info = [" - + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "], z shape info = [" - + Arrays.toString(op.z().shapeInfoDataBuffer().asInt()) + "]"); + "x.length()=" + getX(op, oc).length() + ", z.length()=" + getZ(op, oc).length() + " - x shape info = [" + + Arrays.toString(getX(op, oc).shapeInfoDataBuffer().asInt()) + "], z shape info = [" + + Arrays.toString(getZ(op, oc).shapeInfoDataBuffer().asInt()) + "]"); if (op.dimensions() != null) { invokeScalarAlongDimension(op); - return op.z(); + return getZ(op, oc); } - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val x = ((BaseCpuDataBuffer) getX(op, oc).data()).getOpaqueDataBuffer(); val scalar = ((BaseCpuDataBuffer) op.scalar().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) getZ(op, oc).data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case SCALAR: loop.execScalar(null, - op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - scalar, (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType())); + op.opNum(), + x, (LongPointer) getX(op, 
oc).shapeInfoDataBuffer().addressPointer(), null,
+                    z, (LongPointer) getZ(op, oc).shapeInfoDataBuffer().addressPointer(), null,
+                    scalar, (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), null,
+                    getPointerForExtraArgs(op, getZ(op, oc).dataType()));
                break;
            case SCALAR_BOOL:
                loop.execScalarBool(null,
                        op.opNum(),
-                        x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null,
-                        z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null,
+                        x, (LongPointer) getX(op, oc).shapeInfoDataBuffer().addressPointer(), null,
+                        z, (LongPointer) getZ(op, oc).shapeInfoDataBuffer().addressPointer(), null,
                        scalar, (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), null,
-                        getPointerForExtraArgs(op, op.x().dataType()));
+                        getPointerForExtraArgs(op, getX(op, oc).dataType()));
                break;
            default:
                throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]");
        }
@@ -654,7 +692,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
        profilingConfigurableHookOut(op, st);
 
-        return op.z();
+        return getZ(op, oc);
    }
 
    private Pointer getPointerForExtraArgs(Op op, DataType type) {
@@ -670,6 +708,14 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
    }
 
    private void exec(TransformOp op) {
+        exec(op, null);
+    }
+
+    private void exec(TransformOp op, OpContext oc) {
+        INDArray x = getX(op, oc);
+        INDArray y = getY(op, oc);
+        INDArray z = getZ(op, oc);
+
        long st = 0;
 
//        validateDataType(Nd4j.dataType(), op);
@@ -681,8 +727,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
        // Pow operations might be special
        if (op.opNum() == 31) {
-            if (op.y() != null && op.y().isScalar()) {
-                op.setY(Nd4j.valueArrayOf(op.x().shape(), op.y().getDouble(0)));
+            if (y != null && y.isScalar()) {
+//                op.setY(Nd4j.valueArrayOf(op.x().shape(), op.y().getDouble(0)));
+                y = Nd4j.valueArrayOf(x.shape(), y.getDouble(0));
+                setY(y, op, oc);
            }
        }
@@ -723,33 +770,26 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
        } else
            st = profilingConfigurableHookIn(op);
 
-        if (op.y() != null) {
+        if (y != null) {
 
-            if (op.z() == null)
-                op.setZ(Nd4j.create(op.resultType(), op.x().shape()));
+            if (z == null) {
+                setZ(Nd4j.create(op.resultType(), x.shape()), op, oc);
+                z = getZ(op, oc);
+            }
+//                op.setZ(Nd4j.create(op.resultType(), op.x().shape()));
 
-            op.validateDataTypes(experimentalMode.get());
+            op.validateDataTypes(oc, experimentalMode.get());
 
            //log.info("X type: {}; Y type: {}; Z type: {}; OpNum: {}", op.x().dataType(), op.y().dataType(), op.z().dataType(), op.opNum());
 
-            int xEWS = op.x().elementWiseStride();
-            int yEWS = op.y().elementWiseStride();
-            int zEWS = op.z().elementWiseStride();
-
-            boolean xRow = op.x().isRowVector();
-            boolean yRow = op.y().isRowVector();
-            boolean zRow = op.z().isRowVector();
-
-            if (op.x().length() != op.y().length() || op.x().length() != op.z().length())
+            if (x.length() != y.length() || x.length() != z.length())
                throw new ND4JIllegalStateException("X, Y and Z arguments should have the same length for PairwiseTransform " +
-                        op.opName() + ". x: length " + op.x().length() + ", shape " + Arrays.toString(op.x().shape()) +
-                        "; y: " + op.y().length() + ", shape " + Arrays.toString(op.y().shape()) +
-                        "; z: " + op.z().length() + ", shape " + Arrays.toString(op.z().shape()));
+                        op.opName() + ".
x: length " + x.length() + ", shape " + Arrays.toString(x.shape()) + + "; y: " + y.length() + ", shape " + Arrays.toString(y.shape()) + + "; z: " + z.length() + ", shape " + Arrays.toString(z.shape())); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = ((BaseCpuDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case TRANSFORM_ANY: @@ -757,78 +797,81 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { case TRANSFORM_STRICT: case TRANSFORM_SAME: if (!experimentalMode.get()) - Preconditions.checkArgument(op.x().dataType() == op.y().dataType() || op.y().dataType() == DataType.BOOL, "Op.X and Op.Y must have the same data type, but got " + op.x().dataType() + " vs " + op.y().dataType()); + Preconditions.checkArgument(x.dataType() == y.dataType() || y.dataType() == DataType.BOOL, + "Op.X and Op.Y must have the same data type, but got %s vs. %s", x.dataType(), y.dataType()); loop.execPairwiseTransform(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType())); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType())); break; case TRANSFORM_BOOL: case PAIRWISE_BOOL: loop.execPairwiseTransformBool(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType())); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType())); break; } } else { - if (op.z() == null) - op.setZ(Nd4j.createUninitialized(op.resultType(), op.x().shape())); + if (z == null) { + setZ(Nd4j.createUninitialized((oc != null ? 
op.resultType(oc) : op.resultType()), x.shape()), op, oc); + z = getZ(op, oc); + } - op.validateDataTypes(experimentalMode.get()); + op.validateDataTypes(oc, experimentalMode.get()); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case TRANSFORM_FLOAT: { - val xtraz = getPointerForExtraArgs(op, op.z().dataType()); + val xtraz = getPointerForExtraArgs(op, z.dataType()); loop.execTransformFloat(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } case TRANSFORM_STRICT: { - val xtraz = getPointerForExtraArgs(op, op.z().dataType()); + val xtraz = getPointerForExtraArgs(op, z.dataType()); loop.execTransformStrict(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } case TRANSFORM_SAME: { - val xtraz = getPointerForExtraArgs(op, op.z().dataType()); + val xtraz = getPointerForExtraArgs(op, z.dataType()); loop.execTransformSame(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } case TRANSFORM_ANY: { - val xtraz = getPointerForExtraArgs(op, op.x().dataType()); + val xtraz = getPointerForExtraArgs(op, x.dataType()); val opNum = op.opNum(); loop.execTransformAny(dummy, opNum, - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } case TRANSFORM_BOOL: { - val xtraz = getPointerForExtraArgs(op, op.x().dataType()); + val xtraz = getPointerForExtraArgs(op, x.dataType()); val opNum = op.opNum(); loop.execTransformBool(dummy, opNum, - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } @@ -845,6 +888,14 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } public INDArray exec(BroadcastOp op) { + return exec(op, null); + } + + public INDArray exec(BroadcastOp op, OpContext oc) { + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + long st = profilingConfigurableHookIn(op); op.validateDataTypes(experimentalMode.get()); @@ -856,7 +907,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * and the associated offsets for each {@link INDArray#tensorAlongDimension(int, int...)} * The first item is the shape information. The second one is the offsets. 
*/ - Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension); + Pair tadBuffers = tadManager.getTADOnlyShapeInfo(x, dimension); Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer(); Pointer hostTadOffsets = tadBuffers.getSecond().addressPointer(); @@ -864,17 +915,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Pointer devTadShapeInfoZ = null; Pointer devTadOffsetsZ = null; - // if (!Arrays.equals(op.x().shape(),op.z().shape()) || !Arrays.equals(op.x().stride(),op.z().stride()) || op.x().ordering() != op.z().ordering()) { + // if (!Arrays.equals(x.shape(),z.shape()) || !Arrays.equals(x.stride(),z.stride()) || x.ordering() != z.ordering()) { // that's the place where we're going to have second TAD in place - Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension); + Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(z, dimension); devTadShapeInfoZ = tadBuffersZ.getFirst().addressPointer(); devTadOffsetsZ = tadBuffersZ.getSecond().addressPointer(); /* log.info("Broascast dimension: {}", Arrays.toString(dimension)); - log.info("x shape: {}; x TAD: {}; comp TAD: {}", Arrays.toString(op.x().shapeInfoDataBuffer().asInt()), Arrays.toString(tadBuffers.getFirst().asInt()), Arrays.toString(op.x().tensorAlongDimension(0, dimension).shapeInfoDataBuffer().asInt())); - log.info("z shape: {}; z TAD: {}", Arrays.toString(op.z().shapeInfoDataBuffer().asInt()), Arrays.toString(tadBuffersZ.getFirst().asInt())); - log.info("y shape: {}", Arrays.toString(op.y().shapeInfoDataBuffer().asInt())); + log.info("x shape: {}; x TAD: {}; comp TAD: {}", Arrays.toString(x.shapeInfoDataBuffer().asInt()), Arrays.toString(tadBuffers.getFirst().asInt()), Arrays.toString(x.tensorAlongDimension(0, dimension).shapeInfoDataBuffer().asInt())); + log.info("z shape: {}; z TAD: {}", Arrays.toString(z.shapeInfoDataBuffer().asInt()), Arrays.toString(tadBuffersZ.getFirst().asInt())); + log.info("y shape: {}", Arrays.toString(y.shapeInfoDataBuffer().asInt())); log.info("-------------"); */ @@ -885,23 +936,23 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer(); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = ((BaseCpuDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case BROADCAST: loop.execBroadcast(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case BROADCAST_BOOL: loop.execBroadcastBool(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), 
null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; @@ -912,7 +963,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - return op.z(); + return z; } @@ -1202,6 +1253,22 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { */ @Override public INDArray exec(RandomOp op, Random rng) { + return exec(op, null, rng); + } + + + public INDArray exec(RandomOp op, OpContext oc, Random rng) { + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + + if(op instanceof BaseRandomOp && ((BaseRandomOp)op).isTripleArgRngOp() && z != null && x == null && y == null){ + //Ugly hack to ensure the triple arg call occurs + //See GaussianDistribution.setZ etc + x = z; + y = z; + } + if (!(rng instanceof CpuNativeRandom)) throw new IllegalStateException( "You should use one of NativeRandom classes for NativeOperations execution. Op class: " + op.getClass().getName()); @@ -1210,30 +1277,30 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { //validateDataType(Nd4j.dataType(), op); - Preconditions.checkArgument(op.z().isR(), "Op.Z must have one of floating point types"); + Preconditions.checkArgument(z.isR(), "Op.Z must have one of floating point types"); - val x = op.x() == null ? null : ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = y == null ? null : ((BaseCpuDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = z == null ? 
null : ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); - if (op.x() != null && op.y() != null && op.z() != null) { + if (x != null && y != null && z != null) { // triple arg call loop.execRandom3(null, op.opNum(), rng.getStatePointer(), // rng state ptr - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - op.extraArgsDataBuff(op.z().dataType()).addressPointer()); - } else if (op.x() != null && op.z() != null) { + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, + op.extraArgsDataBuff(z.dataType()).addressPointer()); + } else if (x != null && z != null) { //double arg call loop.execRandom2(null, op.opNum(), rng.getStatePointer(), // rng state ptr - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - op.extraArgsDataBuff(op.z().dataType()).addressPointer()); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, + op.extraArgsDataBuff(z.dataType()).addressPointer()); } else { // single arg call loop.execRandom(null, op.opNum(), rng.getStatePointer(), // rng state ptr - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - op.extraArgsDataBuff(op.z().dataType()).addressPointer()); + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, + op.extraArgsDataBuff(z.dataType()).addressPointer()); } if (loop.lastErrorCode() != 0) @@ -1241,7 +1308,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { profilingConfigurableHookOut(op, st); - return op.z(); + return z; } @Override @@ -1678,11 +1745,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public List calculateOutputShape(@NonNull CustomOp op) { + return calculateOutputShape(op, null); + } + + @Override + public List calculateOutputShape(@NonNull CustomOp op, OpContext opContext) { val lc = op.opName().toLowerCase(); val hash = op.opHash(); val result = new ArrayList(); - if(op.numInputArguments() < 1 && op.getDescriptor().getNumInputs() != -2) { + int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments(); + if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: number of input args was 0", op.getClass().getName()); @@ -1690,10 +1763,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { return Collections.emptyList(); } - - val inputBuffers = new PointerPointer<>(op.numInputArguments()); - val inputShapes = new PointerPointer<>(op.numInputArguments()); - val inputArgs = op.inputArguments(); + val inputBuffers = new PointerPointer<>(nIn); + val inputShapes = new PointerPointer<>(nIn); + val inputArgs = opContext != null ? opContext.getInputArrays() : op.inputArguments(); int cnt= 0; for (val in: inputArgs) { if (!in.isEmpty()) @@ -1703,76 +1775,95 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } - val iArgs = op.numIArguments() > 0 ? new LongPointer(op.numIArguments()) : null; + int nIArgs = opContext != null ? opContext.numIArguments() : op.numIArguments(); + val iArgs = nIArgs > 0 ? 
new LongPointer(nIArgs) : null; cnt = 0; - val iArgs1 = op.iArgs(); - for (val i: iArgs1) - iArgs.put(cnt++, i); + if(opContext != null){ + for (val i: opContext.getIArguments()) + iArgs.put(cnt++, i); + } else { + for (val i: op.iArgs()) + iArgs.put(cnt++, i); + } - val tArgs = op.numTArguments() > 0 ? new DoublePointer(op.numTArguments()) : null; - val bArgs = op.numBArguments() > 0 ? new BooleanPointer(op.numBArguments()) : null; + int nTArgs = opContext != null ? opContext.numTArguments() : op.numTArguments(); + val tArgs = nTArgs > 0 ? new DoublePointer(nTArgs) : null; - val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null; + int nBArgs = opContext != null ? opContext.numBArguments() : op.numBArguments(); + val bArgs = nBArgs > 0 ? new BooleanPointer(nBArgs) : null; - cnt = 0; - val bArgs1 = op.bArgs(); - for (val b: bArgs1) + int nDArgs = opContext != null ? opContext.numDArguments() : op.numDArguments(); + val dArgs = nDArgs > 0 ? new IntPointer(nDArgs) : null; + + cnt = 0; + if(opContext != null){ + for (val b: opContext.getBArguments()) bArgs.put(cnt++, b); - - cnt = 0; - val tArgs1 = op.tArgs(); - for (val t: tArgs1) - tArgs.put(cnt++, t); - - cnt = 0; - val dArgs1 = op.dArgs(); - for (val d: dArgs1) - dArgs.put(cnt++, d.toInt()); + } else { + for (val b: op.bArgs()) + bArgs.put(cnt++, b); + } - OpaqueShapeList ptrptr; - try { - ptrptr = loop.calculateOutputShapes2(null, - hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs, - op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments(), dArgs, op.numDArguments()); + cnt = 0; + if(opContext != null){ + for (val b: opContext.getTArguments()) + tArgs.put(cnt++, b); + } else { + for (val b: op.tArgs()) + tArgs.put(cnt++, b); + } - if (loop.lastErrorCode() != 0) - throw new RuntimeException(loop.lastErrorMessage()); - } catch (Throwable t){ - StringBuilder sb = new StringBuilder(); - sb.append("Inputs: [("); - for( int i=0; i<inputArgs.size(); i++ ){ - if(i > 0) - sb.append("), ("); - sb.append(Shape.shapeToStringShort(inputArgs.get(i))); - } - sb.append(")]"); - if(op instanceof DifferentialFunction && ((DifferentialFunction)op).getSameDiff() != null){ - appendSameDiffInfo(sb, (DifferentialFunction) op); - } + cnt = 0; + if(opContext != null){ + for (val b: opContext.getDArguments()) + dArgs.put(cnt++, b.toInt()); + } else { + for (val b: op.dArgs()) + dArgs.put(cnt++, b.toInt()); + } - log.error("Failed to calculate output shapes for op " + op.opName() + ". Attempted to execute with " + - String.valueOf(op.numInputArguments()) + " inputs, " + - String.valueOf(op.numOutputArguments()) + " outputs, "+ - String.valueOf(op.numTArguments()) + " targs and " + - String.valueOf(op.numIArguments()) + " iargs. " + - sb.toString() + - " - Please see above message (printed out from c++) for a possible cause of error."); - throw t; + + OpaqueShapeList ptrptr; + try { + ptrptr = loop.calculateOutputShapes2(null, + hash, inputBuffers, inputShapes, nIn, tArgs, + nTArgs, iArgs, nIArgs, bArgs, nBArgs, dArgs, nDArgs); + + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + } catch (Throwable t){ + StringBuilder sb = new StringBuilder(); + sb.append("Inputs: [("); + for( int i=0; i<inputArgs.size(); i++ ){ + if(i > 0) + sb.append("), ("); sb.append(Shape.shapeToStringShort(inputArgs.get(i))); } + sb.append(")]"); + if(op instanceof DifferentialFunction && ((DifferentialFunction)op).getSameDiff() != null){ + appendSameDiffInfo(sb, (DifferentialFunction) op); + } + + int nOut = opContext != null ?
opContext.numOutputArguments() : op.numOutputArguments(); + log.error("Failed to calculate output shapes for op {}. Attempted to execute with {} inputs, {} outputs, " + + "{} targs, {} iargs, {} bargs and {} dargs. {} - Please see above message (printed out from c++) for a possible cause of error.", + op.opName(), nIn, nOut, nTArgs, nIArgs, nBArgs, nDArgs, sb.toString()); + throw t; + } if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - if (ptrptr == null) - throw new RuntimeException(); + if (ptrptr == null) + throw new RuntimeException(); - for (int e = 0; e < loop.getShapeListSize(ptrptr); e++ ) - result.add(getShapeFromPointer(new PagedPointer(loop.getShape(ptrptr, e)).asLongPointer())); + for (int e = 0; e < loop.getShapeListSize(ptrptr); e++ ) + result.add(getShapeFromPointer(new PagedPointer(loop.getShape(ptrptr, e)).asLongPointer())); - loop.deleteShapeList(ptrptr); + loop.deleteShapeList(ptrptr); if(log.isTraceEnabled()){ String[] arr = new String[result.size()]; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 6c9633a41..4f228717a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -385,6 +385,7 @@ public class RandomOpValidation extends BaseOpValidation { @Test public void testUniformDtype(){ + Nd4j.getRandom().setSeed(12345); for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){ SameDiff sd = SameDiff.create(); SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java new file mode 100644 index 000000000..7addd5098 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java @@ -0,0 +1,169 @@ +package org.nd4j.autodiff.samediff; + +import lombok.extern.slf4j.Slf4j; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; +import org.nd4j.imports.TFGraphs.TFGraphTestZooModels; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.AtomicBoolean; +import org.nd4j.resources.Resources; + +import java.io.File; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +@Slf4j +public class SameDiffMultiThreadTests extends BaseND4JTest { + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + @Override + public long getTimeoutMilliseconds() { + return 60000L; + } + + @Test + public void testSimple() throws Exception { + + int nThreads = 4; + int nRuns = 1000; + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10); + + SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 10, 10)); + SDVariable b1 = sd.var("b1", 
Nd4j.rand(DataType.FLOAT, 10)); + SDVariable w2 = sd.var("w2", Nd4j.rand(DataType.FLOAT, 10, 10)); + SDVariable b2 = sd.var("b2", Nd4j.rand(DataType.FLOAT, 10)); + SDVariable w3 = sd.var("w3", Nd4j.rand(DataType.FLOAT, 10, 10)); + SDVariable b3 = sd.var("b3", Nd4j.rand(DataType.FLOAT, 10)); + + SDVariable l1 = sd.nn.tanh(in.mmul(w1).add(b1)); + SDVariable l2 = sd.nn.sigmoid(l1.mmul(w2).add(b2)); + SDVariable l3 = sd.nn.softmax("out", l2.mmul(w3).add(b3)); + + SDVariable loss = sd.loss.logLoss("loss", label, l3); + + INDArray[] inputArrs = new INDArray[nThreads]; + INDArray[] expOut = new INDArray[nThreads]; + for( int i=0; i<nThreads; i++ ){ + inputArrs[i] = Nd4j.rand(DataType.FLOAT, 1, 10); + expOut[i] = sd.outputSingle(Collections.singletonMap("in", inputArrs[i]), "out"); + } + + AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads]; + AtomicInteger[] counters = new AtomicInteger[nThreads]; + Semaphore s = new Semaphore(nThreads); + CountDownLatch latch = new CountDownLatch(nThreads); + + doTest(sd, nThreads, nRuns, inputArrs, expOut, "in", "out", failuresByThread, counters, s, latch); + + s.release(nThreads); + latch.await(); + + for(int i=0; i<nThreads; i++ ){ + assertFalse("Thread " + i + " failed", failuresByThread[i].get()); + assertEquals("Thread " + i + " number of runs", nRuns, counters[i].get()); + } + } + + @Test + public void testMobilenet() throws Exception { + TFGraphTestZooModels.currentTestDir = testDir.newFolder(); + File f = Resources.asFile("tf_graphs/zoo_models/mobilenet_v2_1.0_224/tf_model.txt"); + SameDiff sd = TFGraphTestZooModels.LOADER.apply(f, "mobilenet_v2_1.0_224"); + + int nThreads = 4; + int nRuns = 30; + INDArray[] inputArrs = new INDArray[nThreads]; + INDArray[] expOut = new INDArray[nThreads]; + for( int i=0; i<nThreads; i++ ){ + if(i == 0 || i > 2) + inputArrs[i] = Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3); + else if(i == 1) + inputArrs[i] = Nd4j.zeros(DataType.FLOAT, 1, 224, 224, 3); + else if(i == 2) + inputArrs[i] = Nd4j.ones(DataType.FLOAT, 1, 224, 224, 3); + + expOut[i] = sd.outputSingle(Collections.singletonMap("input", inputArrs[i]), "MobilenetV2/Predictions/Reshape_1"); + Nd4j.getExecutioner().commit(); + } + + AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads]; + AtomicInteger[] counters = new AtomicInteger[nThreads]; + Semaphore s = new Semaphore(nThreads); + CountDownLatch latch = new CountDownLatch(nThreads); + + doTest(sd, nThreads, nRuns, inputArrs, expOut, "input", "MobilenetV2/Predictions/Reshape_1", failuresByThread, counters, s, latch); + + s.release(nThreads); + latch.await(); + + for(int i=0; i<nThreads; i++ ){ + assertFalse("Thread " + i + " failed", failuresByThread[i].get()); + assertEquals("Thread " + i + " number of runs", nRuns, counters[i].get()); + } + }
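Note for reviewers: the tests above capture the contract this patch is aiming for — a single SameDiff instance shared across threads, no external synchronization, and each thread's outputs checked against a single-threaded baseline. A minimal sketch of the same usage pattern outside the JUnit harness (the ExecutorService wiring, the shapes, and the "in"/"out" names are illustrative assumptions here, not something the patch adds):

    // Sketch only: concurrent inference on one shared SameDiff instance.
    // Safe after this patch because per-call arrays live in an OpContext
    // rather than being set on the op instances stored in the graph.
    public static INDArray[] parallelInference(SameDiff sd, int nThreads) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        List<Future<INDArray>> futures = new ArrayList<>();
        for (int i = 0; i < nThreads; i++) {
            INDArray input = Nd4j.rand(DataType.FLOAT, 1, 10);   // one input per thread
            futures.add(pool.submit(() ->
                    sd.outputSingle(Collections.singletonMap("in", input), "out")));
        }
        INDArray[] out = new INDArray[nThreads];
        for (int i = 0; i < nThreads; i++)
            out[i] = futures.get(i).get();   // rethrows any per-thread failure
        pool.shutdown();
        return out;
    }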