Fix functions of OpaqueVariablesSet
This commit is contained in:
		
							parent
							
								
									b57f1d52cc
								
							
						
					
					
						commit
						8d1fe8b1b3
					
				| @ -1678,14 +1678,14 @@ ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList); | ||||
| 
 | ||||
| ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer); | ||||
| 
 | ||||
| typedef nd4j::graph::VariablesSet OpaqueVariableSet; | ||||
| typedef nd4j::graph::VariablesSet OpaqueVariablesSet; | ||||
| typedef nd4j::graph::Variable OpaqueVariable; | ||||
| 
 | ||||
| ND4J_EXPORT OpaqueVariableSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs); | ||||
| ND4J_EXPORT OpaqueVariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs); | ||||
| 
 | ||||
| ND4J_EXPORT Nd4jLong getVariableSetSize(OpaqueVariableSet* set); | ||||
| ND4J_EXPORT Nd4jStatus getVariableSetStatus(OpaqueVariableSet* set); | ||||
| ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariableSet* set, Nd4jLong i); | ||||
| ND4J_EXPORT Nd4jLong getVariablesSetSize(OpaqueVariablesSet* set); | ||||
| ND4J_EXPORT Nd4jStatus getVariablesSetStatus(OpaqueVariablesSet* set); | ||||
| ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariablesSet* set, Nd4jLong i); | ||||
| ND4J_EXPORT int getVariableId(OpaqueVariable* variable); | ||||
| ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable); | ||||
| ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable); | ||||
| @ -1699,7 +1699,7 @@ ND4J_EXPORT void deleteIntArray(Nd4jPointer pointer); | ||||
| ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer); | ||||
| ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer); | ||||
| 
 | ||||
| ND4J_EXPORT void deleteVariablesSet(OpaqueVariableSet pointer); | ||||
| ND4J_EXPORT void deleteVariablesSet(OpaqueVariablesSet* pointer); | ||||
| 
 | ||||
| // GraphState creation
 | ||||
| ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id); | ||||
|  | ||||
| @ -2337,11 +2337,11 @@ nd4j::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLo | ||||
|     return nullptr; | ||||
| } | ||||
| 
 | ||||
| Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) { | ||||
| Nd4jLong getVariablesSetSize(nd4j::graph::VariablesSet* set) { | ||||
|     return set->size(); | ||||
| } | ||||
| 
 | ||||
| Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) { | ||||
| Nd4jStatus getVariablesSetStatus(nd4j::graph::VariablesSet* set) { | ||||
|     return set->status(); | ||||
| } | ||||
| 
 | ||||
| @ -2396,14 +2396,8 @@ void deleteLongArray(Nd4jPointer pointer) { | ||||
|     delete[] ptr; | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| static void deleteVariablesSetT(Nd4jPointer pointer) { | ||||
|     auto ptr = reinterpret_cast<nd4j::graph::VariablesSet*>(pointer); | ||||
|     delete ptr; | ||||
| } | ||||
| 
 | ||||
| void deleteVariablesSet(Nd4jPointer pointer) { | ||||
|     deleteVariablesSetT<double>(pointer); | ||||
| void deleteVariablesSet(nd4j::graph::VariablesSet* pointer) { | ||||
|     delete pointer; | ||||
| } | ||||
| 
 | ||||
| const char* getAllOperations() { | ||||
|  | ||||
| @ -1935,6 +1935,19 @@ void execAggregate(Nd4jPointer *extraPointers, | ||||
|     nd4j::DebugHelper::checkErrorCode(stream, "execAggregateFloat(...) failed"); | ||||
| } | ||||
| 
 | ||||
| void batchExecutor(Nd4jPointer *extraPointers, | ||||
|                                int numAggregates, | ||||
|                                int opNum, | ||||
|                                int maxArgs, | ||||
|                                int maxShapes, | ||||
|                                int maxIntArrays, | ||||
|                                int maxIntArraySize, | ||||
|                                int maxIdx, | ||||
|                                int maxReals, | ||||
|                                void *ptrToArguments, | ||||
|                                nd4j::DataType dtype) { | ||||
| } | ||||
| 
 | ||||
| void execAggregateBatch(Nd4jPointer *extraPointers, | ||||
| 									int numAggregates, int opNum, | ||||
| 									int maxArgs, int maxShapes, | ||||
| @ -2808,11 +2821,11 @@ VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, N | ||||
| 	return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs); | ||||
| } | ||||
| 
 | ||||
| Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) { | ||||
| Nd4jLong getVariablesSetSize(nd4j::graph::VariablesSet* set) { | ||||
|     return set->size(); | ||||
| } | ||||
| 
 | ||||
| Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) { | ||||
| Nd4jStatus getVariablesSetStatus(nd4j::graph::VariablesSet* set) { | ||||
|     return set->status(); | ||||
| } | ||||
| 
 | ||||
| @ -2867,14 +2880,8 @@ void deleteLongArray(Nd4jPointer pointer) { | ||||
| 	delete[] ptr; | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| static void deleteVariablesSetT(Nd4jPointer pointer) { | ||||
| 	nd4j::graph::VariablesSet* ptr = reinterpret_cast<nd4j::graph::VariablesSet*>(pointer); | ||||
| 	delete ptr; | ||||
| } | ||||
| 
 | ||||
| void deleteVariablesSet(Nd4jPointer pointer) { | ||||
| 	deleteVariablesSetT<double>(pointer); | ||||
| void deleteVariablesSet(nd4j::graph::VariablesSet* pointer) { | ||||
| 	delete pointer; | ||||
| } | ||||
| 
 | ||||
| void deleteShapeList(Nd4jPointer shapeList) { | ||||
|  | ||||
| @ -1068,11 +1068,11 @@ public interface NativeOps { | ||||
| 
 | ||||
|     int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer); | ||||
| 
 | ||||
|     OpaqueVariableSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs); | ||||
|     OpaqueVariablesSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs); | ||||
| 
 | ||||
|     long getVariableSetSize(OpaqueVariableSet set); | ||||
|     int getVariableSetStatus(OpaqueVariableSet set); | ||||
|     OpaqueVariable getVariable(OpaqueVariableSet set, long i); | ||||
|     long getVariablesSetSize(OpaqueVariablesSet set); | ||||
|     int getVariablesSetStatus(OpaqueVariablesSet set); | ||||
|     OpaqueVariable getVariable(OpaqueVariablesSet set, long i); | ||||
|     int getVariableId(OpaqueVariable variable); | ||||
|     int getVariableIndex(OpaqueVariable variable); | ||||
|     String getVariableName(OpaqueVariable variable); | ||||
| @ -1095,7 +1095,7 @@ public interface NativeOps { | ||||
| 
 | ||||
|     void deleteNPArrayMap(Pointer pointer); | ||||
| 
 | ||||
|     void deleteVariablesSet(OpaqueVariableSet pointer); | ||||
|     void deleteVariablesSet(OpaqueVariablesSet pointer); | ||||
| 
 | ||||
|     // GraphState creation | ||||
|     Pointer getGraphState(long id); | ||||
|  | ||||
| @ -22,6 +22,6 @@ import org.bytedeco.javacpp.Pointer; | ||||
|  * | ||||
|  * @author saudet | ||||
|  */ | ||||
| public class OpaqueVariableSet extends Pointer { | ||||
|     public OpaqueVariableSet(Pointer p) { super(p); } | ||||
| public class OpaqueVariablesSet extends Pointer { | ||||
|     public OpaqueVariablesSet(Pointer p) { super(p); } | ||||
| } | ||||
| @ -76,7 +76,7 @@ import org.nd4j.nativeblas.OpaqueConstantDataBuffer; | ||||
| import org.nd4j.nativeblas.OpaqueShapeList; | ||||
| import org.nd4j.nativeblas.OpaqueTadPack; | ||||
| import org.nd4j.nativeblas.OpaqueVariable; | ||||
| import org.nd4j.nativeblas.OpaqueVariableSet; | ||||
| import org.nd4j.nativeblas.OpaqueVariablesSet; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| @ -2427,14 +2427,14 @@ public class CudaExecutioner extends DefaultOpExecutioner { | ||||
| 
 | ||||
|         val newMap = new LinkedHashMap<String, INDArray>(); | ||||
| 
 | ||||
|         OpaqueVariableSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); | ||||
|         OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); | ||||
| 
 | ||||
|         OpStatus status = OpStatus.byNumber(nativeOps.getVariableSetStatus(result)); | ||||
|         OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(result)); | ||||
| 
 | ||||
|         if (status != OpStatus.ND4J_STATUS_OK) | ||||
|             throw new ND4JIllegalStateException("Op execution failed: " + status); | ||||
| 
 | ||||
|         for (int e = 0; e < nativeOps.getVariableSetSize(result); e++) { | ||||
|         for (int e = 0; e < nativeOps.getVariablesSetSize(result); e++) { | ||||
|             OpaqueVariable var = nativeOps.getVariable(result, e); | ||||
|             int nodeId = nativeOps.getVariableId(var); | ||||
|             int index = nativeOps.getVariableIndex(var); | ||||
|  | ||||
| @ -2988,13 +2988,13 @@ public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList); | ||||
| 
 | ||||
| public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer); | ||||
| 
 | ||||
| public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); | ||||
| public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); | ||||
| public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); | ||||
| public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); | ||||
| public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); | ||||
| public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); | ||||
| 
 | ||||
| public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set); | ||||
| public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set); | ||||
| public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i); | ||||
| public native @Cast("Nd4jLong") long getVariablesSetSize(OpaqueVariablesSet set); | ||||
| public native @Cast("Nd4jStatus") int getVariablesSetStatus(OpaqueVariablesSet set); | ||||
| public native OpaqueVariable getVariable(OpaqueVariablesSet set, @Cast("Nd4jLong") long i); | ||||
| public native int getVariableId(OpaqueVariable variable); | ||||
| public native int getVariableIndex(OpaqueVariable variable); | ||||
| public native @Cast("char*") String getVariableName(OpaqueVariable variable); | ||||
| @ -3008,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer); | ||||
| public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer); | ||||
| public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer); | ||||
| 
 | ||||
| public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer); | ||||
| public native void deleteVariablesSet(OpaqueVariablesSet pointer); | ||||
| 
 | ||||
| // GraphState creation | ||||
| public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id); | ||||
|  | ||||
| @ -116,7 +116,7 @@ public class Nd4jCudaPresets implements InfoMapper { | ||||
|                 .put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack")) | ||||
|                 .put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper")) | ||||
|                 .put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList")) | ||||
|                 .put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet")) | ||||
|                 .put(new Info("OpaqueVariablesSet").pointerTypes("OpaqueVariablesSet")) | ||||
|                 .put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable")) | ||||
|                 .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) | ||||
|                 .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) | ||||
|  | ||||
| @ -74,7 +74,7 @@ import org.nd4j.nativeblas.OpaqueConstantDataBuffer; | ||||
| import org.nd4j.nativeblas.OpaqueShapeList; | ||||
| import org.nd4j.nativeblas.OpaqueTadPack; | ||||
| import org.nd4j.nativeblas.OpaqueVariable; | ||||
| import org.nd4j.nativeblas.OpaqueVariableSet; | ||||
| import org.nd4j.nativeblas.OpaqueVariablesSet; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| @ -1950,14 +1950,14 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { | ||||
| 
 | ||||
|         val newMap = new LinkedHashMap<String, INDArray>(); | ||||
| 
 | ||||
|             OpaqueVariableSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); | ||||
|             OpaqueVariablesSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); | ||||
| 
 | ||||
|             OpStatus status = OpStatus.byNumber(loop.getVariableSetStatus(result)); | ||||
|             OpStatus status = OpStatus.byNumber(loop.getVariablesSetStatus(result)); | ||||
| 
 | ||||
|             if (status != OpStatus.ND4J_STATUS_OK) | ||||
|                 throw new ND4JIllegalStateException("Op execution failed: " + status); | ||||
| 
 | ||||
|             for (int e = 0; e < loop.getVariableSetSize(result); e++) { | ||||
|             for (int e = 0; e < loop.getVariablesSetSize(result); e++) { | ||||
|                 OpaqueVariable var = loop.getVariable(result, e); | ||||
|                 int nodeId = loop.getVariableId(var); | ||||
|                 int index = loop.getVariableIndex(var); | ||||
|  | ||||
| @ -2988,13 +2988,13 @@ public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList); | ||||
| 
 | ||||
| public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer); | ||||
| 
 | ||||
| public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); | ||||
| public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); | ||||
| public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); | ||||
| public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); | ||||
| public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); | ||||
| public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); | ||||
| 
 | ||||
| public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set); | ||||
| public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set); | ||||
| public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i); | ||||
| public native @Cast("Nd4jLong") long getVariablesSetSize(OpaqueVariablesSet set); | ||||
| public native @Cast("Nd4jStatus") int getVariablesSetStatus(OpaqueVariablesSet set); | ||||
| public native OpaqueVariable getVariable(OpaqueVariablesSet set, @Cast("Nd4jLong") long i); | ||||
| public native int getVariableId(OpaqueVariable variable); | ||||
| public native int getVariableIndex(OpaqueVariable variable); | ||||
| public native @Cast("char*") String getVariableName(OpaqueVariable variable); | ||||
| @ -3008,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer); | ||||
| public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer); | ||||
| public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer); | ||||
| 
 | ||||
| public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer); | ||||
| public native void deleteVariablesSet(OpaqueVariablesSet pointer); | ||||
| 
 | ||||
| // GraphState creation | ||||
| public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id); | ||||
|  | ||||
| @ -159,7 +159,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled { | ||||
|                         .put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack")) | ||||
|                         .put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper")) | ||||
|                         .put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList")) | ||||
|                         .put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet")) | ||||
|                         .put(new Info("OpaqueVariablesSet").pointerTypes("OpaqueVariablesSet")) | ||||
|                         .put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable")) | ||||
|                         .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) | ||||
|                         .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user