Configurable DataType for ops (#201)
* initial commit Signed-off-by: raver119 <raver119@gmail.com> * - one more test for OneHot with dtype - one more signature in Nd4j Signed-off-by: raver119 <raver119@gmail.com> * ones_as/zeros_as now accept dtype Signed-off-by: raver119 <raver119@gmail.com> * one more test Signed-off-by: raver119 <raver119@gmail.com> * - more updates for configurable data types - ones_as/zeros_as java side + tests Signed-off-by: raver119 <raver119@gmail.com> * few c++ tests fixed Signed-off-by: raver119 <raver119@gmail.com> * few more changes around DArgs Signed-off-by: raver119 <raver119@gmail.com>master
parent
ba961c7601
commit
5d98cfcf47
|
@ -1518,7 +1518,7 @@ ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPoi
|
||||||
typedef nd4j::ShapeList OpaqueShapeList;
|
typedef nd4j::ShapeList OpaqueShapeList;
|
||||||
|
|
||||||
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
|
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
|
||||||
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
|
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs);
|
||||||
|
|
||||||
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
|
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
|
||||||
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);
|
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);
|
||||||
|
|
|
@ -1974,7 +1974,7 @@ void deleteShapeList(Nd4jPointer shapeList) {
|
||||||
delete list;
|
delete list;
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
|
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||||
nd4j::graph::VariableSpace varSpace;
|
nd4j::graph::VariableSpace varSpace;
|
||||||
Context block(2, &varSpace);
|
Context block(2, &varSpace);
|
||||||
nd4j::ShapeList inShapes;
|
nd4j::ShapeList inShapes;
|
||||||
|
@ -1988,6 +1988,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
||||||
for (int e = 0; e < numBArgs; e++)
|
for (int e = 0; e < numBArgs; e++)
|
||||||
block.getBArguments()->push_back(bArgs[e]);
|
block.getBArguments()->push_back(bArgs[e]);
|
||||||
|
|
||||||
|
for (int e = 0; e < numDArgs; e++)
|
||||||
|
block.getDArguments()->push_back((nd4j::DataType) dArgs[e]);
|
||||||
|
|
||||||
for (int e = 0; e < numInputShapes; e++) {
|
for (int e = 0; e < numInputShapes; e++) {
|
||||||
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
|
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
|
||||||
|
|
||||||
|
@ -2015,11 +2018,11 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
||||||
return shapeList;
|
return shapeList;
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
|
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||||
try {
|
try {
|
||||||
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
|
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||||
|
|
||||||
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs);
|
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||||
|
|
|
@ -2684,7 +2684,7 @@ const char* getAllCustomOps() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
|
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||||
nd4j::graph::VariableSpace varSpace;
|
nd4j::graph::VariableSpace varSpace;
|
||||||
Context block(2, &varSpace);
|
Context block(2, &varSpace);
|
||||||
nd4j::ShapeList inShapes;
|
nd4j::ShapeList inShapes;
|
||||||
|
@ -2698,6 +2698,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
||||||
for (int e = 0; e < numBArgs; e++)
|
for (int e = 0; e < numBArgs; e++)
|
||||||
block.getBArguments()->push_back(bArgs[e]);
|
block.getBArguments()->push_back(bArgs[e]);
|
||||||
|
|
||||||
|
for (int e = 0; e < numDArgs; e++)
|
||||||
|
block.getDArguments()->push_back((nd4j::DataType) dArgs[e]);
|
||||||
|
|
||||||
for (int e = 0; e < numInputShapes; e++) {
|
for (int e = 0; e < numInputShapes; e++) {
|
||||||
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
|
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
|
||||||
|
|
||||||
|
@ -2722,12 +2725,12 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
||||||
return shapeList;
|
return shapeList;
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
|
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||||
try {
|
try {
|
||||||
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
|
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||||
|
|
||||||
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs,
|
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs,
|
||||||
iArgs, numIArgs, bArgs, numBArgs);
|
iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||||
|
|
|
@ -112,6 +112,14 @@ public struct FlatNode : IFlatbufferObject
|
||||||
public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } }
|
public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||||
public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
|
public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
|
||||||
public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } }
|
public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||||
|
public DType ExtraTypes(int j) { int o = __p.__offset(48); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
|
||||||
|
public int ExtraTypesLength { get { int o = __p.__offset(48); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||||
|
#if ENABLE_SPAN_T
|
||||||
|
public Span<byte> GetExtraTypesBytes() { return __p.__vector_as_span(48); }
|
||||||
|
#else
|
||||||
|
public ArraySegment<byte>? GetExtraTypesBytes() { return __p.__vector_as_arraysegment(48); }
|
||||||
|
#endif
|
||||||
|
public DType[] GetExtraTypesArray() { return __p.__vector_as_array<DType>(48); }
|
||||||
|
|
||||||
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
|
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
|
||||||
int id = 0,
|
int id = 0,
|
||||||
|
@ -135,9 +143,11 @@ public struct FlatNode : IFlatbufferObject
|
||||||
Offset<FlatArray> scalarOffset = default(Offset<FlatArray>),
|
Offset<FlatArray> scalarOffset = default(Offset<FlatArray>),
|
||||||
VectorOffset controlDepsOffset = default(VectorOffset),
|
VectorOffset controlDepsOffset = default(VectorOffset),
|
||||||
VectorOffset varControlDepsOffset = default(VectorOffset),
|
VectorOffset varControlDepsOffset = default(VectorOffset),
|
||||||
VectorOffset controlDepForOffset = default(VectorOffset)) {
|
VectorOffset controlDepForOffset = default(VectorOffset),
|
||||||
builder.StartObject(22);
|
VectorOffset extraTypesOffset = default(VectorOffset)) {
|
||||||
|
builder.StartObject(23);
|
||||||
FlatNode.AddOpNum(builder, opNum);
|
FlatNode.AddOpNum(builder, opNum);
|
||||||
|
FlatNode.AddExtraTypes(builder, extraTypesOffset);
|
||||||
FlatNode.AddControlDepFor(builder, controlDepForOffset);
|
FlatNode.AddControlDepFor(builder, controlDepForOffset);
|
||||||
FlatNode.AddVarControlDeps(builder, varControlDepsOffset);
|
FlatNode.AddVarControlDeps(builder, varControlDepsOffset);
|
||||||
FlatNode.AddControlDeps(builder, controlDepsOffset);
|
FlatNode.AddControlDeps(builder, controlDepsOffset);
|
||||||
|
@ -162,7 +172,7 @@ public struct FlatNode : IFlatbufferObject
|
||||||
return FlatNode.EndFlatNode(builder);
|
return FlatNode.EndFlatNode(builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); }
|
public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(23); }
|
||||||
public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); }
|
public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); }
|
||||||
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
||||||
public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); }
|
public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); }
|
||||||
|
@ -224,6 +234,10 @@ public struct FlatNode : IFlatbufferObject
|
||||||
public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
|
public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
|
||||||
public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
|
public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
|
||||||
public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||||
|
public static void AddExtraTypes(FlatBufferBuilder builder, VectorOffset extraTypesOffset) { builder.AddOffset(22, extraTypesOffset.Value, 0); }
|
||||||
|
public static VectorOffset CreateExtraTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
|
||||||
|
public static VectorOffset CreateExtraTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||||
|
public static void StartExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||||
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
|
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
|
||||||
int o = builder.EndObject();
|
int o = builder.EndObject();
|
||||||
return new Offset<FlatNode>(o);
|
return new Offset<FlatNode>(o);
|
||||||
|
|
|
@ -72,6 +72,10 @@ public final class FlatNode extends Table {
|
||||||
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
|
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
|
||||||
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||||
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
|
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
|
||||||
|
public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; }
|
||||||
|
public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; }
|
||||||
|
public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); }
|
||||||
|
public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); }
|
||||||
|
|
||||||
public static int createFlatNode(FlatBufferBuilder builder,
|
public static int createFlatNode(FlatBufferBuilder builder,
|
||||||
int id,
|
int id,
|
||||||
|
@ -95,9 +99,11 @@ public final class FlatNode extends Table {
|
||||||
int scalarOffset,
|
int scalarOffset,
|
||||||
int controlDepsOffset,
|
int controlDepsOffset,
|
||||||
int varControlDepsOffset,
|
int varControlDepsOffset,
|
||||||
int controlDepForOffset) {
|
int controlDepForOffset,
|
||||||
builder.startObject(22);
|
int extraTypesOffset) {
|
||||||
|
builder.startObject(23);
|
||||||
FlatNode.addOpNum(builder, opNum);
|
FlatNode.addOpNum(builder, opNum);
|
||||||
|
FlatNode.addExtraTypes(builder, extraTypesOffset);
|
||||||
FlatNode.addControlDepFor(builder, controlDepForOffset);
|
FlatNode.addControlDepFor(builder, controlDepForOffset);
|
||||||
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
|
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
|
||||||
FlatNode.addControlDeps(builder, controlDepsOffset);
|
FlatNode.addControlDeps(builder, controlDepsOffset);
|
||||||
|
@ -122,7 +128,7 @@ public final class FlatNode extends Table {
|
||||||
return FlatNode.endFlatNode(builder);
|
return FlatNode.endFlatNode(builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
|
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); }
|
||||||
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
|
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
|
||||||
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
|
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
|
||||||
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
|
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
|
||||||
|
@ -171,6 +177,9 @@ public final class FlatNode extends Table {
|
||||||
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
|
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
|
||||||
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||||
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||||
|
public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); }
|
||||||
|
public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
|
||||||
|
public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
|
||||||
public static int endFlatNode(FlatBufferBuilder builder) {
|
public static int endFlatNode(FlatBufferBuilder builder) {
|
||||||
int o = builder.endObject();
|
int o = builder.endObject();
|
||||||
return o;
|
return o;
|
||||||
|
|
|
@ -339,7 +339,29 @@ class FlatNode(object):
|
||||||
return self._tab.VectorLen(o)
|
return self._tab.VectorLen(o)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def FlatNodeStart(builder): builder.StartObject(22)
|
# FlatNode
|
||||||
|
def ExtraTypes(self, j):
|
||||||
|
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
|
||||||
|
if o != 0:
|
||||||
|
a = self._tab.Vector(o)
|
||||||
|
return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# FlatNode
|
||||||
|
def ExtraTypesAsNumpy(self):
|
||||||
|
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
|
||||||
|
if o != 0:
|
||||||
|
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# FlatNode
|
||||||
|
def ExtraTypesLength(self):
|
||||||
|
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
|
||||||
|
if o != 0:
|
||||||
|
return self._tab.VectorLen(o)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def FlatNodeStart(builder): builder.StartObject(23)
|
||||||
def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0)
|
def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0)
|
||||||
def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
|
def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
|
||||||
def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0)
|
def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0)
|
||||||
|
@ -375,4 +397,6 @@ def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTR
|
||||||
def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
||||||
def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0)
|
def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0)
|
||||||
def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
||||||
|
def FlatNodeAddExtraTypes(builder, extraTypes): builder.PrependUOffsetTRelativeSlot(22, flatbuffers.number_types.UOffsetTFlags.py_type(extraTypes), 0)
|
||||||
|
def FlatNodeStartExtraTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1)
|
||||||
def FlatNodeEnd(builder): return builder.EndObject()
|
def FlatNodeEnd(builder): return builder.EndObject()
|
||||||
|
|
|
@ -38,7 +38,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
VT_SCALAR = 40,
|
VT_SCALAR = 40,
|
||||||
VT_CONTROLDEPS = 42,
|
VT_CONTROLDEPS = 42,
|
||||||
VT_VARCONTROLDEPS = 44,
|
VT_VARCONTROLDEPS = 44,
|
||||||
VT_CONTROLDEPFOR = 46
|
VT_CONTROLDEPFOR = 46,
|
||||||
|
VT_EXTRATYPES = 48
|
||||||
};
|
};
|
||||||
int32_t id() const {
|
int32_t id() const {
|
||||||
return GetField<int32_t>(VT_ID, 0);
|
return GetField<int32_t>(VT_ID, 0);
|
||||||
|
@ -106,6 +107,9 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor() const {
|
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor() const {
|
||||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOR);
|
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOR);
|
||||||
}
|
}
|
||||||
|
const flatbuffers::Vector<int8_t> *extraTypes() const {
|
||||||
|
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_EXTRATYPES);
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int32_t>(verifier, VT_ID) &&
|
VerifyField<int32_t>(verifier, VT_ID) &&
|
||||||
|
@ -153,6 +157,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
VerifyOffset(verifier, VT_CONTROLDEPFOR) &&
|
VerifyOffset(verifier, VT_CONTROLDEPFOR) &&
|
||||||
verifier.VerifyVector(controlDepFor()) &&
|
verifier.VerifyVector(controlDepFor()) &&
|
||||||
verifier.VerifyVectorOfStrings(controlDepFor()) &&
|
verifier.VerifyVectorOfStrings(controlDepFor()) &&
|
||||||
|
VerifyOffset(verifier, VT_EXTRATYPES) &&
|
||||||
|
verifier.VerifyVector(extraTypes()) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -226,6 +232,9 @@ struct FlatNodeBuilder {
|
||||||
void add_controlDepFor(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor) {
|
void add_controlDepFor(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor) {
|
||||||
fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor);
|
fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor);
|
||||||
}
|
}
|
||||||
|
void add_extraTypes(flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes) {
|
||||||
|
fbb_.AddOffset(FlatNode::VT_EXTRATYPES, extraTypes);
|
||||||
|
}
|
||||||
explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
|
@ -261,9 +270,11 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNode(
|
||||||
flatbuffers::Offset<FlatArray> scalar = 0,
|
flatbuffers::Offset<FlatArray> scalar = 0,
|
||||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
|
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
|
||||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps = 0,
|
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps = 0,
|
||||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0) {
|
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0,
|
||||||
|
flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes = 0) {
|
||||||
FlatNodeBuilder builder_(_fbb);
|
FlatNodeBuilder builder_(_fbb);
|
||||||
builder_.add_opNum(opNum);
|
builder_.add_opNum(opNum);
|
||||||
|
builder_.add_extraTypes(extraTypes);
|
||||||
builder_.add_controlDepFor(controlDepFor);
|
builder_.add_controlDepFor(controlDepFor);
|
||||||
builder_.add_varControlDeps(varControlDeps);
|
builder_.add_varControlDeps(varControlDeps);
|
||||||
builder_.add_controlDeps(controlDeps);
|
builder_.add_controlDeps(controlDeps);
|
||||||
|
@ -311,7 +322,8 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
|
||||||
flatbuffers::Offset<FlatArray> scalar = 0,
|
flatbuffers::Offset<FlatArray> scalar = 0,
|
||||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
|
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
|
||||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps = nullptr,
|
const std::vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps = nullptr,
|
||||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr) {
|
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr,
|
||||||
|
const std::vector<int8_t> *extraTypes = nullptr) {
|
||||||
return nd4j::graph::CreateFlatNode(
|
return nd4j::graph::CreateFlatNode(
|
||||||
_fbb,
|
_fbb,
|
||||||
id,
|
id,
|
||||||
|
@ -335,7 +347,8 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
|
||||||
scalar,
|
scalar,
|
||||||
controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
|
controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
|
||||||
varControlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*varControlDeps) : 0,
|
varControlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*varControlDeps) : 0,
|
||||||
controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0);
|
controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0,
|
||||||
|
extraTypes ? _fbb.CreateVector<int8_t>(*extraTypes) : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) {
|
inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) {
|
||||||
|
|
|
@ -398,11 +398,36 @@ nd4j.graph.FlatNode.prototype.controlDepForLength = function() {
|
||||||
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {number} index
|
||||||
|
* @returns {nd4j.graph.DType}
|
||||||
|
*/
|
||||||
|
nd4j.graph.FlatNode.prototype.extraTypes = function(index) {
|
||||||
|
var offset = this.bb.__offset(this.bb_pos, 48);
|
||||||
|
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @returns {number}
|
||||||
|
*/
|
||||||
|
nd4j.graph.FlatNode.prototype.extraTypesLength = function() {
|
||||||
|
var offset = this.bb.__offset(this.bb_pos, 48);
|
||||||
|
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @returns {Int8Array}
|
||||||
|
*/
|
||||||
|
nd4j.graph.FlatNode.prototype.extraTypesArray = function() {
|
||||||
|
var offset = this.bb.__offset(this.bb_pos, 48);
|
||||||
|
return offset ? new Int8Array(this.bb.bytes().buffer, this.bb.bytes().byteOffset + this.bb.__vector(this.bb_pos + offset), this.bb.__vector_len(this.bb_pos + offset)) : null;
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {flatbuffers.Builder} builder
|
* @param {flatbuffers.Builder} builder
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatNode.startFlatNode = function(builder) {
|
nd4j.graph.FlatNode.startFlatNode = function(builder) {
|
||||||
builder.startObject(22);
|
builder.startObject(23);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -854,6 +879,35 @@ nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) {
|
||||||
builder.startVector(4, numElems, 4);
|
builder.startVector(4, numElems, 4);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {flatbuffers.Builder} builder
|
||||||
|
* @param {flatbuffers.Offset} extraTypesOffset
|
||||||
|
*/
|
||||||
|
nd4j.graph.FlatNode.addExtraTypes = function(builder, extraTypesOffset) {
|
||||||
|
builder.addFieldOffset(22, extraTypesOffset, 0);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {flatbuffers.Builder} builder
|
||||||
|
* @param {Array.<nd4j.graph.DType>} data
|
||||||
|
* @returns {flatbuffers.Offset}
|
||||||
|
*/
|
||||||
|
nd4j.graph.FlatNode.createExtraTypesVector = function(builder, data) {
|
||||||
|
builder.startVector(1, data.length, 1);
|
||||||
|
for (var i = data.length - 1; i >= 0; i--) {
|
||||||
|
builder.addInt8(data[i]);
|
||||||
|
}
|
||||||
|
return builder.endVector();
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {flatbuffers.Builder} builder
|
||||||
|
* @param {number} numElems
|
||||||
|
*/
|
||||||
|
nd4j.graph.FlatNode.startExtraTypesVector = function(builder, numElems) {
|
||||||
|
builder.startVector(1, numElems, 1);
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {flatbuffers.Builder} builder
|
* @param {flatbuffers.Builder} builder
|
||||||
* @returns {flatbuffers.Offset}
|
* @returns {flatbuffers.Offset}
|
||||||
|
|
|
@ -587,9 +587,9 @@ namespace nd4j {
|
||||||
block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
|
block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
|
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
|
||||||
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
|
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
|
||||||
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
|
block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -624,9 +624,9 @@ namespace nd4j {
|
||||||
block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
|
block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
|
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
|
||||||
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
|
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
|
||||||
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
|
block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -664,9 +664,9 @@ namespace nd4j {
|
||||||
block->getBArguments()->push_back(node->extraBools()->Get(e));
|
block->getBArguments()->push_back(node->extraBools()->Get(e));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
|
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
|
||||||
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
|
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
|
||||||
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
|
block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,9 @@ table FlatNode {
|
||||||
controlDeps:[string];
|
controlDeps:[string];
|
||||||
varControlDeps:[string];
|
varControlDeps:[string];
|
||||||
controlDepFor:[string];
|
controlDepFor:[string];
|
||||||
|
|
||||||
|
// DArgs
|
||||||
|
extraTypes:[DType];
|
||||||
}
|
}
|
||||||
|
|
||||||
root_type FlatNode;
|
root_type FlatNode;
|
|
@ -25,7 +25,7 @@
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
OP_IMPL(ones_as, 1, 1, false) {
|
CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) {
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
output->assign(1);
|
output->assign(1);
|
||||||
|
@ -33,11 +33,21 @@ namespace nd4j {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(ones_as) {
|
||||||
|
auto in = inputShape->at(0);
|
||||||
|
auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in);
|
||||||
|
auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in);
|
||||||
|
|
||||||
|
nd4j_printf("numD: %i; dtype: %s\n", block.numD(), DataTypeUtils::asString(dtype).c_str());
|
||||||
|
|
||||||
|
return SHAPELIST(shape);
|
||||||
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(ones_as) {
|
DECLARE_TYPES(ones_as) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
->setAllowedOutputTypes(nd4j::DataType::ANY)
|
->setAllowedOutputTypes(nd4j::DataType::ANY)
|
||||||
->setSameMode(true);
|
->setSameMode(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
OP_IMPL(zeros_as, 1, 1, false) {
|
CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) {
|
||||||
auto out = OUTPUT_VARIABLE(0);
|
auto out = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
out->assign(0); // output is filled by zero by default
|
out->assign(0); // output is filled by zero by default
|
||||||
|
@ -35,11 +35,20 @@ namespace nd4j {
|
||||||
DECLARE_SYN(zeroslike, zeros_as);
|
DECLARE_SYN(zeroslike, zeros_as);
|
||||||
DECLARE_SYN(zeros_like, zeros_as);
|
DECLARE_SYN(zeros_like, zeros_as);
|
||||||
|
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(zeros_as) {
|
||||||
|
auto in = inputShape->at(0);
|
||||||
|
auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in);
|
||||||
|
auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in);
|
||||||
|
|
||||||
|
return SHAPELIST(shape);
|
||||||
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(zeros_as) {
|
DECLARE_TYPES(zeros_as) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
->setAllowedOutputTypes(nd4j::DataType::ANY)
|
->setAllowedOutputTypes(nd4j::DataType::ANY)
|
||||||
->setSameMode(true);
|
->setSameMode(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -487,7 +487,7 @@ namespace nd4j {
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_zeros_as)
|
#if NOT_EXCLUDED(OP_zeros_as)
|
||||||
DECLARE_OP(zeros_as, 1, 1, false);
|
DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -497,7 +497,7 @@ namespace nd4j {
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_ones_as)
|
#if NOT_EXCLUDED(OP_ones_as)
|
||||||
DECLARE_OP(ones_as, 1, 1, false);
|
DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
|
||||||
|
|
||||||
double tArgs[] = { -1.0, 1.0, 0.01 };
|
double tArgs[] = { -1.0, 1.0, 0.01 };
|
||||||
|
|
||||||
auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0);
|
auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0);
|
||||||
shape::printShapeInfoLinear("Result", shapes->at(0));
|
shape::printShapeInfoLinear("Result", shapes->at(0));
|
||||||
ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0)));
|
ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0)));
|
||||||
|
|
||||||
|
|
|
@ -2978,6 +2978,24 @@ TEST_F(DeclarableOpsTests8, ones_as_test2) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests8, ones_as_test3) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<double>(10.);
|
||||||
|
//auto y = NDArrayFactory::create<double>(100.);
|
||||||
|
auto exp = NDArrayFactory::create<int>(1.);
|
||||||
|
|
||||||
|
nd4j::ops::ones_as op;
|
||||||
|
|
||||||
|
auto results = op.evaluate({&x}, {}, {}, {}, {nd4j::DataType::INT32});
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
auto y = results->at(0);
|
||||||
|
ASSERT_TRUE(y->isSameShape(exp));
|
||||||
|
ASSERT_TRUE(y->equalsTo(exp));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) {
|
TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) {
|
||||||
|
|
||||||
|
|
|
@ -112,7 +112,7 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
|
||||||
Nd4jLong iArgs[] = {1};
|
Nd4jLong iArgs[] = {1};
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
|
|
||||||
auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
|
auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0);
|
||||||
|
|
||||||
ASSERT_EQ(3, shapeList->size());
|
ASSERT_EQ(3, shapeList->size());
|
||||||
|
|
||||||
|
@ -1065,7 +1065,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({}, {&arrayX, &arrayY});
|
NDArray::prepareSpecialUse({}, {&arrayX, &arrayY});
|
||||||
nd4j::ops::greater_equal op;
|
nd4j::ops::greater_equal op;
|
||||||
auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
|
auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0);
|
||||||
NDArray::registerSpecialUse({}, {&arrayX, &arrayY});
|
NDArray::registerSpecialUse({}, {&arrayX, &arrayY});
|
||||||
delete shapeList;
|
delete shapeList;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1579,7 +1579,7 @@ TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
auto shapeList = ::calculateOutputShapes2(nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, const_cast<double*>(tArgs.data()), tArgs.size(),
|
auto shapeList = ::calculateOutputShapes2(nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, const_cast<double*>(tArgs.data()), tArgs.size(),
|
||||||
const_cast<Nd4jLong*>(iArgs.data()), iArgs.size(), nullptr, bArgsF.size());
|
const_cast<Nd4jLong*>(iArgs.data()), iArgs.size(), nullptr, bArgsF.size(), nullptr, 0);
|
||||||
// Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs
|
// Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs
|
||||||
ASSERT_EQ(1, shapeList->size());
|
ASSERT_EQ(1, shapeList->size());
|
||||||
|
|
||||||
|
|
|
@ -4704,7 +4704,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
-1,
|
-1,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0);
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
|
||||||
|
|
||||||
return flatNode;
|
return flatNode;
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.autodiff.samediff.serde;
|
package org.nd4j.autodiff.samediff.serde;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.shade.guava.primitives.Ints;
|
import org.nd4j.shade.guava.primitives.Ints;
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
|
@ -361,6 +362,11 @@ public class FlatBuffersMapper {
|
||||||
for (int i = 0; i < extraBools.length; i++) {
|
for (int i = 0; i < extraBools.length; i++) {
|
||||||
extraBools[i] = fn.extraBools(i);
|
extraBools[i] = fn.extraBools(i);
|
||||||
}
|
}
|
||||||
|
DataType[] extraDTypes = new DataType[fn.extraTypesLength()];
|
||||||
|
for (int i = 0; i < extraDTypes.length; i++) {
|
||||||
|
extraDTypes[i] = DataType.fromInt(fn.extraTypes(i));
|
||||||
|
}
|
||||||
|
|
||||||
int[] dimensions = new int[fn.dimensionsLength()];
|
int[] dimensions = new int[fn.dimensionsLength()];
|
||||||
for (int i = 0; i < dimensions.length; i++) {
|
for (int i = 0; i < dimensions.length; i++) {
|
||||||
dimensions[i] = fn.dimensions(i);
|
dimensions[i] = fn.dimensions(i);
|
||||||
|
@ -401,6 +407,7 @@ public class FlatBuffersMapper {
|
||||||
((CustomOp) op).addIArgument(extraInteger);
|
((CustomOp) op).addIArgument(extraInteger);
|
||||||
((CustomOp) op).addTArgument(extraParams);
|
((CustomOp) op).addTArgument(extraParams);
|
||||||
((CustomOp) op).addBArgument(extraBools);
|
((CustomOp) op).addBArgument(extraBools);
|
||||||
|
((CustomOp) op).addDArgument(extraDTypes);
|
||||||
|
|
||||||
op.setPropertiesForFunction(props);
|
op.setPropertiesForFunction(props);
|
||||||
return op;
|
return op;
|
||||||
|
@ -714,11 +721,20 @@ public class FlatBuffersMapper {
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean[] boolArgs = null;
|
boolean[] boolArgs = null;
|
||||||
|
byte[] dtypeArgs = null;
|
||||||
long[] extraBits = null;
|
long[] extraBits = null;
|
||||||
if (node.opType() == Op.Type.CUSTOM) {
|
if (node.opType() == Op.Type.CUSTOM) {
|
||||||
DynamicCustomOp dynamicCustomOp = (DynamicCustomOp) node;
|
val dynamicCustomOp = (DynamicCustomOp) node;
|
||||||
extraBits = dynamicCustomOp.iArgs();
|
extraBits = dynamicCustomOp.iArgs();
|
||||||
boolArgs = dynamicCustomOp.bArgs();
|
boolArgs = dynamicCustomOp.bArgs();
|
||||||
|
|
||||||
|
if (dynamicCustomOp.numDArguments() > 0) {
|
||||||
|
dtypeArgs = new byte[dynamicCustomOp.numDArguments()];
|
||||||
|
val d = dynamicCustomOp.dArgs();
|
||||||
|
for (int e = 0; e < dtypeArgs.length; e++) {
|
||||||
|
dtypeArgs[e] = (byte) d[e].toInt();
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (node instanceof Enter) {
|
} else if (node instanceof Enter) {
|
||||||
// in case of Enter node we'll be storing unique frame reference
|
// in case of Enter node we'll be storing unique frame reference
|
||||||
val frameName = ((Enter) node).getFrameName();
|
val frameName = ((Enter) node).getFrameName();
|
||||||
|
@ -817,6 +833,7 @@ public class FlatBuffersMapper {
|
||||||
int extraz = FlatNode.createExtraParamsVector(bufferBuilder, extras);
|
int extraz = FlatNode.createExtraParamsVector(bufferBuilder, extras);
|
||||||
int integerArgs = FlatNode.createExtraIntegerVector(bufferBuilder, extraBits);
|
int integerArgs = FlatNode.createExtraIntegerVector(bufferBuilder, extraBits);
|
||||||
int bArgs = FlatNode.createExtraBoolsVector(bufferBuilder, boolArgs != null ? boolArgs : new boolean[0]);
|
int bArgs = FlatNode.createExtraBoolsVector(bufferBuilder, boolArgs != null ? boolArgs : new boolean[0]);
|
||||||
|
int dArgs = FlatNode.createOutputTypesVector(bufferBuilder, dtypeArgs != null ? dtypeArgs : new byte[0]);
|
||||||
int dimensions = FlatNode.createDimensionsVector(bufferBuilder, dims);
|
int dimensions = FlatNode.createDimensionsVector(bufferBuilder, dims);
|
||||||
int fname = bufferBuilder.createString(node.getOwnName());
|
int fname = bufferBuilder.createString(node.getOwnName());
|
||||||
int scopeName = bufferBuilder.createString("");
|
int scopeName = bufferBuilder.createString("");
|
||||||
|
@ -896,7 +913,8 @@ public class FlatBuffersMapper {
|
||||||
scalar,
|
scalar,
|
||||||
opCds,
|
opCds,
|
||||||
varCds,
|
varCds,
|
||||||
cdsFor
|
cdsFor,
|
||||||
|
dArgs
|
||||||
);
|
);
|
||||||
|
|
||||||
return flatNode;
|
return flatNode;
|
||||||
|
|
|
@ -73,6 +73,10 @@ public final class FlatNode extends Table {
|
||||||
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
|
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
|
||||||
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||||
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
|
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
|
||||||
|
public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; }
|
||||||
|
public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; }
|
||||||
|
public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); }
|
||||||
|
public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); }
|
||||||
|
|
||||||
public static int createFlatNode(FlatBufferBuilder builder,
|
public static int createFlatNode(FlatBufferBuilder builder,
|
||||||
int id,
|
int id,
|
||||||
|
@ -96,9 +100,11 @@ public final class FlatNode extends Table {
|
||||||
int scalarOffset,
|
int scalarOffset,
|
||||||
int controlDepsOffset,
|
int controlDepsOffset,
|
||||||
int varControlDepsOffset,
|
int varControlDepsOffset,
|
||||||
int controlDepForOffset) {
|
int controlDepForOffset,
|
||||||
builder.startObject(22);
|
int extraTypesOffset) {
|
||||||
|
builder.startObject(23);
|
||||||
FlatNode.addOpNum(builder, opNum);
|
FlatNode.addOpNum(builder, opNum);
|
||||||
|
FlatNode.addExtraTypes(builder, extraTypesOffset);
|
||||||
FlatNode.addControlDepFor(builder, controlDepForOffset);
|
FlatNode.addControlDepFor(builder, controlDepForOffset);
|
||||||
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
|
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
|
||||||
FlatNode.addControlDeps(builder, controlDepsOffset);
|
FlatNode.addControlDeps(builder, controlDepsOffset);
|
||||||
|
@ -123,7 +129,7 @@ public final class FlatNode extends Table {
|
||||||
return FlatNode.endFlatNode(builder);
|
return FlatNode.endFlatNode(builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
|
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); }
|
||||||
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
|
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
|
||||||
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
|
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
|
||||||
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
|
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
|
||||||
|
@ -172,6 +178,9 @@ public final class FlatNode extends Table {
|
||||||
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
|
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
|
||||||
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||||
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||||
|
public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); }
|
||||||
|
public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
|
||||||
|
public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
|
||||||
public static int endFlatNode(FlatBufferBuilder builder) {
|
public static int endFlatNode(FlatBufferBuilder builder) {
|
||||||
int o = builder.endObject();
|
int o = builder.endObject();
|
||||||
return o;
|
return o;
|
||||||
|
|
|
@ -61,6 +61,7 @@ public class DifferentialFunctionClassHolder {
|
||||||
add("tArguments");
|
add("tArguments");
|
||||||
add("iArguments");
|
add("iArguments");
|
||||||
add("bArguments");
|
add("bArguments");
|
||||||
|
add("dArguments");
|
||||||
add("hash");
|
add("hash");
|
||||||
add("opName");
|
add("opName");
|
||||||
add("sameDiff");
|
add("sameDiff");
|
||||||
|
|
|
@ -20,6 +20,7 @@ import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
@ -33,9 +34,10 @@ public abstract class BaseOpContext implements OpContext {
|
||||||
protected Map<Integer,INDArray> fastpath_in = new HashMap<>();
|
protected Map<Integer,INDArray> fastpath_in = new HashMap<>();
|
||||||
protected Map<Integer,INDArray> fastpath_out = new HashMap<>();
|
protected Map<Integer,INDArray> fastpath_out = new HashMap<>();
|
||||||
|
|
||||||
protected List<Double> fastpath_d = new ArrayList<>();
|
protected List<Double> fastpath_t = new ArrayList<>();
|
||||||
protected List<Boolean> fastpath_b = new ArrayList<>();
|
protected List<Boolean> fastpath_b = new ArrayList<>();
|
||||||
protected List<Long> fastpath_i = new ArrayList<>();
|
protected List<Long> fastpath_i = new ArrayList<>();
|
||||||
|
protected List<DataType> fastpath_d = new ArrayList<>();
|
||||||
|
|
||||||
@Setter()
|
@Setter()
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -55,14 +57,14 @@ public abstract class BaseOpContext implements OpContext {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setTArguments(double... arguments) {
|
public void setTArguments(double... arguments) {
|
||||||
fastpath_d.clear();
|
fastpath_t.clear();
|
||||||
for (val v:arguments)
|
for (val v:arguments)
|
||||||
fastpath_d.add(v);
|
fastpath_t.add(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Double> getTArguments(){
|
public List<Double> getTArguments(){
|
||||||
return fastpath_d;
|
return fastpath_t;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -77,6 +79,18 @@ public abstract class BaseOpContext implements OpContext {
|
||||||
return fastpath_b;
|
return fastpath_b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setDArguments(DataType... arguments) {
|
||||||
|
fastpath_d.clear();
|
||||||
|
for (val v:arguments)
|
||||||
|
fastpath_d.add(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> getDArguments() {
|
||||||
|
return fastpath_d;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setInputArray(int index, @NonNull INDArray array) {
|
public void setInputArray(int index, @NonNull INDArray array) {
|
||||||
fastpath_in.put(index, array);
|
fastpath_in.put(index, array);
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops;
|
package org.nd4j.linalg.api.ops;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
|
|
||||||
|
@ -57,12 +58,18 @@ public interface CustomOp {
|
||||||
|
|
||||||
boolean[] bArgs();
|
boolean[] bArgs();
|
||||||
|
|
||||||
|
DataType[] dArgs();
|
||||||
|
|
||||||
|
void addTArgument(double... arg);
|
||||||
|
|
||||||
void addIArgument(int... arg);
|
void addIArgument(int... arg);
|
||||||
|
|
||||||
void addIArgument(long... arg);
|
void addIArgument(long... arg);
|
||||||
|
|
||||||
void addBArgument(boolean... arg);
|
void addBArgument(boolean... arg);
|
||||||
|
|
||||||
|
void addDArgument(DataType... arg);
|
||||||
|
|
||||||
void removeIArgument(Integer arg);
|
void removeIArgument(Integer arg);
|
||||||
|
|
||||||
Boolean getBArgument(int index);
|
Boolean getBArgument(int index);
|
||||||
|
@ -71,8 +78,6 @@ public interface CustomOp {
|
||||||
|
|
||||||
int numIArguments();
|
int numIArguments();
|
||||||
|
|
||||||
void addTArgument(double... arg);
|
|
||||||
|
|
||||||
void removeTArgument(Double arg);
|
void removeTArgument(Double arg);
|
||||||
|
|
||||||
Double getTArgument(int index);
|
Double getTArgument(int index);
|
||||||
|
@ -81,6 +86,8 @@ public interface CustomOp {
|
||||||
|
|
||||||
int numBArguments();
|
int numBArguments();
|
||||||
|
|
||||||
|
int numDArguments();
|
||||||
|
|
||||||
void addInputArgument(INDArray... arg);
|
void addInputArgument(INDArray... arg);
|
||||||
|
|
||||||
void removeInputArgument(INDArray arg);
|
void removeInputArgument(INDArray arg);
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops;
|
package org.nd4j.linalg.api.ops;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.shade.guava.collect.Lists;
|
import org.nd4j.shade.guava.collect.Lists;
|
||||||
import org.nd4j.shade.guava.primitives.Doubles;
|
import org.nd4j.shade.guava.primitives.Doubles;
|
||||||
import org.nd4j.shade.guava.primitives.Longs;
|
import org.nd4j.shade.guava.primitives.Longs;
|
||||||
|
@ -62,6 +63,9 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
protected List<Boolean> bArguments = new ArrayList<>();
|
protected List<Boolean> bArguments = new ArrayList<>();
|
||||||
|
|
||||||
|
@Builder.Default
|
||||||
|
protected List<DataType> dArguments = new ArrayList<>();
|
||||||
|
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
protected List<Integer> axis = new ArrayList<>();
|
protected List<Integer> axis = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -77,6 +81,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
iArguments = new ArrayList<>();
|
iArguments = new ArrayList<>();
|
||||||
tArguments = new ArrayList<>();
|
tArguments = new ArrayList<>();
|
||||||
bArguments = new ArrayList<>();
|
bArguments = new ArrayList<>();
|
||||||
|
dArguments = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
public DynamicCustomOp(SameDiff sameDiff, SDVariable arg) {
|
public DynamicCustomOp(SameDiff sameDiff, SDVariable arg) {
|
||||||
|
@ -93,6 +98,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
iArguments = new ArrayList<>();
|
iArguments = new ArrayList<>();
|
||||||
tArguments = new ArrayList<>();
|
tArguments = new ArrayList<>();
|
||||||
bArguments = new ArrayList<>();
|
bArguments = new ArrayList<>();
|
||||||
|
dArguments = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
public DynamicCustomOp(String opName, INDArray input, INDArray output, List<Double> tArguments, int[] iArguments) {
|
public DynamicCustomOp(String opName, INDArray input, INDArray output, List<Double> tArguments, int[] iArguments) {
|
||||||
|
@ -132,6 +138,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
this.iArguments.add((Long) a.longValue());
|
this.iArguments.add((Long) a.longValue());
|
||||||
}
|
}
|
||||||
bArguments = new ArrayList<>();
|
bArguments = new ArrayList<>();
|
||||||
|
dArguments = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -173,6 +180,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
iArguments = new ArrayList<>();
|
iArguments = new ArrayList<>();
|
||||||
tArguments = new ArrayList<>();
|
tArguments = new ArrayList<>();
|
||||||
bArguments = new ArrayList<>();
|
bArguments = new ArrayList<>();
|
||||||
|
dArguments = new ArrayList<>();
|
||||||
this.inplaceCall = inPlace;
|
this.inplaceCall = inPlace;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,6 +193,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
iArguments = new ArrayList<>();
|
iArguments = new ArrayList<>();
|
||||||
tArguments = new ArrayList<>();
|
tArguments = new ArrayList<>();
|
||||||
bArguments = new ArrayList<>();
|
bArguments = new ArrayList<>();
|
||||||
|
dArguments = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -260,6 +269,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
return hash;
|
return hash;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int numDArguments() {
|
||||||
|
return dArguments.size();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<INDArray> outputArguments() {
|
public List<INDArray> outputArguments() {
|
||||||
return outputArguments;
|
return outputArguments;
|
||||||
|
@ -280,6 +294,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
return Doubles.toArray(tArguments);
|
return Doubles.toArray(tArguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType[] dArgs() {
|
||||||
|
return dArguments.toArray(new DataType[dArguments.size()]);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addIArgument(int... arg) {
|
public void addIArgument(int... arg) {
|
||||||
for (long a: arg)
|
for (long a: arg)
|
||||||
|
@ -323,6 +342,15 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
addTArgument(Doubles.asList(arg).toArray(new Double[arg.length]));
|
addTArgument(Doubles.asList(arg).toArray(new Double[arg.length]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void addDArgument(DataType... arg) {
|
||||||
|
if (dArguments == null)
|
||||||
|
dArguments = new ArrayList<>();
|
||||||
|
|
||||||
|
if (arg != null)
|
||||||
|
dArguments.addAll(Arrays.asList(arg));
|
||||||
|
}
|
||||||
|
|
||||||
private void addTArgument(Double... arg) {
|
private void addTArgument(Double... arg) {
|
||||||
tArguments.addAll(Arrays.asList(arg));
|
tArguments.addAll(Arrays.asList(arg));
|
||||||
}
|
}
|
||||||
|
@ -650,6 +678,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
private List<INDArray> outputArguments = new ArrayList<>();
|
private List<INDArray> outputArguments = new ArrayList<>();
|
||||||
private List<Double> tArguments = new ArrayList<>();
|
private List<Double> tArguments = new ArrayList<>();
|
||||||
private List<Long> iArguments = new ArrayList<>();
|
private List<Long> iArguments = new ArrayList<>();
|
||||||
|
private List<DataType> dArguments = new ArrayList<>();
|
||||||
private List<Boolean> bArguments = new ArrayList<>();
|
private List<Boolean> bArguments = new ArrayList<>();
|
||||||
|
|
||||||
protected DynamicCustomOpsBuilder(String opName, long hash, int numInputs, int numOutputs, boolean inplaceAllowed, int numTArguments, int numIArguments) {
|
protected DynamicCustomOpsBuilder(String opName, long hash, int numInputs, int numOutputs, boolean inplaceAllowed, int numTArguments, int numIArguments) {
|
||||||
|
@ -870,6 +899,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
result.iArguments = iArguments;
|
result.iArguments = iArguments;
|
||||||
result.tArguments = tArguments;
|
result.tArguments = tArguments;
|
||||||
result.bArguments = bArguments;
|
result.bArguments = bArguments;
|
||||||
|
result.dArguments = dArguments;
|
||||||
result.inplaceCall = inplaceCall;
|
result.inplaceCall = inplaceCall;
|
||||||
result.hash = opHash;
|
result.hash = opHash;
|
||||||
result.outputShapes = outputShapes;
|
result.outputShapes = outputShapes;
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops;
|
package org.nd4j.linalg.api.ops;
|
||||||
|
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
@ -43,9 +44,15 @@ public interface OpContext extends AutoCloseable {
|
||||||
* @param arguments
|
* @param arguments
|
||||||
*/
|
*/
|
||||||
void setTArguments(double... arguments);
|
void setTArguments(double... arguments);
|
||||||
|
|
||||||
List<Double> getTArguments();
|
List<Double> getTArguments();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method sets data type arguments required for operation
|
||||||
|
* @param arguments
|
||||||
|
*/
|
||||||
|
void setDArguments(DataType... arguments);
|
||||||
|
List<DataType> getDArguments();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method sets boolean arguments required for operation
|
* This method sets boolean arguments required for operation
|
||||||
* @param arguments
|
* @param arguments
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
||||||
|
@ -246,6 +247,21 @@ public class ScatterUpdate implements CustomOp {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType[] dArgs() {
|
||||||
|
return new DataType[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void addDArgument(DataType... arg) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int numDArguments() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void clearArrays() {
|
public void clearArrays() {
|
||||||
op.clearArrays();
|
op.clearArrays();
|
||||||
|
|
|
@ -83,6 +83,9 @@ public class OneHot extends DynamicCustomOp {
|
||||||
addIArgument(depth);
|
addIArgument(depth);
|
||||||
addTArgument(on);
|
addTArgument(on);
|
||||||
addTArgument(off);
|
addTArgument(off);
|
||||||
|
|
||||||
|
if (outputType != null)
|
||||||
|
addDArgument(outputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -53,6 +55,22 @@ public class OnesLike extends DynamicCustomOp {
|
||||||
public OnesLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) {
|
public OnesLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) {
|
||||||
super(name, sameDiff, new SDVariable[]{input}, false);
|
super(name, sameDiff, new SDVariable[]{input}, false);
|
||||||
this.outputType = dataType;
|
this.outputType = dataType;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public OnesLike(@NonNull INDArray input, DataType dataType) {
|
||||||
|
this.addInputArgument(input);
|
||||||
|
this.outputType = dataType;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public OnesLike(@NonNull INDArray input) {
|
||||||
|
this(input, input.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addArgs() {
|
||||||
|
if (outputType != null)
|
||||||
|
addDArgument(outputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,6 +96,8 @@ public class OnesLike extends DynamicCustomOp {
|
||||||
if(attributesForNode.containsKey("T")) {
|
if(attributesForNode.containsKey("T")) {
|
||||||
outputType = TFGraphMapper.convertType(attributesForNode.get("T").getType());
|
outputType = TFGraphMapper.convertType(attributesForNode.get("T").getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -3438,6 +3438,16 @@ public class Nd4j {
|
||||||
return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length);
|
return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create 2D double array based on java 2d double array. and ordering
|
||||||
|
*
|
||||||
|
* @param data the data to use
|
||||||
|
* @return the created ndarray.
|
||||||
|
*/
|
||||||
|
public static INDArray create(int[][] data) {
|
||||||
|
return createFromArray(data);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create 3D int array based on 3D java int array.
|
* create 3D int array based on 3D java int array.
|
||||||
* @param data java 3D i array.
|
* @param data java 3D i array.
|
||||||
|
|
|
@ -1056,7 +1056,7 @@ public interface NativeOps {
|
||||||
|
|
||||||
OpaqueShapeList calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs);
|
OpaqueShapeList calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs);
|
||||||
|
|
||||||
OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs);
|
OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, @Cast("int *") IntPointer dArgs, int numDArgs);
|
||||||
|
|
||||||
long getShapeListSize(OpaqueShapeList list);
|
long getShapeListSize(OpaqueShapeList list);
|
||||||
LongPointer getShape(OpaqueShapeList list, long i);
|
LongPointer getShape(OpaqueShapeList list, long i);
|
||||||
|
@ -1156,6 +1156,7 @@ public interface NativeOps {
|
||||||
void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo);
|
void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo);
|
||||||
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
||||||
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
|
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
|
||||||
|
void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments);
|
||||||
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
|
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
|
||||||
void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
|
void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
|
||||||
void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
|
void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
|
||||||
|
|
|
@ -1928,6 +1928,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.bArgs().length) : null;
|
val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.bArgs().length) : null;
|
||||||
|
|
||||||
|
val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null;
|
||||||
|
|
||||||
cnt = 0;
|
cnt = 0;
|
||||||
for (val b: op.bArgs())
|
for (val b: op.bArgs())
|
||||||
bArgs.put(cnt++, b);
|
bArgs.put(cnt++, b);
|
||||||
|
@ -1936,7 +1938,12 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
for (val t: op.tArgs())
|
for (val t: op.tArgs())
|
||||||
tArgs.put(cnt++, t);
|
tArgs.put(cnt++, t);
|
||||||
|
|
||||||
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
|
cnt = 0;
|
||||||
|
val dArgs1 = op.dArgs();
|
||||||
|
for (val d: dArgs1)
|
||||||
|
dArgs.put(cnt++, d.toInt());
|
||||||
|
|
||||||
|
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments());
|
||||||
|
|
||||||
if (nativeOps.lastErrorCode() != 0)
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||||
|
@ -2003,6 +2010,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
context.setBArguments(op.bArgs());
|
context.setBArguments(op.bArgs());
|
||||||
context.setIArguments(op.iArgs());
|
context.setIArguments(op.iArgs());
|
||||||
context.setTArguments(op.tArgs());
|
context.setTArguments(op.tArgs());
|
||||||
|
context.setDArguments(op.dArgs());
|
||||||
|
|
||||||
val result = exec(op, context);
|
val result = exec(op, context);
|
||||||
val states = context.getRngStates();
|
val states = context.getRngStates();
|
||||||
|
|
|
@ -18,12 +18,10 @@ package org.nd4j.linalg.jcublas.ops.executioner;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.BooleanPointer;
|
import org.bytedeco.javacpp.*;
|
||||||
import org.bytedeco.javacpp.DoublePointer;
|
|
||||||
import org.bytedeco.javacpp.LongPointer;
|
|
||||||
import org.bytedeco.javacpp.Pointer;
|
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
|
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseOpContext;
|
import org.nd4j.linalg.api.ops.BaseOpContext;
|
||||||
|
@ -76,6 +74,18 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setDArguments(DataType... arguments) {
|
||||||
|
if (arguments.length > 0) {
|
||||||
|
super.setDArguments(arguments);
|
||||||
|
val args = new int[arguments.length];
|
||||||
|
for (int e = 0; e < arguments.length; e++)
|
||||||
|
args[e] = arguments[e].toInt();
|
||||||
|
|
||||||
|
nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setRngStates(long rootState, long nodeState) {
|
public void setRngStates(long rootState, long nodeState) {
|
||||||
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
|
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
|
||||||
|
|
|
@ -2984,12 +2984,12 @@ public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointe
|
||||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs);
|
||||||
|
|
||||||
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
|
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
|
||||||
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
|
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
|
||||||
|
|
|
@ -17,10 +17,9 @@
|
||||||
package org.nd4j.linalg.cpu.nativecpu.ops;
|
package org.nd4j.linalg.cpu.nativecpu.ops;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.bytedeco.javacpp.BooleanPointer;
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.DoublePointer;
|
import org.bytedeco.javacpp.*;
|
||||||
import org.bytedeco.javacpp.LongPointer;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseOpContext;
|
import org.nd4j.linalg.api.ops.BaseOpContext;
|
||||||
import org.nd4j.linalg.api.ops.ExecutionMode;
|
import org.nd4j.linalg.api.ops.ExecutionMode;
|
||||||
|
@ -73,6 +72,18 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setDArguments(DataType... arguments) {
|
||||||
|
if (arguments.length > 0) {
|
||||||
|
super.setDArguments(arguments);
|
||||||
|
val args = new int[arguments.length];
|
||||||
|
for (int e = 0; e < arguments.length; e++)
|
||||||
|
args[e] = arguments[e].toInt();
|
||||||
|
|
||||||
|
nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setRngStates(long rootState, long nodeState) {
|
public void setRngStates(long rootState, long nodeState) {
|
||||||
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
|
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
|
||||||
|
|
|
@ -1636,6 +1636,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
context.setBArguments(op.bArgs());
|
context.setBArguments(op.bArgs());
|
||||||
context.setIArguments(op.iArgs());
|
context.setIArguments(op.iArgs());
|
||||||
context.setTArguments(op.tArgs());
|
context.setTArguments(op.tArgs());
|
||||||
|
context.setDArguments(op.dArgs());
|
||||||
|
|
||||||
val result = exec(op, context);
|
val result = exec(op, context);
|
||||||
val states = context.getRngStates();
|
val states = context.getRngStates();
|
||||||
|
@ -1712,6 +1713,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
val bArgs = op.numBArguments() > 0 ? new BooleanPointer(op.numBArguments()) : null;
|
val bArgs = op.numBArguments() > 0 ? new BooleanPointer(op.numBArguments()) : null;
|
||||||
|
|
||||||
|
val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null;
|
||||||
|
|
||||||
cnt = 0;
|
cnt = 0;
|
||||||
val bArgs1 = op.bArgs();
|
val bArgs1 = op.bArgs();
|
||||||
for (val b: bArgs1)
|
for (val b: bArgs1)
|
||||||
|
@ -1722,11 +1725,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
for (val t: tArgs1)
|
for (val t: tArgs1)
|
||||||
tArgs.put(cnt++, t);
|
tArgs.put(cnt++, t);
|
||||||
|
|
||||||
|
cnt = 0;
|
||||||
|
val dArgs1 = op.dArgs();
|
||||||
|
for (val d: dArgs1)
|
||||||
|
dArgs.put(cnt++, d.toInt());
|
||||||
|
|
||||||
|
|
||||||
OpaqueShapeList ptrptr;
|
OpaqueShapeList ptrptr;
|
||||||
try {
|
try {
|
||||||
ptrptr = loop.calculateOutputShapes2(null,
|
ptrptr = loop.calculateOutputShapes2(null,
|
||||||
hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs,
|
hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs,
|
||||||
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments());
|
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments(), dArgs, op.numDArguments());
|
||||||
|
|
||||||
if (loop.lastErrorCode() != 0)
|
if (loop.lastErrorCode() != 0)
|
||||||
throw new RuntimeException(loop.lastErrorMessage());
|
throw new RuntimeException(loop.lastErrorMessage());
|
||||||
|
|
|
@ -2987,12 +2987,12 @@ public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointe
|
||||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
|
||||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs);
|
||||||
|
|
||||||
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
|
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
|
||||||
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
|
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
|
||||||
|
@ -17951,7 +17951,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
// #if NOT_EXCLUDED(OP_zeros_as)
|
// #if NOT_EXCLUDED(OP_zeros_as)
|
||||||
@Namespace("nd4j::ops") public static class zeros_as extends DeclarableOp {
|
@Namespace("nd4j::ops") public static class zeros_as extends DeclarableCustomOp {
|
||||||
static { Loader.load(); }
|
static { Loader.load(); }
|
||||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
public zeros_as(Pointer p) { super(p); }
|
public zeros_as(Pointer p) { super(p); }
|
||||||
|
@ -17962,10 +17962,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
return (zeros_as)super.position(position);
|
return (zeros_as)super.position(position);
|
||||||
}
|
}
|
||||||
|
|
||||||
public zeros_as() { super((Pointer)null); allocate(); }
|
public zeros_as() { super((Pointer)null); allocate(); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -17975,7 +17975,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
// #if NOT_EXCLUDED(OP_ones_as)
|
// #if NOT_EXCLUDED(OP_ones_as)
|
||||||
@Namespace("nd4j::ops") public static class ones_as extends DeclarableOp {
|
@Namespace("nd4j::ops") public static class ones_as extends DeclarableCustomOp {
|
||||||
static { Loader.load(); }
|
static { Loader.load(); }
|
||||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
public ones_as(Pointer p) { super(p); }
|
public ones_as(Pointer p) { super(p); }
|
||||||
|
@ -17986,10 +17986,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
return (ones_as)super.position(position);
|
return (ones_as)super.position(position);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ones_as() { super((Pointer)null); allocate(); }
|
public ones_as() { super((Pointer)null); allocate(); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -1169,6 +1169,26 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOneHot4() {
|
||||||
|
|
||||||
|
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
|
||||||
|
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
SDVariable indices = sd.constant("indices", indicesArr);
|
||||||
|
int depth = 3;
|
||||||
|
int axis = -1;
|
||||||
|
SDVariable oneHot = sd.oneHot("oneHot", indices, depth, axis, 5.0, 0.0, DataType.INT32);
|
||||||
|
|
||||||
|
INDArray exp = Nd4j.create(new int[][]{{5, 0, 0}, {0,0,5}, {0,0,0}, {0, 5, 0}});
|
||||||
|
|
||||||
|
String err = OpValidation.validate(new TestCase(sd)
|
||||||
|
.expected(oneHot, exp)
|
||||||
|
.gradientCheck(false));
|
||||||
|
|
||||||
|
assertNull(err);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOneHot3() {
|
public void testOneHot3() {
|
||||||
//https://github.com/deeplearning4j/deeplearning4j/issues/6872
|
//https://github.com/deeplearning4j/deeplearning4j/issues/6872
|
||||||
|
@ -1204,8 +1224,6 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLinspace(){
|
public void testLinspace(){
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
|
@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeArea;
|
||||||
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
|
||||||
|
@ -1673,4 +1674,13 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
|
|
||||||
assertEquals(expected, ret[0]);
|
assertEquals(expected, ret[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOnesLike_1() {
|
||||||
|
val x = Nd4j.create(DataType.FLOAT, 3, 4, 5);
|
||||||
|
val e = Nd4j.ones(DataType.INT32, 3, 4, 5);
|
||||||
|
|
||||||
|
val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0];
|
||||||
|
assertEquals(e, z);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue