execution mode (#183)
* initial commit Signed-off-by: raver119 <raver119@gmail.com> * execution mode java side Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * move exec mode to ContextPrototype Signed-off-by: raver119 <raver119@gmail.com> * copyrights Signed-off-by: raver119 <raver119@gmail.com>master
parent
03a1859d0a
commit
531a72fabd
|
@ -1600,6 +1600,7 @@ ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId);
|
|||
ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr);
|
||||
ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow);
|
||||
ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride);
|
||||
ND4J_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode);
|
||||
ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace);
|
||||
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
|
||||
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
||||
|
|
|
@ -2796,6 +2796,13 @@ void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) {
|
|||
ptr->allowHelpers(reallyAllow);
|
||||
}
|
||||
|
||||
void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
|
||||
if (execMode < 0 || execMode > 2)
|
||||
execMode = 0;
|
||||
|
||||
ptr->setExecutionMode((samediff::ExecutionMode) execMode);
|
||||
}
|
||||
|
||||
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
||||
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
|
||||
}
|
||||
|
|
|
@ -3780,6 +3780,13 @@ void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) {
|
|||
ptr->allowHelpers(reallyAllow);
|
||||
}
|
||||
|
||||
void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
|
||||
if (execMode < 0 || execMode > 2)
|
||||
execMode = 0;
|
||||
|
||||
ptr->setExecutionMode((samediff::ExecutionMode) execMode);
|
||||
}
|
||||
|
||||
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
|
||||
try {
|
||||
auto dtype = DataTypeUtils::fromInt(dataType);
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef SD_EXECUTIONMODE_H
|
||||
#define SD_EXECUTIONMODE_H
|
||||
|
||||
namespace samediff {
|
||||
enum ExecutionMode {
|
||||
MODE_UNDEFINED = 0,
|
||||
MODE_TRAINING = 1,
|
||||
MODE_INFERENCE = 2,
|
||||
};
|
||||
}
|
||||
|
||||
#endif //SD_EXECUTIONMODE_H
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -208,6 +209,12 @@ namespace nd4j {
|
|||
|
||||
void setShapeFunctionOverride(bool reallyOverride);
|
||||
bool shapeFunctionOverride();
|
||||
|
||||
samediff::ExecutionMode executionMode();
|
||||
void setExecutionMode(samediff::ExecutionMode executionMode);
|
||||
|
||||
bool isTraining();
|
||||
bool isInference();
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -28,6 +29,7 @@
|
|||
#include <RandomGenerator.h>
|
||||
#include <ops/declarable/OpDescriptor.h>
|
||||
#include <execution/Engine.h>
|
||||
#include <execution/ExecutionMode.h>
|
||||
|
||||
#ifndef __STANDALONE_BUILD__
|
||||
#include <config.h>
|
||||
|
@ -60,6 +62,8 @@ namespace nd4j {
|
|||
|
||||
// target engine for execution
|
||||
samediff::Engine _engine = DEFAULT_ENGINE;
|
||||
|
||||
samediff::ExecutionMode _execMode = samediff::ExecutionMode::MODE_UNDEFINED;
|
||||
public:
|
||||
explicit ContextPrototype(nd4j::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false);
|
||||
~ContextPrototype() = default;
|
||||
|
|
|
@ -535,6 +535,22 @@ namespace nd4j {
|
|||
bool Context::shapeFunctionOverride() {
|
||||
return _shapeFunctionOverride;
|
||||
}
|
||||
|
||||
samediff::ExecutionMode Context::executionMode() {
|
||||
return _execMode;
|
||||
}
|
||||
|
||||
void Context::setExecutionMode(samediff::ExecutionMode executionMode) {
|
||||
_execMode = executionMode;
|
||||
}
|
||||
|
||||
bool Context::isTraining() {
|
||||
return _execMode == samediff::ExecutionMode::MODE_TRAINING;
|
||||
}
|
||||
|
||||
bool Context::isInference() {
|
||||
return _execMode == samediff::ExecutionMode::MODE_INFERENCE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import lombok.val;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
|
@ -35,6 +37,10 @@ public abstract class BaseOpContext implements OpContext {
|
|||
protected List<Boolean> fastpath_b = new ArrayList<>();
|
||||
protected List<Long> fastpath_i = new ArrayList<>();
|
||||
|
||||
@Setter()
|
||||
@Getter
|
||||
protected ExecutionMode executionMode = ExecutionMode.UNDEFINED;
|
||||
|
||||
@Override
|
||||
public void setIArguments(long... arguments) {
|
||||
fastpath_i.clear();
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops;
|
||||
|
||||
/**
|
||||
* This enum describes execution mode for current op/graph. For some operations different execution modes might yield performance/implementation differences
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public enum ExecutionMode {
|
||||
UNDEFINED,
|
||||
TRAINING,
|
||||
INFERENCE
|
||||
}
|
|
@ -142,4 +142,17 @@ public interface OpContext extends AutoCloseable {
|
|||
* @param reallyOverride
|
||||
*/
|
||||
void shapeFunctionOverride(boolean reallyOverride);
|
||||
|
||||
/**
|
||||
* This method returns current execution mode for Context
|
||||
* @return
|
||||
*/
|
||||
ExecutionMode getExecutionMode();
|
||||
|
||||
/**
|
||||
* This method allows to set certain execution mode
|
||||
*
|
||||
* @param mode
|
||||
*/
|
||||
void setExecutionMode(ExecutionMode mode);
|
||||
}
|
||||
|
|
|
@ -1158,6 +1158,7 @@ public interface NativeOps {
|
|||
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
|
||||
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
|
||||
void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
|
||||
void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
|
||||
void ctxShapeFunctionOverride(OpaqueContext ptr, boolean reallyOverride);
|
||||
void deleteGraphContext(OpaqueContext ptr);
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
|
|||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseOpContext;
|
||||
import org.nd4j.linalg.api.ops.ExecutionMode;
|
||||
import org.nd4j.linalg.api.ops.OpContext;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
|
||||
|
@ -126,4 +127,10 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
|||
public void shapeFunctionOverride(boolean reallyOverride) {
|
||||
nativeOps.ctxShapeFunctionOverride(context, reallyOverride);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setExecutionMode(@NonNull ExecutionMode mode) {
|
||||
super.setExecutionMode(mode);
|
||||
nativeOps.ctxSetExecutionMode(context, mode.ordinal());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -759,6 +759,40 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
|||
// #endif //SD_ENGINE_H
|
||||
|
||||
|
||||
// Parsed from execution/ExecutionMode.h
|
||||
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
// #ifndef SD_EXECUTIONMODE_H
|
||||
// #define SD_EXECUTIONMODE_H
|
||||
/** enum samediff::ExecutionMode */
|
||||
public static final int
|
||||
MODE_UNDEFINED = 0,
|
||||
MODE_TRAINING = 1,
|
||||
MODE_INFERENCE = 2;
|
||||
|
||||
|
||||
// #endif //SD_EXECUTIONMODE_H
|
||||
|
||||
|
||||
// Parsed from memory/MemoryType.h
|
||||
|
||||
//
|
||||
|
@ -3055,6 +3089,7 @@ public native OpaqueContext createGraphContext(int nodeId);
|
|||
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
|
||||
public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
|
||||
public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride);
|
||||
public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
|
||||
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
|
||||
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
|
||||
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||
|
@ -6416,6 +6451,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
|||
|
||||
public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride);
|
||||
public native @Cast("bool") boolean shapeFunctionOverride();
|
||||
|
||||
public native @Cast("samediff::ExecutionMode") int executionMode();
|
||||
public native void setExecutionMode(@Cast("samediff::ExecutionMode") int executionMode);
|
||||
|
||||
public native @Cast("bool") boolean isTraining();
|
||||
public native @Cast("bool") boolean isInference();
|
||||
}
|
||||
|
||||
|
||||
|
@ -6456,6 +6497,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
|||
// #include <RandomGenerator.h>
|
||||
// #include <ops/declarable/OpDescriptor.h>
|
||||
// #include <execution/Engine.h>
|
||||
// #include <execution/ExecutionMode.h>
|
||||
|
||||
// #ifndef __STANDALONE_BUILD__
|
||||
// #include <config.h>
|
||||
|
|
|
@ -39,6 +39,7 @@ import org.bytedeco.javacpp.tools.InfoMapper;
|
|||
"array/TadPack.h",
|
||||
"execution/ErrorReference.h",
|
||||
"execution/Engine.h",
|
||||
"execution/ExecutionMode.h",
|
||||
"memory/MemoryType.h",
|
||||
"Environment.h",
|
||||
"types/utf8string.h",
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.bytedeco.javacpp.LongPointer;
|
|||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseOpContext;
|
||||
import org.nd4j.linalg.api.ops.ExecutionMode;
|
||||
import org.nd4j.linalg.api.ops.OpContext;
|
||||
import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
@ -118,4 +119,10 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
|
|||
public void shapeFunctionOverride(boolean reallyOverride) {
|
||||
nativeOps.ctxShapeFunctionOverride(context, reallyOverride);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setExecutionMode(@NonNull ExecutionMode mode) {
|
||||
super.setExecutionMode(mode);
|
||||
nativeOps.ctxSetExecutionMode(context, mode.ordinal());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -780,6 +780,40 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
|||
// #endif //SD_ENGINE_H
|
||||
|
||||
|
||||
// Parsed from execution/ExecutionMode.h
|
||||
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
// #ifndef SD_EXECUTIONMODE_H
|
||||
// #define SD_EXECUTIONMODE_H
|
||||
/** enum samediff::ExecutionMode */
|
||||
public static final int
|
||||
MODE_UNDEFINED = 0,
|
||||
MODE_TRAINING = 1,
|
||||
MODE_INFERENCE = 2;
|
||||
|
||||
|
||||
// #endif //SD_EXECUTIONMODE_H
|
||||
|
||||
|
||||
// Parsed from Environment.h
|
||||
|
||||
/*******************************************************************************
|
||||
|
@ -3058,6 +3092,7 @@ public native OpaqueContext createGraphContext(int nodeId);
|
|||
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
|
||||
public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
|
||||
public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride);
|
||||
public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
|
||||
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
|
||||
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
|
||||
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||
|
@ -6419,6 +6454,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
|||
|
||||
public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride);
|
||||
public native @Cast("bool") boolean shapeFunctionOverride();
|
||||
|
||||
public native @Cast("samediff::ExecutionMode") int executionMode();
|
||||
public native void setExecutionMode(@Cast("samediff::ExecutionMode") int executionMode);
|
||||
|
||||
public native @Cast("bool") boolean isTraining();
|
||||
public native @Cast("bool") boolean isInference();
|
||||
}
|
||||
|
||||
|
||||
|
@ -6459,6 +6500,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
|||
// #include <RandomGenerator.h>
|
||||
// #include <ops/declarable/OpDescriptor.h>
|
||||
// #include <execution/Engine.h>
|
||||
// #include <execution/ExecutionMode.h>
|
||||
|
||||
// #ifndef __STANDALONE_BUILD__
|
||||
// #include <config.h>
|
||||
|
|
|
@ -42,6 +42,7 @@ import java.util.Scanner;
|
|||
"array/TadPack.h",
|
||||
"execution/ErrorReference.h",
|
||||
"execution/Engine.h",
|
||||
"execution/ExecutionMode.h",
|
||||
"Environment.h",
|
||||
"types/utf8string.h",
|
||||
"NativeOps.h",
|
||||
|
|
Loading…
Reference in New Issue