diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index cd6274dfb..862ffa42f 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -1600,6 +1600,7 @@ ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId); ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr); ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow); ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride); +ND4J_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode); ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 1b1d22fbf..2e203584d 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -2796,6 +2796,13 @@ void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) { ptr->allowHelpers(reallyAllow); } +void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) { + if (execMode < 0 || execMode > 2) + execMode = 0; + + ptr->setExecutionMode((samediff::ExecutionMode) execMode); +} + nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); } diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 419cadef5..b7995cb75 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3780,6 +3780,13 @@ void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) { ptr->allowHelpers(reallyAllow); } +void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) { + if (execMode < 0 || execMode > 2) + execMode = 0; + + ptr->setExecutionMode((samediff::ExecutionMode) execMode); +} + OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { try { auto dtype = DataTypeUtils::fromInt(dataType); diff --git a/libnd4j/include/execution/ExecutionMode.h b/libnd4j/include/execution/ExecutionMode.h new file mode 100644 index 000000000..ea97e3fc9 --- /dev/null +++ b/libnd4j/include/execution/ExecutionMode.h @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_EXECUTIONMODE_H +#define SD_EXECUTIONMODE_H + +namespace samediff { + enum ExecutionMode { + MODE_UNDEFINED = 0, + MODE_TRAINING = 1, + MODE_INFERENCE = 2, + }; +} + +#endif //SD_EXECUTIONMODE_H diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index f4fa6d16d..5e0f094e1 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -208,6 +209,12 @@ namespace nd4j { void setShapeFunctionOverride(bool reallyOverride); bool shapeFunctionOverride(); + + samediff::ExecutionMode executionMode(); + void setExecutionMode(samediff::ExecutionMode executionMode); + + bool isTraining(); + bool isInference(); }; } } diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index a9d05b7b4..bf5d389e4 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,6 +29,7 @@ #include #include #include +#include #ifndef __STANDALONE_BUILD__ #include @@ -60,6 +62,8 @@ namespace nd4j { // target engine for execution samediff::Engine _engine = DEFAULT_ENGINE; + + samediff::ExecutionMode _execMode = samediff::ExecutionMode::MODE_UNDEFINED; public: explicit ContextPrototype(nd4j::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false); ~ContextPrototype() = default; diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 4876675dc..5efd13a20 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -535,6 +535,22 @@ namespace nd4j { bool Context::shapeFunctionOverride() { return _shapeFunctionOverride; } + + samediff::ExecutionMode Context::executionMode() { + return _execMode; + } + + void Context::setExecutionMode(samediff::ExecutionMode executionMode) { + _execMode = executionMode; + } + + bool Context::isTraining() { + return _execMode == samediff::ExecutionMode::MODE_TRAINING; + } + + bool Context::isInference() { + return _execMode == samediff::ExecutionMode::MODE_INFERENCE; + } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java index ffdce18ee..050868b36 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java @@ -16,7 +16,9 @@ package org.nd4j.linalg.api.ops; +import lombok.Getter; import lombok.NonNull; +import lombok.Setter; import lombok.val; import org.nd4j.linalg.api.ndarray.INDArray; @@ -35,6 +37,10 @@ public abstract class BaseOpContext implements OpContext { protected List fastpath_b = new ArrayList<>(); protected List fastpath_i = new ArrayList<>(); + @Setter() + @Getter + protected ExecutionMode executionMode = ExecutionMode.UNDEFINED; + @Override public void setIArguments(long... arguments) { fastpath_i.clear(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ExecutionMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ExecutionMode.java new file mode 100644 index 000000000..5d33e744c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ExecutionMode.java @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops; + +/** + * This enum describes execution mode for current op/graph. For some operations different execution modes might yield performance/implementation differences + * + * @author raver119@gmail.com + */ +public enum ExecutionMode { + UNDEFINED, + TRAINING, + INFERENCE +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index dda6aef24..3deefe7c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -142,4 +142,17 @@ public interface OpContext extends AutoCloseable { * @param reallyOverride */ void shapeFunctionOverride(boolean reallyOverride); + + /** + * This method returns current execution mode for Context + * @return + */ + ExecutionMode getExecutionMode(); + + /** + * This method allows to set certain execution mode + * + * @param mode + */ + void setExecutionMode(ExecutionMode mode); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index 1d1b837e7..95c97068e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1158,6 +1158,7 @@ public interface NativeOps { void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments); void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments); void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow); + void ctxSetExecutionMode(OpaqueContext ptr, int execMode); void ctxShapeFunctionOverride(OpaqueContext ptr, boolean reallyOverride); void deleteGraphContext(OpaqueContext ptr); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index 6f37be02a..487f38232 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -27,6 +27,7 @@ import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; +import org.nd4j.linalg.api.ops.ExecutionMode; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; @@ -126,4 +127,10 @@ public class CudaOpContext extends BaseOpContext implements OpContext { public void shapeFunctionOverride(boolean reallyOverride) { nativeOps.ctxShapeFunctionOverride(context, reallyOverride); } + + @Override + public void setExecutionMode(@NonNull ExecutionMode mode) { + super.setExecutionMode(mode); + nativeOps.ctxSetExecutionMode(context, mode.ordinal()); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 5aa685c7a..d466c41a5 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -759,6 +759,40 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { // #endif //SD_ENGINE_H +// Parsed from execution/ExecutionMode.h + +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_EXECUTIONMODE_H +// #define SD_EXECUTIONMODE_H + /** enum samediff::ExecutionMode */ + public static final int + MODE_UNDEFINED = 0, + MODE_TRAINING = 1, + MODE_INFERENCE = 2; + + +// #endif //SD_EXECUTIONMODE_H + + // Parsed from memory/MemoryType.h // @@ -3055,6 +3089,7 @@ public native OpaqueContext createGraphContext(int nodeId); public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride); +public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); @@ -6416,6 +6451,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride); public native @Cast("bool") boolean shapeFunctionOverride(); + + public native @Cast("samediff::ExecutionMode") int executionMode(); + public native void setExecutionMode(@Cast("samediff::ExecutionMode") int executionMode); + + public native @Cast("bool") boolean isTraining(); + public native @Cast("bool") boolean isInference(); } @@ -6456,6 +6497,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); // #include // #include // #include +// #include // #ifndef __STANDALONE_BUILD__ // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java index aa6d91519..cdfefce31 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java @@ -39,6 +39,7 @@ import org.bytedeco.javacpp.tools.InfoMapper; "array/TadPack.h", "execution/ErrorReference.h", "execution/Engine.h", + "execution/ExecutionMode.h", "memory/MemoryType.h", "Environment.h", "types/utf8string.h", diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index fce391a05..1863d6c1c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -23,6 +23,7 @@ import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; +import org.nd4j.linalg.api.ops.ExecutionMode; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer; import org.nd4j.linalg.primitives.Pair; @@ -118,4 +119,10 @@ public class CpuOpContext extends BaseOpContext implements OpContext { public void shapeFunctionOverride(boolean reallyOverride) { nativeOps.ctxShapeFunctionOverride(context, reallyOverride); } + + @Override + public void setExecutionMode(@NonNull ExecutionMode mode) { + super.setExecutionMode(mode); + nativeOps.ctxSetExecutionMode(context, mode.ordinal()); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index cfabc651c..99002b3a6 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -780,6 +780,40 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // #endif //SD_ENGINE_H +// Parsed from execution/ExecutionMode.h + +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_EXECUTIONMODE_H +// #define SD_EXECUTIONMODE_H + /** enum samediff::ExecutionMode */ + public static final int + MODE_UNDEFINED = 0, + MODE_TRAINING = 1, + MODE_INFERENCE = 2; + + +// #endif //SD_EXECUTIONMODE_H + + // Parsed from Environment.h /******************************************************************************* @@ -3058,6 +3092,7 @@ public native OpaqueContext createGraphContext(int nodeId); public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride); +public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); @@ -6419,6 +6454,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride); public native @Cast("bool") boolean shapeFunctionOverride(); + + public native @Cast("samediff::ExecutionMode") int executionMode(); + public native void setExecutionMode(@Cast("samediff::ExecutionMode") int executionMode); + + public native @Cast("bool") boolean isTraining(); + public native @Cast("bool") boolean isInference(); } @@ -6459,6 +6500,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); // #include // #include // #include +// #include // #ifndef __STANDALONE_BUILD__ // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index c2fca8d89..577626864 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -42,6 +42,7 @@ import java.util.Scanner; "array/TadPack.h", "execution/ErrorReference.h", "execution/Engine.h", + "execution/ExecutionMode.h", "Environment.h", "types/utf8string.h", "NativeOps.h",