Fix functions of OpaqueVariablesSet

master
Samuel Audet 2019-07-26 07:22:44 +00:00 committed by AlexDBlack
parent b57f1d52cc
commit 8d1fe8b1b3
11 changed files with 58 additions and 57 deletions

View File

@ -1678,14 +1678,14 @@ ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList);
ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer);
typedef nd4j::graph::VariablesSet OpaqueVariableSet;
typedef nd4j::graph::VariablesSet OpaqueVariablesSet;
typedef nd4j::graph::Variable OpaqueVariable;
ND4J_EXPORT OpaqueVariableSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs);
ND4J_EXPORT OpaqueVariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs);
ND4J_EXPORT Nd4jLong getVariableSetSize(OpaqueVariableSet* set);
ND4J_EXPORT Nd4jStatus getVariableSetStatus(OpaqueVariableSet* set);
ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariableSet* set, Nd4jLong i);
ND4J_EXPORT Nd4jLong getVariablesSetSize(OpaqueVariablesSet* set);
ND4J_EXPORT Nd4jStatus getVariablesSetStatus(OpaqueVariablesSet* set);
ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariablesSet* set, Nd4jLong i);
ND4J_EXPORT int getVariableId(OpaqueVariable* variable);
ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable);
ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable);
@ -1699,7 +1699,7 @@ ND4J_EXPORT void deleteIntArray(Nd4jPointer pointer);
ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer);
ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer);
ND4J_EXPORT void deleteVariablesSet(OpaqueVariableSet pointer);
ND4J_EXPORT void deleteVariablesSet(OpaqueVariablesSet* pointer);
// GraphState creation
ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id);

View File

@ -2337,11 +2337,11 @@ nd4j::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLo
return nullptr;
}
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) {
Nd4jLong getVariablesSetSize(nd4j::graph::VariablesSet* set) {
return set->size();
}
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) {
Nd4jStatus getVariablesSetStatus(nd4j::graph::VariablesSet* set) {
return set->status();
}
@ -2396,14 +2396,8 @@ void deleteLongArray(Nd4jPointer pointer) {
delete[] ptr;
}
template <typename T>
static void deleteVariablesSetT(Nd4jPointer pointer) {
auto ptr = reinterpret_cast<nd4j::graph::VariablesSet*>(pointer);
delete ptr;
}
void deleteVariablesSet(Nd4jPointer pointer) {
deleteVariablesSetT<double>(pointer);
void deleteVariablesSet(nd4j::graph::VariablesSet* pointer) {
delete pointer;
}
const char* getAllOperations() {

View File

@ -1935,6 +1935,19 @@ void execAggregate(Nd4jPointer *extraPointers,
nd4j::DebugHelper::checkErrorCode(stream, "execAggregateFloat(...) failed");
}
void batchExecutor(Nd4jPointer *extraPointers,
int numAggregates,
int opNum,
int maxArgs,
int maxShapes,
int maxIntArrays,
int maxIntArraySize,
int maxIdx,
int maxReals,
void *ptrToArguments,
nd4j::DataType dtype) {
}
void execAggregateBatch(Nd4jPointer *extraPointers,
int numAggregates, int opNum,
int maxArgs, int maxShapes,
@ -2808,11 +2821,11 @@ VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, N
return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs);
}
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) {
Nd4jLong getVariablesSetSize(nd4j::graph::VariablesSet* set) {
return set->size();
}
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) {
Nd4jStatus getVariablesSetStatus(nd4j::graph::VariablesSet* set) {
return set->status();
}
@ -2867,14 +2880,8 @@ void deleteLongArray(Nd4jPointer pointer) {
delete[] ptr;
}
template <typename T>
static void deleteVariablesSetT(Nd4jPointer pointer) {
nd4j::graph::VariablesSet* ptr = reinterpret_cast<nd4j::graph::VariablesSet*>(pointer);
delete ptr;
}
void deleteVariablesSet(Nd4jPointer pointer) {
deleteVariablesSetT<double>(pointer);
void deleteVariablesSet(nd4j::graph::VariablesSet* pointer) {
delete pointer;
}
void deleteShapeList(Nd4jPointer shapeList) {

View File

@ -1068,11 +1068,11 @@ public interface NativeOps {
int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer);
OpaqueVariableSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
OpaqueVariablesSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
long getVariableSetSize(OpaqueVariableSet set);
int getVariableSetStatus(OpaqueVariableSet set);
OpaqueVariable getVariable(OpaqueVariableSet set, long i);
long getVariablesSetSize(OpaqueVariablesSet set);
int getVariablesSetStatus(OpaqueVariablesSet set);
OpaqueVariable getVariable(OpaqueVariablesSet set, long i);
int getVariableId(OpaqueVariable variable);
int getVariableIndex(OpaqueVariable variable);
String getVariableName(OpaqueVariable variable);
@ -1095,7 +1095,7 @@ public interface NativeOps {
void deleteNPArrayMap(Pointer pointer);
void deleteVariablesSet(OpaqueVariableSet pointer);
void deleteVariablesSet(OpaqueVariablesSet pointer);
// GraphState creation
Pointer getGraphState(long id);

View File

@ -22,6 +22,6 @@ import org.bytedeco.javacpp.Pointer;
*
* @author saudet
*/
public class OpaqueVariableSet extends Pointer {
public OpaqueVariableSet(Pointer p) { super(p); }
public class OpaqueVariablesSet extends Pointer {
public OpaqueVariablesSet(Pointer p) { super(p); }
}

View File

@ -76,7 +76,7 @@ import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariableSet;
import org.nd4j.nativeblas.OpaqueVariablesSet;
import java.util.*;
@ -2427,14 +2427,14 @@ public class CudaExecutioner extends DefaultOpExecutioner {
val newMap = new LinkedHashMap<String, INDArray>();
OpaqueVariableSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
OpStatus status = OpStatus.byNumber(nativeOps.getVariableSetStatus(result));
OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(result));
if (status != OpStatus.ND4J_STATUS_OK)
throw new ND4JIllegalStateException("Op execution failed: " + status);
for (int e = 0; e < nativeOps.getVariableSetSize(result); e++) {
for (int e = 0; e < nativeOps.getVariablesSetSize(result); e++) {
OpaqueVariable var = nativeOps.getVariable(result, e);
int nodeId = nativeOps.getVariableId(var);
int index = nativeOps.getVariableIndex(var);

View File

@ -2988,13 +2988,13 @@ public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set);
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set);
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i);
public native @Cast("Nd4jLong") long getVariablesSetSize(OpaqueVariablesSet set);
public native @Cast("Nd4jStatus") int getVariablesSetStatus(OpaqueVariablesSet set);
public native OpaqueVariable getVariable(OpaqueVariablesSet set, @Cast("Nd4jLong") long i);
public native int getVariableId(OpaqueVariable variable);
public native int getVariableIndex(OpaqueVariable variable);
public native @Cast("char*") String getVariableName(OpaqueVariable variable);
@ -3008,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer);
public native void deleteVariablesSet(OpaqueVariablesSet pointer);
// GraphState creation
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);

View File

@ -116,7 +116,7 @@ public class Nd4jCudaPresets implements InfoMapper {
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet"))
.put(new Info("OpaqueVariablesSet").pointerTypes("OpaqueVariablesSet"))
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))

View File

@ -74,7 +74,7 @@ import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariableSet;
import org.nd4j.nativeblas.OpaqueVariablesSet;
import java.util.*;
@ -1950,14 +1950,14 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
val newMap = new LinkedHashMap<String, INDArray>();
OpaqueVariableSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
OpaqueVariablesSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
OpStatus status = OpStatus.byNumber(loop.getVariableSetStatus(result));
OpStatus status = OpStatus.byNumber(loop.getVariablesSetStatus(result));
if (status != OpStatus.ND4J_STATUS_OK)
throw new ND4JIllegalStateException("Op execution failed: " + status);
for (int e = 0; e < loop.getVariableSetSize(result); e++) {
for (int e = 0; e < loop.getVariablesSetSize(result); e++) {
OpaqueVariable var = loop.getVariable(result, e);
int nodeId = loop.getVariableId(var);
int index = loop.getVariableIndex(var);

View File

@ -2988,13 +2988,13 @@ public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set);
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set);
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i);
public native @Cast("Nd4jLong") long getVariablesSetSize(OpaqueVariablesSet set);
public native @Cast("Nd4jStatus") int getVariablesSetStatus(OpaqueVariablesSet set);
public native OpaqueVariable getVariable(OpaqueVariablesSet set, @Cast("Nd4jLong") long i);
public native int getVariableId(OpaqueVariable variable);
public native int getVariableIndex(OpaqueVariable variable);
public native @Cast("char*") String getVariableName(OpaqueVariable variable);
@ -3008,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer);
public native void deleteVariablesSet(OpaqueVariablesSet pointer);
// GraphState creation
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);

View File

@ -159,7 +159,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled {
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet"))
.put(new Info("OpaqueVariablesSet").pointerTypes("OpaqueVariablesSet"))
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))