diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 141ecb6ec..01b656861 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -1518,7 +1518,7 @@ ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPoi typedef nd4j::ShapeList OpaqueShapeList; ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); -ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs); +ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs); ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list); ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i); diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 3ba971aa5..0410c833b 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -1974,7 +1974,7 @@ void deleteShapeList(Nd4jPointer shapeList) { delete list; } -nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { +nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { nd4j::graph::VariableSpace varSpace; Context block(2, &varSpace); nd4j::ShapeList inShapes; @@ -1988,6 +1988,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D for (int e = 0; e < numBArgs; e++) block.getBArguments()->push_back(bArgs[e]); + for (int e = 0; e < numDArgs; e++) + block.getDArguments()->push_back((nd4j::DataType) dArgs[e]); + for (int e = 0; e < numInputShapes; e++) { auto shape_ = reinterpret_cast(inputShapes[e]); @@ -2015,11 +2018,11 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D return shapeList; } -nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { +nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { try { auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs); + return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 45de82b32..d65dcaed5 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -2684,7 +2684,7 @@ const char* getAllCustomOps() { } -nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { +nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { nd4j::graph::VariableSpace varSpace; Context block(2, &varSpace); nd4j::ShapeList inShapes; @@ -2698,6 +2698,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D for (int e = 0; e < numBArgs; e++) block.getBArguments()->push_back(bArgs[e]); + for (int e = 0; e < numDArgs; e++) + block.getDArguments()->push_back((nd4j::DataType) dArgs[e]); + for (int e = 0; e < numInputShapes; e++) { auto shape_ = reinterpret_cast(inputShapes[e]); @@ -2722,12 +2725,12 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D return shapeList; } -nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { +nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { try { auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, - iArgs, numIArgs, bArgs, numBArgs); + iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs index 7fa9722db..c94e0fcc4 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs @@ -112,6 +112,14 @@ public struct FlatNode : IFlatbufferObject public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } } public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } } + public DType ExtraTypes(int j) { int o = __p.__offset(48); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; } + public int ExtraTypesLength { get { int o = __p.__offset(48); return o != 0 ? __p.__vector_len(o) : 0; } } +#if ENABLE_SPAN_T + public Span GetExtraTypesBytes() { return __p.__vector_as_span(48); } +#else + public ArraySegment? GetExtraTypesBytes() { return __p.__vector_as_arraysegment(48); } +#endif + public DType[] GetExtraTypesArray() { return __p.__vector_as_array(48); } public static Offset CreateFlatNode(FlatBufferBuilder builder, int id = 0, @@ -135,9 +143,11 @@ public struct FlatNode : IFlatbufferObject Offset scalarOffset = default(Offset), VectorOffset controlDepsOffset = default(VectorOffset), VectorOffset varControlDepsOffset = default(VectorOffset), - VectorOffset controlDepForOffset = default(VectorOffset)) { - builder.StartObject(22); + VectorOffset controlDepForOffset = default(VectorOffset), + VectorOffset extraTypesOffset = default(VectorOffset)) { + builder.StartObject(23); FlatNode.AddOpNum(builder, opNum); + FlatNode.AddExtraTypes(builder, extraTypesOffset); FlatNode.AddControlDepFor(builder, controlDepForOffset); FlatNode.AddVarControlDeps(builder, varControlDepsOffset); FlatNode.AddControlDeps(builder, controlDepsOffset); @@ -162,7 +172,7 @@ public struct FlatNode : IFlatbufferObject return FlatNode.EndFlatNode(builder); } - public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); } + public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(23); } public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); } @@ -224,6 +234,10 @@ public struct FlatNode : IFlatbufferObject public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } + public static void AddExtraTypes(FlatBufferBuilder builder, VectorOffset extraTypesOffset) { builder.AddOffset(22, extraTypesOffset.Value, 0); } + public static VectorOffset CreateExtraTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); } + public static VectorOffset CreateExtraTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } + public static void StartExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static Offset EndFlatNode(FlatBufferBuilder builder) { int o = builder.EndObject(); return new Offset(o); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java index 8a72cc00a..2fe0a0ee9 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java @@ -72,6 +72,10 @@ public final class FlatNode extends Table { public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; } public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; } public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; } + public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; } + public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; } + public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); } + public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); } public static int createFlatNode(FlatBufferBuilder builder, int id, @@ -95,9 +99,11 @@ public final class FlatNode extends Table { int scalarOffset, int controlDepsOffset, int varControlDepsOffset, - int controlDepForOffset) { - builder.startObject(22); + int controlDepForOffset, + int extraTypesOffset) { + builder.startObject(23); FlatNode.addOpNum(builder, opNum); + FlatNode.addExtraTypes(builder, extraTypesOffset); FlatNode.addControlDepFor(builder, controlDepForOffset); FlatNode.addVarControlDeps(builder, varControlDepsOffset); FlatNode.addControlDeps(builder, controlDepsOffset); @@ -122,7 +128,7 @@ public final class FlatNode extends Table { return FlatNode.endFlatNode(builder); } - public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); } + public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); } public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); } @@ -171,6 +177,9 @@ public final class FlatNode extends Table { public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); } public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); } + public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); } + public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); } public static int endFlatNode(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py index 889eca62f..d5104efb6 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py @@ -339,7 +339,29 @@ class FlatNode(object): return self._tab.VectorLen(o) return 0 -def FlatNodeStart(builder): builder.StartObject(22) + # FlatNode + def ExtraTypes(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # FlatNode + def ExtraTypesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o) + return 0 + + # FlatNode + def ExtraTypesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + +def FlatNodeStart(builder): builder.StartObject(23) def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0) def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0) @@ -375,4 +397,6 @@ def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTR def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4) def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0) def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def FlatNodeAddExtraTypes(builder, extraTypes): builder.PrependUOffsetTRelativeSlot(22, flatbuffers.number_types.UOffsetTFlags.py_type(extraTypes), 0) +def FlatNodeStartExtraTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1) def FlatNodeEnd(builder): return builder.EndObject() diff --git a/libnd4j/include/graph/generated/node_generated.h b/libnd4j/include/graph/generated/node_generated.h index 6ca85f7b0..92f4ab126 100644 --- a/libnd4j/include/graph/generated/node_generated.h +++ b/libnd4j/include/graph/generated/node_generated.h @@ -38,7 +38,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_SCALAR = 40, VT_CONTROLDEPS = 42, VT_VARCONTROLDEPS = 44, - VT_CONTROLDEPFOR = 46 + VT_CONTROLDEPFOR = 46, + VT_EXTRATYPES = 48 }; int32_t id() const { return GetField(VT_ID, 0); @@ -106,6 +107,9 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector> *controlDepFor() const { return GetPointer> *>(VT_CONTROLDEPFOR); } + const flatbuffers::Vector *extraTypes() const { + return GetPointer *>(VT_EXTRATYPES); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_ID) && @@ -153,6 +157,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_CONTROLDEPFOR) && verifier.VerifyVector(controlDepFor()) && verifier.VerifyVectorOfStrings(controlDepFor()) && + VerifyOffset(verifier, VT_EXTRATYPES) && + verifier.VerifyVector(extraTypes()) && verifier.EndTable(); } }; @@ -226,6 +232,9 @@ struct FlatNodeBuilder { void add_controlDepFor(flatbuffers::Offset>> controlDepFor) { fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor); } + void add_extraTypes(flatbuffers::Offset> extraTypes) { + fbb_.AddOffset(FlatNode::VT_EXTRATYPES, extraTypes); + } explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -261,9 +270,11 @@ inline flatbuffers::Offset CreateFlatNode( flatbuffers::Offset scalar = 0, flatbuffers::Offset>> controlDeps = 0, flatbuffers::Offset>> varControlDeps = 0, - flatbuffers::Offset>> controlDepFor = 0) { + flatbuffers::Offset>> controlDepFor = 0, + flatbuffers::Offset> extraTypes = 0) { FlatNodeBuilder builder_(_fbb); builder_.add_opNum(opNum); + builder_.add_extraTypes(extraTypes); builder_.add_controlDepFor(controlDepFor); builder_.add_varControlDeps(varControlDeps); builder_.add_controlDeps(controlDeps); @@ -311,7 +322,8 @@ inline flatbuffers::Offset CreateFlatNodeDirect( flatbuffers::Offset scalar = 0, const std::vector> *controlDeps = nullptr, const std::vector> *varControlDeps = nullptr, - const std::vector> *controlDepFor = nullptr) { + const std::vector> *controlDepFor = nullptr, + const std::vector *extraTypes = nullptr) { return nd4j::graph::CreateFlatNode( _fbb, id, @@ -335,7 +347,8 @@ inline flatbuffers::Offset CreateFlatNodeDirect( scalar, controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, varControlDeps ? _fbb.CreateVector>(*varControlDeps) : 0, - controlDepFor ? _fbb.CreateVector>(*controlDepFor) : 0); + controlDepFor ? _fbb.CreateVector>(*controlDepFor) : 0, + extraTypes ? _fbb.CreateVector(*extraTypes) : 0); } inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) { diff --git a/libnd4j/include/graph/generated/node_generated.js b/libnd4j/include/graph/generated/node_generated.js index dd83c4356..3f831a582 100644 --- a/libnd4j/include/graph/generated/node_generated.js +++ b/libnd4j/include/graph/generated/node_generated.js @@ -398,11 +398,36 @@ nd4j.graph.FlatNode.prototype.controlDepForLength = function() { return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; }; +/** + * @param {number} index + * @returns {nd4j.graph.DType} + */ +nd4j.graph.FlatNode.prototype.extraTypes = function(index) { + var offset = this.bb.__offset(this.bb_pos, 48); + return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0); +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatNode.prototype.extraTypesLength = function() { + var offset = this.bb.__offset(this.bb_pos, 48); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + +/** + * @returns {Int8Array} + */ +nd4j.graph.FlatNode.prototype.extraTypesArray = function() { + var offset = this.bb.__offset(this.bb_pos, 48); + return offset ? new Int8Array(this.bb.bytes().buffer, this.bb.bytes().byteOffset + this.bb.__vector(this.bb_pos + offset), this.bb.__vector_len(this.bb_pos + offset)) : null; +}; + /** * @param {flatbuffers.Builder} builder */ nd4j.graph.FlatNode.startFlatNode = function(builder) { - builder.startObject(22); + builder.startObject(23); }; /** @@ -854,6 +879,35 @@ nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) { builder.startVector(4, numElems, 4); }; +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} extraTypesOffset + */ +nd4j.graph.FlatNode.addExtraTypes = function(builder, extraTypesOffset) { + builder.addFieldOffset(22, extraTypesOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatNode.createExtraTypesVector = function(builder, data) { + builder.startVector(1, data.length, 1); + for (var i = data.length - 1; i >= 0; i--) { + builder.addInt8(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatNode.startExtraTypesVector = function(builder, numElems) { + builder.startVector(1, numElems, 1); +}; + /** * @param {flatbuffers.Builder} builder * @returns {flatbuffers.Offset} diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 47c31cdf7..4c79ccb3e 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -587,9 +587,9 @@ namespace nd4j { block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); } - if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) { - for (int e = 0; e < (int) node->outputTypes()->size(); e++) { - block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e)); + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { + for (int e = 0; e < (int) node->extraTypes()->size(); e++) { + block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e)); } } @@ -624,9 +624,9 @@ namespace nd4j { block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); } - if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) { - for (int e = 0; e < (int) node->outputTypes()->size(); e++) { - block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e)); + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { + for (int e = 0; e < (int) node->extraTypes()->size(); e++) { + block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e)); } } @@ -664,9 +664,9 @@ namespace nd4j { block->getBArguments()->push_back(node->extraBools()->Get(e)); } - if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) { - for (int e = 0; e < (int) node->outputTypes()->size(); e++) { - block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e)); + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { + for (int e = 0; e < (int) node->extraTypes()->size(); e++) { + block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e)); } } diff --git a/libnd4j/include/graph/scheme/node.fbs b/libnd4j/include/graph/scheme/node.fbs index 92975e216..8e63186f5 100644 --- a/libnd4j/include/graph/scheme/node.fbs +++ b/libnd4j/include/graph/scheme/node.fbs @@ -57,7 +57,9 @@ table FlatNode { controlDeps:[string]; varControlDeps:[string]; controlDepFor:[string]; - + + // DArgs + extraTypes:[DType]; } root_type FlatNode; \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp index a20c0110b..702aa6711 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(ones_as, 1, 1, false) { + CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) { auto output = OUTPUT_VARIABLE(0); output->assign(1); @@ -33,11 +33,21 @@ namespace nd4j { return Status::OK(); } + DECLARE_SHAPE_FN(ones_as) { + auto in = inputShape->at(0); + auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); + auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in); + + nd4j_printf("numD: %i; dtype: %s\n", block.numD(), DataTypeUtils::asString(dtype).c_str()); + + return SHAPELIST(shape); + } + DECLARE_TYPES(ones_as) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::ANY) - ->setSameMode(true); + ->setSameMode(false); } } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/zeros_as.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/zeros_as.cpp index 6b461043a..56b4264d0 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/zeros_as.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/zeros_as.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(zeros_as, 1, 1, false) { + CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) { auto out = OUTPUT_VARIABLE(0); out->assign(0); // output is filled by zero by default @@ -35,11 +35,20 @@ namespace nd4j { DECLARE_SYN(zeroslike, zeros_as); DECLARE_SYN(zeros_like, zeros_as); + + DECLARE_SHAPE_FN(zeros_as) { + auto in = inputShape->at(0); + auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); + auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in); + + return SHAPELIST(shape); + } + DECLARE_TYPES(zeros_as) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::ANY) - ->setSameMode(true); + ->setSameMode(false); } } } diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 791027baa..c5d0ff207 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -487,7 +487,7 @@ namespace nd4j { * */ #if NOT_EXCLUDED(OP_zeros_as) - DECLARE_OP(zeros_as, 1, 1, false); + DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0); #endif /** @@ -497,7 +497,7 @@ namespace nd4j { * */ #if NOT_EXCLUDED(OP_ones_as) - DECLARE_OP(ones_as, 1, 1, false); + DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0); #endif /** diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index cff57b62d..a85772cec 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests16, test_range_2) { double tArgs[] = { -1.0, 1.0, 0.01 }; - auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0); + auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0); shape::printShapeInfoLinear("Result", shapes->at(0)); ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index 05c21a8f0..002a31d6e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -2978,6 +2978,24 @@ TEST_F(DeclarableOpsTests8, ones_as_test2) { delete results; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, ones_as_test3) { + + auto x = NDArrayFactory::create(10.); + //auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); + + nd4j::ops::ones_as op; + + auto results = op.evaluate({&x}, {}, {}, {}, {nd4j::DataType::INT32}); + ASSERT_EQ(Status::OK(), results->status()); + auto y = results->at(0); + ASSERT_TRUE(y->isSameShape(exp)); + ASSERT_TRUE(y->equalsTo(exp)); + + delete results; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) { diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index f058d9112..ee828a6e2 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -112,7 +112,7 @@ TEST_F(JavaInteropTests, TestShapeExposure3) { Nd4jLong iArgs[] = {1}; auto hash = op.getOpHash(); - auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0); + auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0); ASSERT_EQ(3, shapeList->size()); @@ -1065,7 +1065,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) { NDArray::prepareSpecialUse({}, {&arrayX, &arrayY}); nd4j::ops::greater_equal op; - auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0); + auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0); NDArray::registerSpecialUse({}, {&arrayX, &arrayY}); delete shapeList; } diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp index 42eb50be0..0306fb555 100644 --- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp @@ -1579,7 +1579,7 @@ TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) { #endif auto shapeList = ::calculateOutputShapes2(nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, const_cast(tArgs.data()), tArgs.size(), - const_cast(iArgs.data()), iArgs.size(), nullptr, bArgsF.size()); + const_cast(iArgs.data()), iArgs.size(), nullptr, bArgsF.size(), nullptr, 0); // Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs ASSERT_EQ(1, shapeList->size()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index f3e0510cb..de421b297 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -4704,7 +4704,7 @@ public class SameDiff extends SDBaseOps { 0, 0, -1, - 0, 0, 0, 0, 0, 0, 0, 0, 0); + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); return flatNode; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index a88a9c84f..d87a890ff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -17,6 +17,7 @@ package org.nd4j.autodiff.samediff.serde; import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.guava.primitives.Ints; import com.google.flatbuffers.FlatBufferBuilder; import java.nio.ByteOrder; @@ -361,6 +362,11 @@ public class FlatBuffersMapper { for (int i = 0; i < extraBools.length; i++) { extraBools[i] = fn.extraBools(i); } + DataType[] extraDTypes = new DataType[fn.extraTypesLength()]; + for (int i = 0; i < extraDTypes.length; i++) { + extraDTypes[i] = DataType.fromInt(fn.extraTypes(i)); + } + int[] dimensions = new int[fn.dimensionsLength()]; for (int i = 0; i < dimensions.length; i++) { dimensions[i] = fn.dimensions(i); @@ -401,6 +407,7 @@ public class FlatBuffersMapper { ((CustomOp) op).addIArgument(extraInteger); ((CustomOp) op).addTArgument(extraParams); ((CustomOp) op).addBArgument(extraBools); + ((CustomOp) op).addDArgument(extraDTypes); op.setPropertiesForFunction(props); return op; @@ -714,11 +721,20 @@ public class FlatBuffersMapper { } boolean[] boolArgs = null; + byte[] dtypeArgs = null; long[] extraBits = null; if (node.opType() == Op.Type.CUSTOM) { - DynamicCustomOp dynamicCustomOp = (DynamicCustomOp) node; + val dynamicCustomOp = (DynamicCustomOp) node; extraBits = dynamicCustomOp.iArgs(); boolArgs = dynamicCustomOp.bArgs(); + + if (dynamicCustomOp.numDArguments() > 0) { + dtypeArgs = new byte[dynamicCustomOp.numDArguments()]; + val d = dynamicCustomOp.dArgs(); + for (int e = 0; e < dtypeArgs.length; e++) { + dtypeArgs[e] = (byte) d[e].toInt(); + } + } } else if (node instanceof Enter) { // in case of Enter node we'll be storing unique frame reference val frameName = ((Enter) node).getFrameName(); @@ -817,6 +833,7 @@ public class FlatBuffersMapper { int extraz = FlatNode.createExtraParamsVector(bufferBuilder, extras); int integerArgs = FlatNode.createExtraIntegerVector(bufferBuilder, extraBits); int bArgs = FlatNode.createExtraBoolsVector(bufferBuilder, boolArgs != null ? boolArgs : new boolean[0]); + int dArgs = FlatNode.createOutputTypesVector(bufferBuilder, dtypeArgs != null ? dtypeArgs : new byte[0]); int dimensions = FlatNode.createDimensionsVector(bufferBuilder, dims); int fname = bufferBuilder.createString(node.getOwnName()); int scopeName = bufferBuilder.createString(""); @@ -896,7 +913,8 @@ public class FlatBuffersMapper { scalar, opCds, varCds, - cdsFor + cdsFor, + dArgs ); return flatNode; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java index ca411435d..fdda379b7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java @@ -73,6 +73,10 @@ public final class FlatNode extends Table { public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; } public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; } public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; } + public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; } + public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; } + public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); } + public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); } public static int createFlatNode(FlatBufferBuilder builder, int id, @@ -96,9 +100,11 @@ public final class FlatNode extends Table { int scalarOffset, int controlDepsOffset, int varControlDepsOffset, - int controlDepForOffset) { - builder.startObject(22); + int controlDepForOffset, + int extraTypesOffset) { + builder.startObject(23); FlatNode.addOpNum(builder, opNum); + FlatNode.addExtraTypes(builder, extraTypesOffset); FlatNode.addControlDepFor(builder, controlDepForOffset); FlatNode.addVarControlDeps(builder, varControlDepsOffset); FlatNode.addControlDeps(builder, controlDepsOffset); @@ -123,7 +129,7 @@ public final class FlatNode extends Table { return FlatNode.endFlatNode(builder); } - public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); } + public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); } public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); } @@ -172,6 +178,9 @@ public final class FlatNode extends Table { public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); } public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); } + public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); } + public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); } public static int endFlatNode(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java index 05ac2495c..cf82510c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java @@ -61,6 +61,7 @@ public class DifferentialFunctionClassHolder { add("tArguments"); add("iArguments"); add("bArguments"); + add("dArguments"); add("hash"); add("opName"); add("sameDiff"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java index 050868b36..4a56e2a88 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java @@ -20,6 +20,7 @@ import lombok.Getter; import lombok.NonNull; import lombok.Setter; import lombok.val; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.*; @@ -33,9 +34,10 @@ public abstract class BaseOpContext implements OpContext { protected Map fastpath_in = new HashMap<>(); protected Map fastpath_out = new HashMap<>(); - protected List fastpath_d = new ArrayList<>(); + protected List fastpath_t = new ArrayList<>(); protected List fastpath_b = new ArrayList<>(); protected List fastpath_i = new ArrayList<>(); + protected List fastpath_d = new ArrayList<>(); @Setter() @Getter @@ -55,14 +57,14 @@ public abstract class BaseOpContext implements OpContext { @Override public void setTArguments(double... arguments) { - fastpath_d.clear(); + fastpath_t.clear(); for (val v:arguments) - fastpath_d.add(v); + fastpath_t.add(v); } @Override public List getTArguments(){ - return fastpath_d; + return fastpath_t; } @Override @@ -77,6 +79,18 @@ public abstract class BaseOpContext implements OpContext { return fastpath_b; } + @Override + public void setDArguments(DataType... arguments) { + fastpath_d.clear(); + for (val v:arguments) + fastpath_d.add(v); + } + + @Override + public List getDArguments() { + return fastpath_d; + } + @Override public void setInputArray(int index, @NonNull INDArray array) { fastpath_in.put(index, array); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java index cfa4f9b75..befdfb605 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; @@ -57,12 +58,18 @@ public interface CustomOp { boolean[] bArgs(); + DataType[] dArgs(); + + void addTArgument(double... arg); + void addIArgument(int... arg); void addIArgument(long... arg); void addBArgument(boolean... arg); + void addDArgument(DataType... arg); + void removeIArgument(Integer arg); Boolean getBArgument(int index); @@ -71,8 +78,6 @@ public interface CustomOp { int numIArguments(); - void addTArgument(double... arg); - void removeTArgument(Double arg); Double getTArgument(int index); @@ -81,6 +86,8 @@ public interface CustomOp { int numBArguments(); + int numDArguments(); + void addInputArgument(INDArray... arg); void removeInputArgument(INDArray arg); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index e46dfab4b..f4116ba3e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.guava.collect.Lists; import org.nd4j.shade.guava.primitives.Doubles; import org.nd4j.shade.guava.primitives.Longs; @@ -62,6 +63,9 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { @Builder.Default protected List bArguments = new ArrayList<>(); + @Builder.Default + protected List dArguments = new ArrayList<>(); + @Builder.Default protected List axis = new ArrayList<>(); @@ -77,6 +81,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { iArguments = new ArrayList<>(); tArguments = new ArrayList<>(); bArguments = new ArrayList<>(); + dArguments = new ArrayList<>(); } public DynamicCustomOp(SameDiff sameDiff, SDVariable arg) { @@ -93,6 +98,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { iArguments = new ArrayList<>(); tArguments = new ArrayList<>(); bArguments = new ArrayList<>(); + dArguments = new ArrayList<>(); } public DynamicCustomOp(String opName, INDArray input, INDArray output, List tArguments, int[] iArguments) { @@ -132,6 +138,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { this.iArguments.add((Long) a.longValue()); } bArguments = new ArrayList<>(); + dArguments = new ArrayList<>(); } /** @@ -173,6 +180,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { iArguments = new ArrayList<>(); tArguments = new ArrayList<>(); bArguments = new ArrayList<>(); + dArguments = new ArrayList<>(); this.inplaceCall = inPlace; } @@ -185,6 +193,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { iArguments = new ArrayList<>(); tArguments = new ArrayList<>(); bArguments = new ArrayList<>(); + dArguments = new ArrayList<>(); } @@ -260,6 +269,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { return hash; } + @Override + public int numDArguments() { + return dArguments.size(); + } + @Override public List outputArguments() { return outputArguments; @@ -280,6 +294,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { return Doubles.toArray(tArguments); } + @Override + public DataType[] dArgs() { + return dArguments.toArray(new DataType[dArguments.size()]); + } + @Override public void addIArgument(int... arg) { for (long a: arg) @@ -323,6 +342,15 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { addTArgument(Doubles.asList(arg).toArray(new Double[arg.length])); } + @Override + public void addDArgument(DataType... arg) { + if (dArguments == null) + dArguments = new ArrayList<>(); + + if (arg != null) + dArguments.addAll(Arrays.asList(arg)); + } + private void addTArgument(Double... arg) { tArguments.addAll(Arrays.asList(arg)); } @@ -650,6 +678,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { private List outputArguments = new ArrayList<>(); private List tArguments = new ArrayList<>(); private List iArguments = new ArrayList<>(); + private List dArguments = new ArrayList<>(); private List bArguments = new ArrayList<>(); protected DynamicCustomOpsBuilder(String opName, long hash, int numInputs, int numOutputs, boolean inplaceAllowed, int numTArguments, int numIArguments) { @@ -870,6 +899,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { result.iArguments = iArguments; result.tArguments = tArguments; result.bArguments = bArguments; + result.dArguments = dArguments; result.inplaceCall = inplaceCall; result.hash = opHash; result.outputShapes = outputShapes; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index 3deefe7c0..4063746b3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops; import org.bytedeco.javacpp.Pointer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -43,9 +44,15 @@ public interface OpContext extends AutoCloseable { * @param arguments */ void setTArguments(double... arguments); - List getTArguments(); + /** + * This method sets data type arguments required for operation + * @param arguments + */ + void setDArguments(DataType... arguments); + List getDArguments(); + /** * This method sets boolean arguments required for operation * @param arguments diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java index 83020cb57..313b7ccb4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.custom; import lombok.NonNull; import lombok.val; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.CustomOpDescriptor; @@ -246,6 +247,21 @@ public class ScatterUpdate implements CustomOp { } + @Override + public DataType[] dArgs() { + return new DataType[0]; + } + + @Override + public void addDArgument(DataType... arg) { + + } + + @Override + public int numDArguments() { + return 0; + } + @Override public void clearArrays() { op.clearArrays(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java index d442e4623..beb9d09b9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java @@ -83,6 +83,9 @@ public class OneHot extends DynamicCustomOp { addIArgument(depth); addTArgument(on); addTArgument(off); + + if (outputType != null) + addDArgument(outputType); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java index a8f49bdf2..4b4b3e578 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -53,6 +55,22 @@ public class OnesLike extends DynamicCustomOp { public OnesLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) { super(name, sameDiff, new SDVariable[]{input}, false); this.outputType = dataType; + addArgs(); + } + + public OnesLike(@NonNull INDArray input, DataType dataType) { + this.addInputArgument(input); + this.outputType = dataType; + addArgs(); + } + + public OnesLike(@NonNull INDArray input) { + this(input, input.dataType()); + } + + public void addArgs() { + if (outputType != null) + addDArgument(outputType); } @@ -78,6 +96,8 @@ public class OnesLike extends DynamicCustomOp { if(attributesForNode.containsKey("T")) { outputType = TFGraphMapper.convertType(attributesForNode.get("T").getType()); } + + addArgs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index f0bf4bc5e..d32aff5b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -3438,6 +3438,16 @@ public class Nd4j { return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length); } + /** + * Create 2D double array based on java 2d double array. and ordering + * + * @param data the data to use + * @return the created ndarray. + */ + public static INDArray create(int[][] data) { + return createFromArray(data); + } + /** * create 3D int array based on 3D java int array. * @param data java 3D i array. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index 95c97068e..d284974eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1056,7 +1056,7 @@ public interface NativeOps { OpaqueShapeList calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs); - OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs); + OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, @Cast("int *") IntPointer dArgs, int numDArgs); long getShapeListSize(OpaqueShapeList list); LongPointer getShape(OpaqueShapeList list, long i); @@ -1156,6 +1156,7 @@ public interface NativeOps { void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments); + void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments); void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments); void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow); void ctxSetExecutionMode(OpaqueContext ptr, int execMode); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 04b86dc02..f18bd1459 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1928,6 +1928,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.bArgs().length) : null; + val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null; + cnt = 0; for (val b: op.bArgs()) bArgs.put(cnt++, b); @@ -1936,7 +1938,12 @@ public class CudaExecutioner extends DefaultOpExecutioner { for (val t: op.tArgs()) tArgs.put(cnt++, t); - OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); + cnt = 0; + val dArgs1 = op.dArgs(); + for (val d: dArgs1) + dArgs.put(cnt++, d.toInt()); + + OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments()); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); @@ -2003,6 +2010,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { context.setBArguments(op.bArgs()); context.setIArguments(op.iArgs()); context.setTArguments(op.tArgs()); + context.setDArguments(op.dArgs()); val result = exec(op, context); val states = context.getRngStates(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index 487f38232..01127e891 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -18,12 +18,10 @@ package org.nd4j.linalg.jcublas.ops.executioner; import lombok.NonNull; import lombok.val; -import org.bytedeco.javacpp.BooleanPointer; -import org.bytedeco.javacpp.DoublePointer; -import org.bytedeco.javacpp.LongPointer; -import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.*; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; @@ -76,6 +74,18 @@ public class CudaOpContext extends BaseOpContext implements OpContext { } } + @Override + public void setDArguments(DataType... arguments) { + if (arguments.length > 0) { + super.setDArguments(arguments); + val args = new int[arguments.length]; + for (int e = 0; e < arguments.length; e++) + args[e] = arguments[e].toInt(); + + nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length); + }; + } + @Override public void setRngStates(long rootState, long nodeState) { nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index d1840ab63..8d0029bc3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -2984,12 +2984,12 @@ public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointe public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs); public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list); public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index 1863d6c1c..461646311 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -17,10 +17,9 @@ package org.nd4j.linalg.cpu.nativecpu.ops; import lombok.NonNull; -import org.bytedeco.javacpp.BooleanPointer; -import org.bytedeco.javacpp.DoublePointer; -import org.bytedeco.javacpp.LongPointer; -import org.bytedeco.javacpp.Pointer; +import lombok.val; +import org.bytedeco.javacpp.*; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.ExecutionMode; @@ -73,6 +72,18 @@ public class CpuOpContext extends BaseOpContext implements OpContext { }; } + @Override + public void setDArguments(DataType... arguments) { + if (arguments.length > 0) { + super.setDArguments(arguments); + val args = new int[arguments.length]; + for (int e = 0; e < arguments.length; e++) + args[e] = arguments[e].toInt(); + + nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length); + }; + } + @Override public void setRngStates(long rootState, long nodeState) { nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index ebeab58f4..cc3d17b5f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -1636,6 +1636,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { context.setBArguments(op.bArgs()); context.setIArguments(op.iArgs()); context.setTArguments(op.tArgs()); + context.setDArguments(op.dArgs()); val result = exec(op, context); val states = context.getRngStates(); @@ -1712,6 +1713,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val bArgs = op.numBArguments() > 0 ? new BooleanPointer(op.numBArguments()) : null; + val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null; + cnt = 0; val bArgs1 = op.bArgs(); for (val b: bArgs1) @@ -1722,11 +1725,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { for (val t: tArgs1) tArgs.put(cnt++, t); + cnt = 0; + val dArgs1 = op.dArgs(); + for (val d: dArgs1) + dArgs.put(cnt++, d.toInt()); + + OpaqueShapeList ptrptr; try { ptrptr = loop.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs, - op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments()); + op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments(), dArgs, op.numDArguments()); if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 40a2f5236..93fbb71d7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -2987,12 +2987,12 @@ public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointe public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs); public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list); public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i); @@ -17951,7 +17951,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * */ // #if NOT_EXCLUDED(OP_zeros_as) - @Namespace("nd4j::ops") public static class zeros_as extends DeclarableOp { + @Namespace("nd4j::ops") public static class zeros_as extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public zeros_as(Pointer p) { super(p); } @@ -17962,10 +17962,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (zeros_as)super.position(position); } - public zeros_as() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public zeros_as() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** @@ -17975,7 +17975,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * */ // #if NOT_EXCLUDED(OP_ones_as) - @Namespace("nd4j::ops") public static class ones_as extends DeclarableOp { + @Namespace("nd4j::ops") public static class ones_as extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public ones_as(Pointer p) { super(p); } @@ -17986,10 +17986,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (ones_as)super.position(position); } - public ones_as() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public ones_as() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index e02e4b91d..59932d670 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -1169,6 +1169,26 @@ public class MiscOpValidation extends BaseOpValidation { assertNull(err); } + @Test + public void testOneHot4() { + + INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); + + SameDiff sd = SameDiff.create(); + SDVariable indices = sd.constant("indices", indicesArr); + int depth = 3; + int axis = -1; + SDVariable oneHot = sd.oneHot("oneHot", indices, depth, axis, 5.0, 0.0, DataType.INT32); + + INDArray exp = Nd4j.create(new int[][]{{5, 0, 0}, {0,0,5}, {0,0,0}, {0, 5, 0}}); + + String err = OpValidation.validate(new TestCase(sd) + .expected(oneHot, exp) + .gradientCheck(false)); + + assertNull(err); + } + @Test public void testOneHot3() { //https://github.com/deeplearning4j/deeplearning4j/issues/6872 @@ -1204,8 +1224,6 @@ public class MiscOpValidation extends BaseOpValidation { assertNull(err); } - - @Test public void testLinspace(){ SameDiff sd = SameDiff.create(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index e7160a1d8..e9d2979c6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeArea; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.shape.Create; +import org.nd4j.linalg.api.ops.impl.shape.OnesLike; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp; @@ -1673,4 +1674,13 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + + @Test + public void testOnesLike_1() { + val x = Nd4j.create(DataType.FLOAT, 3, 4, 5); + val e = Nd4j.ones(DataType.INT32, 3, 4, 5); + + val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0]; + assertEquals(e, z); + } }