Fix functions of OpaqueVariablesSet

master
Samuel Audet 2019-07-26 07:22:44 +00:00 committed by AlexDBlack
parent b57f1d52cc
commit 8d1fe8b1b3
11 changed files with 58 additions and 57 deletions

View File

@ -1678,14 +1678,14 @@ ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList);
ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer); ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer);
typedef nd4j::graph::VariablesSet OpaqueVariableSet; typedef nd4j::graph::VariablesSet OpaqueVariablesSet;
typedef nd4j::graph::Variable OpaqueVariable; typedef nd4j::graph::Variable OpaqueVariable;
ND4J_EXPORT OpaqueVariableSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs); ND4J_EXPORT OpaqueVariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs);
ND4J_EXPORT Nd4jLong getVariableSetSize(OpaqueVariableSet* set); ND4J_EXPORT Nd4jLong getVariablesSetSize(OpaqueVariablesSet* set);
ND4J_EXPORT Nd4jStatus getVariableSetStatus(OpaqueVariableSet* set); ND4J_EXPORT Nd4jStatus getVariablesSetStatus(OpaqueVariablesSet* set);
ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariableSet* set, Nd4jLong i); ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariablesSet* set, Nd4jLong i);
ND4J_EXPORT int getVariableId(OpaqueVariable* variable); ND4J_EXPORT int getVariableId(OpaqueVariable* variable);
ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable); ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable);
ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable); ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable);
@ -1699,7 +1699,7 @@ ND4J_EXPORT void deleteIntArray(Nd4jPointer pointer);
ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer); ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer);
ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer); ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer);
ND4J_EXPORT void deleteVariablesSet(OpaqueVariableSet pointer); ND4J_EXPORT void deleteVariablesSet(OpaqueVariablesSet* pointer);
// GraphState creation // GraphState creation
ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id); ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id);

View File

@ -2337,11 +2337,11 @@ nd4j::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLo
return nullptr; return nullptr;
} }
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) { Nd4jLong getVariablesSetSize(nd4j::graph::VariablesSet* set) {
return set->size(); return set->size();
} }
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) { Nd4jStatus getVariablesSetStatus(nd4j::graph::VariablesSet* set) {
return set->status(); return set->status();
} }
@ -2396,14 +2396,8 @@ void deleteLongArray(Nd4jPointer pointer) {
delete[] ptr; delete[] ptr;
} }
template <typename T> void deleteVariablesSet(nd4j::graph::VariablesSet* pointer) {
static void deleteVariablesSetT(Nd4jPointer pointer) { delete pointer;
auto ptr = reinterpret_cast<nd4j::graph::VariablesSet*>(pointer);
delete ptr;
}
void deleteVariablesSet(Nd4jPointer pointer) {
deleteVariablesSetT<double>(pointer);
} }
const char* getAllOperations() { const char* getAllOperations() {

View File

@ -1935,6 +1935,19 @@ void execAggregate(Nd4jPointer *extraPointers,
nd4j::DebugHelper::checkErrorCode(stream, "execAggregateFloat(...) failed"); nd4j::DebugHelper::checkErrorCode(stream, "execAggregateFloat(...) failed");
} }
void batchExecutor(Nd4jPointer *extraPointers,
int numAggregates,
int opNum,
int maxArgs,
int maxShapes,
int maxIntArrays,
int maxIntArraySize,
int maxIdx,
int maxReals,
void *ptrToArguments,
nd4j::DataType dtype) {
}
void execAggregateBatch(Nd4jPointer *extraPointers, void execAggregateBatch(Nd4jPointer *extraPointers,
int numAggregates, int opNum, int numAggregates, int opNum,
int maxArgs, int maxShapes, int maxArgs, int maxShapes,
@ -2808,11 +2821,11 @@ VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, N
return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs); return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs);
} }
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) { Nd4jLong getVariablesSetSize(nd4j::graph::VariablesSet* set) {
return set->size(); return set->size();
} }
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) { Nd4jStatus getVariablesSetStatus(nd4j::graph::VariablesSet* set) {
return set->status(); return set->status();
} }
@ -2867,14 +2880,8 @@ void deleteLongArray(Nd4jPointer pointer) {
delete[] ptr; delete[] ptr;
} }
template <typename T> void deleteVariablesSet(nd4j::graph::VariablesSet* pointer) {
static void deleteVariablesSetT(Nd4jPointer pointer) { delete pointer;
nd4j::graph::VariablesSet* ptr = reinterpret_cast<nd4j::graph::VariablesSet*>(pointer);
delete ptr;
}
void deleteVariablesSet(Nd4jPointer pointer) {
deleteVariablesSetT<double>(pointer);
} }
void deleteShapeList(Nd4jPointer shapeList) { void deleteShapeList(Nd4jPointer shapeList) {

View File

@ -1068,11 +1068,11 @@ public interface NativeOps {
int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer); int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer);
OpaqueVariableSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs); OpaqueVariablesSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
long getVariableSetSize(OpaqueVariableSet set); long getVariablesSetSize(OpaqueVariablesSet set);
int getVariableSetStatus(OpaqueVariableSet set); int getVariablesSetStatus(OpaqueVariablesSet set);
OpaqueVariable getVariable(OpaqueVariableSet set, long i); OpaqueVariable getVariable(OpaqueVariablesSet set, long i);
int getVariableId(OpaqueVariable variable); int getVariableId(OpaqueVariable variable);
int getVariableIndex(OpaqueVariable variable); int getVariableIndex(OpaqueVariable variable);
String getVariableName(OpaqueVariable variable); String getVariableName(OpaqueVariable variable);
@ -1095,7 +1095,7 @@ public interface NativeOps {
void deleteNPArrayMap(Pointer pointer); void deleteNPArrayMap(Pointer pointer);
void deleteVariablesSet(OpaqueVariableSet pointer); void deleteVariablesSet(OpaqueVariablesSet pointer);
// GraphState creation // GraphState creation
Pointer getGraphState(long id); Pointer getGraphState(long id);

View File

@ -22,6 +22,6 @@ import org.bytedeco.javacpp.Pointer;
* *
* @author saudet * @author saudet
*/ */
public class OpaqueVariableSet extends Pointer { public class OpaqueVariablesSet extends Pointer {
public OpaqueVariableSet(Pointer p) { super(p); } public OpaqueVariablesSet(Pointer p) { super(p); }
} }

View File

@ -76,7 +76,7 @@ import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList; import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack; import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable; import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariableSet; import org.nd4j.nativeblas.OpaqueVariablesSet;
import java.util.*; import java.util.*;
@ -2427,14 +2427,14 @@ public class CudaExecutioner extends DefaultOpExecutioner {
val newMap = new LinkedHashMap<String, INDArray>(); val newMap = new LinkedHashMap<String, INDArray>();
OpaqueVariableSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
OpStatus status = OpStatus.byNumber(nativeOps.getVariableSetStatus(result)); OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(result));
if (status != OpStatus.ND4J_STATUS_OK) if (status != OpStatus.ND4J_STATUS_OK)
throw new ND4JIllegalStateException("Op execution failed: " + status); throw new ND4JIllegalStateException("Op execution failed: " + status);
for (int e = 0; e < nativeOps.getVariableSetSize(result); e++) { for (int e = 0; e < nativeOps.getVariablesSetSize(result); e++) {
OpaqueVariable var = nativeOps.getVariable(result, e); OpaqueVariable var = nativeOps.getVariable(result, e);
int nodeId = nativeOps.getVariableId(var); int nodeId = nativeOps.getVariableId(var);
int index = nativeOps.getVariableIndex(var); int index = nativeOps.getVariableIndex(var);

View File

@ -2988,13 +2988,13 @@ public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer); public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set); public native @Cast("Nd4jLong") long getVariablesSetSize(OpaqueVariablesSet set);
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set); public native @Cast("Nd4jStatus") int getVariablesSetStatus(OpaqueVariablesSet set);
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i); public native OpaqueVariable getVariable(OpaqueVariablesSet set, @Cast("Nd4jLong") long i);
public native int getVariableId(OpaqueVariable variable); public native int getVariableId(OpaqueVariable variable);
public native int getVariableIndex(OpaqueVariable variable); public native int getVariableIndex(OpaqueVariable variable);
public native @Cast("char*") String getVariableName(OpaqueVariable variable); public native @Cast("char*") String getVariableName(OpaqueVariable variable);
@ -3008,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer); public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer); public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer); public native void deleteVariablesSet(OpaqueVariablesSet pointer);
// GraphState creation // GraphState creation
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id); public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);

View File

@ -116,7 +116,7 @@ public class Nd4jCudaPresets implements InfoMapper {
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack")) .put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper")) .put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList")) .put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet")) .put(new Info("OpaqueVariablesSet").pointerTypes("OpaqueVariablesSet"))
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable")) .put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))

View File

@ -74,7 +74,7 @@ import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList; import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack; import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable; import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariableSet; import org.nd4j.nativeblas.OpaqueVariablesSet;
import java.util.*; import java.util.*;
@ -1950,14 +1950,14 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
val newMap = new LinkedHashMap<String, INDArray>(); val newMap = new LinkedHashMap<String, INDArray>();
OpaqueVariableSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); OpaqueVariablesSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
OpStatus status = OpStatus.byNumber(loop.getVariableSetStatus(result)); OpStatus status = OpStatus.byNumber(loop.getVariablesSetStatus(result));
if (status != OpStatus.ND4J_STATUS_OK) if (status != OpStatus.ND4J_STATUS_OK)
throw new ND4JIllegalStateException("Op execution failed: " + status); throw new ND4JIllegalStateException("Op execution failed: " + status);
for (int e = 0; e < loop.getVariableSetSize(result); e++) { for (int e = 0; e < loop.getVariablesSetSize(result); e++) {
OpaqueVariable var = loop.getVariable(result, e); OpaqueVariable var = loop.getVariable(result, e);
int nodeId = loop.getVariableId(var); int nodeId = loop.getVariableId(var);
int index = loop.getVariableIndex(var); int index = loop.getVariableIndex(var);

View File

@ -2988,13 +2988,13 @@ public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer); public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set); public native @Cast("Nd4jLong") long getVariablesSetSize(OpaqueVariablesSet set);
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set); public native @Cast("Nd4jStatus") int getVariablesSetStatus(OpaqueVariablesSet set);
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i); public native OpaqueVariable getVariable(OpaqueVariablesSet set, @Cast("Nd4jLong") long i);
public native int getVariableId(OpaqueVariable variable); public native int getVariableId(OpaqueVariable variable);
public native int getVariableIndex(OpaqueVariable variable); public native int getVariableIndex(OpaqueVariable variable);
public native @Cast("char*") String getVariableName(OpaqueVariable variable); public native @Cast("char*") String getVariableName(OpaqueVariable variable);
@ -3008,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer); public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer); public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer); public native void deleteVariablesSet(OpaqueVariablesSet pointer);
// GraphState creation // GraphState creation
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id); public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);

View File

@ -159,7 +159,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled {
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack")) .put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper")) .put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList")) .put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet")) .put(new Info("OpaqueVariablesSet").pointerTypes("OpaqueVariablesSet"))
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable")) .put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))