Configurable DataType for ops (#201)
* initial commit Signed-off-by: raver119 <raver119@gmail.com> * - one more test for OneHot with dtype - one more signature in Nd4j Signed-off-by: raver119 <raver119@gmail.com> * ones_as/zeros_as now accept dtype Signed-off-by: raver119 <raver119@gmail.com> * one more test Signed-off-by: raver119 <raver119@gmail.com> * - more updates for configurable data types - ones_as/zeros_as java side + tests Signed-off-by: raver119 <raver119@gmail.com> * few c++ tests fixed Signed-off-by: raver119 <raver119@gmail.com> * few more changes around DArgs Signed-off-by: raver119 <raver119@gmail.com>master
parent
ba961c7601
commit
5d98cfcf47
|
@ -1518,7 +1518,7 @@ ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPoi
|
|||
typedef nd4j::ShapeList OpaqueShapeList;
|
||||
|
||||
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
|
||||
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
|
||||
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs);
|
||||
|
||||
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
|
||||
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);
|
||||
|
|
|
@ -1974,7 +1974,7 @@ void deleteShapeList(Nd4jPointer shapeList) {
|
|||
delete list;
|
||||
}
|
||||
|
||||
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
|
||||
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||
nd4j::graph::VariableSpace varSpace;
|
||||
Context block(2, &varSpace);
|
||||
nd4j::ShapeList inShapes;
|
||||
|
@ -1988,6 +1988,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
|||
for (int e = 0; e < numBArgs; e++)
|
||||
block.getBArguments()->push_back(bArgs[e]);
|
||||
|
||||
for (int e = 0; e < numDArgs; e++)
|
||||
block.getDArguments()->push_back((nd4j::DataType) dArgs[e]);
|
||||
|
||||
for (int e = 0; e < numInputShapes; e++) {
|
||||
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
|
||||
|
||||
|
@ -2015,11 +2018,11 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
|||
return shapeList;
|
||||
}
|
||||
|
||||
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
|
||||
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||
try {
|
||||
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||
|
||||
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs);
|
||||
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
|
||||
} catch (std::exception &e) {
|
||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||
|
|
|
@ -2684,7 +2684,7 @@ const char* getAllCustomOps() {
|
|||
}
|
||||
|
||||
|
||||
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
|
||||
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||
nd4j::graph::VariableSpace varSpace;
|
||||
Context block(2, &varSpace);
|
||||
nd4j::ShapeList inShapes;
|
||||
|
@ -2698,6 +2698,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
|||
for (int e = 0; e < numBArgs; e++)
|
||||
block.getBArguments()->push_back(bArgs[e]);
|
||||
|
||||
for (int e = 0; e < numDArgs; e++)
|
||||
block.getDArguments()->push_back((nd4j::DataType) dArgs[e]);
|
||||
|
||||
for (int e = 0; e < numInputShapes; e++) {
|
||||
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
|
||||
|
||||
|
@ -2722,12 +2725,12 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
|||
return shapeList;
|
||||
}
|
||||
|
||||
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
|
||||
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||
try {
|
||||
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||
|
||||
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs,
|
||||
iArgs, numIArgs, bArgs, numBArgs);
|
||||
iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
|
||||
} catch (std::exception &e) {
|
||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||
|
|
|
@ -112,6 +112,14 @@ public struct FlatNode : IFlatbufferObject
|
|||
public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
|
||||
public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
public DType ExtraTypes(int j) { int o = __p.__offset(48); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
|
||||
public int ExtraTypesLength { get { int o = __p.__offset(48); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
#if ENABLE_SPAN_T
|
||||
public Span<byte> GetExtraTypesBytes() { return __p.__vector_as_span(48); }
|
||||
#else
|
||||
public ArraySegment<byte>? GetExtraTypesBytes() { return __p.__vector_as_arraysegment(48); }
|
||||
#endif
|
||||
public DType[] GetExtraTypesArray() { return __p.__vector_as_array<DType>(48); }
|
||||
|
||||
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
|
||||
int id = 0,
|
||||
|
@ -135,9 +143,11 @@ public struct FlatNode : IFlatbufferObject
|
|||
Offset<FlatArray> scalarOffset = default(Offset<FlatArray>),
|
||||
VectorOffset controlDepsOffset = default(VectorOffset),
|
||||
VectorOffset varControlDepsOffset = default(VectorOffset),
|
||||
VectorOffset controlDepForOffset = default(VectorOffset)) {
|
||||
builder.StartObject(22);
|
||||
VectorOffset controlDepForOffset = default(VectorOffset),
|
||||
VectorOffset extraTypesOffset = default(VectorOffset)) {
|
||||
builder.StartObject(23);
|
||||
FlatNode.AddOpNum(builder, opNum);
|
||||
FlatNode.AddExtraTypes(builder, extraTypesOffset);
|
||||
FlatNode.AddControlDepFor(builder, controlDepForOffset);
|
||||
FlatNode.AddVarControlDeps(builder, varControlDepsOffset);
|
||||
FlatNode.AddControlDeps(builder, controlDepsOffset);
|
||||
|
@ -162,7 +172,7 @@ public struct FlatNode : IFlatbufferObject
|
|||
return FlatNode.EndFlatNode(builder);
|
||||
}
|
||||
|
||||
public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); }
|
||||
public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(23); }
|
||||
public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); }
|
||||
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
||||
public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); }
|
||||
|
@ -224,6 +234,10 @@ public struct FlatNode : IFlatbufferObject
|
|||
public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
|
||||
public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||
public static void AddExtraTypes(FlatBufferBuilder builder, VectorOffset extraTypesOffset) { builder.AddOffset(22, extraTypesOffset.Value, 0); }
|
||||
public static VectorOffset CreateExtraTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
|
||||
public static VectorOffset CreateExtraTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
|
||||
int o = builder.EndObject();
|
||||
return new Offset<FlatNode>(o);
|
||||
|
|
|
@ -72,6 +72,10 @@ public final class FlatNode extends Table {
|
|||
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
|
||||
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
|
||||
public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; }
|
||||
public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; }
|
||||
public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); }
|
||||
public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); }
|
||||
|
||||
public static int createFlatNode(FlatBufferBuilder builder,
|
||||
int id,
|
||||
|
@ -95,9 +99,11 @@ public final class FlatNode extends Table {
|
|||
int scalarOffset,
|
||||
int controlDepsOffset,
|
||||
int varControlDepsOffset,
|
||||
int controlDepForOffset) {
|
||||
builder.startObject(22);
|
||||
int controlDepForOffset,
|
||||
int extraTypesOffset) {
|
||||
builder.startObject(23);
|
||||
FlatNode.addOpNum(builder, opNum);
|
||||
FlatNode.addExtraTypes(builder, extraTypesOffset);
|
||||
FlatNode.addControlDepFor(builder, controlDepForOffset);
|
||||
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
|
||||
FlatNode.addControlDeps(builder, controlDepsOffset);
|
||||
|
@ -122,7 +128,7 @@ public final class FlatNode extends Table {
|
|||
return FlatNode.endFlatNode(builder);
|
||||
}
|
||||
|
||||
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
|
||||
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); }
|
||||
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
|
||||
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
|
||||
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
|
||||
|
@ -171,6 +177,9 @@ public final class FlatNode extends Table {
|
|||
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
|
||||
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||
public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); }
|
||||
public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
|
||||
public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
|
||||
public static int endFlatNode(FlatBufferBuilder builder) {
|
||||
int o = builder.endObject();
|
||||
return o;
|
||||
|
|
|
@ -339,7 +339,29 @@ class FlatNode(object):
|
|||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
def FlatNodeStart(builder): builder.StartObject(22)
|
||||
# FlatNode
|
||||
def ExtraTypes(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
|
||||
return 0
|
||||
|
||||
# FlatNode
|
||||
def ExtraTypesAsNumpy(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
|
||||
if o != 0:
|
||||
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o)
|
||||
return 0
|
||||
|
||||
# FlatNode
|
||||
def ExtraTypesLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
def FlatNodeStart(builder): builder.StartObject(23)
|
||||
def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0)
|
||||
def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
|
||||
def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0)
|
||||
|
@ -375,4 +397,6 @@ def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTR
|
|||
def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
||||
def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0)
|
||||
def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
||||
def FlatNodeAddExtraTypes(builder, extraTypes): builder.PrependUOffsetTRelativeSlot(22, flatbuffers.number_types.UOffsetTFlags.py_type(extraTypes), 0)
|
||||
def FlatNodeStartExtraTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1)
|
||||
def FlatNodeEnd(builder): return builder.EndObject()
|
||||
|
|
|
@ -38,7 +38,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
VT_SCALAR = 40,
|
||||
VT_CONTROLDEPS = 42,
|
||||
VT_VARCONTROLDEPS = 44,
|
||||
VT_CONTROLDEPFOR = 46
|
||||
VT_CONTROLDEPFOR = 46,
|
||||
VT_EXTRATYPES = 48
|
||||
};
|
||||
int32_t id() const {
|
||||
return GetField<int32_t>(VT_ID, 0);
|
||||
|
@ -106,6 +107,9 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor() const {
|
||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOR);
|
||||
}
|
||||
const flatbuffers::Vector<int8_t> *extraTypes() const {
|
||||
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_EXTRATYPES);
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyField<int32_t>(verifier, VT_ID) &&
|
||||
|
@ -153,6 +157,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
VerifyOffset(verifier, VT_CONTROLDEPFOR) &&
|
||||
verifier.VerifyVector(controlDepFor()) &&
|
||||
verifier.VerifyVectorOfStrings(controlDepFor()) &&
|
||||
VerifyOffset(verifier, VT_EXTRATYPES) &&
|
||||
verifier.VerifyVector(extraTypes()) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
};
|
||||
|
@ -226,6 +232,9 @@ struct FlatNodeBuilder {
|
|||
void add_controlDepFor(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor) {
|
||||
fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor);
|
||||
}
|
||||
void add_extraTypes(flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes) {
|
||||
fbb_.AddOffset(FlatNode::VT_EXTRATYPES, extraTypes);
|
||||
}
|
||||
explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
|
@ -261,9 +270,11 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNode(
|
|||
flatbuffers::Offset<FlatArray> scalar = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0) {
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes = 0) {
|
||||
FlatNodeBuilder builder_(_fbb);
|
||||
builder_.add_opNum(opNum);
|
||||
builder_.add_extraTypes(extraTypes);
|
||||
builder_.add_controlDepFor(controlDepFor);
|
||||
builder_.add_varControlDeps(varControlDeps);
|
||||
builder_.add_controlDeps(controlDeps);
|
||||
|
@ -311,7 +322,8 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
|
|||
flatbuffers::Offset<FlatArray> scalar = 0,
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps = nullptr,
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr) {
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr,
|
||||
const std::vector<int8_t> *extraTypes = nullptr) {
|
||||
return nd4j::graph::CreateFlatNode(
|
||||
_fbb,
|
||||
id,
|
||||
|
@ -335,7 +347,8 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
|
|||
scalar,
|
||||
controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
|
||||
varControlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*varControlDeps) : 0,
|
||||
controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0);
|
||||
controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0,
|
||||
extraTypes ? _fbb.CreateVector<int8_t>(*extraTypes) : 0);
|
||||
}
|
||||
|
||||
inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) {
|
||||
|
|
|
@ -398,11 +398,36 @@ nd4j.graph.FlatNode.prototype.controlDepForLength = function() {
|
|||
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {number} index
|
||||
* @returns {nd4j.graph.DType}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.extraTypes = function(index) {
|
||||
var offset = this.bb.__offset(this.bb_pos, 48);
|
||||
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
|
||||
};
|
||||
|
||||
/**
|
||||
* @returns {number}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.extraTypesLength = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 48);
|
||||
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @returns {Int8Array}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.extraTypesArray = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 48);
|
||||
return offset ? new Int8Array(this.bb.bytes().buffer, this.bb.bytes().byteOffset + this.bb.__vector(this.bb_pos + offset), this.bb.__vector_len(this.bb_pos + offset)) : null;
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
*/
|
||||
nd4j.graph.FlatNode.startFlatNode = function(builder) {
|
||||
builder.startObject(22);
|
||||
builder.startObject(23);
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -854,6 +879,35 @@ nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) {
|
|||
builder.startVector(4, numElems, 4);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {flatbuffers.Offset} extraTypesOffset
|
||||
*/
|
||||
nd4j.graph.FlatNode.addExtraTypes = function(builder, extraTypesOffset) {
|
||||
builder.addFieldOffset(22, extraTypesOffset, 0);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {Array.<nd4j.graph.DType>} data
|
||||
* @returns {flatbuffers.Offset}
|
||||
*/
|
||||
nd4j.graph.FlatNode.createExtraTypesVector = function(builder, data) {
|
||||
builder.startVector(1, data.length, 1);
|
||||
for (var i = data.length - 1; i >= 0; i--) {
|
||||
builder.addInt8(data[i]);
|
||||
}
|
||||
return builder.endVector();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {number} numElems
|
||||
*/
|
||||
nd4j.graph.FlatNode.startExtraTypesVector = function(builder, numElems) {
|
||||
builder.startVector(1, numElems, 1);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @returns {flatbuffers.Offset}
|
||||
|
|
|
@ -587,9 +587,9 @@ namespace nd4j {
|
|||
block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
|
||||
}
|
||||
|
||||
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
|
||||
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
|
||||
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
|
||||
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
|
||||
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
|
||||
block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -624,9 +624,9 @@ namespace nd4j {
|
|||
block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
|
||||
}
|
||||
|
||||
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
|
||||
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
|
||||
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
|
||||
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
|
||||
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
|
||||
block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -664,9 +664,9 @@ namespace nd4j {
|
|||
block->getBArguments()->push_back(node->extraBools()->Get(e));
|
||||
}
|
||||
|
||||
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
|
||||
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
|
||||
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
|
||||
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
|
||||
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
|
||||
block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -57,7 +57,9 @@ table FlatNode {
|
|||
controlDeps:[string];
|
||||
varControlDeps:[string];
|
||||
controlDepFor:[string];
|
||||
|
||||
|
||||
// DArgs
|
||||
extraTypes:[DType];
|
||||
}
|
||||
|
||||
root_type FlatNode;
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
OP_IMPL(ones_as, 1, 1, false) {
|
||||
CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) {
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
output->assign(1);
|
||||
|
@ -33,11 +33,21 @@ namespace nd4j {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(ones_as) {
|
||||
auto in = inputShape->at(0);
|
||||
auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in);
|
||||
auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in);
|
||||
|
||||
nd4j_printf("numD: %i; dtype: %s\n", block.numD(), DataTypeUtils::asString(dtype).c_str());
|
||||
|
||||
return SHAPELIST(shape);
|
||||
}
|
||||
|
||||
DECLARE_TYPES(ones_as) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedOutputTypes(nd4j::DataType::ANY)
|
||||
->setSameMode(true);
|
||||
->setSameMode(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
OP_IMPL(zeros_as, 1, 1, false) {
|
||||
CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) {
|
||||
auto out = OUTPUT_VARIABLE(0);
|
||||
|
||||
out->assign(0); // output is filled by zero by default
|
||||
|
@ -35,11 +35,20 @@ namespace nd4j {
|
|||
DECLARE_SYN(zeroslike, zeros_as);
|
||||
DECLARE_SYN(zeros_like, zeros_as);
|
||||
|
||||
|
||||
DECLARE_SHAPE_FN(zeros_as) {
|
||||
auto in = inputShape->at(0);
|
||||
auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in);
|
||||
auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in);
|
||||
|
||||
return SHAPELIST(shape);
|
||||
}
|
||||
|
||||
DECLARE_TYPES(zeros_as) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedOutputTypes(nd4j::DataType::ANY)
|
||||
->setSameMode(true);
|
||||
->setSameMode(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -487,7 +487,7 @@ namespace nd4j {
|
|||
*
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_zeros_as)
|
||||
DECLARE_OP(zeros_as, 1, 1, false);
|
||||
DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -497,7 +497,7 @@ namespace nd4j {
|
|||
*
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_ones_as)
|
||||
DECLARE_OP(ones_as, 1, 1, false);
|
||||
DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
|
|||
|
||||
double tArgs[] = { -1.0, 1.0, 0.01 };
|
||||
|
||||
auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0);
|
||||
auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0);
|
||||
shape::printShapeInfoLinear("Result", shapes->at(0));
|
||||
ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0)));
|
||||
|
||||
|
|
|
@ -2978,6 +2978,24 @@ TEST_F(DeclarableOpsTests8, ones_as_test2) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests8, ones_as_test3) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>(10.);
|
||||
//auto y = NDArrayFactory::create<double>(100.);
|
||||
auto exp = NDArrayFactory::create<int>(1.);
|
||||
|
||||
nd4j::ops::ones_as op;
|
||||
|
||||
auto results = op.evaluate({&x}, {}, {}, {}, {nd4j::DataType::INT32});
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
auto y = results->at(0);
|
||||
ASSERT_TRUE(y->isSameShape(exp));
|
||||
ASSERT_TRUE(y->equalsTo(exp));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) {
|
||||
|
||||
|
|
|
@ -112,7 +112,7 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
|
|||
Nd4jLong iArgs[] = {1};
|
||||
auto hash = op.getOpHash();
|
||||
|
||||
auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
|
||||
auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0);
|
||||
|
||||
ASSERT_EQ(3, shapeList->size());
|
||||
|
||||
|
@ -1065,7 +1065,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
|
|||
|
||||
NDArray::prepareSpecialUse({}, {&arrayX, &arrayY});
|
||||
nd4j::ops::greater_equal op;
|
||||
auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
|
||||
auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0);
|
||||
NDArray::registerSpecialUse({}, {&arrayX, &arrayY});
|
||||
delete shapeList;
|
||||
}
|
||||
|
|
|
@ -1579,7 +1579,7 @@ TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) {
|
|||
#endif
|
||||
|
||||
auto shapeList = ::calculateOutputShapes2(nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, const_cast<double*>(tArgs.data()), tArgs.size(),
|
||||
const_cast<Nd4jLong*>(iArgs.data()), iArgs.size(), nullptr, bArgsF.size());
|
||||
const_cast<Nd4jLong*>(iArgs.data()), iArgs.size(), nullptr, bArgsF.size(), nullptr, 0);
|
||||
// Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs
|
||||
ASSERT_EQ(1, shapeList->size());
|
||||
|
||||
|
|
|
@ -4704,7 +4704,7 @@ public class SameDiff extends SDBaseOps {
|
|||
0,
|
||||
0,
|
||||
-1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0);
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
|
||||
|
||||
return flatNode;
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.nd4j.autodiff.samediff.serde;
|
||||
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.shade.guava.primitives.Ints;
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import java.nio.ByteOrder;
|
||||
|
@ -361,6 +362,11 @@ public class FlatBuffersMapper {
|
|||
for (int i = 0; i < extraBools.length; i++) {
|
||||
extraBools[i] = fn.extraBools(i);
|
||||
}
|
||||
DataType[] extraDTypes = new DataType[fn.extraTypesLength()];
|
||||
for (int i = 0; i < extraDTypes.length; i++) {
|
||||
extraDTypes[i] = DataType.fromInt(fn.extraTypes(i));
|
||||
}
|
||||
|
||||
int[] dimensions = new int[fn.dimensionsLength()];
|
||||
for (int i = 0; i < dimensions.length; i++) {
|
||||
dimensions[i] = fn.dimensions(i);
|
||||
|
@ -401,6 +407,7 @@ public class FlatBuffersMapper {
|
|||
((CustomOp) op).addIArgument(extraInteger);
|
||||
((CustomOp) op).addTArgument(extraParams);
|
||||
((CustomOp) op).addBArgument(extraBools);
|
||||
((CustomOp) op).addDArgument(extraDTypes);
|
||||
|
||||
op.setPropertiesForFunction(props);
|
||||
return op;
|
||||
|
@ -714,11 +721,20 @@ public class FlatBuffersMapper {
|
|||
}
|
||||
|
||||
boolean[] boolArgs = null;
|
||||
byte[] dtypeArgs = null;
|
||||
long[] extraBits = null;
|
||||
if (node.opType() == Op.Type.CUSTOM) {
|
||||
DynamicCustomOp dynamicCustomOp = (DynamicCustomOp) node;
|
||||
val dynamicCustomOp = (DynamicCustomOp) node;
|
||||
extraBits = dynamicCustomOp.iArgs();
|
||||
boolArgs = dynamicCustomOp.bArgs();
|
||||
|
||||
if (dynamicCustomOp.numDArguments() > 0) {
|
||||
dtypeArgs = new byte[dynamicCustomOp.numDArguments()];
|
||||
val d = dynamicCustomOp.dArgs();
|
||||
for (int e = 0; e < dtypeArgs.length; e++) {
|
||||
dtypeArgs[e] = (byte) d[e].toInt();
|
||||
}
|
||||
}
|
||||
} else if (node instanceof Enter) {
|
||||
// in case of Enter node we'll be storing unique frame reference
|
||||
val frameName = ((Enter) node).getFrameName();
|
||||
|
@ -817,6 +833,7 @@ public class FlatBuffersMapper {
|
|||
int extraz = FlatNode.createExtraParamsVector(bufferBuilder, extras);
|
||||
int integerArgs = FlatNode.createExtraIntegerVector(bufferBuilder, extraBits);
|
||||
int bArgs = FlatNode.createExtraBoolsVector(bufferBuilder, boolArgs != null ? boolArgs : new boolean[0]);
|
||||
int dArgs = FlatNode.createOutputTypesVector(bufferBuilder, dtypeArgs != null ? dtypeArgs : new byte[0]);
|
||||
int dimensions = FlatNode.createDimensionsVector(bufferBuilder, dims);
|
||||
int fname = bufferBuilder.createString(node.getOwnName());
|
||||
int scopeName = bufferBuilder.createString("");
|
||||
|
@ -896,7 +913,8 @@ public class FlatBuffersMapper {
|
|||
scalar,
|
||||
opCds,
|
||||
varCds,
|
||||
cdsFor
|
||||
cdsFor,
|
||||
dArgs
|
||||
);
|
||||
|
||||
return flatNode;
|
||||
|
|
|
@ -73,6 +73,10 @@ public final class FlatNode extends Table {
|
|||
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
|
||||
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
|
||||
public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; }
|
||||
public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; }
|
||||
public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); }
|
||||
public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); }
|
||||
|
||||
public static int createFlatNode(FlatBufferBuilder builder,
|
||||
int id,
|
||||
|
@ -96,9 +100,11 @@ public final class FlatNode extends Table {
|
|||
int scalarOffset,
|
||||
int controlDepsOffset,
|
||||
int varControlDepsOffset,
|
||||
int controlDepForOffset) {
|
||||
builder.startObject(22);
|
||||
int controlDepForOffset,
|
||||
int extraTypesOffset) {
|
||||
builder.startObject(23);
|
||||
FlatNode.addOpNum(builder, opNum);
|
||||
FlatNode.addExtraTypes(builder, extraTypesOffset);
|
||||
FlatNode.addControlDepFor(builder, controlDepForOffset);
|
||||
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
|
||||
FlatNode.addControlDeps(builder, controlDepsOffset);
|
||||
|
@ -123,7 +129,7 @@ public final class FlatNode extends Table {
|
|||
return FlatNode.endFlatNode(builder);
|
||||
}
|
||||
|
||||
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
|
||||
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); }
|
||||
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
|
||||
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
|
||||
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
|
||||
|
@ -172,6 +178,9 @@ public final class FlatNode extends Table {
|
|||
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
|
||||
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||
public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); }
|
||||
public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
|
||||
public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
|
||||
public static int endFlatNode(FlatBufferBuilder builder) {
|
||||
int o = builder.endObject();
|
||||
return o;
|
||||
|
|
|
@ -61,6 +61,7 @@ public class DifferentialFunctionClassHolder {
|
|||
add("tArguments");
|
||||
add("iArguments");
|
||||
add("bArguments");
|
||||
add("dArguments");
|
||||
add("hash");
|
||||
add("opName");
|
||||
add("sameDiff");
|
||||
|
|
|
@ -20,6 +20,7 @@ import lombok.Getter;
|
|||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import lombok.val;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.*;
|
||||
|
@ -33,9 +34,10 @@ public abstract class BaseOpContext implements OpContext {
|
|||
protected Map<Integer,INDArray> fastpath_in = new HashMap<>();
|
||||
protected Map<Integer,INDArray> fastpath_out = new HashMap<>();
|
||||
|
||||
protected List<Double> fastpath_d = new ArrayList<>();
|
||||
protected List<Double> fastpath_t = new ArrayList<>();
|
||||
protected List<Boolean> fastpath_b = new ArrayList<>();
|
||||
protected List<Long> fastpath_i = new ArrayList<>();
|
||||
protected List<DataType> fastpath_d = new ArrayList<>();
|
||||
|
||||
@Setter()
|
||||
@Getter
|
||||
|
@ -55,14 +57,14 @@ public abstract class BaseOpContext implements OpContext {
|
|||
|
||||
@Override
|
||||
public void setTArguments(double... arguments) {
|
||||
fastpath_d.clear();
|
||||
fastpath_t.clear();
|
||||
for (val v:arguments)
|
||||
fastpath_d.add(v);
|
||||
fastpath_t.add(v);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> getTArguments(){
|
||||
return fastpath_d;
|
||||
return fastpath_t;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -77,6 +79,18 @@ public abstract class BaseOpContext implements OpContext {
|
|||
return fastpath_b;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setDArguments(DataType... arguments) {
|
||||
fastpath_d.clear();
|
||||
for (val v:arguments)
|
||||
fastpath_d.add(v);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> getDArguments() {
|
||||
return fastpath_d;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setInputArray(int index, @NonNull INDArray array) {
|
||||
fastpath_in.put(index, array);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
|
||||
|
@ -57,12 +58,18 @@ public interface CustomOp {
|
|||
|
||||
boolean[] bArgs();
|
||||
|
||||
DataType[] dArgs();
|
||||
|
||||
void addTArgument(double... arg);
|
||||
|
||||
void addIArgument(int... arg);
|
||||
|
||||
void addIArgument(long... arg);
|
||||
|
||||
void addBArgument(boolean... arg);
|
||||
|
||||
void addDArgument(DataType... arg);
|
||||
|
||||
void removeIArgument(Integer arg);
|
||||
|
||||
Boolean getBArgument(int index);
|
||||
|
@ -71,8 +78,6 @@ public interface CustomOp {
|
|||
|
||||
int numIArguments();
|
||||
|
||||
void addTArgument(double... arg);
|
||||
|
||||
void removeTArgument(Double arg);
|
||||
|
||||
Double getTArgument(int index);
|
||||
|
@ -81,6 +86,8 @@ public interface CustomOp {
|
|||
|
||||
int numBArguments();
|
||||
|
||||
int numDArguments();
|
||||
|
||||
void addInputArgument(INDArray... arg);
|
||||
|
||||
void removeInputArgument(INDArray arg);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.shade.guava.collect.Lists;
|
||||
import org.nd4j.shade.guava.primitives.Doubles;
|
||||
import org.nd4j.shade.guava.primitives.Longs;
|
||||
|
@ -62,6 +63,9 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
@Builder.Default
|
||||
protected List<Boolean> bArguments = new ArrayList<>();
|
||||
|
||||
@Builder.Default
|
||||
protected List<DataType> dArguments = new ArrayList<>();
|
||||
|
||||
@Builder.Default
|
||||
protected List<Integer> axis = new ArrayList<>();
|
||||
|
||||
|
@ -77,6 +81,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
iArguments = new ArrayList<>();
|
||||
tArguments = new ArrayList<>();
|
||||
bArguments = new ArrayList<>();
|
||||
dArguments = new ArrayList<>();
|
||||
}
|
||||
|
||||
public DynamicCustomOp(SameDiff sameDiff, SDVariable arg) {
|
||||
|
@ -93,6 +98,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
iArguments = new ArrayList<>();
|
||||
tArguments = new ArrayList<>();
|
||||
bArguments = new ArrayList<>();
|
||||
dArguments = new ArrayList<>();
|
||||
}
|
||||
|
||||
public DynamicCustomOp(String opName, INDArray input, INDArray output, List<Double> tArguments, int[] iArguments) {
|
||||
|
@ -132,6 +138,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
this.iArguments.add((Long) a.longValue());
|
||||
}
|
||||
bArguments = new ArrayList<>();
|
||||
dArguments = new ArrayList<>();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -173,6 +180,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
iArguments = new ArrayList<>();
|
||||
tArguments = new ArrayList<>();
|
||||
bArguments = new ArrayList<>();
|
||||
dArguments = new ArrayList<>();
|
||||
this.inplaceCall = inPlace;
|
||||
}
|
||||
|
||||
|
@ -185,6 +193,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
iArguments = new ArrayList<>();
|
||||
tArguments = new ArrayList<>();
|
||||
bArguments = new ArrayList<>();
|
||||
dArguments = new ArrayList<>();
|
||||
}
|
||||
|
||||
|
||||
|
@ -260,6 +269,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
return hash;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numDArguments() {
|
||||
return dArguments.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<INDArray> outputArguments() {
|
||||
return outputArguments;
|
||||
|
@ -280,6 +294,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
return Doubles.toArray(tArguments);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType[] dArgs() {
|
||||
return dArguments.toArray(new DataType[dArguments.size()]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addIArgument(int... arg) {
|
||||
for (long a: arg)
|
||||
|
@ -323,6 +342,15 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
addTArgument(Doubles.asList(arg).toArray(new Double[arg.length]));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addDArgument(DataType... arg) {
|
||||
if (dArguments == null)
|
||||
dArguments = new ArrayList<>();
|
||||
|
||||
if (arg != null)
|
||||
dArguments.addAll(Arrays.asList(arg));
|
||||
}
|
||||
|
||||
private void addTArgument(Double... arg) {
|
||||
tArguments.addAll(Arrays.asList(arg));
|
||||
}
|
||||
|
@ -650,6 +678,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
private List<INDArray> outputArguments = new ArrayList<>();
|
||||
private List<Double> tArguments = new ArrayList<>();
|
||||
private List<Long> iArguments = new ArrayList<>();
|
||||
private List<DataType> dArguments = new ArrayList<>();
|
||||
private List<Boolean> bArguments = new ArrayList<>();
|
||||
|
||||
protected DynamicCustomOpsBuilder(String opName, long hash, int numInputs, int numOutputs, boolean inplaceAllowed, int numTArguments, int numIArguments) {
|
||||
|
@ -870,6 +899,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
result.iArguments = iArguments;
|
||||
result.tArguments = tArguments;
|
||||
result.bArguments = bArguments;
|
||||
result.dArguments = dArguments;
|
||||
result.inplaceCall = inplaceCall;
|
||||
result.hash = opHash;
|
||||
result.outputShapes = outputShapes;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops;
|
||||
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
|
@ -43,9 +44,15 @@ public interface OpContext extends AutoCloseable {
|
|||
* @param arguments
|
||||
*/
|
||||
void setTArguments(double... arguments);
|
||||
|
||||
List<Double> getTArguments();
|
||||
|
||||
/**
|
||||
* This method sets data type arguments required for operation
|
||||
* @param arguments
|
||||
*/
|
||||
void setDArguments(DataType... arguments);
|
||||
List<DataType> getDArguments();
|
||||
|
||||
/**
|
||||
* This method sets boolean arguments required for operation
|
||||
* @param arguments
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.custom;
|
|||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.CustomOp;
|
||||
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
||||
|
@ -246,6 +247,21 @@ public class ScatterUpdate implements CustomOp {
|
|||
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType[] dArgs() {
|
||||
return new DataType[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addDArgument(DataType... arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numDArguments() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clearArrays() {
|
||||
op.clearArrays();
|
||||
|
|
|
@ -83,6 +83,9 @@ public class OneHot extends DynamicCustomOp {
|
|||
addIArgument(depth);
|
||||
addTArgument(on);
|
||||
addTArgument(off);
|
||||
|
||||
if (outputType != null)
|
||||
addDArgument(outputType);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -53,6 +55,22 @@ public class OnesLike extends DynamicCustomOp {
|
|||
public OnesLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) {
|
||||
super(name, sameDiff, new SDVariable[]{input}, false);
|
||||
this.outputType = dataType;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public OnesLike(@NonNull INDArray input, DataType dataType) {
|
||||
this.addInputArgument(input);
|
||||
this.outputType = dataType;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public OnesLike(@NonNull INDArray input) {
|
||||
this(input, input.dataType());
|
||||
}
|
||||
|
||||
public void addArgs() {
|
||||
if (outputType != null)
|
||||
addDArgument(outputType);
|
||||
}
|
||||
|
||||
|
||||
|
@ -78,6 +96,8 @@ public class OnesLike extends DynamicCustomOp {
|
|||
if(attributesForNode.containsKey("T")) {
|
||||
outputType = TFGraphMapper.convertType(attributesForNode.get("T").getType());
|
||||
}
|
||||
|
||||
addArgs();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -3438,6 +3438,16 @@ public class Nd4j {
|
|||
return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create 2D double array based on java 2d double array. and ordering
|
||||
*
|
||||
* @param data the data to use
|
||||
* @return the created ndarray.
|
||||
*/
|
||||
public static INDArray create(int[][] data) {
|
||||
return createFromArray(data);
|
||||
}
|
||||
|
||||
/**
|
||||
* create 3D int array based on 3D java int array.
|
||||
* @param data java 3D i array.
|
||||
|
|
|
@ -1056,7 +1056,7 @@ public interface NativeOps {
|
|||
|
||||
OpaqueShapeList calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs);
|
||||
|
||||
OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs);
|
||||
OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, @Cast("int *") IntPointer dArgs, int numDArgs);
|
||||
|
||||
long getShapeListSize(OpaqueShapeList list);
|
||||
LongPointer getShape(OpaqueShapeList list, long i);
|
||||
|
@ -1156,6 +1156,7 @@ public interface NativeOps {
|
|||
void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo);
|
||||
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
||||
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
|
||||
void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments);
|
||||
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
|
||||
void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
|
||||
void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
|
||||
|
|
|
@ -1928,6 +1928,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.bArgs().length) : null;
|
||||
|
||||
val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null;
|
||||
|
||||
cnt = 0;
|
||||
for (val b: op.bArgs())
|
||||
bArgs.put(cnt++, b);
|
||||
|
@ -1936,7 +1938,12 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
for (val t: op.tArgs())
|
||||
tArgs.put(cnt++, t);
|
||||
|
||||
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
|
||||
cnt = 0;
|
||||
val dArgs1 = op.dArgs();
|
||||
for (val d: dArgs1)
|
||||
dArgs.put(cnt++, d.toInt());
|
||||
|
||||
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments());
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
@ -2003,6 +2010,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
context.setBArguments(op.bArgs());
|
||||
context.setIArguments(op.iArgs());
|
||||
context.setTArguments(op.tArgs());
|
||||
context.setDArguments(op.dArgs());
|
||||
|
||||
val result = exec(op, context);
|
||||
val states = context.getRngStates();
|
||||
|
|
|
@ -18,12 +18,10 @@ package org.nd4j.linalg.jcublas.ops.executioner;
|
|||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.bytedeco.javacpp.BooleanPointer;
|
||||
import org.bytedeco.javacpp.DoublePointer;
|
||||
import org.bytedeco.javacpp.LongPointer;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.bytedeco.javacpp.*;
|
||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseOpContext;
|
||||
|
@ -76,6 +74,18 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setDArguments(DataType... arguments) {
|
||||
if (arguments.length > 0) {
|
||||
super.setDArguments(arguments);
|
||||
val args = new int[arguments.length];
|
||||
for (int e = 0; e < arguments.length; e++)
|
||||
args[e] = arguments[e].toInt();
|
||||
|
||||
nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length);
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setRngStates(long rootState, long nodeState) {
|
||||
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
|
||||
|
|
|
@ -2984,12 +2984,12 @@ public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointe
|
|||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs);
|
||||
|
||||
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
|
||||
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
|
||||
|
|
|
@ -17,10 +17,9 @@
|
|||
package org.nd4j.linalg.cpu.nativecpu.ops;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.bytedeco.javacpp.BooleanPointer;
|
||||
import org.bytedeco.javacpp.DoublePointer;
|
||||
import org.bytedeco.javacpp.LongPointer;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import lombok.val;
|
||||
import org.bytedeco.javacpp.*;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseOpContext;
|
||||
import org.nd4j.linalg.api.ops.ExecutionMode;
|
||||
|
@ -73,6 +72,18 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
|
|||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setDArguments(DataType... arguments) {
|
||||
if (arguments.length > 0) {
|
||||
super.setDArguments(arguments);
|
||||
val args = new int[arguments.length];
|
||||
for (int e = 0; e < arguments.length; e++)
|
||||
args[e] = arguments[e].toInt();
|
||||
|
||||
nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length);
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setRngStates(long rootState, long nodeState) {
|
||||
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
|
||||
|
|
|
@ -1636,6 +1636,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
context.setBArguments(op.bArgs());
|
||||
context.setIArguments(op.iArgs());
|
||||
context.setTArguments(op.tArgs());
|
||||
context.setDArguments(op.dArgs());
|
||||
|
||||
val result = exec(op, context);
|
||||
val states = context.getRngStates();
|
||||
|
@ -1712,6 +1713,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
val bArgs = op.numBArguments() > 0 ? new BooleanPointer(op.numBArguments()) : null;
|
||||
|
||||
val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null;
|
||||
|
||||
cnt = 0;
|
||||
val bArgs1 = op.bArgs();
|
||||
for (val b: bArgs1)
|
||||
|
@ -1722,11 +1725,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
for (val t: tArgs1)
|
||||
tArgs.put(cnt++, t);
|
||||
|
||||
cnt = 0;
|
||||
val dArgs1 = op.dArgs();
|
||||
for (val d: dArgs1)
|
||||
dArgs.put(cnt++, d.toInt());
|
||||
|
||||
|
||||
OpaqueShapeList ptrptr;
|
||||
try {
|
||||
ptrptr = loop.calculateOutputShapes2(null,
|
||||
hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs,
|
||||
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments());
|
||||
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments(), dArgs, op.numDArguments());
|
||||
|
||||
if (loop.lastErrorCode() != 0)
|
||||
throw new RuntimeException(loop.lastErrorMessage());
|
||||
|
|
|
@ -2987,12 +2987,12 @@ public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointe
|
|||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs);
|
||||
|
||||
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
|
||||
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
|
||||
|
@ -17951,7 +17951,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
*
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_zeros_as)
|
||||
@Namespace("nd4j::ops") public static class zeros_as extends DeclarableOp {
|
||||
@Namespace("nd4j::ops") public static class zeros_as extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public zeros_as(Pointer p) { super(p); }
|
||||
|
@ -17962,10 +17962,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
return (zeros_as)super.position(position);
|
||||
}
|
||||
|
||||
public zeros_as() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
public zeros_as() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
|
@ -17975,7 +17975,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
*
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_ones_as)
|
||||
@Namespace("nd4j::ops") public static class ones_as extends DeclarableOp {
|
||||
@Namespace("nd4j::ops") public static class ones_as extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public ones_as(Pointer p) { super(p); }
|
||||
|
@ -17986,10 +17986,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
return (ones_as)super.position(position);
|
||||
}
|
||||
|
||||
public ones_as() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
public ones_as() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -1169,6 +1169,26 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOneHot4() {
|
||||
|
||||
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable indices = sd.constant("indices", indicesArr);
|
||||
int depth = 3;
|
||||
int axis = -1;
|
||||
SDVariable oneHot = sd.oneHot("oneHot", indices, depth, axis, 5.0, 0.0, DataType.INT32);
|
||||
|
||||
INDArray exp = Nd4j.create(new int[][]{{5, 0, 0}, {0,0,5}, {0,0,0}, {0, 5, 0}});
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.expected(oneHot, exp)
|
||||
.gradientCheck(false));
|
||||
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOneHot3() {
|
||||
//https://github.com/deeplearning4j/deeplearning4j/issues/6872
|
||||
|
@ -1204,8 +1224,6 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
assertNull(err);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testLinspace(){
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeArea;
|
|||
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
|
||||
|
@ -1673,4 +1674,13 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOnesLike_1() {
|
||||
val x = Nd4j.create(DataType.FLOAT, 3, 4, 5);
|
||||
val e = Nd4j.ones(DataType.INT32, 3, 4, 5);
|
||||
|
||||
val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0];
|
||||
assertEquals(e, z);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue