Fix functions of OpaqueVariablesSet
parent
b57f1d52cc
commit
8d1fe8b1b3
|
@ -1678,14 +1678,14 @@ ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList);
|
|||
|
||||
ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer);
|
||||
|
||||
typedef nd4j::graph::VariablesSet OpaqueVariableSet;
|
||||
typedef nd4j::graph::VariablesSet OpaqueVariablesSet;
|
||||
typedef nd4j::graph::Variable OpaqueVariable;
|
||||
|
||||
ND4J_EXPORT OpaqueVariableSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs);
|
||||
ND4J_EXPORT OpaqueVariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs);
|
||||
|
||||
ND4J_EXPORT Nd4jLong getVariableSetSize(OpaqueVariableSet* set);
|
||||
ND4J_EXPORT Nd4jStatus getVariableSetStatus(OpaqueVariableSet* set);
|
||||
ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariableSet* set, Nd4jLong i);
|
||||
ND4J_EXPORT Nd4jLong getVariablesSetSize(OpaqueVariablesSet* set);
|
||||
ND4J_EXPORT Nd4jStatus getVariablesSetStatus(OpaqueVariablesSet* set);
|
||||
ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariablesSet* set, Nd4jLong i);
|
||||
ND4J_EXPORT int getVariableId(OpaqueVariable* variable);
|
||||
ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable);
|
||||
ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable);
|
||||
|
@ -1699,7 +1699,7 @@ ND4J_EXPORT void deleteIntArray(Nd4jPointer pointer);
|
|||
ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer);
|
||||
ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer);
|
||||
|
||||
ND4J_EXPORT void deleteVariablesSet(OpaqueVariableSet pointer);
|
||||
ND4J_EXPORT void deleteVariablesSet(OpaqueVariablesSet* pointer);
|
||||
|
||||
// GraphState creation
|
||||
ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id);
|
||||
|
|
|
@ -2337,11 +2337,11 @@ nd4j::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLo
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) {
|
||||
Nd4jLong getVariablesSetSize(nd4j::graph::VariablesSet* set) {
|
||||
return set->size();
|
||||
}
|
||||
|
||||
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) {
|
||||
Nd4jStatus getVariablesSetStatus(nd4j::graph::VariablesSet* set) {
|
||||
return set->status();
|
||||
}
|
||||
|
||||
|
@ -2396,14 +2396,8 @@ void deleteLongArray(Nd4jPointer pointer) {
|
|||
delete[] ptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void deleteVariablesSetT(Nd4jPointer pointer) {
|
||||
auto ptr = reinterpret_cast<nd4j::graph::VariablesSet*>(pointer);
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
void deleteVariablesSet(Nd4jPointer pointer) {
|
||||
deleteVariablesSetT<double>(pointer);
|
||||
void deleteVariablesSet(nd4j::graph::VariablesSet* pointer) {
|
||||
delete pointer;
|
||||
}
|
||||
|
||||
const char* getAllOperations() {
|
||||
|
|
|
@ -1935,6 +1935,19 @@ void execAggregate(Nd4jPointer *extraPointers,
|
|||
nd4j::DebugHelper::checkErrorCode(stream, "execAggregateFloat(...) failed");
|
||||
}
|
||||
|
||||
void batchExecutor(Nd4jPointer *extraPointers,
|
||||
int numAggregates,
|
||||
int opNum,
|
||||
int maxArgs,
|
||||
int maxShapes,
|
||||
int maxIntArrays,
|
||||
int maxIntArraySize,
|
||||
int maxIdx,
|
||||
int maxReals,
|
||||
void *ptrToArguments,
|
||||
nd4j::DataType dtype) {
|
||||
}
|
||||
|
||||
void execAggregateBatch(Nd4jPointer *extraPointers,
|
||||
int numAggregates, int opNum,
|
||||
int maxArgs, int maxShapes,
|
||||
|
@ -2808,11 +2821,11 @@ VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, N
|
|||
return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs);
|
||||
}
|
||||
|
||||
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) {
|
||||
Nd4jLong getVariablesSetSize(nd4j::graph::VariablesSet* set) {
|
||||
return set->size();
|
||||
}
|
||||
|
||||
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) {
|
||||
Nd4jStatus getVariablesSetStatus(nd4j::graph::VariablesSet* set) {
|
||||
return set->status();
|
||||
}
|
||||
|
||||
|
@ -2867,14 +2880,8 @@ void deleteLongArray(Nd4jPointer pointer) {
|
|||
delete[] ptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void deleteVariablesSetT(Nd4jPointer pointer) {
|
||||
nd4j::graph::VariablesSet* ptr = reinterpret_cast<nd4j::graph::VariablesSet*>(pointer);
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
void deleteVariablesSet(Nd4jPointer pointer) {
|
||||
deleteVariablesSetT<double>(pointer);
|
||||
void deleteVariablesSet(nd4j::graph::VariablesSet* pointer) {
|
||||
delete pointer;
|
||||
}
|
||||
|
||||
void deleteShapeList(Nd4jPointer shapeList) {
|
||||
|
|
|
@ -1068,11 +1068,11 @@ public interface NativeOps {
|
|||
|
||||
int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer);
|
||||
|
||||
OpaqueVariableSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
OpaqueVariablesSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
|
||||
long getVariableSetSize(OpaqueVariableSet set);
|
||||
int getVariableSetStatus(OpaqueVariableSet set);
|
||||
OpaqueVariable getVariable(OpaqueVariableSet set, long i);
|
||||
long getVariablesSetSize(OpaqueVariablesSet set);
|
||||
int getVariablesSetStatus(OpaqueVariablesSet set);
|
||||
OpaqueVariable getVariable(OpaqueVariablesSet set, long i);
|
||||
int getVariableId(OpaqueVariable variable);
|
||||
int getVariableIndex(OpaqueVariable variable);
|
||||
String getVariableName(OpaqueVariable variable);
|
||||
|
@ -1095,7 +1095,7 @@ public interface NativeOps {
|
|||
|
||||
void deleteNPArrayMap(Pointer pointer);
|
||||
|
||||
void deleteVariablesSet(OpaqueVariableSet pointer);
|
||||
void deleteVariablesSet(OpaqueVariablesSet pointer);
|
||||
|
||||
// GraphState creation
|
||||
Pointer getGraphState(long id);
|
||||
|
|
|
@ -22,6 +22,6 @@ import org.bytedeco.javacpp.Pointer;
|
|||
*
|
||||
* @author saudet
|
||||
*/
|
||||
public class OpaqueVariableSet extends Pointer {
|
||||
public OpaqueVariableSet(Pointer p) { super(p); }
|
||||
public class OpaqueVariablesSet extends Pointer {
|
||||
public OpaqueVariablesSet(Pointer p) { super(p); }
|
||||
}
|
|
@ -76,7 +76,7 @@ import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
|
|||
import org.nd4j.nativeblas.OpaqueShapeList;
|
||||
import org.nd4j.nativeblas.OpaqueTadPack;
|
||||
import org.nd4j.nativeblas.OpaqueVariable;
|
||||
import org.nd4j.nativeblas.OpaqueVariableSet;
|
||||
import org.nd4j.nativeblas.OpaqueVariablesSet;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
|
@ -2427,14 +2427,14 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
val newMap = new LinkedHashMap<String, INDArray>();
|
||||
|
||||
OpaqueVariableSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||
OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||
|
||||
OpStatus status = OpStatus.byNumber(nativeOps.getVariableSetStatus(result));
|
||||
OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(result));
|
||||
|
||||
if (status != OpStatus.ND4J_STATUS_OK)
|
||||
throw new ND4JIllegalStateException("Op execution failed: " + status);
|
||||
|
||||
for (int e = 0; e < nativeOps.getVariableSetSize(result); e++) {
|
||||
for (int e = 0; e < nativeOps.getVariablesSetSize(result); e++) {
|
||||
OpaqueVariable var = nativeOps.getVariable(result, e);
|
||||
int nodeId = nativeOps.getVariableId(var);
|
||||
int index = nativeOps.getVariableIndex(var);
|
||||
|
|
|
@ -2988,13 +2988,13 @@ public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
|
|||
|
||||
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
||||
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
||||
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
||||
|
||||
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set);
|
||||
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set);
|
||||
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i);
|
||||
public native @Cast("Nd4jLong") long getVariablesSetSize(OpaqueVariablesSet set);
|
||||
public native @Cast("Nd4jStatus") int getVariablesSetStatus(OpaqueVariablesSet set);
|
||||
public native OpaqueVariable getVariable(OpaqueVariablesSet set, @Cast("Nd4jLong") long i);
|
||||
public native int getVariableId(OpaqueVariable variable);
|
||||
public native int getVariableIndex(OpaqueVariable variable);
|
||||
public native @Cast("char*") String getVariableName(OpaqueVariable variable);
|
||||
|
@ -3008,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
|
|||
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||
|
||||
public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer);
|
||||
public native void deleteVariablesSet(OpaqueVariablesSet pointer);
|
||||
|
||||
// GraphState creation
|
||||
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);
|
||||
|
|
|
@ -116,7 +116,7 @@ public class Nd4jCudaPresets implements InfoMapper {
|
|||
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
|
||||
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
|
||||
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
|
||||
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet"))
|
||||
.put(new Info("OpaqueVariablesSet").pointerTypes("OpaqueVariablesSet"))
|
||||
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
|
||||
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
|
||||
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))
|
||||
|
|
|
@ -74,7 +74,7 @@ import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
|
|||
import org.nd4j.nativeblas.OpaqueShapeList;
|
||||
import org.nd4j.nativeblas.OpaqueTadPack;
|
||||
import org.nd4j.nativeblas.OpaqueVariable;
|
||||
import org.nd4j.nativeblas.OpaqueVariableSet;
|
||||
import org.nd4j.nativeblas.OpaqueVariablesSet;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
|
@ -1950,14 +1950,14 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
val newMap = new LinkedHashMap<String, INDArray>();
|
||||
|
||||
OpaqueVariableSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||
OpaqueVariablesSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||
|
||||
OpStatus status = OpStatus.byNumber(loop.getVariableSetStatus(result));
|
||||
OpStatus status = OpStatus.byNumber(loop.getVariablesSetStatus(result));
|
||||
|
||||
if (status != OpStatus.ND4J_STATUS_OK)
|
||||
throw new ND4JIllegalStateException("Op execution failed: " + status);
|
||||
|
||||
for (int e = 0; e < loop.getVariableSetSize(result); e++) {
|
||||
for (int e = 0; e < loop.getVariablesSetSize(result); e++) {
|
||||
OpaqueVariable var = loop.getVariable(result, e);
|
||||
int nodeId = loop.getVariableId(var);
|
||||
int index = loop.getVariableIndex(var);
|
||||
|
|
|
@ -2988,13 +2988,13 @@ public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
|
|||
|
||||
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
||||
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
||||
public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
||||
|
||||
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set);
|
||||
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set);
|
||||
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i);
|
||||
public native @Cast("Nd4jLong") long getVariablesSetSize(OpaqueVariablesSet set);
|
||||
public native @Cast("Nd4jStatus") int getVariablesSetStatus(OpaqueVariablesSet set);
|
||||
public native OpaqueVariable getVariable(OpaqueVariablesSet set, @Cast("Nd4jLong") long i);
|
||||
public native int getVariableId(OpaqueVariable variable);
|
||||
public native int getVariableIndex(OpaqueVariable variable);
|
||||
public native @Cast("char*") String getVariableName(OpaqueVariable variable);
|
||||
|
@ -3008,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
|
|||
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||
|
||||
public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer);
|
||||
public native void deleteVariablesSet(OpaqueVariablesSet pointer);
|
||||
|
||||
// GraphState creation
|
||||
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);
|
||||
|
|
|
@ -159,7 +159,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled {
|
|||
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
|
||||
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
|
||||
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
|
||||
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet"))
|
||||
.put(new Info("OpaqueVariablesSet").pointerTypes("OpaqueVariablesSet"))
|
||||
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
|
||||
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
|
||||
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))
|
||||
|
|
Loading…
Reference in New Issue