Configurable DataType for ops (#201)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* - one more test for OneHot with dtype
- one more signature in Nd4j

Signed-off-by: raver119 <raver119@gmail.com>

* ones_as/zeros_as now accept dtype

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* - more updates for configurable data types
- ones_as/zeros_as java side + tests

Signed-off-by: raver119 <raver119@gmail.com>

* few c++ tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* few more changes around DArgs

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-01-30 18:46:12 +03:00 committed by GitHub
parent ba961c7601
commit 5d98cfcf47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 448 additions and 87 deletions

View File

@ -1518,7 +1518,7 @@ ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPoi
typedef nd4j::ShapeList OpaqueShapeList;
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs);
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);

View File

@ -1974,7 +1974,7 @@ void deleteShapeList(Nd4jPointer shapeList) {
delete list;
}
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
nd4j::graph::VariableSpace varSpace;
Context block(2, &varSpace);
nd4j::ShapeList inShapes;
@ -1988,6 +1988,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
for (int e = 0; e < numBArgs; e++)
block.getBArguments()->push_back(bArgs[e]);
for (int e = 0; e < numDArgs; e++)
block.getDArguments()->push_back((nd4j::DataType) dArgs[e]);
for (int e = 0; e < numInputShapes; e++) {
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
@ -2015,11 +2018,11 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
return shapeList;
}
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
try {
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs);
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
} catch (std::exception &e) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());

View File

@ -2684,7 +2684,7 @@ const char* getAllCustomOps() {
}
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
nd4j::graph::VariableSpace varSpace;
Context block(2, &varSpace);
nd4j::ShapeList inShapes;
@ -2698,6 +2698,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
for (int e = 0; e < numBArgs; e++)
block.getBArguments()->push_back(bArgs[e]);
for (int e = 0; e < numDArgs; e++)
block.getDArguments()->push_back((nd4j::DataType) dArgs[e]);
for (int e = 0; e < numInputShapes; e++) {
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
@ -2722,12 +2725,12 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
return shapeList;
}
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
try {
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs,
iArgs, numIArgs, bArgs, numBArgs);
iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
} catch (std::exception &e) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());

View File

@ -112,6 +112,14 @@ public struct FlatNode : IFlatbufferObject
public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } }
public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } }
public DType ExtraTypes(int j) { int o = __p.__offset(48); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
public int ExtraTypesLength { get { int o = __p.__offset(48); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T
public Span<byte> GetExtraTypesBytes() { return __p.__vector_as_span(48); }
#else
public ArraySegment<byte>? GetExtraTypesBytes() { return __p.__vector_as_arraysegment(48); }
#endif
public DType[] GetExtraTypesArray() { return __p.__vector_as_array<DType>(48); }
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
int id = 0,
@ -135,9 +143,11 @@ public struct FlatNode : IFlatbufferObject
Offset<FlatArray> scalarOffset = default(Offset<FlatArray>),
VectorOffset controlDepsOffset = default(VectorOffset),
VectorOffset varControlDepsOffset = default(VectorOffset),
VectorOffset controlDepForOffset = default(VectorOffset)) {
builder.StartObject(22);
VectorOffset controlDepForOffset = default(VectorOffset),
VectorOffset extraTypesOffset = default(VectorOffset)) {
builder.StartObject(23);
FlatNode.AddOpNum(builder, opNum);
FlatNode.AddExtraTypes(builder, extraTypesOffset);
FlatNode.AddControlDepFor(builder, controlDepForOffset);
FlatNode.AddVarControlDeps(builder, varControlDepsOffset);
FlatNode.AddControlDeps(builder, controlDepsOffset);
@ -162,7 +172,7 @@ public struct FlatNode : IFlatbufferObject
return FlatNode.EndFlatNode(builder);
}
public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); }
public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(23); }
public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); }
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); }
@ -224,6 +234,10 @@ public struct FlatNode : IFlatbufferObject
public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static void AddExtraTypes(FlatBufferBuilder builder, VectorOffset extraTypesOffset) { builder.AddOffset(22, extraTypesOffset.Value, 0); }
public static VectorOffset CreateExtraTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
public static VectorOffset CreateExtraTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static void StartExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
int o = builder.EndObject();
return new Offset<FlatNode>(o);

View File

@ -72,6 +72,10 @@ public final class FlatNode extends Table {
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; }
public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; }
public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); }
public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); }
public static int createFlatNode(FlatBufferBuilder builder,
int id,
@ -95,9 +99,11 @@ public final class FlatNode extends Table {
int scalarOffset,
int controlDepsOffset,
int varControlDepsOffset,
int controlDepForOffset) {
builder.startObject(22);
int controlDepForOffset,
int extraTypesOffset) {
builder.startObject(23);
FlatNode.addOpNum(builder, opNum);
FlatNode.addExtraTypes(builder, extraTypesOffset);
FlatNode.addControlDepFor(builder, controlDepForOffset);
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
FlatNode.addControlDeps(builder, controlDepsOffset);
@ -122,7 +128,7 @@ public final class FlatNode extends Table {
return FlatNode.endFlatNode(builder);
}
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); }
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
@ -171,6 +177,9 @@ public final class FlatNode extends Table {
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); }
public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
public static int endFlatNode(FlatBufferBuilder builder) {
int o = builder.endObject();
return o;

View File

@ -339,7 +339,29 @@ class FlatNode(object):
return self._tab.VectorLen(o)
return 0
def FlatNodeStart(builder): builder.StartObject(22)
# FlatNode
def ExtraTypes(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
if o != 0:
a = self._tab.Vector(o)
return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
return 0
# FlatNode
def ExtraTypesAsNumpy(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
if o != 0:
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o)
return 0
# FlatNode
def ExtraTypesLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
if o != 0:
return self._tab.VectorLen(o)
return 0
def FlatNodeStart(builder): builder.StartObject(23)
def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0)
def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0)
@ -375,4 +397,6 @@ def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTR
def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0)
def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def FlatNodeAddExtraTypes(builder, extraTypes): builder.PrependUOffsetTRelativeSlot(22, flatbuffers.number_types.UOffsetTFlags.py_type(extraTypes), 0)
def FlatNodeStartExtraTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1)
def FlatNodeEnd(builder): return builder.EndObject()

View File

@ -38,7 +38,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_SCALAR = 40,
VT_CONTROLDEPS = 42,
VT_VARCONTROLDEPS = 44,
VT_CONTROLDEPFOR = 46
VT_CONTROLDEPFOR = 46,
VT_EXTRATYPES = 48
};
int32_t id() const {
return GetField<int32_t>(VT_ID, 0);
@ -106,6 +107,9 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOR);
}
const flatbuffers::Vector<int8_t> *extraTypes() const {
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_EXTRATYPES);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_ID) &&
@ -153,6 +157,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyOffset(verifier, VT_CONTROLDEPFOR) &&
verifier.VerifyVector(controlDepFor()) &&
verifier.VerifyVectorOfStrings(controlDepFor()) &&
VerifyOffset(verifier, VT_EXTRATYPES) &&
verifier.VerifyVector(extraTypes()) &&
verifier.EndTable();
}
};
@ -226,6 +232,9 @@ struct FlatNodeBuilder {
void add_controlDepFor(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor) {
fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor);
}
void add_extraTypes(flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes) {
fbb_.AddOffset(FlatNode::VT_EXTRATYPES, extraTypes);
}
explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@ -261,9 +270,11 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNode(
flatbuffers::Offset<FlatArray> scalar = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0) {
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0,
flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes = 0) {
FlatNodeBuilder builder_(_fbb);
builder_.add_opNum(opNum);
builder_.add_extraTypes(extraTypes);
builder_.add_controlDepFor(controlDepFor);
builder_.add_varControlDeps(varControlDeps);
builder_.add_controlDeps(controlDeps);
@ -311,7 +322,8 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
flatbuffers::Offset<FlatArray> scalar = 0,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps = nullptr,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr) {
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr,
const std::vector<int8_t> *extraTypes = nullptr) {
return nd4j::graph::CreateFlatNode(
_fbb,
id,
@ -335,7 +347,8 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
scalar,
controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
varControlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*varControlDeps) : 0,
controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0);
controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0,
extraTypes ? _fbb.CreateVector<int8_t>(*extraTypes) : 0);
}
inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) {

View File

@ -398,11 +398,36 @@ nd4j.graph.FlatNode.prototype.controlDepForLength = function() {
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
};
/**
* @param {number} index
* @returns {nd4j.graph.DType}
*/
nd4j.graph.FlatNode.prototype.extraTypes = function(index) {
var offset = this.bb.__offset(this.bb_pos, 48);
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
};
/**
* @returns {number}
*/
nd4j.graph.FlatNode.prototype.extraTypesLength = function() {
var offset = this.bb.__offset(this.bb_pos, 48);
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
};
/**
* @returns {Int8Array}
*/
nd4j.graph.FlatNode.prototype.extraTypesArray = function() {
var offset = this.bb.__offset(this.bb_pos, 48);
return offset ? new Int8Array(this.bb.bytes().buffer, this.bb.bytes().byteOffset + this.bb.__vector(this.bb_pos + offset), this.bb.__vector_len(this.bb_pos + offset)) : null;
};
/**
* @param {flatbuffers.Builder} builder
*/
nd4j.graph.FlatNode.startFlatNode = function(builder) {
builder.startObject(22);
builder.startObject(23);
};
/**
@ -854,6 +879,35 @@ nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) {
builder.startVector(4, numElems, 4);
};
/**
* @param {flatbuffers.Builder} builder
* @param {flatbuffers.Offset} extraTypesOffset
*/
nd4j.graph.FlatNode.addExtraTypes = function(builder, extraTypesOffset) {
builder.addFieldOffset(22, extraTypesOffset, 0);
};
/**
* @param {flatbuffers.Builder} builder
* @param {Array.<nd4j.graph.DType>} data
* @returns {flatbuffers.Offset}
*/
nd4j.graph.FlatNode.createExtraTypesVector = function(builder, data) {
builder.startVector(1, data.length, 1);
for (var i = data.length - 1; i >= 0; i--) {
builder.addInt8(data[i]);
}
return builder.endVector();
};
/**
* @param {flatbuffers.Builder} builder
* @param {number} numElems
*/
nd4j.graph.FlatNode.startExtraTypesVector = function(builder, numElems) {
builder.startVector(1, numElems, 1);
};
/**
* @param {flatbuffers.Builder} builder
* @returns {flatbuffers.Offset}

View File

@ -587,9 +587,9 @@ namespace nd4j {
block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
}
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
}
}
@ -624,9 +624,9 @@ namespace nd4j {
block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
}
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
}
}
@ -664,9 +664,9 @@ namespace nd4j {
block->getBArguments()->push_back(node->extraBools()->Get(e));
}
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
}
}

View File

@ -58,6 +58,8 @@ table FlatNode {
varControlDeps:[string];
controlDepFor:[string];
// DArgs
extraTypes:[DType];
}
root_type FlatNode;

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(ones_as, 1, 1, false) {
CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) {
auto output = OUTPUT_VARIABLE(0);
output->assign(1);
@ -33,11 +33,21 @@ namespace nd4j {
return Status::OK();
}
DECLARE_SHAPE_FN(ones_as) {
auto in = inputShape->at(0);
auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in);
auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in);
nd4j_printf("numD: %i; dtype: %s\n", block.numD(), DataTypeUtils::asString(dtype).c_str());
return SHAPELIST(shape);
}
DECLARE_TYPES(ones_as) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::ANY)
->setSameMode(true);
->setSameMode(false);
}
}
}

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(zeros_as, 1, 1, false) {
CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) {
auto out = OUTPUT_VARIABLE(0);
out->assign(0); // output is filled by zero by default
@ -35,11 +35,20 @@ namespace nd4j {
DECLARE_SYN(zeroslike, zeros_as);
DECLARE_SYN(zeros_like, zeros_as);
DECLARE_SHAPE_FN(zeros_as) {
auto in = inputShape->at(0);
auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in);
auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in);
return SHAPELIST(shape);
}
DECLARE_TYPES(zeros_as) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::ANY)
->setSameMode(true);
->setSameMode(false);
}
}
}

View File

@ -487,7 +487,7 @@ namespace nd4j {
*
*/
#if NOT_EXCLUDED(OP_zeros_as)
DECLARE_OP(zeros_as, 1, 1, false);
DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0);
#endif
/**
@ -497,7 +497,7 @@ namespace nd4j {
*
*/
#if NOT_EXCLUDED(OP_ones_as)
DECLARE_OP(ones_as, 1, 1, false);
DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0);
#endif
/**

View File

@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
double tArgs[] = { -1.0, 1.0, 0.01 };
auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0);
auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0);
shape::printShapeInfoLinear("Result", shapes->at(0));
ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0)));

View File

@ -2978,6 +2978,24 @@ TEST_F(DeclarableOpsTests8, ones_as_test2) {
delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests8, ones_as_test3) {
auto x = NDArrayFactory::create<double>(10.);
//auto y = NDArrayFactory::create<double>(100.);
auto exp = NDArrayFactory::create<int>(1.);
nd4j::ops::ones_as op;
auto results = op.evaluate({&x}, {}, {}, {}, {nd4j::DataType::INT32});
ASSERT_EQ(Status::OK(), results->status());
auto y = results->at(0);
ASSERT_TRUE(y->isSameShape(exp));
ASSERT_TRUE(y->equalsTo(exp));
delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) {

View File

@ -112,7 +112,7 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
Nd4jLong iArgs[] = {1};
auto hash = op.getOpHash();
auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0);
ASSERT_EQ(3, shapeList->size());
@ -1065,7 +1065,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
NDArray::prepareSpecialUse({}, {&arrayX, &arrayY});
nd4j::ops::greater_equal op;
auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0);
NDArray::registerSpecialUse({}, {&arrayX, &arrayY});
delete shapeList;
}

View File

@ -1579,7 +1579,7 @@ TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) {
#endif
auto shapeList = ::calculateOutputShapes2(nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, const_cast<double*>(tArgs.data()), tArgs.size(),
const_cast<Nd4jLong*>(iArgs.data()), iArgs.size(), nullptr, bArgsF.size());
const_cast<Nd4jLong*>(iArgs.data()), iArgs.size(), nullptr, bArgsF.size(), nullptr, 0);
// Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs
ASSERT_EQ(1, shapeList->size());

View File

@ -4704,7 +4704,7 @@ public class SameDiff extends SDBaseOps {
0,
0,
-1,
0, 0, 0, 0, 0, 0, 0, 0, 0);
0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
return flatNode;
}

View File

@ -17,6 +17,7 @@
package org.nd4j.autodiff.samediff.serde;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.shade.guava.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import java.nio.ByteOrder;
@ -361,6 +362,11 @@ public class FlatBuffersMapper {
for (int i = 0; i < extraBools.length; i++) {
extraBools[i] = fn.extraBools(i);
}
DataType[] extraDTypes = new DataType[fn.extraTypesLength()];
for (int i = 0; i < extraDTypes.length; i++) {
extraDTypes[i] = DataType.fromInt(fn.extraTypes(i));
}
int[] dimensions = new int[fn.dimensionsLength()];
for (int i = 0; i < dimensions.length; i++) {
dimensions[i] = fn.dimensions(i);
@ -401,6 +407,7 @@ public class FlatBuffersMapper {
((CustomOp) op).addIArgument(extraInteger);
((CustomOp) op).addTArgument(extraParams);
((CustomOp) op).addBArgument(extraBools);
((CustomOp) op).addDArgument(extraDTypes);
op.setPropertiesForFunction(props);
return op;
@ -714,11 +721,20 @@ public class FlatBuffersMapper {
}
boolean[] boolArgs = null;
byte[] dtypeArgs = null;
long[] extraBits = null;
if (node.opType() == Op.Type.CUSTOM) {
DynamicCustomOp dynamicCustomOp = (DynamicCustomOp) node;
val dynamicCustomOp = (DynamicCustomOp) node;
extraBits = dynamicCustomOp.iArgs();
boolArgs = dynamicCustomOp.bArgs();
if (dynamicCustomOp.numDArguments() > 0) {
dtypeArgs = new byte[dynamicCustomOp.numDArguments()];
val d = dynamicCustomOp.dArgs();
for (int e = 0; e < dtypeArgs.length; e++) {
dtypeArgs[e] = (byte) d[e].toInt();
}
}
} else if (node instanceof Enter) {
// in case of Enter node we'll be storing unique frame reference
val frameName = ((Enter) node).getFrameName();
@ -817,6 +833,7 @@ public class FlatBuffersMapper {
int extraz = FlatNode.createExtraParamsVector(bufferBuilder, extras);
int integerArgs = FlatNode.createExtraIntegerVector(bufferBuilder, extraBits);
int bArgs = FlatNode.createExtraBoolsVector(bufferBuilder, boolArgs != null ? boolArgs : new boolean[0]);
int dArgs = FlatNode.createOutputTypesVector(bufferBuilder, dtypeArgs != null ? dtypeArgs : new byte[0]);
int dimensions = FlatNode.createDimensionsVector(bufferBuilder, dims);
int fname = bufferBuilder.createString(node.getOwnName());
int scopeName = bufferBuilder.createString("");
@ -896,7 +913,8 @@ public class FlatBuffersMapper {
scalar,
opCds,
varCds,
cdsFor
cdsFor,
dArgs
);
return flatNode;

View File

@ -73,6 +73,10 @@ public final class FlatNode extends Table {
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; }
public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; }
public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); }
public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); }
public static int createFlatNode(FlatBufferBuilder builder,
int id,
@ -96,9 +100,11 @@ public final class FlatNode extends Table {
int scalarOffset,
int controlDepsOffset,
int varControlDepsOffset,
int controlDepForOffset) {
builder.startObject(22);
int controlDepForOffset,
int extraTypesOffset) {
builder.startObject(23);
FlatNode.addOpNum(builder, opNum);
FlatNode.addExtraTypes(builder, extraTypesOffset);
FlatNode.addControlDepFor(builder, controlDepForOffset);
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
FlatNode.addControlDeps(builder, controlDepsOffset);
@ -123,7 +129,7 @@ public final class FlatNode extends Table {
return FlatNode.endFlatNode(builder);
}
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); }
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
@ -172,6 +178,9 @@ public final class FlatNode extends Table {
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); }
public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
public static int endFlatNode(FlatBufferBuilder builder) {
int o = builder.endObject();
return o;

View File

@ -61,6 +61,7 @@ public class DifferentialFunctionClassHolder {
add("tArguments");
add("iArguments");
add("bArguments");
add("dArguments");
add("hash");
add("opName");
add("sameDiff");

View File

@ -20,6 +20,7 @@ import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.val;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.*;
@ -33,9 +34,10 @@ public abstract class BaseOpContext implements OpContext {
protected Map<Integer,INDArray> fastpath_in = new HashMap<>();
protected Map<Integer,INDArray> fastpath_out = new HashMap<>();
protected List<Double> fastpath_d = new ArrayList<>();
protected List<Double> fastpath_t = new ArrayList<>();
protected List<Boolean> fastpath_b = new ArrayList<>();
protected List<Long> fastpath_i = new ArrayList<>();
protected List<DataType> fastpath_d = new ArrayList<>();
@Setter()
@Getter
@ -55,14 +57,14 @@ public abstract class BaseOpContext implements OpContext {
@Override
public void setTArguments(double... arguments) {
fastpath_d.clear();
fastpath_t.clear();
for (val v:arguments)
fastpath_d.add(v);
fastpath_t.add(v);
}
@Override
public List<Double> getTArguments(){
return fastpath_d;
return fastpath_t;
}
@Override
@ -77,6 +79,18 @@ public abstract class BaseOpContext implements OpContext {
return fastpath_b;
}
@Override
public void setDArguments(DataType... arguments) {
fastpath_d.clear();
for (val v:arguments)
fastpath_d.add(v);
}
@Override
public List<DataType> getDArguments() {
return fastpath_d;
}
@Override
public void setInputArray(int index, @NonNull INDArray array) {
fastpath_in.put(index, array);

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
@ -57,12 +58,18 @@ public interface CustomOp {
boolean[] bArgs();
DataType[] dArgs();
void addTArgument(double... arg);
void addIArgument(int... arg);
void addIArgument(long... arg);
void addBArgument(boolean... arg);
void addDArgument(DataType... arg);
void removeIArgument(Integer arg);
Boolean getBArgument(int index);
@ -71,8 +78,6 @@ public interface CustomOp {
int numIArguments();
void addTArgument(double... arg);
void removeTArgument(Double arg);
Double getTArgument(int index);
@ -81,6 +86,8 @@ public interface CustomOp {
int numBArguments();
int numDArguments();
void addInputArgument(INDArray... arg);
void removeInputArgument(INDArray arg);

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.shade.guava.collect.Lists;
import org.nd4j.shade.guava.primitives.Doubles;
import org.nd4j.shade.guava.primitives.Longs;
@ -62,6 +63,9 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
@Builder.Default
protected List<Boolean> bArguments = new ArrayList<>();
@Builder.Default
protected List<DataType> dArguments = new ArrayList<>();
@Builder.Default
protected List<Integer> axis = new ArrayList<>();
@ -77,6 +81,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
iArguments = new ArrayList<>();
tArguments = new ArrayList<>();
bArguments = new ArrayList<>();
dArguments = new ArrayList<>();
}
public DynamicCustomOp(SameDiff sameDiff, SDVariable arg) {
@ -93,6 +98,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
iArguments = new ArrayList<>();
tArguments = new ArrayList<>();
bArguments = new ArrayList<>();
dArguments = new ArrayList<>();
}
public DynamicCustomOp(String opName, INDArray input, INDArray output, List<Double> tArguments, int[] iArguments) {
@ -132,6 +138,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
this.iArguments.add((Long) a.longValue());
}
bArguments = new ArrayList<>();
dArguments = new ArrayList<>();
}
/**
@ -173,6 +180,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
iArguments = new ArrayList<>();
tArguments = new ArrayList<>();
bArguments = new ArrayList<>();
dArguments = new ArrayList<>();
this.inplaceCall = inPlace;
}
@ -185,6 +193,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
iArguments = new ArrayList<>();
tArguments = new ArrayList<>();
bArguments = new ArrayList<>();
dArguments = new ArrayList<>();
}
@ -260,6 +269,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
return hash;
}
@Override
public int numDArguments() {
return dArguments.size();
}
@Override
public List<INDArray> outputArguments() {
return outputArguments;
@ -280,6 +294,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
return Doubles.toArray(tArguments);
}
@Override
public DataType[] dArgs() {
return dArguments.toArray(new DataType[dArguments.size()]);
}
@Override
public void addIArgument(int... arg) {
for (long a: arg)
@ -323,6 +342,15 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
addTArgument(Doubles.asList(arg).toArray(new Double[arg.length]));
}
@Override
public void addDArgument(DataType... arg) {
if (dArguments == null)
dArguments = new ArrayList<>();
if (arg != null)
dArguments.addAll(Arrays.asList(arg));
}
private void addTArgument(Double... arg) {
tArguments.addAll(Arrays.asList(arg));
}
@ -650,6 +678,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
private List<INDArray> outputArguments = new ArrayList<>();
private List<Double> tArguments = new ArrayList<>();
private List<Long> iArguments = new ArrayList<>();
private List<DataType> dArguments = new ArrayList<>();
private List<Boolean> bArguments = new ArrayList<>();
protected DynamicCustomOpsBuilder(String opName, long hash, int numInputs, int numOutputs, boolean inplaceAllowed, int numTArguments, int numIArguments) {
@ -870,6 +899,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
result.iArguments = iArguments;
result.tArguments = tArguments;
result.bArguments = bArguments;
result.dArguments = dArguments;
result.inplaceCall = inplaceCall;
result.hash = opHash;
result.outputShapes = outputShapes;

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
@ -43,9 +44,15 @@ public interface OpContext extends AutoCloseable {
* @param arguments
*/
void setTArguments(double... arguments);
List<Double> getTArguments();
/**
* This method sets data type arguments required for operation
* @param arguments
*/
void setDArguments(DataType... arguments);
List<DataType> getDArguments();
/**
* This method sets boolean arguments required for operation
* @param arguments

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import lombok.val;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
@ -246,6 +247,21 @@ public class ScatterUpdate implements CustomOp {
}
@Override
public DataType[] dArgs() {
return new DataType[0];
}
@Override
public void addDArgument(DataType... arg) {
}
@Override
public int numDArguments() {
return 0;
}
@Override
public void clearArrays() {
op.clearArrays();

View File

@ -83,6 +83,9 @@ public class OneHot extends DynamicCustomOp {
addIArgument(depth);
addTArgument(on);
addTArgument(off);
if (outputType != null)
addDArgument(outputType);
}
@Override

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -53,6 +55,22 @@ public class OnesLike extends DynamicCustomOp {
public OnesLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) {
super(name, sameDiff, new SDVariable[]{input}, false);
this.outputType = dataType;
addArgs();
}
public OnesLike(@NonNull INDArray input, DataType dataType) {
this.addInputArgument(input);
this.outputType = dataType;
addArgs();
}
public OnesLike(@NonNull INDArray input) {
this(input, input.dataType());
}
public void addArgs() {
if (outputType != null)
addDArgument(outputType);
}
@ -78,6 +96,8 @@ public class OnesLike extends DynamicCustomOp {
if(attributesForNode.containsKey("T")) {
outputType = TFGraphMapper.convertType(attributesForNode.get("T").getType());
}
addArgs();
}
@Override

View File

@ -3438,6 +3438,16 @@ public class Nd4j {
return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length);
}
/**
* Create 2D double array based on java 2d double array. and ordering
*
* @param data the data to use
* @return the created ndarray.
*/
public static INDArray create(int[][] data) {
return createFromArray(data);
}
/**
* create 3D int array based on 3D java int array.
* @param data java 3D i array.

View File

@ -1056,7 +1056,7 @@ public interface NativeOps {
OpaqueShapeList calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs);
OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs);
OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, @Cast("int *") IntPointer dArgs, int numDArgs);
long getShapeListSize(OpaqueShapeList list);
LongPointer getShape(OpaqueShapeList list, long i);
@ -1156,6 +1156,7 @@ public interface NativeOps {
void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo);
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments);
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
void ctxSetExecutionMode(OpaqueContext ptr, int execMode);

View File

@ -1928,6 +1928,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.bArgs().length) : null;
val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null;
cnt = 0;
for (val b: op.bArgs())
bArgs.put(cnt++, b);
@ -1936,7 +1938,12 @@ public class CudaExecutioner extends DefaultOpExecutioner {
for (val t: op.tArgs())
tArgs.put(cnt++, t);
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
cnt = 0;
val dArgs1 = op.dArgs();
for (val d: dArgs1)
dArgs.put(cnt++, d.toInt());
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments());
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
@ -2003,6 +2010,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
context.setBArguments(op.bArgs());
context.setIArguments(op.iArgs());
context.setTArguments(op.tArgs());
context.setDArguments(op.dArgs());
val result = exec(op, context);
val states = context.getRngStates();

View File

@ -18,12 +18,10 @@ package org.nd4j.linalg.jcublas.ops.executioner;
import lombok.NonNull;
import lombok.val;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.*;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOpContext;
@ -76,6 +74,18 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
}
}
@Override
public void setDArguments(DataType... arguments) {
if (arguments.length > 0) {
super.setDArguments(arguments);
val args = new int[arguments.length];
for (int e = 0; e < arguments.length; e++)
args[e] = arguments[e].toInt();
nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length);
};
}
@Override
public void setRngStates(long rootState, long nodeState) {
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);

View File

@ -2984,12 +2984,12 @@ public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointe
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs);
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);

View File

@ -17,10 +17,9 @@
package org.nd4j.linalg.cpu.nativecpu.ops;
import lombok.NonNull;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import lombok.val;
import org.bytedeco.javacpp.*;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOpContext;
import org.nd4j.linalg.api.ops.ExecutionMode;
@ -73,6 +72,18 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
};
}
@Override
public void setDArguments(DataType... arguments) {
if (arguments.length > 0) {
super.setDArguments(arguments);
val args = new int[arguments.length];
for (int e = 0; e < arguments.length; e++)
args[e] = arguments[e].toInt();
nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length);
};
}
@Override
public void setRngStates(long rootState, long nodeState) {
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);

View File

@ -1636,6 +1636,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
context.setBArguments(op.bArgs());
context.setIArguments(op.iArgs());
context.setTArguments(op.tArgs());
context.setDArguments(op.dArgs());
val result = exec(op, context);
val states = context.getRngStates();
@ -1712,6 +1713,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
val bArgs = op.numBArguments() > 0 ? new BooleanPointer(op.numBArguments()) : null;
val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null;
cnt = 0;
val bArgs1 = op.bArgs();
for (val b: bArgs1)
@ -1722,11 +1725,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
for (val t: tArgs1)
tArgs.put(cnt++, t);
cnt = 0;
val dArgs1 = op.dArgs();
for (val d: dArgs1)
dArgs.put(cnt++, d.toInt());
OpaqueShapeList ptrptr;
try {
ptrptr = loop.calculateOutputShapes2(null,
hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs,
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments());
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments(), dArgs, op.numDArguments());
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());

View File

@ -2987,12 +2987,12 @@ public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointe
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs);
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs);
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
@ -17951,7 +17951,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
*
*/
// #if NOT_EXCLUDED(OP_zeros_as)
@Namespace("nd4j::ops") public static class zeros_as extends DeclarableOp {
@Namespace("nd4j::ops") public static class zeros_as extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public zeros_as(Pointer p) { super(p); }
@ -17962,10 +17962,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
return (zeros_as)super.position(position);
}
public zeros_as() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
public zeros_as() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
@ -17975,7 +17975,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
*
*/
// #if NOT_EXCLUDED(OP_ones_as)
@Namespace("nd4j::ops") public static class ones_as extends DeclarableOp {
@Namespace("nd4j::ops") public static class ones_as extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ones_as(Pointer p) { super(p); }
@ -17986,10 +17986,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
return (ones_as)super.position(position);
}
public ones_as() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
public ones_as() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**

View File

@ -1169,6 +1169,26 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err);
}
@Test
public void testOneHot4() {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
SameDiff sd = SameDiff.create();
SDVariable indices = sd.constant("indices", indicesArr);
int depth = 3;
int axis = -1;
SDVariable oneHot = sd.oneHot("oneHot", indices, depth, axis, 5.0, 0.0, DataType.INT32);
INDArray exp = Nd4j.create(new int[][]{{5, 0, 0}, {0,0,5}, {0,0,0}, {0, 5, 0}});
String err = OpValidation.validate(new TestCase(sd)
.expected(oneHot, exp)
.gradientCheck(false));
assertNull(err);
}
@Test
public void testOneHot3() {
//https://github.com/deeplearning4j/deeplearning4j/issues/6872
@ -1204,8 +1224,6 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err);
}
@Test
public void testLinspace(){
SameDiff sd = SameDiff.create();

View File

@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeArea;
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.shape.Create;
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
@ -1673,4 +1674,13 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(expected, ret[0]);
}
@Test
public void testOnesLike_1() {
val x = Nd4j.create(DataType.FLOAT, 3, 4, 5);
val e = Nd4j.ones(DataType.INT32, 3, 4, 5);
val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0];
assertEquals(e, z);
}
}