execution mode (#183)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* execution mode java side

Signed-off-by: raver119 <raver119@gmail.com>

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* move exec mode to ContextPrototype

Signed-off-by: raver119 <raver119@gmail.com>

* copyrights

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-01-27 10:00:07 +03:00 committed by GitHub
parent 03a1859d0a
commit 531a72fabd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 222 additions and 0 deletions

View File

@ -1600,6 +1600,7 @@ ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId);
ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr); ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr);
ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow); ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow);
ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride); ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride);
ND4J_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode);
ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace);
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);

View File

@ -2796,6 +2796,13 @@ void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) {
ptr->allowHelpers(reallyAllow); ptr->allowHelpers(reallyAllow);
} }
void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
if (execMode < 0 || execMode > 2)
execMode = 0;
ptr->setExecutionMode((samediff::ExecutionMode) execMode);
}
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
} }

View File

@ -3780,6 +3780,13 @@ void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) {
ptr->allowHelpers(reallyAllow); ptr->allowHelpers(reallyAllow);
} }
void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
if (execMode < 0 || execMode > 2)
execMode = 0;
ptr->setExecutionMode((samediff::ExecutionMode) execMode);
}
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
try { try {
auto dtype = DataTypeUtils::fromInt(dataType); auto dtype = DataTypeUtils::fromInt(dataType);

View File

@ -0,0 +1,32 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SD_EXECUTIONMODE_H
#define SD_EXECUTIONMODE_H
namespace samediff {
enum ExecutionMode {
MODE_UNDEFINED = 0,
MODE_TRAINING = 1,
MODE_INFERENCE = 2,
};
}
#endif //SD_EXECUTIONMODE_H

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -208,6 +209,12 @@ namespace nd4j {
void setShapeFunctionOverride(bool reallyOverride); void setShapeFunctionOverride(bool reallyOverride);
bool shapeFunctionOverride(); bool shapeFunctionOverride();
samediff::ExecutionMode executionMode();
void setExecutionMode(samediff::ExecutionMode executionMode);
bool isTraining();
bool isInference();
}; };
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -28,6 +29,7 @@
#include <RandomGenerator.h> #include <RandomGenerator.h>
#include <ops/declarable/OpDescriptor.h> #include <ops/declarable/OpDescriptor.h>
#include <execution/Engine.h> #include <execution/Engine.h>
#include <execution/ExecutionMode.h>
#ifndef __STANDALONE_BUILD__ #ifndef __STANDALONE_BUILD__
#include <config.h> #include <config.h>
@ -60,6 +62,8 @@ namespace nd4j {
// target engine for execution // target engine for execution
samediff::Engine _engine = DEFAULT_ENGINE; samediff::Engine _engine = DEFAULT_ENGINE;
samediff::ExecutionMode _execMode = samediff::ExecutionMode::MODE_UNDEFINED;
public: public:
explicit ContextPrototype(nd4j::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false); explicit ContextPrototype(nd4j::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false);
~ContextPrototype() = default; ~ContextPrototype() = default;

View File

@ -535,6 +535,22 @@ namespace nd4j {
bool Context::shapeFunctionOverride() { bool Context::shapeFunctionOverride() {
return _shapeFunctionOverride; return _shapeFunctionOverride;
} }
samediff::ExecutionMode Context::executionMode() {
return _execMode;
}
void Context::setExecutionMode(samediff::ExecutionMode executionMode) {
_execMode = executionMode;
}
bool Context::isTraining() {
return _execMode == samediff::ExecutionMode::MODE_TRAINING;
}
bool Context::isInference() {
return _execMode == samediff::ExecutionMode::MODE_INFERENCE;
}
} }
} }

View File

@ -16,7 +16,9 @@
package org.nd4j.linalg.api.ops; package org.nd4j.linalg.api.ops;
import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import lombok.Setter;
import lombok.val; import lombok.val;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -35,6 +37,10 @@ public abstract class BaseOpContext implements OpContext {
protected List<Boolean> fastpath_b = new ArrayList<>(); protected List<Boolean> fastpath_b = new ArrayList<>();
protected List<Long> fastpath_i = new ArrayList<>(); protected List<Long> fastpath_i = new ArrayList<>();
@Setter()
@Getter
protected ExecutionMode executionMode = ExecutionMode.UNDEFINED;
@Override @Override
public void setIArguments(long... arguments) { public void setIArguments(long... arguments) {
fastpath_i.clear(); fastpath_i.clear();

View File

@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops;
/**
* This enum describes execution mode for current op/graph. For some operations different execution modes might yield performance/implementation differences
*
* @author raver119@gmail.com
*/
public enum ExecutionMode {
UNDEFINED,
TRAINING,
INFERENCE
}

View File

@ -142,4 +142,17 @@ public interface OpContext extends AutoCloseable {
* @param reallyOverride * @param reallyOverride
*/ */
void shapeFunctionOverride(boolean reallyOverride); void shapeFunctionOverride(boolean reallyOverride);
/**
* This method returns current execution mode for Context
* @return
*/
ExecutionMode getExecutionMode();
/**
* This method allows to set certain execution mode
*
* @param mode
*/
void setExecutionMode(ExecutionMode mode);
} }

View File

@ -1158,6 +1158,7 @@ public interface NativeOps {
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments); void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments); void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow); void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
void ctxShapeFunctionOverride(OpaqueContext ptr, boolean reallyOverride); void ctxShapeFunctionOverride(OpaqueContext ptr, boolean reallyOverride);
void deleteGraphContext(OpaqueContext ptr); void deleteGraphContext(OpaqueContext ptr);

View File

@ -27,6 +27,7 @@ import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.BaseOpContext;
import org.nd4j.linalg.api.ops.ExecutionMode;
import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
@ -126,4 +127,10 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
public void shapeFunctionOverride(boolean reallyOverride) { public void shapeFunctionOverride(boolean reallyOverride) {
nativeOps.ctxShapeFunctionOverride(context, reallyOverride); nativeOps.ctxShapeFunctionOverride(context, reallyOverride);
} }
@Override
public void setExecutionMode(@NonNull ExecutionMode mode) {
super.setExecutionMode(mode);
nativeOps.ctxSetExecutionMode(context, mode.ordinal());
}
} }

View File

@ -759,6 +759,40 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
// #endif //SD_ENGINE_H // #endif //SD_ENGINE_H
// Parsed from execution/ExecutionMode.h
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
// #ifndef SD_EXECUTIONMODE_H
// #define SD_EXECUTIONMODE_H
/** enum samediff::ExecutionMode */
public static final int
MODE_UNDEFINED = 0,
MODE_TRAINING = 1,
MODE_INFERENCE = 2;
// #endif //SD_EXECUTIONMODE_H
// Parsed from memory/MemoryType.h // Parsed from memory/MemoryType.h
// //
@ -3055,6 +3089,7 @@ public native OpaqueContext createGraphContext(int nodeId);
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride); public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride);
public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
@ -6416,6 +6451,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride); public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride);
public native @Cast("bool") boolean shapeFunctionOverride(); public native @Cast("bool") boolean shapeFunctionOverride();
public native @Cast("samediff::ExecutionMode") int executionMode();
public native void setExecutionMode(@Cast("samediff::ExecutionMode") int executionMode);
public native @Cast("bool") boolean isTraining();
public native @Cast("bool") boolean isInference();
} }
@ -6456,6 +6497,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <RandomGenerator.h> // #include <RandomGenerator.h>
// #include <ops/declarable/OpDescriptor.h> // #include <ops/declarable/OpDescriptor.h>
// #include <execution/Engine.h> // #include <execution/Engine.h>
// #include <execution/ExecutionMode.h>
// #ifndef __STANDALONE_BUILD__ // #ifndef __STANDALONE_BUILD__
// #include <config.h> // #include <config.h>

View File

@ -39,6 +39,7 @@ import org.bytedeco.javacpp.tools.InfoMapper;
"array/TadPack.h", "array/TadPack.h",
"execution/ErrorReference.h", "execution/ErrorReference.h",
"execution/Engine.h", "execution/Engine.h",
"execution/ExecutionMode.h",
"memory/MemoryType.h", "memory/MemoryType.h",
"Environment.h", "Environment.h",
"types/utf8string.h", "types/utf8string.h",

View File

@ -23,6 +23,7 @@ import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.BaseOpContext;
import org.nd4j.linalg.api.ops.ExecutionMode;
import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer; import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -118,4 +119,10 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
public void shapeFunctionOverride(boolean reallyOverride) { public void shapeFunctionOverride(boolean reallyOverride) {
nativeOps.ctxShapeFunctionOverride(context, reallyOverride); nativeOps.ctxShapeFunctionOverride(context, reallyOverride);
} }
@Override
public void setExecutionMode(@NonNull ExecutionMode mode) {
super.setExecutionMode(mode);
nativeOps.ctxSetExecutionMode(context, mode.ordinal());
}
} }

View File

@ -780,6 +780,40 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
// #endif //SD_ENGINE_H // #endif //SD_ENGINE_H
// Parsed from execution/ExecutionMode.h
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
// #ifndef SD_EXECUTIONMODE_H
// #define SD_EXECUTIONMODE_H
/** enum samediff::ExecutionMode */
public static final int
MODE_UNDEFINED = 0,
MODE_TRAINING = 1,
MODE_INFERENCE = 2;
// #endif //SD_EXECUTIONMODE_H
// Parsed from Environment.h // Parsed from Environment.h
/******************************************************************************* /*******************************************************************************
@ -3058,6 +3092,7 @@ public native OpaqueContext createGraphContext(int nodeId);
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride); public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride);
public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
@ -6419,6 +6454,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride); public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride);
public native @Cast("bool") boolean shapeFunctionOverride(); public native @Cast("bool") boolean shapeFunctionOverride();
public native @Cast("samediff::ExecutionMode") int executionMode();
public native void setExecutionMode(@Cast("samediff::ExecutionMode") int executionMode);
public native @Cast("bool") boolean isTraining();
public native @Cast("bool") boolean isInference();
} }
@ -6459,6 +6500,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <RandomGenerator.h> // #include <RandomGenerator.h>
// #include <ops/declarable/OpDescriptor.h> // #include <ops/declarable/OpDescriptor.h>
// #include <execution/Engine.h> // #include <execution/Engine.h>
// #include <execution/ExecutionMode.h>
// #ifndef __STANDALONE_BUILD__ // #ifndef __STANDALONE_BUILD__
// #include <config.h> // #include <config.h>

View File

@ -42,6 +42,7 @@ import java.util.Scanner;
"array/TadPack.h", "array/TadPack.h",
"execution/ErrorReference.h", "execution/ErrorReference.h",
"execution/Engine.h", "execution/Engine.h",
"execution/ExecutionMode.h",
"Environment.h", "Environment.h",
"types/utf8string.h", "types/utf8string.h",
"NativeOps.h", "NativeOps.h",