SameDiff multi-threaded inference (#263)

* #8682 Don't log OpenMP BLAS threads for CUDA

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8654 Add SameDiff multi-threaded tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Switching to op context for SameDiff exec

Signed-off-by: Alex Black <blacka101@gmail.com>

* Next steps

Signed-off-by: Alex Black <blacka101@gmail.com>

* Most tests back to passing

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Better tests, test refactoring

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small tweak

Signed-off-by: Alex Black <blacka101@gmail.com>

* Code duplication reduction

Signed-off-by: Alex Black <blacka101@gmail.com>

* More code deduplication

Signed-off-by: Alex Black <blacka101@gmail.com>

* CUDA fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* More CUDA fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* More fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* ND4S small fixes

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-03-20 21:24:39 +11:00 committed by GitHub
parent b23ebee432
commit f79207033b
58 changed files with 1426 additions and 691 deletions
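The net effect of this change set: per-op execution state (input/output arrays and i/t/b/d arguments) moves off the shared op instances and into per-session OpContext objects, so one SameDiff instance can serve inference from several threads concurrently. A minimal sketch of the usage this enables — the graph, variable names, and class name below are illustrative, not from this commit:

import java.util.Collections;
import java.util.Map;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MultiThreadedInference {
    public static void main(String[] args) throws InterruptedException {
        // Build a trivial graph: out = in x w
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
        SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
        SDVariable out = in.mmul("out", w);

        // Each thread runs inference against the same SameDiff instance;
        // op state lives in per-session OpContexts, not on the shared ops
        Runnable task = () -> {
            INDArray x = Nd4j.rand(DataType.FLOAT, 16, 4);
            Map<String, INDArray> m = sd.output(Collections.singletonMap("in", x), "out");
            System.out.println(m.get("out").shapeInfoToString());
        };

        Thread t1 = new Thread(task);
        Thread t2 = new Thread(task);
        t1.start();
        t2.start();
        t1.join();
        t2.join();
    }
}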

File: DifferentialFunction.java

@@ -31,6 +31,7 @@ import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.shade.jackson.annotation.JsonIgnore;
@@ -708,6 +709,10 @@ public abstract class DifferentialFunction {
         throw new ND4JIllegalStateException("calculateOutputShape() method leaked out for [" + this.opName() + "]");
     }

+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc){
+        throw new ND4JIllegalStateException("calculateOutputShape(OpContext) method leaked out for [" + this.opName() + "]");
+    }
+
     /**
      * Calculate the data types for the output arrays.
      * Though datatypes can also be inferred from {@link #calculateOutputShape()}, this method differs in that it does not

File: BaseListener.java

@@ -5,6 +5,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.autodiff.samediff.internal.Variable;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -60,12 +61,12 @@ public abstract class BaseListener implements Listener {
     }

     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
         //No op
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         //No op
     }

File: Listener.java

@@ -5,6 +5,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.autodiff.samediff.internal.Variable;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -104,7 +105,7 @@ public interface Listener {
      * @param at Current iteration/epoch etc
      * @param op Operation that has just been executed
      */
-    void preOpExecution(SameDiff sd, At at, SameDiffOp op);
+    void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext);

     /**
      * Called at the end of each operation execution<br>
@@ -117,7 +118,7 @@ public interface Listener {
      * @param op Operation that has just been executed
      * @param outputs The output arrays for the just-executed operation
      */
-    void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs);
+    void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs);

     /**
      * Called when any activation becomes available.
@@ -127,7 +128,7 @@ public interface Listener {
      * Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}<br>
      * It is guaranteed to be called for variables from requiredVariables().<br>
      * <br>
-     * Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, INDArray[])} -
+     * Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, OpContext, INDArray[])} -
      * both contain the same information/arrays
      *
      * @param sd The SameDiff instance

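Custom listeners have to be updated for the new preOpExecution/opExecution signatures. A hypothetical example — the class name and logging are illustrative; the OpContext parameter and BaseListener base class are from this commit:

import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class OpContextLoggingListener extends BaseListener {

    @Override
    public boolean isActive(Operation operation) {
        return true;    // listen for all operation types
    }

    @Override
    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
        // Inputs are now read from the OpContext rather than from the op instance
        if (opContext != null)
            System.out.println(op.getName() + ": " + opContext.numInputArguments() + " input(s) set");
    }

    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op,
                            OpContext opContext, INDArray[] outputs) {
        System.out.println(op.getName() + ": " + outputs.length + " output(s) produced");
    }
}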
File: ArraySavingListener.java

@@ -9,6 +9,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -44,7 +45,7 @@ public class ArraySavingListener extends BaseListener {

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         List<String> outNames = op.getOutputsOfOp();

         for(int i=0; i<outputs.length; i++ ){
             String filename = (count++) + "_" + outNames.get(i).replaceAll("/", "__") + ".bin";

File: ExecDebuggingListener.java

@@ -11,6 +11,7 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOp;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.ops.ScalarOp;

 import java.util.Arrays;
@@ -77,7 +78,7 @@ public class ExecDebuggingListener extends BaseListener {
     }

     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
         if(lastIter != at.iteration()){
             lastIter = at.iteration();
             stepThisIter = 0;

File: OpBenchmarkListener.java

@@ -9,6 +9,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.util.ArrayUtil;
@@ -79,12 +80,12 @@ public class OpBenchmarkListener extends BaseListener {
     }

     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
         start = System.currentTimeMillis();
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         long now = System.currentTimeMillis();

         if (mode == Mode.SINGLE_ITER_PRINT && printActive && (now-start) > this.minRuntime) {

File: UIListener.java

@@ -19,6 +19,7 @@ import org.nd4j.graph.UIInfoType;
 import org.nd4j.graph.UIStaticInfoRecord;
 import org.nd4j.graph.ui.LogFileWriter;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.learning.config.IUpdater;
 import org.nd4j.linalg.primitives.Pair;
@@ -410,7 +411,7 @@ public class UIListener extends BaseListener {

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {

         //Do training set evaluation, if required

File: ProfilingListener.java

@@ -30,6 +30,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.AtomicBoolean;
@@ -192,7 +193,7 @@ public class ProfilingListener extends BaseListener {
     }

     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
         if (logActive) {
             opStartNano = System.nanoTime();
@@ -202,7 +203,7 @@ public class ProfilingListener extends BaseListener {
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         if (logActive) {
             long now = System.nanoTime();

File: SameDiff.java

@@ -105,7 +105,6 @@ import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
 * In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)}
 */
 @AllArgsConstructor
-@Builder
 @Slf4j
 public class SameDiff extends SDBaseOps {
     protected static final String GRAD_FN_KEY = "grad";
@@ -1232,25 +1231,6 @@ public class SameDiff extends SDBaseOps {
         return result;
     }

-    /**
-     * Create a new SameDiff instance from an existing instance.
-     * Note that state (variables and functions) is shared between the two SameDiff instance
-     *
-     * @param originalSameDiff Original SameDiff instance
-     * @return Copy
-     */
-    public static SameDiff create(SameDiff originalSameDiff) {
-        SameDiff ret = SameDiff.builder()
-                .sameDiffFunctionInstances(originalSameDiff.sameDiffFunctionInstances)
-                .build();
-        ret.variables.putAll(originalSameDiff.variables);
-        //ensuring proper sameDiff reference
-        DifferentialFunctionFactory differentialFunctionFactory = new DifferentialFunctionFactory(ret);
-        ret.functionFactory = differentialFunctionFactory;
-        return ret;
-    }
-
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;

File: InferenceSession.java

@@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.internal;
 import lombok.*;
 import lombok.extern.slf4j.Slf4j;
+import org.bytedeco.javacpp.Pointer;
 import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.listeners.At;
 import org.nd4j.autodiff.listeners.Listener;
@@ -46,6 +47,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.INDArrayIndex;
 import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.util.ArrayUtil;

 import java.util.*;
@@ -65,7 +67,7 @@ import java.util.*;
 * @author Alex Black
 */
 @Slf4j
-public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
+public class InferenceSession extends AbstractSession<INDArray, Pair<SameDiffOp,OpContext>> {
     private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\n" +
             "Alternatively, arrays defined in a workspace must be replaced after the workspace has been closed.";
@@ -83,6 +85,8 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
     private IdentityDependencyTracker<INDArray, Dep> arrayUseTracker = new IdentityDependencyTracker<>();

+    private Map<String,OpContext> opContexts = new HashMap<>();
+
     public InferenceSession(@NonNull SameDiff sameDiff) {
         super(sameDiff);
         mmgr = new ArrayCacheMemoryMgr();
@@ -204,18 +208,19 @@
     }

     @Override
-    public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
+    public INDArray[] getOutputs(Pair<SameDiffOp,OpContext> opPair, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                                  Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables) {
+        SameDiffOp op = opPair.getFirst();
         at.setFrameIter(outputFrameIter);
         if (listeners != null && listeners.size() > 0) {
             SameDiffOp sdOp = sameDiff.getOps().get(op.getOp().getOwnName());
             for (Listener l : listeners) {
                 if (l.isActive(at.operation()))
-                    l.preOpExecution(sameDiff, at, sdOp);
+                    l.preOpExecution(sameDiff, at, sdOp, opPair.getSecond());
             }
         }

-        INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
+        INDArray[] out = doExec(op.getOp(), opPair.getRight(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
         if (log.isTraceEnabled()) {
             StringBuilder sb = new StringBuilder();
@@ -246,7 +251,7 @@
                     }

-                    l.opExecution(sameDiff, at, batch, op, out);
+                    l.opExecution(sameDiff, at, batch, op, opPair.getSecond(), out);

                     for (String varName : namedOuts.keySet()) {
                         l.activationAvailable(sameDiff, at, batch, op, varName, namedOuts.get(varName));
@@ -255,6 +260,8 @@
             }
         }
         op.getOp().clearArrays();
+        if(opPair.getSecond() != null)
+            opPair.getSecond().purge();

         //Record array uses for memory management/deallocation
@@ -343,7 +350,7 @@
         return out;
     }

-    public INDArray[] doExec(DifferentialFunction op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
+    public INDArray[] doExec(DifferentialFunction op, OpContext opContext, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                              Set<String> constAndPhInputs) {

         int totalInputs = (opInputs == null ? 0 : opInputs.size()) + (constAndPhInputs == null ? 0 : constAndPhInputs.size())
@@ -467,31 +474,31 @@
             return new INDArray[]{out};
         } else if (op instanceof Assert) {
             Assert a = (Assert)op;
-            boolean condition = a.getInputArgument(0).getDouble(0) != 0.0;
+            boolean condition = opContext.getInputArray(0).getDouble(0) != 0.0;
             if(!condition){
                 //Assertion failed
                 String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution";
                 if(a.numInputArguments() >= 3) {
-                    INDArray msg = a.getInputArgument(2);
+                    INDArray msg = opContext.getInputArray(2);
                     if (msg != null && msg.dataType() == DataType.UTF8) {
                         s += ": " + msg.getString(0);
                     }
                 }
                 if(a.numInputArguments() >= 5){
-                    INDArray arr = a.getInputArgument(4);
+                    INDArray arr = opContext.getInputArray(4);
                     s += "\n" + arr;
                 }
                 throw new IllegalStateException(s);
             }
-            return ((Assert) op).outputArguments().toArray(new INDArray[0]);
+            return opContext.getOutputArrays().toArray(new INDArray[0]);
         } else if (op instanceof CustomOp) {
             CustomOp c = (CustomOp) op;
-            Nd4j.exec(c);
-            return c.outputArguments().toArray(new INDArray[0]);
+            Nd4j.exec(c, opContext);
+            return opContext.getOutputArrays().toArray(new INDArray[0]);
         } else if (op instanceof Op) {
             Op o = (Op) op;
-            Nd4j.exec(o);
-            return new INDArray[]{o.z()};
+            Nd4j.exec(o, opContext);
+            return new INDArray[]{opContext.getOutputArray(0)};
         } else {
             throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
         }
@@ -774,7 +781,7 @@
     }

     @Override
-    public SameDiffOp getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
+    public Pair<SameDiffOp,OpContext> getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                                            Set<String> constAndPhInputs, Map<String, INDArray> placeholderValues, Set<String> allReqVariables) {
         SameDiffOp sdo = sameDiff.getOps().get(opName);
         DifferentialFunction df = sdo.getOp();
@@ -786,7 +793,7 @@
         if (df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration ||
                 df instanceof Merge || df instanceof Switch || df instanceof BaseTensorOp) {
             //Control dependencies and tensor ops (like TensorArray, TensorArrayRead etc) don't need inputs set, execution is a special case
-            return sdo;
+            return new Pair<>(sdo, null);
         }

         //Infer the args based on the inputs (variable + frame + iteration)
@@ -839,24 +846,39 @@
         //TODO let's find a way to use in-place modification for loops where possible to reduce memory requirements
         boolean isLoop = !frameIter.getFrame().equals(OUTER_FRAME) && frameIter.getIteration() > 0;

+        OpContext oc = opContexts.get(opName);
+        if(oc == null){
+            oc = Nd4j.getExecutioner().buildContext();
+            opContexts.put(opName, oc);
+        }
+
         if (df instanceof CustomOp) {
             DynamicCustomOp customOp = (DynamicCustomOp) df;
             if (args != null) {
-                customOp.setInputArguments(args);
+                oc.setInputArrays(args);
             }

             if (df instanceof Identity) {
                 //We don't need to allocate an output array for Identity, we pass through the input array without copying
-                return sdo;
+                return new Pair<>(sdo, oc);
             }

-            List<LongShapeDescriptor> outShape = customOp.calculateOutputShape();
+            if(customOp.numIArguments() > 0)
+                oc.setIArguments(customOp.iArgs());
+            if(customOp.numDArguments() > 0)
+                oc.setDArguments(customOp.dArgs());
+            if(customOp.numTArguments() > 0)
+                oc.setTArguments(customOp.tArgs());
+            if(customOp.numBArguments() > 0)
+                oc.setBArguments(customOp.bArgs());
+
+            List<LongShapeDescriptor> outShape = customOp.calculateOutputShape(oc);
             Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
             String[] outNames = df.outputVariablesNames();
             Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation" +
                     " with %s outputs (number of shapes and outputs must be equal)", df.opName(), outShape.size(), outNames.length);
             for (int i = 0; i < outShape.size(); i++) {
-                INDArray currOutput = (customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i));
                 LongShapeDescriptor reqShape = outShape.get(i);

                 //Issue: many ops have multiple valid output datatypes, and output shape calc can't at present know which: https://github.com/deeplearning4j/deeplearning4j/issues/6872
@@ -870,7 +892,7 @@
                 //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc
                 boolean isOutput = allReqVariables.contains(outNames[i]);
                 INDArray out = mmgr.allocate(isOutput, reqShape);
-                customOp.setOutputArgument(i, out);
+                oc.setOutputArray(i, out);
             }

         } else if (df instanceof Op) {
@@ -909,9 +931,9 @@
             }

             if (args != null && args.length > 0) {
-                op.setX(args[0]);
+                oc.setInputArray(0, args[0]);
                 if (args.length == 2 && !axisArg)
-                    op.setY(args[1]);
+                    oc.setInputArray(1, args[1]);
             }
@@ -920,18 +942,18 @@
             boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]);
             if (emptyReduce) {
                 //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc
-                INDArray z = mmgr.allocate(false, op.x().dataType(), op.x().shape());
-                op.setZ(z);
+                INDArray z = mmgr.allocate(false, oc.getInputArray(0).dataType(), oc.getInputArray(0).shape());
+                oc.setOutputArray(0, z);
             } else {
-                List<LongShapeDescriptor> outputShape = ((BaseOp) op).calculateOutputShape();
+                List<LongShapeDescriptor> outputShape = ((BaseOp) op).calculateOutputShape(oc);
                 Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass());
                 LongShapeDescriptor lsd = outputShape.get(0);
                 INDArray z = mmgr.allocate(isOutput, lsd);
-                op.setZ(z);
+                oc.setOutputArray(0, z);
             }
         }

-        return sdo;
+        return new Pair<>(sdo, oc);
     }

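For reference, the context-based execution above boils down to the following pattern. This is a simplified sketch: execWithContext is a hypothetical helper, and allocation via Nd4j.create is illustrative — the real session allocates through its ArrayCacheMemoryMgr and caches one OpContext per op name rather than closing it.

import java.util.List;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;

public class OpContextExecSketch {

    // Run a custom op via a context: arrays and args are attached to the
    // OpContext, so the shared op instance itself is never mutated.
    public static INDArray execWithContext(DynamicCustomOp op, INDArray... inputs) throws Exception {
        try (OpContext oc = Nd4j.getExecutioner().buildContext()) {
            oc.setInputArrays(inputs);
            if (op.numIArguments() > 0)
                oc.setIArguments(op.iArgs());
            if (op.numTArguments() > 0)
                oc.setTArguments(op.tArgs());

            // Shape calculation now reads from the context, not the op fields
            List<LongShapeDescriptor> shapes = op.calculateOutputShape(oc);
            for (int i = 0; i < shapes.size(); i++) {
                oc.setOutputArray(i, Nd4j.create(shapes.get(i)));
            }

            Nd4j.exec(op, oc);
            return oc.getOutputArray(0);
        }
    }
}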
File: TrainingSession.java

@@ -11,10 +11,12 @@ import org.nd4j.autodiff.samediff.TrainingConfig;
 import org.nd4j.autodiff.samediff.VariableType;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.learning.GradientUpdater;
 import org.nd4j.linalg.learning.regularization.Regularization;
 import org.nd4j.linalg.primitives.AtomicDouble;
+import org.nd4j.linalg.primitives.Pair;

 import java.util.*;
@@ -135,10 +137,11 @@ public class TrainingSession extends InferenceSession {
     }

     @Override
-    public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
+    public INDArray[] getOutputs(Pair<SameDiffOp, OpContext> opPair, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                                  Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables) {
         //Get outputs from InferenceSession
-        INDArray[] out = super.getOutputs(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables);
+        INDArray[] out = super.getOutputs(opPair, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables);
+        SameDiffOp op = opPair.getFirst();

         List<String> outputs = op.getOutputsOfOp();
         int outIdx = 0;

File: ActivationGradientCheckListener.java

@@ -12,6 +12,8 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;

 import java.util.List;
+
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -36,7 +38,7 @@ public class ActivationGradientCheckListener extends BaseListener {
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         Preconditions.checkState(variableName != null, "No variable name has been set yet. Variable name must be set before using this listener");
         Preconditions.checkState(eps != 0.0, "Epsilon has not been set");

File: NonInplaceValidationListener.java

@@ -14,7 +14,10 @@ import org.nd4j.linalg.api.ops.Op;
 import java.security.MessageDigest;
 import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
+
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;

 public class NonInplaceValidationListener extends BaseListener {
@@ -33,25 +36,25 @@ public class NonInplaceValidationListener extends BaseListener {
     }

     @Override
-    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext oc) {
         if(op.getOp().isInPlace()){
             //Don't check inplace op
             return;
         }
         if(op.getOp() instanceof Op){
             Op o = (Op)op.getOp();
-            if(o.x() == null){
+            if(oc.getInputArray(0) == null){
                 //No input op
                 return;
-            } else if(o.y() == null){
-                opInputsOrig = new INDArray[]{o.x()};
-                opInputs = new INDArray[]{o.x().dup()};
+            } else if(oc.getInputArray(1) == null){
+                opInputsOrig = new INDArray[]{oc.getInputArray(0)};
+                opInputs = new INDArray[]{oc.getInputArray(0).dup()};
             } else {
-                opInputsOrig = new INDArray[]{o.x(), o.y()};
-                opInputs = new INDArray[]{o.x().dup(), o.y().dup()};
+                opInputsOrig = new INDArray[]{oc.getInputArray(0), oc.getInputArray(1)};
+                opInputs = new INDArray[]{oc.getInputArray(0).dup(), oc.getInputArray(1).dup()};
             }
         } else if(op.getOp() instanceof DynamicCustomOp){
-            val arr = ((DynamicCustomOp) op.getOp()).inputArguments();
+            List<INDArray> arr = oc.getInputArrays(); // ((DynamicCustomOp) op.getOp()).inputArguments();
             opInputs = new INDArray[arr.size()];
             opInputsOrig = new INDArray[arr.size()];
             for( int i=0; i<arr.size(); i++ ){
@@ -64,7 +67,7 @@ public class NonInplaceValidationListener extends BaseListener {
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         if(op.getOp().isInPlace()){
             //Don't check inplace op
             return;

File: BaseIndexAccumulation.java

@@ -93,6 +93,12 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc){
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();

File: BaseOpContext.java

@@ -55,6 +55,11 @@ public abstract class BaseOpContext implements OpContext {
         return fastpath_i;
     }

+    @Override
+    public int numIArguments() {
+        return fastpath_i.size();
+    }
+
     @Override
     public void setTArguments(double... arguments) {
         fastpath_t.clear();
@@ -67,6 +72,11 @@ public abstract class BaseOpContext implements OpContext {
         return fastpath_t;
     }

+    @Override
+    public int numTArguments() {
+        return fastpath_t.size();
+    }
+
     @Override
     public void setBArguments(boolean... arguments) {
         fastpath_b.clear();
@@ -79,6 +89,11 @@ public abstract class BaseOpContext implements OpContext {
         return fastpath_b;
     }

+    @Override
+    public int numBArguments() {
+        return fastpath_b.size();
+    }
+
     @Override
     public void setDArguments(DataType... arguments) {
         fastpath_d.clear();
@@ -91,6 +106,11 @@ public abstract class BaseOpContext implements OpContext {
         return fastpath_d;
     }

+    @Override
+    public int numDArguments() {
+        return fastpath_d.size();
+    }
+
     @Override
     public void setInputArray(int index, @NonNull INDArray array) {
         fastpath_in.put(index, array);
@@ -110,6 +130,16 @@ public abstract class BaseOpContext implements OpContext {
         return result;
     }

+    @Override
+    public int numInputArguments() {
+        return fastpath_in.size();
+    }
+
+    @Override
+    public INDArray getInputArray(int idx) {
+        return fastpath_in.get(idx);
+    }
+
     @Override
     public List<INDArray> getOutputArrays() {
         val result = new ArrayList<INDArray>();
@@ -129,6 +159,15 @@ public abstract class BaseOpContext implements OpContext {
         fastpath_out.put(index, array);
     }

+    @Override
+    public INDArray getOutputArray(int i) {
+        return fastpath_out.get(i);
+    }
+
+    @Override
+    public int numOutputArguments() {
+        return fastpath_out.size();
+    }
+
     @Override
     public void setInputArrays(@NonNull List<INDArray> arrays) {

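The new num*Arguments() accessors simply report how many arguments have been set on the context so far. A toy sketch (the class name is illustrative, not from this commit):

import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.factory.Nd4j;

public class OpContextArgCounts {
    public static void main(String[] args) throws Exception {
        try (OpContext oc = Nd4j.getExecutioner().buildContext()) {
            oc.setIArguments(1, 2, 3);                // three integer args
            oc.setBArguments(true, false);            // two boolean args
            System.out.println(oc.numIArguments());   // 3
            System.out.println(oc.numBArguments());   // 2
            System.out.println(oc.numTArguments());   // 0 - none set yet
        }
    }
}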
File: BaseReduceBoolOp.java

@@ -72,19 +72,33 @@ public abstract class BaseReduceBoolOp extends BaseReduceOp implements ReduceBoolOp {
     }

     @Override
-    public boolean validateDataTypes() {
-        if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(),"Op.X type must be the same as Op.Y:" +
+    public DataType resultType(OpContext oc) {
+        return DataType.BOOL;
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        if (y != null)
+            Preconditions.checkArgument(x.dataType() == y.dataType(),"Op.X type must be the same as Op.Y:" +
                     " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName());
-        if (z() != null)
-            Preconditions.checkArgument(z().isB(), "Op.X type must be bool: got type %s for op %s", x.dataType(), getClass());
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (z != null)
+            Preconditions.checkArgument(z.isB(), "Op.Z type must be bool: got type %s for op %s", z.dataType(), getClass());
         return true;
     }

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();

File: BaseReduceFloatOp.java

@@ -90,27 +90,43 @@ public abstract class BaseReduceFloatOp extends BaseReduceOp implements ReduceFloatOp {

     @Override
     public DataType resultType() {
-        if (this.x() != null && this.x().isR())
-            return this.x().dataType();
+        return resultType(null);
+    }
+
+    @Override
+    public DataType resultType(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        if (x != null && x.isR())
+            return x.dataType();
         return Nd4j.defaultFloatingPointType();
     }

     @Override
-    public boolean validateDataTypes() {
-        if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(),
-                    "Op.X [%s] type must be the same as Op.Y [%s] for op %s: x.shape=%ndShape, y.shape=%ndShape", x().dataType(),
-                    y().dataType(), getClass().getName(), x(), y() );
+    public boolean validateDataTypes(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        if (y != null)
+            Preconditions.checkArgument(x.dataType() == y.dataType(),
+                    "Op.X [%s] type must be the same as Op.Y [%s] for op %s: x.shape=%ndShape, y.shape=%ndShape", x.dataType(),
+                    y.dataType(), getClass().getName(), x, y );

-        if (z() != null)
-            Preconditions.checkArgument(z().isR(),"Op.Z (result array) must be one of floating types: z datatype = %s", z().dataType());
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (z != null)
+            Preconditions.checkArgument(z.isR(),"Op.Z (result array) must be one of floating types: z datatype = %s", z.dataType());

         return true;
     }

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();

File: BaseReduceLongOp.java

@@ -69,19 +69,33 @@ public abstract class BaseReduceLongOp extends BaseReduceOp implements ReduceLongOp {
     }

     @Override
-    public boolean validateDataTypes() {
-        if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X type must be the same as Op.Y:" +
+    public DataType resultType(OpContext oc) {
+        return DataType.LONG;
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        if (y != null)
+            Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X type must be the same as Op.Y:" +
                     " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName());
-        if (z() != null)
-            Preconditions.checkArgument( z().dataType() == DataType.LONG,"Op.Z must be long: has type %s for op %s", z().dataType(), getClass());
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (z != null)
+            Preconditions.checkArgument( z.dataType() == DataType.LONG,"Op.Z must be long: has type %s for op %s", z.dataType(), getClass());
         return true;
     }

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();

File: BaseReduceSameOp.java

@@ -77,26 +77,42 @@ public abstract class BaseReduceSameOp extends BaseReduceOp implements ReduceSameOp {
     }

     @Override
-    public boolean validateDataTypes() {
-        if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(),"Op.X type must be the same as Op.Y type:" +
+    public DataType resultType(OpContext oc){
+        return oc.getInputArray(0).dataType();
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        if (y != null)
+            Preconditions.checkArgument(x.dataType() == y.dataType(),"Op.X type must be the same as Op.Y type:" +
                     " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName());
-        if (z() != null)
-            Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must be the same as Op.X type. Op.X.datatype=%s, " +
-                    "Op.Z.datatype=%s", x().dataType(), z.dataType());
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (z != null)
+            Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must be the same as Op.X type. Op.X.datatype=%s, " +
+                    "Op.Z.datatype=%s", x.dataType(), z.dataType());
         return true;
     }

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();

         //Calculate reduction shape. Note that reduction on scalar - returns a scalar
         long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims());
-        return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, this.resultType()));
+        DataType rt = oc != null ? resultType(oc) : resultType();
+        return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, rt));
     }

     @Override

File: BaseScalarBoolOp.java

@@ -98,6 +98,12 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp {
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();

File: BaseScalarOp.java

@@ -115,6 +115,13 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp {
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+
         val ret = new ArrayList<LongShapeDescriptor>(1);
         long[] s;

File: BaseTransformAnyOp.java

@@ -89,7 +89,12 @@ public abstract class BaseTransformAnyOp extends BaseTransformOp implements TransformAnyOp {
     }

     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
+    public DataType resultType(OpContext oc) {
+        return oc.getInputArray(0).dataType();
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
         return true;
     }

File: BaseTransformBoolOp.java

@@ -88,20 +88,34 @@ public abstract class BaseTransformBoolOp extends BaseTransformOp implements TransformBoolOp {
     }

     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
+    public DataType resultType(OpContext oc) {
+        return DataType.BOOL;
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
         if (y() != null)
-            Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X must be the same type as Op.Y: " +
-                    "x.datatype=%s, y.datatype=%s", x().dataType(), y.dataType());
+            Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must be the same type as Op.Y: " +
+                    "x.datatype=%s, y.datatype=%s", x.dataType(), y.dataType());

-        if (z() != null)
-            Preconditions.checkArgument(z().isB(),"Op.Z type must be bool: z.datatype=%s for op %s", z().dataType(), getClass());
+        if (z != null)
+            Preconditions.checkArgument(z.isB(),"Op.Z type must be bool: z.datatype=%s for op %s", z.dataType(), getClass());

         return true;
     }

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();
         return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), DataType.BOOL));

File: BaseTransformFloatOp.java

@@ -72,19 +72,37 @@ public abstract class BaseTransformFloatOp extends BaseTransformOp implements TransformFloatOp {
     }

     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
-        if (y() != null && !experimentalMode) {
+    public DataType resultType(OpContext oc) {
+        if (oc.getInputArray(0) != null && oc.getInputArray(0).isR())
+            return oc.getInputArray(0).dataType();
+        return Nd4j.defaultFloatingPointType();
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (y != null && !experimentalMode) {
             Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must have same data type as Op.Y");
         }

-        if (z() != null)
-            Preconditions.checkArgument(z().isR(),"Op.Z must be one of floating types: z.datatype=%s for op %s", z().dataType(), getClass());
+        if (z != null)
+            Preconditions.checkArgument(z.isR(),"Op.Z must be one of floating types: z.datatype=%s for op %s", z.dataType(), getClass());

         return true;
     }

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();
         return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), x.isR() ? x.dataType() : Nd4j.defaultFloatingPointType()));

File: BaseTransformSameOp.java

@@ -89,22 +89,36 @@ public abstract class BaseTransformSameOp extends BaseTransformOp implements TransformSameOp {
     }

     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
-        if (y() != null) {
-            Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X type must be the same as Op.Y type: x.datatype=%s, y.datatype=%s for op %s",
-                    x().dataType(), y().dataType(), getClass());
+    public DataType resultType(OpContext oc) {
+        return oc.getInputArray(0).dataType();
+    }
+
+    @Override
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (y != null) {
+            Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X type must be the same as Op.Y type: x.datatype=%s, y.datatype=%s for op %s",
+                    x.dataType(), y.dataType(), getClass());
         }

-        if (z() != null)
-            Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must be the same as Op.X type: x.datatype=%s, z.datatype=%s for op %s",
-                    x().dataType(), z.dataType(), getClass());
+        if (z != null)
+            Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must be the same as Op.X type: x.datatype=%s, z.datatype=%s for op %s",
+                    x.dataType(), z.dataType(), getClass());

         return true;
     }

     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
         if(x == null)
             return Collections.emptyList();

File: BaseTransformStrictOp.java

@@ -76,20 +76,28 @@ public abstract class BaseTransformStrictOp extends BaseTransformOp implements TransformStrictOp {
         return this.x().dataType();
     }

+    @Override
+    public DataType resultType(OpContext opContext) {
+        return opContext.getInputArray(0).dataType();
+    }
+
     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
-        Preconditions.checkArgument(x().isR(), "Op.X must be one of floating types: x.datatype=%s for op %s", x().dataType(), getClass());
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        Preconditions.checkArgument(x.isR(), "Op.X must be one of floating types: x.datatype=%s for op %s", x.dataType(), getClass());

-        if (y() != null) {
-            Preconditions.checkArgument(y().isR(), "Op.Y must be one of floating types: y.datatype=%s for op %s", y().dataType(), getClass());
+        if (y != null) {
+            Preconditions.checkArgument(y.isR(), "Op.Y must be one of floating types: y.datatype=%s for op %s", y.dataType(), getClass());

             if (!experimentalMode)
                 Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must have same data type as Op.Y");
         }

         if (z() != null)
-            Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must have the same type as Op.X: x.datatype=%s, z.datatype=%s for op %s",
+            Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must have the same type as Op.X: x.datatype=%s, z.datatype=%s for op %s",
                     x.dataType(), z.dataType(), getClass());

         return true;
@@ -102,6 +110,13 @@ public abstract class BaseTransformStrictOp extends BaseTransformOp implements TransformStrictOp {
         return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), x.dataType()));
     }

+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        if(oc.getInputArray(0) == null)
+            return Collections.emptyList();
+        return Collections.singletonList(LongShapeDescriptor.fromShape(oc.getInputArray(0).shape(), oc.getInputArray(0).dataType()));
+    }
+
     @Override
     public List<org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes(List<org.nd4j.linalg.api.buffer.DataType> dataTypes){
         //All strict tranform ops: FP in, FP out

File: CustomOp.java

@@ -108,10 +108,16 @@ public interface CustomOp {

     /**
      * Calculate the output shape for this op
-     * @return
+     * @return Output array shapes
      */
     List<LongShapeDescriptor> calculateOutputShape();

+    /**
+     * Calculate the output shape for this op
+     * @return Output array shapes
+     */
+    List<LongShapeDescriptor> calculateOutputShape(OpContext opContext);
+
     /**
      * Get the custom op descriptor if one is available.
      * @return

View File

@@ -493,6 +493,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
         val descriptor = getDescriptor();
         if (outputShapes != null && !outputShapes.isEmpty())
             return outputShapes;
@@ -504,34 +509,41 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
         //not fully initialized: missing integer args
-        if (descriptor.getNumIArgs() >= 0 && numIArguments() < descriptor.getNumIArgs()) {
+        int nI = oc != null ? oc.numIArguments() : numIArguments();
+        if (descriptor.getNumIArgs() >= 0 && nI < descriptor.getNumIArgs()) {
             if(log.isTraceEnabled()){
                 log.trace("Could not calculate output shape for op {}: not fully initialized ({} IArgs specified, " +
-                        "{} required)", getClass().getName(),numIArguments(), descriptor.getNumIArgs());
+                        "{} required)", getClass().getName(), nI, descriptor.getNumIArgs());
             }
             return Collections.emptyList();
         }

         //not fully initialized: missing floating point args
-        if (descriptor.getNumTArgs() >= 0 && numTArguments() < descriptor.getNumTArgs()) {
+        int nT = oc != null ? oc.numTArguments() : numTArguments();
+        if (descriptor.getNumTArgs() >= 0 && nT < descriptor.getNumTArgs()) {
             if(log.isTraceEnabled()){
                 log.trace("Could not calculate output shape for op {}: not fully initialized ({} TArgs specified, " +
-                        "{} required)", getClass().getName(),numTArguments(), descriptor.getNumTArgs());
+                        "{} required)", getClass().getName(), nT, descriptor.getNumTArgs());
             }
             return Collections.emptyList();
         }

         //not fully initialized: missing INDArray input args
-        if(descriptor.getNumInputs() >= 0 && numInputArguments() < descriptor.getNumInputs()){
+        int nIn = oc != null ? oc.numInputArguments() : numInputArguments();
+        if(descriptor.getNumInputs() >= 0 && nIn < descriptor.getNumInputs()){
             if(log.isTraceEnabled()){
                 log.trace("Could not calculate output shape for op {}: not fully initialized ({} input (INDArray) args specified, " +
-                        "{} required)", getClass().getName(),numInputArguments(), descriptor.getNumInputs());
+                        "{} required)", getClass().getName(), nIn, descriptor.getNumInputs());
             }
             return Collections.emptyList();
         }

-        List<LongShapeDescriptor> ret = Nd4j.getExecutioner().calculateOutputShape(this);
+        List<LongShapeDescriptor> ret;
+        if(oc == null)
+            ret = Nd4j.getExecutioner().calculateOutputShape(this);
+        else
+            ret = Nd4j.getExecutioner().calculateOutputShape(this, oc);
         return ret;
     }
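The pattern above, prefer counts and arrays from the OpContext when one is supplied and fall back to the op's own fields otherwise, recurs throughout this change. A minimal hedged sketch of driving the new overload from user code; the buildContext() factory and the "add" op name are assumptions for illustration, not part of this diff:

    import java.util.List;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;
    import org.nd4j.linalg.api.ops.OpContext;
    import org.nd4j.linalg.api.shape.LongShapeDescriptor;
    import org.nd4j.linalg.factory.Nd4j;

    public class ShapeCalcSketch {
        public static void main(String[] args) throws Exception {
            // The op object itself is never mutated; per-call state lives in the context
            DynamicCustomOp op = DynamicCustomOp.builder("add").build();
            try (OpContext oc = Nd4j.getExecutioner().buildContext()) {
                oc.setInputArray(0, Nd4j.rand(DataType.FLOAT, 3, 4));
                oc.setInputArray(1, Nd4j.rand(DataType.FLOAT, 3, 4));
                List<LongShapeDescriptor> shapes = op.calculateOutputShape(oc);
                System.out.println(shapes); // expected: one [3,4] FLOAT descriptor
            }
        }
    }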
View File
@@ -89,6 +89,14 @@ public class NoOp extends DynamicCustomOp {
         return Collections.singletonList(Nd4j.empty(DataType.BOOL).shapeDescriptor());
     }

+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc){
+        if(oc.getInputArrays() != null && !oc.getInputArrays().isEmpty()){
+            return Collections.singletonList(oc.getInputArray(0).shapeDescriptor());
+        }
+        return Collections.singletonList(Nd4j.empty(DataType.BOOL).shapeDescriptor());
+    }
+
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         return Collections.singletonList(DataType.BOOL);
View File
@@ -39,12 +39,15 @@ public interface OpContext extends AutoCloseable {
     List<Long> getIArguments();
+    int numIArguments();

     /**
      * This method sets floating point arguments required for operation
      * @param arguments
      */
     void setTArguments(double... arguments);
     List<Double> getTArguments();
+    int numTArguments();

     /**
      * This method sets data type arguments required for operation
@@ -52,14 +55,15 @@ public interface OpContext extends AutoCloseable {
      */
     void setDArguments(DataType... arguments);
     List<DataType> getDArguments();
+    int numDArguments();

     /**
      * This method sets boolean arguments required for operation
      * @param arguments
      */
     void setBArguments(boolean... arguments);
     List<Boolean> getBArguments();
+    int numBArguments();

     /**
      * This method sets root-level seed for rng
@@ -99,6 +103,10 @@ public interface OpContext extends AutoCloseable {
      */
     List<INDArray> getInputArrays();

+    int numInputArguments();
+
+    INDArray getInputArray(int idx);
+
     /**
      * This method adds INDArray as output for future op call
      * @param index
@@ -124,6 +132,10 @@ public interface OpContext extends AutoCloseable {
      */
     List<INDArray> getOutputArrays();

+    INDArray getOutputArray(int i);
+
+    int numOutputArguments();
+
     /**
      * This method returns pointer to context, to be used during native op execution
      * @return
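Taken together, these accessors let each inference thread carry its own input/output state in an OpContext instead of mutating shared op fields. A minimal hedged sketch of populating a context (buildContext() as the assumed factory; the setters for T/B arguments and input/output arrays are per the interface and diffs above):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.OpContext;
    import org.nd4j.linalg.factory.Nd4j;

    public class OpContextSketch {
        public static void main(String[] args) throws Exception {
            // Each thread builds (and closes) its own context; OpContext is AutoCloseable.
            try (OpContext ctx = Nd4j.getExecutioner().buildContext()) {
                ctx.setInputArray(0, Nd4j.rand(DataType.FLOAT, 2, 3));
                ctx.setOutputArray(0, Nd4j.createUninitialized(DataType.FLOAT, 2, 3));
                ctx.setTArguments(0.5);          // floating point args, if the op needs them
                ctx.setBArguments(true);         // boolean args likewise
                // The new count accessors report what has been attached so far:
                System.out.println(ctx.numInputArguments() + " in, "
                        + ctx.numOutputArguments() + " out, "
                        + ctx.numTArguments() + " targs, "
                        + ctx.numBArguments() + " bargs");
                INDArray out = ctx.getOutputArray(0); // same array attached above
            }
        }
    }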
View File
@@ -86,7 +86,9 @@ public interface ReduceOp extends Op {
      */
     DataType resultType();

-    boolean validateDataTypes();
+    DataType resultType(OpContext oc);
+
+    boolean validateDataTypes(OpContext oc);

     Number getFinalResult();
View File
@@ -31,7 +31,9 @@ public interface TransformOp extends Op {
      */
     DataType resultType();

+    DataType resultType(OpContext opContext);
+
     Type getOpType();

-    boolean validateDataTypes(boolean experimentalMode);
+    boolean validateDataTypes(OpContext opContext, boolean experimentalMode);
 }
View File
@@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOp;
 import org.nd4j.linalg.api.ops.CustomOpDescriptor;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
@@ -237,6 +238,11 @@ public class ScatterUpdate implements CustomOp {
         return Nd4j.getExecutioner().calculateOutputShape(this);
     }

+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
+        return Nd4j.getExecutioner().calculateOutputShape(this, opContext);
+    }
+
     @Override
     public CustomOpDescriptor getDescriptor() {
         return op.getDescriptor();
View File
@@ -55,7 +55,7 @@ import java.util.*;
  * @author Adam Gibson
  */
 @Slf4j
-public class DefaultOpExecutioner implements OpExecutioner {
+public abstract class DefaultOpExecutioner implements OpExecutioner {

     private static final String SCOPE_PANIC_MSG = "For more details, see the ND4J User Guide: deeplearning4j.org/docs/latest/nd4j-overview#workspaces-panic";
@@ -108,9 +108,10 @@ public class DefaultOpExecutioner implements OpExecutioner {
     }

     @Override
-    public INDArray exec(Op op) {
-        throw new IllegalStateException("Java computation no longer supported");
-    }
+    public abstract INDArray exec(Op op);
+
+    @Override
+    public abstract INDArray exec(Op op, OpContext opContext);

     @Override
     public Op execAndReturn(Op op) {
@@ -175,24 +176,16 @@
     }

     @Override
-    public INDArray exec(ReduceOp op) {
-        throw new UnsupportedOperationException("Java computation no longer supported");
-    }
+    public abstract INDArray exec(ReduceOp op);

     @Override
-    public INDArray exec(Variance accumulation) {
-        throw new UnsupportedOperationException("Operation should use exec special");
-    }
+    public abstract INDArray exec(Variance accumulation);

     @Override
-    public INDArray exec(IndexAccumulation op) {
-        throw new UnsupportedOperationException("Operation should use exec special");
-    }
+    public abstract INDArray exec(IndexAccumulation op);

     @Override
-    public INDArray exec(BroadcastOp broadcast) {
-        throw new IllegalStateException("Java computation no longer supported");
-    }
+    public abstract INDArray exec(BroadcastOp broadcast);

     @Override
     public void exec(MetaOp op) {
@@ -215,9 +208,7 @@
     }

     @Override
-    public INDArray exec(ScalarOp op) {
-        throw new UnsupportedOperationException();
-    }
+    public abstract INDArray exec(ScalarOp op);

     @Override
     public void exec(List<Aggregate> batch) {
@@ -241,9 +232,7 @@
      * @param rng
      */
     @Override
-    public INDArray exec(RandomOp op, Random rng) {
-        throw new UnsupportedOperationException();
-    }
+    public abstract INDArray exec(RandomOp op, Random rng);

     @Deprecated
@@ -741,6 +730,11 @@
         throw new UnsupportedOperationException();
     }

+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(CustomOp op, OpContext opContext) {
+        throw new UnsupportedOperationException();
+    }
+
     @Override
     public INDArray[] allocateOutputArrays(CustomOp op){
         List<LongShapeDescriptor> shapes = calculateOutputShape(op);
@@ -946,4 +940,44 @@
     public String runFullBenchmarkSuit(boolean printOut) {
         throw new UnsupportedOperationException();
     }
+
+    public void setX(INDArray x, Op op, OpContext oc){
+        if(oc != null)
+            oc.setInputArray(0, x);
+        else
+            op.setX(x);
+    }
+
+    public INDArray getX(Op op, OpContext oc){
+        if( oc != null )
+            return oc.getInputArray(0);
+        return op.x();
+    }
+
+    public void setY(INDArray y, Op op, OpContext oc){
+        if(oc != null)
+            oc.setInputArray(1, y);
+        else
+            op.setY(y);
+    }
+
+    public INDArray getY(Op op, OpContext oc){
+        if( oc != null )
+            return oc.getInputArray(1);
+        return op.y();
+    }
+
+    public void setZ(INDArray z, Op op, OpContext oc){
+        if(oc != null)
+            oc.setOutputArray(0, z);
+        else
+            op.setZ(z);
+    }
+
+    public INDArray getZ(Op op, OpContext oc){
+        if( oc != null )
+            return oc.getOutputArray(0);
+        return op.z();
+    }
 }
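These accessors fold the "context if present, op fields otherwise" branching into one place, so backend code can be written once for both execution paths. A hedged sketch of a helper in the same spirit (ensureZ is an invented name; INDArray/Op/OpContext/Nd4j imports as in the class above):

    // Allocate-or-fetch the output array through the same dual path as getZ/setZ.
    static INDArray ensureZ(Op op, OpContext oc) {
        INDArray x = oc != null ? oc.getInputArray(0) : op.x();
        INDArray z = oc != null ? oc.getOutputArray(0) : op.z();
        if (z == null) {
            z = Nd4j.createUninitialized(x.dataType(), x.shape());
            if (oc != null)
                oc.setOutputArray(0, z);   // per-call output, invisible to other threads
            else
                op.setZ(z);                // legacy path: mutate the shared op
        }
        return z;
    }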
View File
@@ -98,6 +98,13 @@ public interface OpExecutioner {
      */
     INDArray exec(Op op);

+    /**
+     * Execute the operation
+     *
+     * @param op the operation to execute
+     */
+    INDArray exec(Op op, OpContext opContext);
+
     /**Execute a TransformOp and return the result
      * @param op the operation to execute
      */
@@ -364,6 +371,8 @@ public interface OpExecutioner {
     List<LongShapeDescriptor> calculateOutputShape(CustomOp op);

+    List<LongShapeDescriptor> calculateOutputShape(CustomOp op, OpContext opContext);
+
     /**
      * Equivalent to calli
      */
View File
@@ -26,6 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.factory.Nd4j;
 import org.tensorflow.framework.AttrValue;
@@ -150,6 +151,11 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
         return OUT_SHAPE;
     }

+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc){
+        return OUT_SHAPE;
+    }
+
     public Op.Type opType() {
         return Op.Type.LOGIC;
     }
View File
@@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseReduceOp;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
@@ -131,8 +132,14 @@ public class Variance extends BaseReduceOp {
     @Override
     public DataType resultType() {
-        if (this.x() != null && this.x().isR())
-            return this.x().dataType();
+        return resultType(null);
+    }
+
+    @Override
+    public DataType resultType(OpContext oc){
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        if (x != null && x.isR())
+            return x.dataType();

         if(this.arg() != null){
             return this.arg().dataType();
@@ -142,14 +149,18 @@ public class Variance extends BaseReduceOp {
     }

     @Override
-    public boolean validateDataTypes() {
-        if (!x().isR())
+    public boolean validateDataTypes(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        if (x != null && !x.isR()) {
             return false;
+        }

-        if (y() != null && !y().isR())
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        if (y != null && !y.isR())
             return false;

-        if (z() != null && !z().isR())
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (z != null && !z.isR())
             return false;

         return true;
@@ -157,15 +168,22 @@ public class Variance extends BaseReduceOp {
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
-        if(args().length < 1) {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        if(oc == null && args().length < 1) {
             throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
         }

         long[] argShape = arg().getShape();
-        if (argShape == null && x() == null) {
+        if (argShape == null && x == null) {
             return Collections.emptyList();
         }
-        long[] inputShape = (argShape == null || Shape.isPlaceholderShape(argShape) ? x().shape() : argShape);
+        long[] inputShape = (argShape == null || Shape.isPlaceholderShape(argShape) ? x.shape() : argShape);

         val ret = new ArrayList<LongShapeDescriptor>(1);
         val reducedShape = Shape.getReducedShape(inputShape,dimensions, isKeepDims());
View File
@@ -23,6 +23,7 @@ import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseTransformOp;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
@@ -94,20 +95,29 @@ public class MaxOut extends BaseTransformOp {
         return Nd4j.defaultFloatingPointType();
     }

+    @Override
+    public DataType resultType(OpContext oc) {
+        return Nd4j.defaultFloatingPointType();
+    }
+
     @Override
     public Type getOpType() {
         return Type.TRANSFORM_STRICT;
     }

     @Override
-    public boolean validateDataTypes(boolean experimentalMode) {
-        if (!x().isR())
+    public boolean validateDataTypes(OpContext oc, boolean experimentalMode) {
+        INDArray x = oc != null ? oc.getInputArray(0) : x();
+        INDArray y = oc != null ? oc.getInputArray(1) : y();
+        INDArray z = oc != null ? oc.getOutputArray(0) : z();
+        if (!x.isR())
             return false;

-        if (y() != null && !y().isR())
+        if (y != null && !y.isR())
             return false;

-        if (z() != null && z().dataType() != x().dataType())
+        if (z != null && z.dataType() != x.dataType())
             return false;

         return true;
View File
@@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseOp;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.ops.RandomOp;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.api.shape.Shape;
@@ -65,6 +66,11 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp {
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
+        return calculateOutputShape(null);
+    }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
         if(shape != null){
             return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Nd4j.defaultFloatingPointType()));
         } else {
@@ -83,4 +89,8 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp {
     public boolean isInPlace(){
         return x == null || x == z || x.data().pointer().address() == z.data().pointer().address();
     }
+
+    public boolean isTripleArgRngOp(){
+        return false;
+    }
 }
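isTripleArgRngOp() defaults to false here, and the four distribution ops in the following files override it to true. The flag appears to mark RNG ops that execute with all of x, y and z attached rather than only a z output to fill; the dispatch sketched below is an assumption for illustration, with invented method names, not code from this diff:

    // Hypothetical branch in a backend executioner:
    if (op instanceof BaseRandomOp && ((BaseRandomOp) op).isTripleArgRngOp()) {
        execTripleArgRandom(op.x(), op.y(), op.z(), rng);  // x/y carry op state alongside z (assumed)
    } else {
        execSingleArgRandom(op.z(), rng);                  // only the output array is needed
    }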
View File
@@ -139,4 +139,9 @@ public class BinomialDistribution extends BaseRandomOp {
         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
         return Collections.singletonList(DataType.DOUBLE);
     }
+
+    @Override
+    public boolean isTripleArgRngOp() {
+        return true;
+    }
 }
View File
@@ -138,4 +138,9 @@ public class GaussianDistribution extends BaseRandomOp {
         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
         return Collections.singletonList(DataType.DOUBLE);
     }
+
+    @Override
+    public boolean isTripleArgRngOp() {
+        return true;
+    }
 }
View File
@@ -135,4 +135,9 @@ public class LogNormalDistribution extends BaseRandomOp {
         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
         return Collections.singletonList(DataType.DOUBLE);
     }
+
+    @Override
+    public boolean isTripleArgRngOp() {
+        return true;
+    }
 }
View File
@@ -136,4 +136,9 @@ public class TruncatedNormalDistribution extends BaseRandomOp {
         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
         return Collections.singletonList(DataType.DOUBLE);
     }
+
+    @Override
+    public boolean isTripleArgRngOp() {
+        return true;
+    }
 }
View File
@@ -6556,6 +6556,10 @@ public class Nd4j {
         return getExecutioner().exec(op);
     }

+    public static INDArray exec(Op op, OpContext context){
+        return getExecutioner().exec(op, context);
+    }
+
     /**
      * Execute the operation and return the result
      *
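A hedged usage sketch for the new static entry point; the Tanh class path and the buildContext() factory are assumptions, while setInputArray/setOutputArray are per the OpContext changes above:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.OpContext;
    import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
    import org.nd4j.linalg.factory.Nd4j;

    public class ExecWithContext {
        public static void main(String[] args) throws Exception {
            Tanh tanh = new Tanh();   // the op object carries no per-call arrays
            try (OpContext ctx = Nd4j.getExecutioner().buildContext()) {
                ctx.setInputArray(0, Nd4j.rand(DataType.FLOAT, 1, 5));
                ctx.setOutputArray(0, Nd4j.createUninitialized(DataType.FLOAT, 1, 5));
                INDArray out = Nd4j.exec(tanh, ctx);  // returns the attached output array
                System.out.println(out);
            }
        }
    }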
View File
@@ -54,7 +54,7 @@ public abstract class Nd4jBlas implements Blas {
         }
         String logInit = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION);
-        if(logInit == null || logInit.isEmpty() || Boolean.parseBoolean(logInit)) {
+        if(logOpenMPBlasThreads() && (logInit == null || logInit.isEmpty() || Boolean.parseBoolean(logInit))) {
             log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads());
         }
     }
@@ -74,4 +74,8 @@ public abstract class Nd4jBlas implements Blas {
         }
         return Vendor.values()[vendor];
     }
+
+    public boolean logOpenMPBlasThreads(){
+        return true;
+    }
 }
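logOpenMPBlasThreads() is a small template-method hook: the base class logs the OpenMP thread count only when the hook returns true, and the CUDA backend (next file) opts out, since OpenMP BLAS threads are not meaningful there. Independently of the hook, the property consulted above can silence initialization logging altogether; the call pattern (set before ND4J initializes) is:

    System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");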
View File
@@ -134,4 +134,9 @@ public class CudaBlas extends Nd4jBlas {
     public int getBlasVendorId() {
         return 1;
     }
+
+    @Override
+    public boolean logOpenMPBlasThreads() {
+        return false;
+    }
 }
View File
@@ -20,6 +20,7 @@ import lombok.AllArgsConstructor;
 import lombok.Data;
 import lombok.NoArgsConstructor;
 import lombok.val;
+import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
 import org.nd4j.linalg.primitives.Pair;
@@ -127,7 +128,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio
             // the only entry place for TADless ops
             processAsGridOp(op);
         } else if (op instanceof BroadcastOp) {
-            invoke((BroadcastOp) op);
+            invoke((BroadcastOp) op, null);
         } else {
             //logger.info("Random op: {}", op.getClass().getSimpleName());
             pushToGrid(new OpDescriptor(op));
@@ -238,7 +239,7 @@
                 flushQueue();

             //logger.info("Sending TransformOp to CudaExecutioner");
-            super.invoke(t);
+            super.invoke(t, null);
         } else if (op instanceof Variance) {
             Variance acc = (Variance) op;
             if (flush)
@@ -258,7 +259,7 @@
                 flushQueue();

             //logger.info("Sending ScalarOp to CudaExecutioner");
-            super.invoke(sc);
+            super.invoke(sc, null);
         } else if (op instanceof BroadcastOp) {
             BroadcastOp broadcastOp = (BroadcastOp) op;
             if (flush)
@@ -268,7 +269,7 @@
             if (dimensions != null) {
                 super.exec(broadcastOp);
             } else {
-                super.invoke(broadcastOp);
+                super.invoke(broadcastOp, null);
             }
         } else if (op instanceof IndexAccumulation) {
             IndexAccumulation indexAccumulation = (IndexAccumulation) op;
@@ -690,7 +691,7 @@
                 flushQueue();

             buildZ(op, new int[] {Integer.MAX_VALUE});
-            super.invoke(op, new int[] {Integer.MAX_VALUE});
+            super.invoke(op, null, new int[] {Integer.MAX_VALUE});
         } else {
             buildZ(op, dimension);
             processAsGridOp(op, dimension);
@@ -708,7 +709,8 @@
     // FIXME: remove CudaContext return opType. We just don't need it
     @Override
-    protected CudaContext invoke(BroadcastOp op) {
+    protected CudaContext invoke(BroadcastOp op, OpContext oc) {
+        Preconditions.checkState(oc == null);
         processAsGridOp(op, op.getDimension());

         return null;
@@ -716,7 +718,8 @@
     // FIXME: remove CudaContext return opType. We just don't need it
     @Override
-    protected CudaContext invoke(ScalarOp op) {
+    protected CudaContext invoke(ScalarOp op, OpContext oc) {
+        Preconditions.checkState(oc == null);
         processAsGridOp(op, null);

         return null;
@@ -724,7 +727,8 @@
     // FIXME: remove CudaContext return opType. We just don't need it
     @Override
-    protected CudaContext invoke(TransformOp op) {
+    protected CudaContext invoke(TransformOp op, OpContext oc) {
+        Preconditions.checkState(oc == null);
         processAsGridOp(op, null);

         return null;
     }
View File
@@ -385,6 +385,7 @@ public class RandomOpValidation extends BaseOpValidation {
     @Test
     public void testUniformDtype(){
+        Nd4j.getRandom().setSeed(12345);
         for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
             SameDiff sd = SameDiff.create();
             SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100));
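Seeding the global RNG at the top of the test pins the random draws, so assertions cannot flake from run to run. The same call pattern gives reproducibility anywhere (equality of the two arrays below assumes the same backend and build):

    Nd4j.getRandom().setSeed(12345);
    INDArray a = Nd4j.rand(DataType.FLOAT, 3, 3);
    Nd4j.getRandom().setSeed(12345);
    INDArray b = Nd4j.rand(DataType.FLOAT, 3, 3);
    // a.equals(b) is expected to hold: same seed, same sequence of draws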
View File
@@ -0,0 +1,169 @@
package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.imports.TFGraphs.TFGraphTestZooModels;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@Slf4j
public class SameDiffMultiThreadTests extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 60000L;
}
@Test
public void testSimple() throws Exception {
int nThreads = 4;
int nRuns = 1000;
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);
SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 10, 10));
SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10));
SDVariable w2 = sd.var("w2", Nd4j.rand(DataType.FLOAT, 10, 10));
SDVariable b2 = sd.var("b2", Nd4j.rand(DataType.FLOAT, 10));
SDVariable w3 = sd.var("w3", Nd4j.rand(DataType.FLOAT, 10, 10));
SDVariable b3 = sd.var("b3", Nd4j.rand(DataType.FLOAT, 10));
SDVariable l1 = sd.nn.tanh(in.mmul(w1).add(b1));
SDVariable l2 = sd.nn.sigmoid(l1.mmul(w2).add(b2));
SDVariable l3 = sd.nn.softmax("out", l2.mmul(w3).add(b3));
SDVariable loss = sd.loss.logLoss("loss", label, l3);
INDArray[] inputArrs = new INDArray[nThreads];
INDArray[] expOut = new INDArray[nThreads];
for( int i=0; i<nThreads; i++ ){
inputArrs[i] = Nd4j.rand(DataType.FLOAT, i+1, 10);
expOut[i] = sd.outputSingle(Collections.singletonMap("in", inputArrs[i]), "out");
}
Semaphore s = new Semaphore(nThreads);
CountDownLatch latch = new CountDownLatch(nThreads);
AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads];
AtomicInteger[] counters = new AtomicInteger[nThreads];
doTest(sd, nThreads, nRuns, inputArrs, expOut, "in", "out", failuresByThread, counters, s, latch);
s.release(nThreads);
latch.await();
for(int i=0; i<nThreads; i++ ){
assertFalse("Thread " + i + " failed", failuresByThread[i].get());
}
for(int i=0; i<nThreads; i++ ){
assertEquals("Thread " + i + " number of runs", nRuns, counters[i].get());
}
}
@Test
public void testMobilenet() throws Exception {
TFGraphTestZooModels.currentTestDir = testDir.newFolder();
File f = Resources.asFile("tf_graphs/zoo_models/mobilenet_v2_1.0_224/tf_model.txt");
SameDiff sd = TFGraphTestZooModels.LOADER.apply(f, "mobilenet_v2_1.0_224");
// System.out.println(sd.summary());
int nThreads = 4;
int nRuns = 30;
INDArray[] inputArrs = new INDArray[nThreads];
INDArray[] expOut = new INDArray[nThreads];
for( int i=0; i<nThreads; i++ ){
if(i == 0 || i > 2)
inputArrs[i] = Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3);
else if(i == 1)
inputArrs[i] = Nd4j.zeros(DataType.FLOAT, 1, 224, 224, 3);
else if(i == 2)
inputArrs[i] = Nd4j.ones(DataType.FLOAT, 1, 224, 224, 3);
expOut[i] = sd.outputSingle(Collections.singletonMap("input", inputArrs[i]), "MobilenetV2/Predictions/Reshape_1");
Nd4j.getExecutioner().commit();
}
AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads];
AtomicInteger[] counters = new AtomicInteger[nThreads];
Semaphore s = new Semaphore(nThreads);
CountDownLatch latch = new CountDownLatch(nThreads);
doTest(sd, nThreads, nRuns, inputArrs, expOut, "input", "MobilenetV2/Predictions/Reshape_1", failuresByThread, counters, s, latch);
s.release(nThreads);
latch.await();
for(int i=0; i<nThreads; i++ ){
assertFalse("Thread " + i + " failed", failuresByThread[i].get());
}
for(int i=0; i<nThreads; i++ ){
assertEquals("Thread " + i + " number of runs", nRuns, counters[i].get());
}
}
public static void doTest(SameDiff sd, int nThreads, int nRuns, INDArray[] inputArrs, INDArray[] expOut,
String inName, String outName,
AtomicBoolean[] failuresByThread, AtomicInteger[] counters, Semaphore s, CountDownLatch latch){
for( int i=0; i<nThreads; i++ ){
failuresByThread[i] = new AtomicBoolean(false);
counters[i] = new AtomicInteger(0);
final int j=i;
Thread t = new Thread(new Runnable() {
@Override
public void run() {
try{
s.acquire(1);
for( int i=0; i<nRuns; i++ ){
INDArray out = sd.outputSingle(Collections.singletonMap(inName, inputArrs[j]), outName);
Nd4j.getExecutioner().commit();
INDArray exp = expOut[j];
if(!exp.equals(out)){
failuresByThread[j].set(true);
log.error("Failure in thread: {}/{} - iteration {}\nExpected ={}\nActual={}", Thread.currentThread().getId(), j, i, exp, out);
break;
}
if(out.closeable())
out.close();
// if(i % 100 == 0){
// log.info("Thread {} at {}", Thread.currentThread().getId(), i);
// }
counters[j].addAndGet(1);
}
} catch (Throwable t){
log.error("Error in thread: {}", Thread.currentThread().getId(), t);
} finally {
latch.countDown();
}
}
});
t.start();
}
}
}
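These tests exercise the headline behaviour of this change: a single SameDiff instance serving inference from several threads at once, with per-call state held in OpContexts rather than in the ops. A minimal hedged sketch of that usage outside a test harness, with the graph shrunk to a single layer:

    import java.util.Collections;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class MultiThreadedInference {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10);
            SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 10, 10));
            sd.nn.softmax("out", in.mmul(w));

            ExecutorService pool = Executors.newFixedThreadPool(4);
            for (int i = 0; i < 4; i++) {
                pool.submit(() -> {
                    INDArray x = Nd4j.rand(DataType.FLOAT, 5, 10);
                    // Concurrent calls on the same instance are the point of this PR
                    INDArray y = sd.outputSingle(Collections.singletonMap("in", x), "out");
                    System.out.println(y.shapeInfoToString());
                });
            }
            pool.shutdown();
        }
    }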
View File
@@ -99,6 +99,10 @@ public class SameDiffTests extends BaseNd4jTest {
     @ClassRule
     public static TemporaryFolder folder = new TemporaryFolder();

+    @Override
+    public long getTimeoutMilliseconds() {
+        return 999999999L;
+    }
+
     @Before
     public void before() {
View File
@@ -36,6 +36,7 @@ import org.nd4j.evaluation.classification.Evaluation.Metric;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.IrisDataSetIterator;
 import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
@@ -336,12 +337,12 @@ public class ListenerTest extends BaseNd4jTest {
         }

         @Override
-        public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
+        public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
             preOpExecutionCount++;
         }

         @Override
-        public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+        public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
             opExecutionCount++;
         }
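With the extra parameter, listeners can inspect the live per-call state (attached arrays, argument counts) instead of only the op definition. A hedged sketch of a listener built against the new signatures; isActive is assumed to be the remaining abstract method on BaseListener, and the import paths mirror those in the surrounding files:

    import org.nd4j.autodiff.listeners.At;
    import org.nd4j.autodiff.listeners.BaseListener;
    import org.nd4j.autodiff.listeners.Operation;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.autodiff.samediff.internal.SameDiffOp;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.OpContext;
    import org.nd4j.linalg.dataset.api.MultiDataSet;

    public class OpIoLoggingListener extends BaseListener {
        @Override
        public boolean isActive(Operation operation) {
            return true;
        }

        @Override
        public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op,
                                OpContext opContext, INDArray[] outputs) {
            int nIn = opContext == null ? 0 : opContext.numInputArguments();
            System.out.println(op.getName() + ": " + nIn + " inputs, " + outputs.length + " outputs");
        }
    }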
View File
@@ -10,6 +10,8 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.linalg.api.ndarray.INDArray;

 import java.util.*;
+
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;

 public class OpExecOrderListener extends BaseListener {
@@ -24,7 +26,7 @@ public class OpExecOrderListener extends BaseListener {
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         String opName = op.getName();
         if(!opSet.contains(opName)){
             opNamesList.add(opName);
View File
@@ -6,6 +6,7 @@ import org.nd4j.autodiff.listeners.Operation;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;

 /**
@@ -20,7 +21,7 @@ public class ExecPrintListener extends BaseListener {
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         System.out.println("------ Op: " + op.getName() + " - opName = " + op.getOp().opName() + ", class = " + op.getOp().getClass().getName() + " ------");
         for(INDArray arr : outputs){
             System.out.println(arr);
View File
@@ -24,6 +24,7 @@ import org.nd4j.autodiff.listeners.Operation;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -56,7 +57,7 @@ public class ImportDebugListener extends BaseListener {
     }

     @Override
-    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
         //No op
         for( int i=0; i<outputs.length; i++ ) {
View File
@@ -750,7 +750,8 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
         INDArray toPermute = Nd4j.create(Nd4j.linspace(0, 7, 8, DataType.DOUBLE).data(), new long[] {2, 2, 2});
-        INDArray permuted = toPermute.permute(2, 1, 0);
+        INDArray permuted = toPermute.dup().permute(2, 1, 0);
+        boolean eq = toPermute.equals(permuted);
         assertNotEquals(toPermute, permuted);

         INDArray permuteOther = toPermute.permute(1, 2, 0);
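The added dup() matters because permute(...) returns a view over the same buffer rather than a copy, which is presumably why the test now detaches the permuted array before comparing. A small sketch of the distinction (same linspace call as the test above):

    INDArray a = Nd4j.linspace(0, 7, 8, DataType.DOUBLE).reshape(2, 2, 2);
    INDArray view = a.permute(2, 1, 0);       // view: shares data with a
    INDArray copy = a.dup().permute(2, 1, 0); // independent: detached from a's buffer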
View File
@@ -86,6 +86,9 @@ class FunctionalOpExecutioner extends OpExecutioner {
     case _ => op.z()
   }

+  def exec(op: Op, context: OpContext): INDArray =
+    Nd4j.getExecutioner.exec(op, context)
+
   def exec(op: FilterOps): INDArray = {
     val retVal: INDArray = Nd4j.create(op.x.dataType(), op.x.shape().map(_.toLong): _*)
     for (i <- 0 until op.x().length().toInt) {
@@ -408,6 +411,9 @@ class FunctionalOpExecutioner extends OpExecutioner {
   def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] =
     Nd4j.getExecutioner.calculateOutputShape(op)

+  def calculateOutputShape(op: CustomOp, ctx: OpContext): java.util.List[LongShapeDescriptor] =
+    Nd4j.getExecutioner.calculateOutputShape(op, ctx)
+
   /**
    * Equivalent to calli
    */