DataTypes & FlatBuffers (#197)

* flatbuffers version upgrade

Signed-off-by: raver119 <raver119@gmail.com>

* flatbuffers version upgrade java side

Signed-off-by: raver119 <raver119@gmail.com>

* flatbuffers dependency version upgrade java side

Signed-off-by: raver119 <raver119@gmail.com>

* MKLDNN version upgrade

Signed-off-by: raver119 <raver119@gmail.com>

* DArgs first pass

Signed-off-by: raver119 <raver119@gmail.com>

* signatures first pass

Signed-off-by: raver119 <raver119@gmail.com>

* signatures second pass

Signed-off-by: raver119 <raver119@gmail.com>

* signatures third pass

Signed-off-by: raver119 <raver119@gmail.com>

* signatures third pass

Signed-off-by: raver119 <raver119@gmail.com>

* signatures fourth pass

Signed-off-by: raver119 <raver119@gmail.com>

* signatures fifth pass

Signed-off-by: raver119 <raver119@gmail.com>

* flatbuffers UI version upgrade java side

Signed-off-by: raver119 <raver119@gmail.com>

* flatbuffers ui update

Signed-off-by: raver119 <raver119@gmail.com>

* flatbuffers downgrade

Signed-off-by: raver119 <raver119@gmail.com>

* flatbuffers downgrade java side

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-01-30 10:07:24 +03:00 committed by GitHub
parent 5039fb22b7
commit ba961c7601
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
84 changed files with 2804 additions and 2497 deletions

View File

@ -5,7 +5,7 @@ project(flatbuffers-download NONE)
include(ExternalProject) include(ExternalProject)
ExternalProject_Add(flatbuffers ExternalProject_Add(flatbuffers
GIT_REPOSITORY https://github.com/google/flatbuffers.git GIT_REPOSITORY https://github.com/google/flatbuffers.git
GIT_TAG v1.10.0 GIT_TAG v1.11.0
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src" SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src"
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-build" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-build"
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""

View File

@ -5,7 +5,7 @@ project(mkldnn-download NONE)
include(ExternalProject) include(ExternalProject)
ExternalProject_Add(mkldnn ExternalProject_Add(mkldnn
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
GIT_TAG v1.1.2 GIT_TAG v1.1.3
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src" SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""

View File

@ -1607,6 +1607,7 @@ ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);

View File

@ -2130,7 +2130,7 @@ Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4
biArgs[e] = bArgs[e]; biArgs[e] = bArgs[e];
// hypothetically at this point we have everything filled // hypothetically at this point we have everything filled
auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, isInplace); auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, std::vector<nd4j::DataType>(), isInplace);
//auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace); //auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace);
@ -2788,6 +2788,15 @@ void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, i
void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) { void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
ptr->setBArguments(arguments, numberOfArguments); ptr->setBArguments(arguments, numberOfArguments);
} }
void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
std::vector<nd4j::DataType> dtypes(numberOfArguments);
for (int e = 0; e < numberOfArguments; e++)
dtypes[e] = (nd4j::DataType) arguments[e];
ptr->setDArguments(dtypes);
}
void deleteGraphContext(nd4j::graph::Context* ptr) { void deleteGraphContext(nd4j::graph::Context* ptr) {
delete ptr; delete ptr;
} }

View File

@ -2831,7 +2831,7 @@ static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer*
// hypothetically at this point we have everything filled // hypothetically at this point we have everything filled
auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, isInplace); auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, std::vector<nd4j::DataType>(), isInplace);
//auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace); //auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace);
@ -3596,6 +3596,14 @@ void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int n
ptr->setBArguments(arguments, numberOfArguments); ptr->setBArguments(arguments, numberOfArguments);
} }
void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
std::vector<nd4j::DataType> dtypes(numberOfArguments);
for (int e = 0; e < numberOfArguments; e++)
dtypes[e] = (nd4j::DataType) arguments[e];
ptr->setDArguments(dtypes);
}
void deleteGraphContext(nd4j::graph::Context* ptr) { void deleteGraphContext(nd4j::graph::Context* ptr) {
delete ptr; delete ptr;
} }

View File

@ -95,6 +95,10 @@ namespace nd4j {
template<typename T> template<typename T>
// struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value || std::is_same<long long, T>::value; }; // struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value || std::is_same<long long, T>::value; };
struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<unsigned int, T>::value || std::is_same<long long, T>::value || std::is_same<unsigned long long, T>::value || std::is_same<long int, T>::value || std::is_same<long unsigned int, T>::value || std::is_same<int8_t, T>::value || std::is_same<uint8_t, T>::value || std::is_same<int16_t, T>::value || std::is_same<uint16_t, T>::value || std::is_same<bool, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value; }; struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<unsigned int, T>::value || std::is_same<long long, T>::value || std::is_same<unsigned long long, T>::value || std::is_same<long int, T>::value || std::is_same<long unsigned int, T>::value || std::is_same<int8_t, T>::value || std::is_same<uint8_t, T>::value || std::is_same<int16_t, T>::value || std::is_same<uint16_t, T>::value || std::is_same<bool, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value; };
template<typename T>
struct scalarTypesForExecution { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<Nd4jLong, T>::value || std::is_same<int, T>::value || std::is_same<bool, T>::value; };
}; };

View File

@ -158,7 +158,7 @@ namespace nd4j {
iargs.push_back(_axis); iargs.push_back(_axis);
auto result = op.execute(inputs, {}, {}, {}); auto result = op.evaluate(inputs);
auto array = new NDArray(result->at(0)->dup()); auto array = new NDArray(result->at(0)->dup());

View File

@ -197,10 +197,12 @@ namespace nd4j {
void setTArguments(double *arguments, int numberOfArguments); void setTArguments(double *arguments, int numberOfArguments);
void setIArguments(Nd4jLong *arguments, int numberOfArguments); void setIArguments(Nd4jLong *arguments, int numberOfArguments);
void setBArguments(bool *arguments, int numberOfArguments); void setBArguments(bool *arguments, int numberOfArguments);
void setDArguments(nd4j::DataType *arguments, int numberOfArguments);
void setTArguments(const std::vector<double> &tArgs); void setTArguments(const std::vector<double> &tArgs);
void setIArguments(const std::vector<Nd4jLong> &tArgs); void setIArguments(const std::vector<Nd4jLong> &tArgs);
void setBArguments(const std::vector<bool> &tArgs); void setBArguments(const std::vector<bool> &tArgs);
void setDArguments(const std::vector<nd4j::DataType> &dArgs);
void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);

View File

@ -47,6 +47,9 @@ namespace nd4j {
std::vector<int> _iArgs; std::vector<int> _iArgs;
std::vector<bool> _bArgs; std::vector<bool> _bArgs;
std::vector<int> _axis; std::vector<int> _axis;
std::vector<nd4j::DataType> _dArgs;
// TODO: remove this field
nd4j::DataType _dataType = nd4j::DataType::FLOAT32; nd4j::DataType _dataType = nd4j::DataType::FLOAT32;
bool _isInplace; bool _isInplace;
@ -93,6 +96,7 @@ namespace nd4j {
std::vector<double>* getTArguments(); std::vector<double>* getTArguments();
std::vector<int>* getIArguments(); std::vector<int>* getIArguments();
std::vector<bool>* getBArguments(); std::vector<bool>* getBArguments();
std::vector<nd4j::DataType>* getDArguments();
std::vector<int>* getAxis(); std::vector<int>* getAxis();
samediff::Engine engine(); samediff::Engine engine();
@ -100,6 +104,7 @@ namespace nd4j {
size_t numT(); size_t numT();
size_t numI(); size_t numI();
size_t numB(); size_t numB();
size_t numD();
std::pair<int, int>* input(int idx); std::pair<int, int>* input(int idx);

View File

@ -38,7 +38,9 @@ namespace nd4j {
class ND4J_EXPORT Node { class ND4J_EXPORT Node {
protected: protected:
// TODO: this field must be removed
nd4j::DataType _dataType; nd4j::DataType _dataType;
OpType _opType; OpType _opType;
ContextPrototype* _protoContext = nullptr; ContextPrototype* _protoContext = nullptr;
Nd4jLong _opNum; Nd4jLong _opNum;
@ -61,6 +63,7 @@ namespace nd4j {
// optional scalar. used in scalar ops and in summary stats // optional scalar. used in scalar ops and in summary stats
// TODO: this field must be removed
NDArray _scalar; NDArray _scalar;
bool _hasExternalOutputs; bool _hasExternalOutputs;
@ -87,15 +90,15 @@ namespace nd4j {
int _scope_id = 0; int _scope_id = 0;
std::string _scope_name; std::string _scope_name;
// TODO: these 3 fields should be removed
int _rewindNode = -1; int _rewindNode = -1;
std::pair<int, int> _rewindLayer = {-1, -1}; std::pair<int, int> _rewindLayer = {-1, -1};
Nd4jLong _frameId = -1; Nd4jLong _frameId = -1;
public: public:
Node(nd4j::ops::DeclarableOp *customOp, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {}); explicit Node(nd4j::ops::DeclarableOp *customOp, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {}); explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
Node(const nd4j::graph::FlatNode *node); explicit Node(const nd4j::graph::FlatNode *node);
~Node(); ~Node();
bool equals(Node *other); bool equals(Node *other);

View File

@ -60,11 +60,13 @@ enum DType {
DType_QINT16 = 16, DType_QINT16 = 16,
DType_BFLOAT16 = 17, DType_BFLOAT16 = 17,
DType_UTF8 = 50, DType_UTF8 = 50,
DType_UTF16 = 51,
DType_UTF32 = 52,
DType_MIN = DType_INHERIT, DType_MIN = DType_INHERIT,
DType_MAX = DType_UTF8 DType_MAX = DType_UTF32
}; };
inline const DType (&EnumValuesDType())[19] { inline const DType (&EnumValuesDType())[21] {
static const DType values[] = { static const DType values[] = {
DType_INHERIT, DType_INHERIT,
DType_BOOL, DType_BOOL,
@ -84,7 +86,9 @@ inline const DType (&EnumValuesDType())[19] {
DType_QINT8, DType_QINT8,
DType_QINT16, DType_QINT16,
DType_BFLOAT16, DType_BFLOAT16,
DType_UTF8 DType_UTF8,
DType_UTF16,
DType_UTF32
}; };
return values; return values;
} }
@ -142,6 +146,8 @@ inline const char * const *EnumNamesDType() {
"", "",
"", "",
"UTF8", "UTF8",
"UTF16",
"UTF32",
nullptr nullptr
}; };
return names; return names;

View File

@ -42,7 +42,9 @@ nd4j.graph.DType = {
QINT8: 15, QINT8: 15,
QINT16: 16, QINT16: 16,
BFLOAT16: 17, BFLOAT16: 17,
UTF8: 50 UTF8: 50,
UTF16: 51,
UTF32: 52
}; };
/** /**

View File

@ -26,6 +26,8 @@ public enum DType : sbyte
QINT16 = 16, QINT16 = 16,
BFLOAT16 = 17, BFLOAT16 = 17,
UTF8 = 50, UTF8 = 50,
UTF16 = 51,
UTF32 = 52,
}; };

View File

@ -23,8 +23,10 @@ public final class DType {
public static final byte QINT16 = 16; public static final byte QINT16 = 16;
public static final byte BFLOAT16 = 17; public static final byte BFLOAT16 = 17;
public static final byte UTF8 = 50; public static final byte UTF8 = 50;
public static final byte UTF16 = 51;
public static final byte UTF32 = 52;
public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", }; public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", "UTF16", "UTF32", };
public static String name(int e) { return names[e]; } public static String name(int e) { return names[e]; }
} }

View File

@ -22,4 +22,6 @@ class DType(object):
QINT16 = 16 QINT16 = 16
BFLOAT16 = 17 BFLOAT16 = 17
UTF8 = 50 UTF8 = 50
UTF16 = 51
UTF32 = 52

View File

@ -26,7 +26,7 @@ public struct UIVariable : IFlatbufferObject
#endif #endif
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); } public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
public VarType Type { get { int o = __p.__offset(8); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } } public VarType Type { get { int o = __p.__offset(8); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } }
public DataType Datatype { get { int o = __p.__offset(10); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } public DType Datatype { get { int o = __p.__offset(10); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
public long Shape(int j) { int o = __p.__offset(12); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; } public long Shape(int j) { int o = __p.__offset(12); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
public int ShapeLength { get { int o = __p.__offset(12); return o != 0 ? __p.__vector_len(o) : 0; } } public int ShapeLength { get { int o = __p.__offset(12); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T #if ENABLE_SPAN_T
@ -70,7 +70,7 @@ public struct UIVariable : IFlatbufferObject
Offset<IntPair> idOffset = default(Offset<IntPair>), Offset<IntPair> idOffset = default(Offset<IntPair>),
StringOffset nameOffset = default(StringOffset), StringOffset nameOffset = default(StringOffset),
VarType type = VarType.VARIABLE, VarType type = VarType.VARIABLE,
DataType datatype = DataType.INHERIT, DType datatype = DType.INHERIT,
VectorOffset shapeOffset = default(VectorOffset), VectorOffset shapeOffset = default(VectorOffset),
VectorOffset controlDepsOffset = default(VectorOffset), VectorOffset controlDepsOffset = default(VectorOffset),
StringOffset outputOfOpOffset = default(StringOffset), StringOffset outputOfOpOffset = default(StringOffset),
@ -101,7 +101,7 @@ public struct UIVariable : IFlatbufferObject
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); } public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
public static void AddType(FlatBufferBuilder builder, VarType type) { builder.AddSbyte(2, (sbyte)type, 0); } public static void AddType(FlatBufferBuilder builder, VarType type) { builder.AddSbyte(2, (sbyte)type, 0); }
public static void AddDatatype(FlatBufferBuilder builder, DataType datatype) { builder.AddSbyte(3, (sbyte)datatype, 0); } public static void AddDatatype(FlatBufferBuilder builder, DType datatype) { builder.AddSbyte(3, (sbyte)datatype, 0); }
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(4, shapeOffset.Value, 0); } public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(4, shapeOffset.Value, 0); }
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); } public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); } public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }

View File

@ -266,8 +266,8 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VarType type() const { VarType type() const {
return static_cast<VarType>(GetField<int8_t>(VT_TYPE, 0)); return static_cast<VarType>(GetField<int8_t>(VT_TYPE, 0));
} }
DataType datatype() const { DType datatype() const {
return static_cast<DataType>(GetField<int8_t>(VT_DATATYPE, 0)); return static_cast<DType>(GetField<int8_t>(VT_DATATYPE, 0));
} }
const flatbuffers::Vector<int64_t> *shape() const { const flatbuffers::Vector<int64_t> *shape() const {
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE); return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
@ -342,7 +342,7 @@ struct UIVariableBuilder {
void add_type(VarType type) { void add_type(VarType type) {
fbb_.AddElement<int8_t>(UIVariable::VT_TYPE, static_cast<int8_t>(type), 0); fbb_.AddElement<int8_t>(UIVariable::VT_TYPE, static_cast<int8_t>(type), 0);
} }
void add_datatype(DataType datatype) { void add_datatype(DType datatype) {
fbb_.AddElement<int8_t>(UIVariable::VT_DATATYPE, static_cast<int8_t>(datatype), 0); fbb_.AddElement<int8_t>(UIVariable::VT_DATATYPE, static_cast<int8_t>(datatype), 0);
} }
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) { void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
@ -389,7 +389,7 @@ inline flatbuffers::Offset<UIVariable> CreateUIVariable(
flatbuffers::Offset<IntPair> id = 0, flatbuffers::Offset<IntPair> id = 0,
flatbuffers::Offset<flatbuffers::String> name = 0, flatbuffers::Offset<flatbuffers::String> name = 0,
VarType type = VarType_VARIABLE, VarType type = VarType_VARIABLE,
DataType datatype = DataType_INHERIT, DType datatype = DType_INHERIT,
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0, flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0, flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
flatbuffers::Offset<flatbuffers::String> outputOfOp = 0, flatbuffers::Offset<flatbuffers::String> outputOfOp = 0,
@ -421,7 +421,7 @@ inline flatbuffers::Offset<UIVariable> CreateUIVariableDirect(
flatbuffers::Offset<IntPair> id = 0, flatbuffers::Offset<IntPair> id = 0,
const char *name = nullptr, const char *name = nullptr,
VarType type = VarType_VARIABLE, VarType type = VarType_VARIABLE,
DataType datatype = DataType_INHERIT, DType datatype = DType_INHERIT,
const std::vector<int64_t> *shape = nullptr, const std::vector<int64_t> *shape = nullptr,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr, const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
const char *outputOfOp = nullptr, const char *outputOfOp = nullptr,

View File

@ -503,11 +503,11 @@ nd4j.graph.UIVariable.prototype.type = function() {
}; };
/** /**
* @returns {nd4j.graph.DataType} * @returns {nd4j.graph.DType}
*/ */
nd4j.graph.UIVariable.prototype.datatype = function() { nd4j.graph.UIVariable.prototype.datatype = function() {
var offset = this.bb.__offset(this.bb_pos, 10); var offset = this.bb.__offset(this.bb_pos, 10);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
}; };
/** /**
@ -668,10 +668,10 @@ nd4j.graph.UIVariable.addType = function(builder, type) {
/** /**
* @param {flatbuffers.Builder} builder * @param {flatbuffers.Builder} builder
* @param {nd4j.graph.DataType} datatype * @param {nd4j.graph.DType} datatype
*/ */
nd4j.graph.UIVariable.addDatatype = function(builder, datatype) { nd4j.graph.UIVariable.addDatatype = function(builder, datatype) {
builder.addFieldInt8(3, datatype, nd4j.graph.DataType.INHERIT); builder.addFieldInt8(3, datatype, nd4j.graph.DType.INHERIT);
}; };
/** /**

View File

@ -551,6 +551,18 @@ namespace nd4j {
bool Context::isInference() { bool Context::isInference() {
return _execMode == samediff::ExecutionMode::MODE_INFERENCE; return _execMode == samediff::ExecutionMode::MODE_INFERENCE;
} }
void Context::setDArguments(nd4j::DataType *arguments, int numberOfArguments) {
_dArgs.clear();
for (int e = 0; e < numberOfArguments; e++)
_dArgs.emplace_back(arguments[e]);
}
void Context::setDArguments(const std::vector<nd4j::DataType> &dArgs) {
_dArgs.clear();
for (auto d:dArgs)
_dArgs.emplace_back(d);
}
} }
} }

View File

@ -173,5 +173,13 @@ namespace nd4j {
return clone; return clone;
} }
std::vector<nd4j::DataType> *ContextPrototype::getDArguments() {
return &_dArgs;
}
size_t ContextPrototype::numD() {
return _dArgs.size();
}
} }
} }

View File

@ -587,6 +587,12 @@ namespace nd4j {
block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
} }
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
}
}
this->setContextPrototype(block); this->setContextPrototype(block);
this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar)); this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));
block->setOpDescriptor(this->getCustomOp()->getOpDescriptor()); block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());
@ -618,6 +624,12 @@ namespace nd4j {
block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
} }
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
}
}
this->setContextPrototype(block); this->setContextPrototype(block);
this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar)); this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));
@ -652,6 +664,12 @@ namespace nd4j {
block->getBArguments()->push_back(node->extraBools()->Get(e)); block->getBArguments()->push_back(node->extraBools()->Get(e));
} }
if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
}
}
for (auto v: _dimensions) for (auto v: _dimensions)
block->getAxis()->emplace_back(v); block->getAxis()->emplace_back(v);

View File

@ -40,7 +40,7 @@ namespace nd4j {
NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); //[nHeads*hS, batch*timeSteps] NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); //[nHeads*hS, batch*timeSteps]
nd4j::ops::matmul mmul; nd4j::ops::matmul mmul;
mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {}); mmul.execute({&projectionPrep, &inputPrep}, {&projected});
projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength}); projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
projected.permutei({2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength] projected.permutei({2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength]
@ -66,7 +66,7 @@ namespace nd4j {
nd4j::ops::matmul_bp mmulBp; nd4j::ops::matmul_bp mmulBp;
NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context); NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context);
NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context); NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context);
mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {}); mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, std::vector<NDArray*>{&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});
dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
dLdProjectionMatrix->assign(dLdProjectionPrep); dLdProjectionMatrix->assign(dLdProjectionPrep);

View File

@ -1516,7 +1516,9 @@
#define INPUT_LIST(INDEX) reinterpret_cast<nd4j::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList()) #define INPUT_LIST(INDEX) reinterpret_cast<nd4j::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())
#define D_ARG(INDEX) block.getDArguments()->at(INDEX)
#define INT_ARG(INDEX) block.getIArguments()->at(INDEX) #define INT_ARG(INDEX) block.getIArguments()->at(INDEX)
#define I_ARG(INDEX) INT_ARG(INDEX)
#define T_ARG(INDEX) block.getTArguments()->at(INDEX) #define T_ARG(INDEX) block.getTArguments()->at(INDEX)
#define B_ARG(INDEX) block.getBArguments()->at(INDEX) #define B_ARG(INDEX) block.getBArguments()->at(INDEX)

View File

@ -36,9 +36,8 @@ namespace nd4j {
public: public:
BooleanOp(const char *name, int numInputs, bool scalar); BooleanOp(const char *name, int numInputs, bool scalar);
bool evaluate(std::initializer_list<nd4j::NDArray*> args); bool verify(const std::vector<nd4j::NDArray*>& args);
bool evaluate(std::vector<nd4j::NDArray*>& args); bool verify(nd4j::graph::Context& block);
bool evaluate(nd4j::graph::Context& block);
Nd4jStatus execute(Context* block) override; Nd4jStatus execute(Context* block) override;

View File

@ -169,13 +169,22 @@ namespace nd4j {
*/ */
virtual Nd4jStatus execute(Context* block); virtual Nd4jStatus execute(Context* block);
nd4j::ResultSet* execute(std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32); Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs);
Nd4jStatus execute(std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
nd4j::ResultSet* execute(const std::vector<NDArray*>& inputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs = std::vector<bool>(), bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32); template <class T>
Nd4jStatus execute(std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs , std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32); Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<T> tArgs);
Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs, std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs);
template <class T>
nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
nd4j::ResultSet* execute(const nd4j::OpArgsHolder& holder, bool isInplace = false); nd4j::ResultSet* execute(const nd4j::OpArgsHolder& holder, bool isInplace = false);

View File

@ -73,7 +73,7 @@ namespace nd4j {
// at first step we build fwd activation // at first step we build fwd activation
nd4j::ops::crelu op; nd4j::ops::crelu op;
auto tmpResult = op.execute({input}, {}, {}, {}); auto tmpResult = op.evaluate({input});
if (tmpResult->status() != ND4J_STATUS_OK) if (tmpResult->status() != ND4J_STATUS_OK)
return tmpResult->status(); return tmpResult->status();
@ -84,7 +84,7 @@ namespace nd4j {
helpers::reluDerivative(block.launchContext(), actv, epsilonNext); helpers::reluDerivative(block.launchContext(), actv, epsilonNext);
// now we split updated array into 2 chunks along last dimension // now we split updated array into 2 chunks along last dimension
nd4j::ops::concat_bp opc; nd4j::ops::concat_bp opc;
auto dec = opc.execute({input, input, actv}, {}, {-1}, {}); auto dec = opc.evaluate({input, input, actv}, {-1});
if (dec->status() != ND4J_STATUS_OK) if (dec->status() != ND4J_STATUS_OK)
return dec->status(); return dec->status();

View File

@ -103,7 +103,7 @@ namespace nd4j {
// if (output->isEmpty()) // if (output->isEmpty())
Nd4jLong width = condition->rankOf(); Nd4jLong width = condition->rankOf();
nd4j::ops::Where op; nd4j::ops::Where op;
std::unique_ptr<ResultSet> res(op.execute({condition}, {}, {}, {})); std::unique_ptr<ResultSet> res(op.evaluate({condition}));
REQUIRE_OK(res->status()); REQUIRE_OK(res->status());
NDArray* whereTrue = res->at(0); NDArray* whereTrue = res->at(0);
if (whereTrue->isEmpty()) if (whereTrue->isEmpty())

View File

@ -66,7 +66,7 @@ namespace nd4j {
auto gradY = OUTPUT_VARIABLE(1); auto gradY = OUTPUT_VARIABLE(1);
gradX->assign(epsNext); gradX->assign(epsNext);
nd4j::ops::floormod op; nd4j::ops::floormod op;
std::unique_ptr<ResultSet> tmpResult(op.execute({x, y}, {}, {}, {})); std::unique_ptr<ResultSet> tmpResult(op.evaluate({x, y}));
if (gradY->rankOf() == gradX->rankOf()) if (gradY->rankOf() == gradX->rankOf())
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY); epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY);

View File

@ -91,7 +91,7 @@ namespace ops {
} }
nd4j::ops::softmax softmax; nd4j::ops::softmax softmax;
softmax.execute({weights}, {weights}, {}, {-2}, {}, true); softmax.execute({weights}, std::vector<NDArray*>{weights}, {}, {-2}, {}, {}, true);
mmul.execute({values, weights}, {output}, {}, {}, {}); mmul.execute({values, weights}, {output}, {}, {}, {});
@ -189,7 +189,7 @@ namespace ops {
nd4j::ops::matmul_bp mmul_bp; nd4j::ops::matmul_bp mmul_bp;
NDArray dLdw(weights.getShapeInfo(), block.workspace()); NDArray dLdw(weights.getShapeInfo(), block.workspace());
mmul_bp.execute({values, &weights, eps}, {dLdv, &dLdw}, {}, {}, {}); mmul_bp.execute({values, &weights, eps}, std::vector<NDArray*>{dLdv, &dLdw}, {}, {}, {});
NDArray dLds(preSoftmax.shapeInfo(), block.workspace()); NDArray dLds(preSoftmax.shapeInfo(), block.workspace());
nd4j::ops::softmax_bp softmax_bp; nd4j::ops::softmax_bp softmax_bp;
@ -198,7 +198,7 @@ namespace ops {
if(normalization) if(normalization)
dLds /= factor; dLds /= factor;
mmul_bp.execute({keys, queries, &dLds}, {dLdk, dLdq}, {}, {1}, {}); mmul_bp.execute({keys, queries, &dLds}, std::vector<NDArray*>{dLdk, dLdq}, {}, {1}, {});
return Status::OK(); return Status::OK();
} }

View File

@ -239,7 +239,7 @@ namespace ops {
auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize}); auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize});
nd4j::ops::matmul_bp matmulBp; nd4j::ops::matmul_bp matmulBp;
NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext()); NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext());
matmulBp.execute({&attnResults, Wo, &epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {}); matmulBp.execute({&attnResults, Wo, &epsPostReshape}, std::vector<NDArray*>{&dLdPreWo, dLdWo}, {}, {}, {});
// dLdAttn // dLdAttn
dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)}); dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)});

View File

@ -40,7 +40,7 @@ namespace nd4j {
//nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf()); //nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf());
nd4j::ops::xw_plus_b op; nd4j::ops::xw_plus_b op;
std::unique_ptr<ResultSet> result(op.execute({x, w, b}, {}, {}, {})); std::unique_ptr<ResultSet> result(op.evaluate({x, w, b}));
REQUIRE_TRUE(Status::OK() == result->status(), 0, "relu_layer: xw_plus_b op failed on input data."); REQUIRE_TRUE(Status::OK() == result->status(), 0, "relu_layer: xw_plus_b op failed on input data.");
auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;

View File

@ -34,7 +34,7 @@ namespace nd4j {
auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
bitcast res; bitcast res;
auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false); auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false);
if (tZ != &z0) { if (tZ != &z0) {
delete tZ; delete tZ;
} }

View File

@ -112,7 +112,7 @@ namespace ops {
NDArray originalIndices(*indices); //->ordering(), indices->shapeInfo(), indices->dataType()); NDArray originalIndices(*indices); //->ordering(), indices->shapeInfo(), indices->dataType());
originalIndices.linspace(0); originalIndices.linspace(0);
ops::dynamic_partition op; ops::dynamic_partition op;
auto res = op.execute({&originalIndices, indices}, {}, {numPartition}); auto res = op.evaluate({&originalIndices, indices}, {numPartition});
REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
ops::dynamic_stitch stichOp; ops::dynamic_stitch stichOp;
std::vector<NDArray*> partitions(numPartition * 2); std::vector<NDArray*> partitions(numPartition * 2);
@ -121,7 +121,7 @@ namespace ops {
partitions[i + numPartition] = gradOutList[i]; partitions[i + numPartition] = gradOutList[i];
} }
auto result = stichOp.execute(partitions, {}, {numPartition}, {}, false); auto result = stichOp.evaluate(partitions, {numPartition});
REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
result->at(0)->reshapei(outputList[0]->getShapeAsVector()); result->at(0)->reshapei(outputList[0]->getShapeAsVector());
outputList[1]->assign(indices); outputList[1]->assign(indices);

View File

@ -66,7 +66,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {
nd4j::ops::gather op; nd4j::ops::gather op;
std::unique_ptr<ResultSet> result(op.execute({input, indeces}, {}, {0}, {})); std::unique_ptr<ResultSet> result(op.evaluate({input, indeces}, {0}));
REQUIRE_TRUE(result->status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op."); REQUIRE_TRUE(result->status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op."); REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
output->assign(result->at(0)); output->assign(result->at(0));
@ -94,7 +94,7 @@ DECLARE_SHAPE_FN(embedding_lookup) {
for (int e = 1; e < outRank; e++) for (int e = 1; e < outRank; e++)
shapeInfo[e] = shape::sizeAt(inShapeInfo, e); shapeInfo[e] = shape::sizeAt(inShapeInfo, e);
auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), shape::order(inShapeInfo), shapeInfo); auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo);
return SHAPELIST(outShapeInfo); return SHAPELIST(outShapeInfo);
} }

View File

@ -74,6 +74,8 @@ namespace nd4j {
DECLARE_SHAPE_FN(onehot) { DECLARE_SHAPE_FN(onehot) {
auto inShape = inputShape->at(0); auto inShape = inputShape->at(0);
nd4j::DataType dtype = block.numD() > 0 ? D_ARG(0) : nd4j::DataType::FLOAT32;
int depth = -1; int depth = -1;
Nd4jLong axis = -1; Nd4jLong axis = -1;
@ -99,7 +101,7 @@ namespace nd4j {
shape.push_back(shape::shapeOf(inShape)[e]); shape.push_back(shape::shapeOf(inShape)[e]);
shape.insert(shape.begin() + axis, depth); shape.insert(shape.begin() + axis, depth);
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', rank + 1, shape.data()); newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', rank + 1, shape.data());
return SHAPELIST(newShape); return SHAPELIST(newShape);
} }

View File

@ -84,7 +84,7 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
// forward steps // forward steps
nd4j::ops::dynamic_rnn dynamicRnn; nd4j::ops::dynamic_rnn dynamicRnn;
auto resultsFW = dynamicRnn.execute({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {}, {timeMajor}, {}, false, x->dataType()); auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor});
hFW->assign(resultsFW->at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW] hFW->assign(resultsFW->at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
hFWFinal->assign(resultsFW->at(1)); hFWFinal->assign(resultsFW->at(1));
@ -97,17 +97,17 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
// reverse x // reverse x
nd4j::ops::reverse_sequence reverse; nd4j::ops::reverse_sequence reverse;
auto resultsIn = timeMajor ? reverse.execute({x, seqLen}, {}, {0, 1}, {}, false, x->dataType()) : reverse.execute({x, seqLen}, {}, {1, 0}, {}, false, x->dataType()); auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0});
REQUIRE_TRUE (resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence."); REQUIRE_TRUE (resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
auto revInput = resultsIn->at(0); auto revInput = resultsIn->at(0);
// backward steps // backward steps
auto resultsBW = dynamicRnn.execute({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {}, {timeMajor}, {}); auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
auto hBWtemp = resultsBW->at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW] auto hBWtemp = resultsBW->at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW]
hBWFinal->assign(resultsBW->at(1)); hBWFinal->assign(resultsBW->at(1));
// reverse hBWtemp // reverse hBWtemp
auto resultsOut = timeMajor ? reverse.execute({hBWtemp, seqLen}, {}, {0, 1}, {}) : reverse.execute({hBWtemp, seqLen}, {}, {1, 0}, {}); auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
hBW->assign(resultsOut->at(0)); hBW->assign(resultsOut->at(0));
delete resultsOut; delete resultsOut;

View File

@ -48,7 +48,7 @@ namespace ops {
auto conv = ArrayUtils::toLongVector(*block.getIArguments()); auto conv = ArrayUtils::toLongVector(*block.getIArguments());
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), shape::order(in), conv); auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), shape::order(in), conv);
return SHAPELIST(newShape); return SHAPELIST(newShape);
} }

View File

@ -51,7 +51,7 @@ namespace helpers {
throw std::runtime_error("multiUnique: cannot execute concat op properly."); throw std::runtime_error("multiUnique: cannot execute concat op properly.");
nd4j::ops::unique opUnique; nd4j::ops::unique opUnique;
auto uResult = opUnique.execute({&arrayFull}, {}, {}, {}); auto uResult = opUnique.evaluate({&arrayFull});
if (Status::OK() != uResult->status()) if (Status::OK() != uResult->status())
throw std::runtime_error("multiUnique: cannot execute unique op properly."); throw std::runtime_error("multiUnique: cannot execute unique op properly.");

View File

@ -36,7 +36,7 @@ namespace nd4j {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL)); return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL));
} }
bool BooleanOp::evaluate(nd4j::graph::Context &block) { bool BooleanOp::verify(nd4j::graph::Context &block) {
// check if scalar or not // check if scalar or not
// validation? // validation?
@ -58,11 +58,6 @@ namespace nd4j {
} }
} }
bool BooleanOp::evaluate(std::initializer_list<nd4j::NDArray *> args) {
std::vector<nd4j::NDArray *> vec(args);
return this->evaluate(vec);
}
bool BooleanOp::prepareOutputs(Context& ctx) { bool BooleanOp::prepareOutputs(Context& ctx) {
auto variableSpace = ctx.getVariableSpace(); auto variableSpace = ctx.getVariableSpace();
@ -120,7 +115,7 @@ namespace nd4j {
return ND4J_STATUS_KERNEL_FAILURE; return ND4J_STATUS_KERNEL_FAILURE;
} }
bool BooleanOp::evaluate(std::vector<nd4j::NDArray *> &args) { bool BooleanOp::verify(const std::vector<nd4j::NDArray *> &args) {
VariableSpace variableSpace; VariableSpace variableSpace;
int cnt = -1; int cnt = -1;
@ -135,7 +130,7 @@ namespace nd4j {
Context block(1, &variableSpace, false); Context block(1, &variableSpace, false);
block.fillInputs(in); block.fillInputs(in);
return this->evaluate(block); return this->verify(block);
} }
} }
} }

View File

@ -15,7 +15,7 @@
******************************************************************************/ ******************************************************************************/
// //
// Created by raver119 on 07.10.2017. // @author raver119@gmail.com
// //
#include <ops/declarable/DeclarableOp.h> #include <ops/declarable/DeclarableOp.h>
@ -27,6 +27,7 @@
#include <ops/declarable/OpRegistrator.h> #include <ops/declarable/OpRegistrator.h>
#include <exceptions/datatype_exception.h> #include <exceptions/datatype_exception.h>
#include <helpers/StringUtils.h> #include <helpers/StringUtils.h>
#include <cstdarg>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -164,6 +165,9 @@ namespace nd4j {
// we build list of input shapes // we build list of input shapes
if (ctx.isFastPath()) { if (ctx.isFastPath()) {
for (const auto p:ctx.fastpath_in()) { for (const auto p:ctx.fastpath_in()) {
if (p == nullptr)
continue;
inSha.push_back(p->getShapeInfo()); inSha.push_back(p->getShapeInfo());
} }
} else { } else {
@ -357,6 +361,9 @@ namespace nd4j {
std::vector<nd4j::DataType> inputTypes(block.width()); std::vector<nd4j::DataType> inputTypes(block.width());
if (block.isFastPath()) { if (block.isFastPath()) {
for (auto array: block.fastpath_in()) { for (auto array: block.fastpath_in()) {
if (array == nullptr)
continue;
inputTypes[inT++] = array->dataType(); inputTypes[inT++] = array->dataType();
if (!_descriptor->checkInputMatch(cnt, array->dataType())) { if (!_descriptor->checkInputMatch(cnt, array->dataType())) {
auto ctype = DataTypeUtils::asString(array->dataType()); auto ctype = DataTypeUtils::asString(array->dataType());
@ -394,6 +401,9 @@ namespace nd4j {
if (block.isFastPath()) { if (block.isFastPath()) {
int index = 0; int index = 0;
for (auto array: block.fastpath_out()) { for (auto array: block.fastpath_out()) {
if (array == nullptr)
continue;
auto cType = array->dataType(); auto cType = array->dataType();
if (_descriptor->isSameMode()) { if (_descriptor->isSameMode()) {
@ -762,39 +772,7 @@ namespace nd4j {
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }
nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace, nd4j::DataType type) { Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<nd4j::DataType>& dArgs, bool isInplace, nd4j::DataType type) {
std::vector<NDArray*> ins(inputs);
std::vector<double> tas(tArgs);
std::vector<Nd4jLong> ias(iArgs);
std::vector<bool> bas(bArgs);
return this->execute(ins, tas, ias, bas, isInplace, type);
}
Nd4jStatus nd4j::ops::DeclarableOp::execute(std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace, nd4j::DataType type) {
std::vector<NDArray*> ins(inputs);
std::vector<NDArray*> ous(outputs);
std::vector<double> tas(tArgs);
std::vector<Nd4jLong> ias(iArgs);
std::vector<bool> bas(bArgs);
return this->execute(ins, ous, tas, ias, bas, isInplace, type);
}
Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace, nd4j::DataType type) {
std::vector<NDArray*> ins(inputs);
std::vector<NDArray*> ous(outputs);
std::vector<double> tas(tArgs);
std::vector<Nd4jLong> ias(iArgs);
std::vector<bool> bas(bArgs);
return this->execute(rng, ins, ous, tas, ias, bas, isInplace, type);
}
Nd4jStatus nd4j::ops::DeclarableOp::execute(std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs, std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace, nd4j::DataType type) {
// TODO: nullptr here might be replaced
nd4j::graph::RandomGenerator rng(0, 0);
return execute(rng, inputs, outputs, tArgs, iArgs, bArgs, isInplace, type);
}
Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs, std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace, nd4j::DataType type) {
VariableSpace variableSpace; VariableSpace variableSpace;
FlowPath fp; FlowPath fp;
variableSpace.setFlowPath(&fp); variableSpace.setFlowPath(&fp);
@ -838,12 +816,124 @@ namespace nd4j {
for (int e = 0; e < bArgs.size(); e++) for (int e = 0; e < bArgs.size(); e++)
block.getBArguments()->push_back(static_cast<int>(bArgs.at(e))); block.getBArguments()->push_back(static_cast<int>(bArgs.at(e)));
for (int e = 0; e < dArgs.size(); e++)
block.getDArguments()->push_back(dArgs.at(e));
Nd4jStatus result = this->execute(&block); Nd4jStatus result = this->execute(&block);
return result; return result;
} }
nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(const std::vector<NDArray*>& inputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, bool isInplace, nd4j::DataType type) { Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs) {
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
}
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<double> tArgs) {
std::vector<double> realArgs(tArgs);
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
}
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<float> tArgs) {
std::vector<double> realArgs;
for (auto v:tArgs)
realArgs.emplace_back(v);
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
}
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<Nd4jLong> iArgs) {
std::vector<Nd4jLong> realArgs(iArgs);
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
}
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<int> iArgs) {
std::vector<Nd4jLong> realArgs;
for (auto v:iArgs)
realArgs.emplace_back(v);
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
}
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<bool> bArgs) {
std::vector<bool> realArgs(bArgs);
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());;
}
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
Context ctx(1);
for (int e = 0; e < inputs.size(); e++) {
if (inputs[e] == nullptr)
break;
ctx.setInputArray(e, inputs[e]);
}
for (int e = 0; e < outputs.size(); e++) {
if (outputs[e] == nullptr)
break;
ctx.setOutputArray(e, outputs[e]);
}
if (isInplace)
ctx.markInplace(isInplace);
ctx.setIArguments(iArgs);
ctx.setTArguments(tArgs);
ctx.setBArguments(bArgs);
ctx.setDArguments(dArgs);
return execute(&ctx);
}
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
}
template <>
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
std::vector<Nd4jLong> realArgs;
for (auto v:iArgs)
realArgs.emplace_back(v);
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
}
template <>
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
std::vector<Nd4jLong> realArgs(iArgs);
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
}
template <>
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
std::vector<double> realArgs;
for (auto v:tArgs)
realArgs.emplace_back(v);
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
}
template <>
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
std::vector<double> realArgs(tArgs);
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
}
template <>
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
std::vector<bool> realArgs(bArgs);
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());;
}
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
VariableSpace variableSpace; VariableSpace variableSpace;
//ResultSet arrayList; //ResultSet arrayList;
FlowPath fp; FlowPath fp;
@ -862,21 +952,23 @@ namespace nd4j {
} }
Context block(1, &variableSpace, false); Context block(1, &variableSpace, false);
block.setDataType(0, type); block.setDataType(0, nd4j::DataType::FLOAT32);
block.fillInputs(in); block.fillInputs(in);
block.markInplace(isInplace); block.markInplace(isInplace);
// block.setRNG(ProviderRNG::getInstance().getRNG()); // block.setRNG(ProviderRNG::getInstance().getRNG());
for (int e = 0; e < tArgs.size(); e++) for (int e = 0; e < tArgs.size(); e++)
block.getTArguments()->emplace_back(tArgs.at(e)); block.getTArguments()->emplace_back(tArgs.at(e));
for (int e = 0; e < iArgs.size(); e++) for (int e = 0; e < iArgs.size(); e++)
block.getIArguments()->emplace_back(iArgs.at(e)); block.getIArguments()->emplace_back(iArgs.at(e));
for (int e = 0; e < bArgs.size(); e++) for (int e = 0; e < bArgs.size(); e++)
block.getBArguments()->push_back(bArgs.at(e)); block.getBArguments()->push_back(bArgs.at(e));
for (int e = 0; e < dArgs.size(); e++)
block.getDArguments()->push_back(dArgs.at(e));
Nd4jStatus status = this->execute(&block); Nd4jStatus status = this->execute(&block);
auto arrayList = new ResultSet(); auto arrayList = new ResultSet();
if (isInplace) if (isInplace)
@ -907,7 +999,8 @@ namespace nd4j {
} }
nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(const nd4j::OpArgsHolder& holder, bool isInplace) { nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(const nd4j::OpArgsHolder& holder, bool isInplace) {
return execute(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), isInplace, nd4j::DataType::DOUBLE); // FIXME: add DArgs to OpArgsHolder
return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector<nd4j::DataType>(), isInplace);
} }
Nd4jStatus nd4j::ops::DeclarableOp::validateInputDimensionsMatch(Context& block) { Nd4jStatus nd4j::ops::DeclarableOp::validateInputDimensionsMatch(Context& block) {

View File

@ -43,7 +43,7 @@ TEST_F(AttentionTests, basic_dot_product_attention) {
auto queries = NDArrayFactory::create<float>('c', {10, 4, 1}); auto queries = NDArrayFactory::create<float>('c', {10, 4, 1});
nd4j::ops::dot_product_attention op; nd4j::ops::dot_product_attention op;
auto result = op.execute({&queries, &keys, &values}, {}, {1, 0}, {}); auto result = op.evaluate({&queries, &keys, &values}, {1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -71,7 +71,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_weights) {
auto queries = NDArrayFactory::create<float>('c', {10, 4, 1}); auto queries = NDArrayFactory::create<float>('c', {10, 4, 1});
nd4j::ops::dot_product_attention op; nd4j::ops::dot_product_attention op;
auto result = op.execute({&queries, &keys, &values}, {}, {1, 1}, {}); auto result = op.evaluate({&queries, &keys, &values}, {1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -85,7 +85,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
mask.assign(1.); mask.assign(1.);
nd4j::ops::dot_product_attention op; nd4j::ops::dot_product_attention op;
auto result = op.execute({&queries, &keys, &values, &mask}, {}, {1, 0}, {}); auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -117,7 +117,7 @@ TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) {
mask.assign(1.); mask.assign(1.);
nd4j::ops::dot_product_attention op; nd4j::ops::dot_product_attention op;
auto result = op.execute({&queries, &keys, &values, &mask}, {}, {1, 0}, {}); auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -153,7 +153,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention) {
auto Wo = NDArrayFactory::create<float>('c', {2* 3, 4}); auto Wo = NDArrayFactory::create<float>('c', {2* 3, 4});
nd4j::ops::multi_head_dot_product_attention op; nd4j::ops::multi_head_dot_product_attention op;
auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {}, {1, 0}, {}); auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -197,7 +197,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) {
nd4j::ops::multi_head_dot_product_attention op; nd4j::ops::multi_head_dot_product_attention op;
auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {}, {1, 0}, {}); auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;

View File

@ -37,7 +37,7 @@ TEST_F(BackpropTests, Test_Add_1) {
NDArray e('c', {2, 3, 4}, nd4j::DataType::FLOAT32); NDArray e('c', {2, 3, 4}, nd4j::DataType::FLOAT32);
nd4j::ops::add_bp op; nd4j::ops::add_bp op;
auto result = op.execute({&x, &y, &e}, {}, {}, {}); auto result = op.evaluate({&x, &y, &e});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());

View File

@ -38,7 +38,7 @@ TEST_F(BooleanOpsTests, LtTest_1) {
nd4j::ops::lt_scalar op; nd4j::ops::lt_scalar op;
ASSERT_TRUE(op.evaluate({x, y})); ASSERT_TRUE(op.verify({x, y}));
delete x; delete x;
delete y; delete y;
@ -51,7 +51,7 @@ TEST_F(BooleanOpsTests, LtTest_2) {
nd4j::ops::lt_scalar op; nd4j::ops::lt_scalar op;
ASSERT_FALSE(op.evaluate({x, y})); ASSERT_FALSE(op.verify({x, y}));
delete x; delete x;
delete y; delete y;
@ -62,7 +62,7 @@ TEST_F(BooleanOpsTests, Is_non_decreasing_1) {
nd4j::ops::is_non_decreasing op; nd4j::ops::is_non_decreasing op;
ASSERT_TRUE(op.evaluate({&x})); ASSERT_TRUE(op.verify({&x}));
} }
@ -71,7 +71,7 @@ TEST_F(BooleanOpsTests, Is_non_decreasing_2) {
nd4j::ops::is_non_decreasing op; nd4j::ops::is_non_decreasing op;
ASSERT_FALSE(op.evaluate({&x})); ASSERT_FALSE(op.verify({&x}));
} }
@ -80,7 +80,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_1) {
nd4j::ops::is_strictly_increasing op; nd4j::ops::is_strictly_increasing op;
ASSERT_TRUE(op.evaluate({&x})); ASSERT_TRUE(op.verify({&x}));
} }
@ -89,7 +89,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_2) {
nd4j::ops::is_strictly_increasing op; nd4j::ops::is_strictly_increasing op;
ASSERT_FALSE(op.evaluate({&x})); ASSERT_FALSE(op.verify({&x}));
} }
@ -98,7 +98,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_3) {
nd4j::ops::is_strictly_increasing op; nd4j::ops::is_strictly_increasing op;
ASSERT_FALSE(op.evaluate({&x})); ASSERT_FALSE(op.verify({&x}));
} }
TEST_F(BooleanOpsTests, Is_strictly_increasing_5) { TEST_F(BooleanOpsTests, Is_strictly_increasing_5) {
@ -107,7 +107,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_5) {
nd4j::ops::is_strictly_increasing op; nd4j::ops::is_strictly_increasing op;
ASSERT_TRUE(op.evaluate({&x})); ASSERT_TRUE(op.verify({&x}));
} }
TEST_F(BooleanOpsTests, Is_strictly_increasing_6) { TEST_F(BooleanOpsTests, Is_strictly_increasing_6) {
@ -118,7 +118,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_6) {
nd4j::ops::is_strictly_increasing op; nd4j::ops::is_strictly_increasing op;
ASSERT_FALSE(op.evaluate({&x})); ASSERT_FALSE(op.verify({&x}));
} }
TEST_F(BooleanOpsTests, Is_numeric_tensor_1) { TEST_F(BooleanOpsTests, Is_numeric_tensor_1) {
@ -126,7 +126,7 @@ TEST_F(BooleanOpsTests, Is_numeric_tensor_1) {
nd4j::ops::is_numeric_tensor op; nd4j::ops::is_numeric_tensor op;
ASSERT_TRUE(op.evaluate({&x})); ASSERT_TRUE(op.verify({&x}));
} }
TEST_F(BooleanOpsTests, test_where_1) { TEST_F(BooleanOpsTests, test_where_1) {
@ -136,7 +136,7 @@ TEST_F(BooleanOpsTests, test_where_1) {
nd4j::ops::choose op; nd4j::ops::choose op;
auto result = op.execute({&x, &y}, {}, {3}); auto result = op.evaluate({&x, &y}, {3});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);

View File

@ -46,7 +46,7 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) {
exp.applyBroadcast(broadcast::Add, {1}, y, exp); exp.applyBroadcast(broadcast::Add, {1}, y, exp);
nd4j::ops::add op; nd4j::ops::add op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -73,7 +73,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_1) {
exp.applyBroadcast(broadcast::Multiply, {1}, y, exp); exp.applyBroadcast(broadcast::Multiply, {1}, y, exp);
nd4j::ops::multiply op; nd4j::ops::multiply op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -98,7 +98,7 @@ TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) {
nd4j::ops::squaredsubtract op; nd4j::ops::squaredsubtract op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -117,7 +117,7 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) {
auto exp = NDArrayFactory::create<float>('c', {1,3}, {1, 0, -1}); auto exp = NDArrayFactory::create<float>('c', {1,3}, {1, 0, -1});
nd4j::ops::subtract op; nd4j::ops::subtract op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -136,7 +136,7 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) {
auto exp = NDArrayFactory::create<float>('c', {1,3}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {1,3}, {1, 2, 3});
nd4j::ops::add op; nd4j::ops::add op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -155,7 +155,7 @@ TEST_F(BroadcastableOpsTests, Test_Maximum_1) {
auto exp = NDArrayFactory::create<float>('c', {2, 3}, {2, 2, 2, 2, 3, 2}); auto exp = NDArrayFactory::create<float>('c', {2, 3}, {2, 2, 2, 2, 3, 2});
nd4j::ops::maximum op; nd4j::ops::maximum op;
auto result = op.execute({&x, &row}, {}, {}, {}); auto result = op.evaluate({&x, &row});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -173,7 +173,7 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) {
auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 1, 1, 1, 1}); auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 1, 1, 1, 1});
nd4j::ops::minimum op; nd4j::ops::minimum op;
auto result = op.execute({&x, &col}, {}, {}, {}); auto result = op.evaluate({&x, &col});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -281,7 +281,7 @@ TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) {
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {3, 4, 5, 6}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {3, 4, 5, 6});
nd4j::ops::add op; nd4j::ops::add op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -331,7 +331,7 @@ TEST_F(BroadcastableOpsTests, Test_Subtract_2) {
auto e = NDArrayFactory::create<float>('c', {2}, {1.0f, 0.0f}); auto e = NDArrayFactory::create<float>('c', {2}, {1.0f, 0.0f});
nd4j::ops::subtract op; nd4j::ops::subtract op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
auto z = result->at(0); auto z = result->at(0);
ASSERT_TRUE(e.equalsTo(z)); ASSERT_TRUE(e.equalsTo(z));
@ -509,7 +509,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_7) {
auto e = NDArrayFactory::create<float>('c', {1}, {8.f}); auto e = NDArrayFactory::create<float>('c', {1}, {8.f});
nd4j::ops::multiply op; nd4j::ops::multiply op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -525,7 +525,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
auto e = NDArrayFactory::create<float>('c', {1, 1}, {8.f}); auto e = NDArrayFactory::create<float>('c', {1, 1}, {8.f});
nd4j::ops::multiply op; nd4j::ops::multiply op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -544,7 +544,7 @@ TEST_F(BroadcastableOpsTests, broadcast_add_1) {
NDArray exp('c', {1,4}, {2,3,4,5}, nd4j::DataType::DOUBLE); NDArray exp('c', {1,4}, {2,3,4,5}, nd4j::DataType::DOUBLE);
nd4j::ops::add op; nd4j::ops::add op;
auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); auto status = op.execute({&x, &y}, {&z});
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(z.equalsTo(exp)); ASSERT_TRUE(z.equalsTo(exp));
@ -559,7 +559,7 @@ TEST_F(BroadcastableOpsTests, broadcast_equals_1) {
NDArray exp('c', {3,4}, {0,0,0,0, 1,1,1,1, 1,1,1,1}, nd4j::DataType::BOOL); NDArray exp('c', {3,4}, {0,0,0,0, 1,1,1,1, 1,1,1,1}, nd4j::DataType::BOOL);
nd4j::ops::equals op; nd4j::ops::equals op;
auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); auto status = op.execute({&x, &y}, {&z});
// z.printIndexedBuffer(); // z.printIndexedBuffer();
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
@ -603,7 +603,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});; NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
nd4j::ops::maximum op; nd4j::ops::maximum op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -622,7 +622,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});; NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
nd4j::ops::maximum op; nd4j::ops::maximum op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -641,7 +641,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});; NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
nd4j::ops::realdiv op; nd4j::ops::realdiv op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -660,7 +660,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});; NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
nd4j::ops::realdiv op; nd4j::ops::realdiv op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -679,7 +679,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2, 0});; NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2, 0});;
nd4j::ops::realdiv op; nd4j::ops::realdiv op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -715,7 +715,7 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {
nd4j::ops::greater op; nd4j::ops::greater op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
auto z = result->at(0); auto z = result->at(0);
@ -741,7 +741,7 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
nd4j::ops::greater op; nd4j::ops::greater op;
auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); auto status = op.execute({&x, &y}, {&z});
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);

View File

@ -140,7 +140,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_2) {
input.linspace(1); input.linspace(1);
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
auto result = op.execute({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); auto result = op.evaluate({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -172,7 +172,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_3) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -201,7 +201,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_4) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -231,7 +231,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_5) {
weights.permutei({2,3,1,0}); weights.permutei({2,3,1,0});
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
// output->printIndexedBuffer(); // output->printIndexedBuffer();
@ -250,7 +250,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_6) {
auto weights = NDArrayFactory::create<TypeParam>('c', {1, 2, 12, 2}); auto weights = NDArrayFactory::create<TypeParam>('c', {1, 2, 12, 2});
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
auto result = op.execute({&input, &weights}, {}, {-1,-1, 1,1, 0,0, 1,1, 1,1}); auto result = op.evaluate({&input, &weights}, {}, {-1,-1, 1,1, 0,0, 1,1, 1,1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -271,7 +271,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_7) {
weights = 3.; weights = 3.;
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -305,7 +305,7 @@ TEST_F(ConvolutionTests1, conv2d_8) {
1.764169, 2.584944, 2.521004, 1.744296, 1.707578, 2.237938, 2.325231, 0.984485, 1.766936, 1.590640, 1.347524, 1.404648, 1.422042, 1.709862, 1.155412}); 1.764169, 2.584944, 2.521004, 1.744296, 1.707578, 2.237938, 2.325231, 0.984485, 1.766936, 1.590640, 1.347524, 1.404648, 1.422042, 1.709862, 1.155412});
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
// output->printBuffer(); // output->printBuffer();
@ -419,7 +419,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_2) {
nd4j::ops::sconv2d op; nd4j::ops::sconv2d op;
auto resultFF = op.execute({&input, &weightsD, &weightsP}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}, {}); auto resultFF = op.evaluate({&input, &weightsD, &weightsP}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});
auto z = resultFF->at(0); auto z = resultFF->at(0);
//z->printShapeInfo("FF shape"); //z->printShapeInfo("FF shape");
@ -452,8 +452,8 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_3) {
auto expOutput = NDArrayFactory::create<TypeParam>('c', {3, 2, 8, 8}); auto expOutput = NDArrayFactory::create<TypeParam>('c', {3, 2, 8, 8});
nd4j::ops::sconv2d op; nd4j::ops::sconv2d op;
Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0}, {}); Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {1, 1, 1, 1, 0, 0, 1, 1, 0});
auto result = op.execute({&input, &weightsD, &weightsP, &bias}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0}, {}); auto result = op.evaluate({&input, &weightsD, &weightsP, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 0});
auto z = result->at(0); auto z = result->at(0);
@ -493,7 +493,7 @@ TEST_F(ConvolutionTests1, sconv2d_4) {
0.962232, 0.980667, 1.623775, 1.417320, 1.845710, 1.237095, 1.762792, 1.352515}); 0.962232, 0.980667, 1.623775, 1.417320, 1.845710, 1.237095, 1.762792, 1.352515});
nd4j::ops::sconv2d op; nd4j::ops::sconv2d op;
auto results = op.execute({&input, &weightsD, &weightsP, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weightsD, &weightsP, &biases}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -531,7 +531,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) {
nd4j::ops::conv2d_bp op; nd4j::ops::conv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); auto results = op.evaluate({&input, &weights, &bias, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {});
ASSERT_TRUE(results->size() == 3); ASSERT_TRUE(results->size() == 3);
@ -581,7 +581,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) {
nd4j::ops::conv2d_bp op; nd4j::ops::conv2d_bp op;
auto results = op.execute({&input, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); auto results = op.evaluate({&input, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {});
ASSERT_TRUE(results->size() == 2); ASSERT_TRUE(results->size() == 2);
@ -664,7 +664,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) {
input.linspace(1); input.linspace(1);
nd4j::ops::sconv2d op; nd4j::ops::sconv2d op;
auto resultFF = op.execute({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); auto resultFF = op.evaluate({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {});
auto z = resultFF->at(0); auto z = resultFF->at(0);
@ -674,7 +674,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) {
nd4j::ops::conv2d op2d; nd4j::ops::conv2d op2d;
// weightsP.printShapeInfo(); // weightsP.printShapeInfo();
auto result2D = op2d.execute({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); auto result2D = op2d.evaluate({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});
auto z2d = result2D->at(0); auto z2d = result2D->at(0);
// z2d->printBuffer(); // z2d->printBuffer();
@ -717,7 +717,7 @@ TEST_F(ConvolutionTests1, TestDeconv_bp_1) {
nd4j::ops::deconv2d_bp op; nd4j::ops::deconv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -771,7 +771,7 @@ TEST_F(ConvolutionTests1, TestDeconv_bp_2) {
nd4j::ops::deconv2d_bp<double> op; nd4j::ops::deconv2d_bp<double> op;
auto result = op.execute({&input, &weights, &bias, &epsilon}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0}); auto result = op.evaluate({&input, &weights, &bias, &epsilon}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -791,7 +791,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
bias.linspace(1); bias.linspace(1);
nd4j::ops::conv1d op; nd4j::ops::conv1d op;
auto result_FF = op.execute({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0}); auto result_FF = op.evaluate({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0});
ASSERT_EQ(ND4J_STATUS_OK, result_FF->status()); ASSERT_EQ(ND4J_STATUS_OK, result_FF->status());
@ -805,7 +805,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
auto epsilonNxt = new NDArray(z->dup()); auto epsilonNxt = new NDArray(z->dup());
epsilonNxt->linspace(1); epsilonNxt->linspace(1);
auto result_BP = op_bp.execute({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0}); auto result_BP = op_bp.evaluate({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0});
ASSERT_EQ(ND4J_STATUS_OK, result_BP->status()); ASSERT_EQ(ND4J_STATUS_OK, result_BP->status());
auto eps = result_BP->at(0); auto eps = result_BP->at(0);
@ -833,7 +833,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) {
input.linspace(1); input.linspace(1);
nd4j::ops::conv1d op; nd4j::ops::conv1d op;
auto result = op.execute({&input, &weights}, {}, {2, 1, 0, 1, 1,0}); auto result = op.evaluate({&input, &weights}, {}, {2, 1, 0, 1, 1,0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -860,7 +860,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_1) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv1d op; nd4j::ops::conv1d op;
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -892,7 +892,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_2) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv1d op; nd4j::ops::conv1d op;
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -923,7 +923,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_3) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv1d op; nd4j::ops::conv1d op;
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -954,7 +954,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_4) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv1d op; nd4j::ops::conv1d op;
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -985,7 +985,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_5) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv1d op; nd4j::ops::conv1d op;
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1016,7 +1016,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_6) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv1d op; nd4j::ops::conv1d op;
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1048,7 +1048,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_7) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv1d op; nd4j::ops::conv1d op;
auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1081,7 +1081,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_8) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv1d op; nd4j::ops::conv1d op;
auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1129,7 +1129,7 @@ TEST_F(ConvolutionTests1, Test_Dilation2D_1) {
weights.linspace(1); weights.linspace(1);
nd4j::ops::dilation2d op; nd4j::ops::dilation2d op;
auto result = op.execute({&input, &weights}, {}, {1, 1,2,2,1, 1,2,2,1}); auto result = op.evaluate({&input, &weights}, {1, 1,2,2,1, 1,2,2,1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1149,7 +1149,7 @@ TEST_F(ConvolutionTests1, Test_Dilation2D_2) {
weights.linspace(1); weights.linspace(1);
nd4j::ops::dilation2d op; nd4j::ops::dilation2d op;
auto result = op.execute({&input, &weights}, {}, {0, 1,2,2,1, 1,2,2,1}); auto result = op.evaluate({&input, &weights}, {0, 1,2,2,1, 1,2,2,1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1188,7 +1188,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::conv2d_bp op; nd4j::ops::conv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto gradI = results->at(0); auto gradI = results->at(0);
auto gradW = results->at(1); auto gradW = results->at(1);
@ -1231,7 +1231,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::conv2d_bp op; nd4j::ops::conv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto gradI = results->at(0); auto gradI = results->at(0);
auto gradW = results->at(1); auto gradW = results->at(1);
@ -1276,7 +1276,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) {
expGradW.permutei({2,3,1,0}); expGradW.permutei({2,3,1,0});
nd4j::ops::conv2d_bp op; nd4j::ops::conv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto gradI = results->at(0); auto gradI = results->at(0);
auto gradW = results->at(1); auto gradW = results->at(1);
auto gradB = results->at(2); auto gradB = results->at(2);
@ -1358,7 +1358,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::conv3dnew_bp op; nd4j::ops::conv3dnew_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto gradI = results->at(0); auto gradI = results->at(0);
auto gradW = results->at(1); auto gradW = results->at(1);
@ -1406,7 +1406,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::conv3dnew_bp op; nd4j::ops::conv3dnew_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto gradI = results->at(0); auto gradI = results->at(0);
auto gradW = results->at(1); auto gradW = results->at(1);
@ -1459,7 +1459,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) {
expGradW.permutei({2, 3, 4, 1, 0}); expGradW.permutei({2, 3, 4, 1, 0});
nd4j::ops::conv3dnew_bp op; nd4j::ops::conv3dnew_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* gradI = results->at(0); auto* gradI = results->at(0);
auto* gradW = results->at(1); auto* gradW = results->at(1);
auto* gradB = results->at(2); auto* gradB = results->at(2);
@ -1502,7 +1502,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::depthwise_conv2d_bp op; nd4j::ops::depthwise_conv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto* gradI = results->at(0); auto* gradI = results->at(0);
auto* gradW = results->at(1); auto* gradW = results->at(1);
@ -1540,7 +1540,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test2) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::depthwise_conv2d_bp op; nd4j::ops::depthwise_conv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto* gradI = results->at(0); auto* gradI = results->at(0);
auto* gradW = results->at(1); auto* gradW = results->at(1);
@ -1568,7 +1568,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test3) {
auto gradB = b.like(); auto gradB = b.like();
nd4j:ops::depthwise_conv2d_bp op; nd4j:ops::depthwise_conv2d_bp op;
auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}, {}); auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0});
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -1607,7 +1607,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test4) {
NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, nd4j::DataType::FLOAT32); NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, nd4j::DataType::FLOAT32);
nd4j::ops::depthwise_conv2d_bp op; nd4j::ops::depthwise_conv2d_bp op;
ResultSet* results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); ResultSet* results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
NDArray* gradI = results->at(0); NDArray* gradI = results->at(0);
NDArray* gradW = results->at(1); NDArray* gradW = results->at(1);
NDArray* gradB = results->at(2); NDArray* gradB = results->at(2);
@ -1662,7 +1662,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test5) {
NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, nd4j::DataType::FLOAT32); NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, nd4j::DataType::FLOAT32);
nd4j::ops::depthwise_conv2d_bp op; nd4j::ops::depthwise_conv2d_bp op;
ResultSet* results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); ResultSet* results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
NDArray* gradI = results->at(0); NDArray* gradI = results->at(0);
NDArray* gradW = results->at(1); NDArray* gradW = results->at(1);
NDArray* gradB = results->at(2); NDArray* gradB = results->at(2);
@ -1706,7 +1706,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test6) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::depthwise_conv2d_bp op; nd4j::ops::depthwise_conv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto* gradI = results->at(0); auto* gradI = results->at(0);
auto* gradW = results->at(1); auto* gradW = results->at(1);
@ -1742,7 +1742,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) {
weights = 1.; weights = 1.;
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1774,7 +1774,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1801,7 +1801,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test3) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1827,7 +1827,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test4) {
expected = 48.; expected = 48.;
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1855,7 +1855,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test5) {
bias = 1.; bias = 1.;
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
// output->printIndexedBuffer(); // output->printIndexedBuffer();
@ -1884,7 +1884,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) {
weights = 0.5; weights = 0.5;
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
// output->printIndexedBuffer(); // output->printIndexedBuffer();
@ -1915,7 +1915,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) {
weights.permutei({2, 3, 4, 1, 0}); weights.permutei({2, 3, 4, 1, 0});
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
// output->printIndexedBuffer(); // output->printIndexedBuffer();
@ -1944,7 +1944,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test8) {
weights.permutei({2, 3, 4, 1, 0}); weights.permutei({2, 3, 4, 1, 0});
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1961,7 +1961,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test9) {
auto e = NDArrayFactory::create<TypeParam>('c', {4, 1, 7, 10, 4}); auto e = NDArrayFactory::create<TypeParam>('c', {4, 1, 7, 10, 4});
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto result = op.execute({&x, &y}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); auto result = op.evaluate({&x, &y}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1977,7 +1977,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test10) {
auto exp = NDArrayFactory::create<TypeParam>('c', {4, 1, 7, 10, 4}); auto exp = NDArrayFactory::create<TypeParam>('c', {4, 1, 7, 10, 4});
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto result = op.execute({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); auto result = op.evaluate({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ShapeList shapeList({x.shapeInfo(), w.shapeInfo()}); ShapeList shapeList({x.shapeInfo(), w.shapeInfo()});
@ -2039,7 +2039,7 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) {
bias = 1.; bias = 1.;
nd4j::ops::pointwise_conv2d op; nd4j::ops::pointwise_conv2d op;
auto results = op.execute({&input, &weights, &bias}, {}, {dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {}, {dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2063,7 +2063,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test11) {
weights = 1.; weights = 1.;
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2087,7 +2087,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test12) {
weights = 1.; weights = 1.;
nd4j::ops::conv3dnew op; nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2205,7 +2205,7 @@ TEST_F(ConvolutionTests1, upsampling2d_test1) {
31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f});
nd4j::ops::upsampling2d op; nd4j::ops::upsampling2d op;
auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW}); auto results = op.evaluate({&input}, {factorH, factorW, isNCHW});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2233,7 +2233,7 @@ TEST_F(ConvolutionTests1, upsampling2d_test2) {
33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f});
nd4j::ops::upsampling2d op; nd4j::ops::upsampling2d op;
auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW}); auto results = op.evaluate({&input}, {factorH, factorW, isNCHW});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2271,7 +2271,7 @@ TEST_F(ConvolutionTests1, upsampling3d_test1) {
67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f});
nd4j::ops::upsampling3d op; nd4j::ops::upsampling3d op;
auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW}); auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2305,7 +2305,7 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) {
65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f});
nd4j::ops::upsampling3d op; nd4j::ops::upsampling3d op;
auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW}); auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW});
auto* output = results->at(0); auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2332,7 +2332,7 @@ TEST_F(ConvolutionTests1, upsampling3d_bp_test1) {
expGradI = 8.; expGradI = 8.;
nd4j::ops::upsampling3d_bp op; nd4j::ops::upsampling3d_bp op;
auto results = op.execute({&input, &gradO}, {}, {isNCDHW}); auto results = op.evaluate({&input, &gradO}, {isNCDHW});
auto* gradI = results->at(0); auto* gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2359,7 +2359,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) {
nd4j::ops::conv2d_input_bp op; nd4j::ops::conv2d_input_bp op;
auto results = op.execute({&inputShape, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}); auto results = op.evaluate({&inputShape, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1});
ASSERT_TRUE(results->size() == 1); ASSERT_TRUE(results->size() == 1);
@ -2424,7 +2424,7 @@ TEST_F(ConvolutionTests1, upsampling3d_bp_test3) {
4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, nd4j::DataType::FLOAT32); 4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, nd4j::DataType::FLOAT32);
nd4j::ops::upsampling3d_bp op; nd4j::ops::upsampling3d_bp op;
auto results = op.execute({&input, &gradO}, {}, {isNCDHW}); auto results = op.evaluate({&input, &gradO}, {isNCDHW});
auto* gradI = results->at(0); auto* gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2457,7 +2457,7 @@ TEST_F(ConvolutionTests1, deconv2d_test1) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::deconv2d op; nd4j::ops::deconv2d op;
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
auto output = results->at(0); auto output = results->at(0);
@ -2490,7 +2490,7 @@ TEST_F(ConvolutionTests1, deconv2d_test2) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::deconv2d op; nd4j::ops::deconv2d op;
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2522,7 +2522,7 @@ TEST_F(ConvolutionTests1, deconv2d_test3) {
bias = 0.2; bias = 0.2;
nd4j::ops::deconv2d op; nd4j::ops::deconv2d op;
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
auto output = results->at(0); auto output = results->at(0);
@ -2557,7 +2557,7 @@ TEST_F(ConvolutionTests1, deconv2d_test4) {
weights.permutei({2,3,1,0}); weights.permutei({2,3,1,0});
nd4j::ops::deconv2d op; nd4j::ops::deconv2d op;
auto result = op.execute({&input, &weights}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); auto result = op.evaluate({&input, &weights}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});
auto z = result->at(0); auto z = result->at(0);
// z->printShapeInfo(); // z->printShapeInfo();
@ -2584,7 +2584,7 @@ TEST_F(ConvolutionTests1, deconv2d_test5) {
weights.permutei({2,3,1,0}); weights.permutei({2,3,1,0});
nd4j::ops::deconv2d op; nd4j::ops::deconv2d op;
auto result = op.execute({&input, &weights}, {&z}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0},{}); auto result = op.execute({&input, &weights}, {&z}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});
ASSERT_EQ(ND4J_STATUS_OK, result); ASSERT_EQ(ND4J_STATUS_OK, result);
@ -2615,7 +2615,7 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_test6) {
input.linspace(1); input.linspace(1);
nd4j::ops::deconv2d op; nd4j::ops::deconv2d op;
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2640,7 +2640,7 @@ TEST_F(ConvolutionTests1, deconv2d_test7) {
nd4j::ops::deconv2d op; nd4j::ops::deconv2d op;
auto result = op.execute({&input, &weights, &bias}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0}); auto result = op.evaluate({&input, &weights, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2683,7 +2683,7 @@ TEST_F(ConvolutionTests1, deconv2d_test8) {
1.471922, 1.484062, 1.212039, 1.144419, 1.266123}); 1.471922, 1.484062, 1.212039, 1.144419, 1.266123});
nd4j::ops::deconv2d op; nd4j::ops::deconv2d op;
auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2718,7 +2718,7 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) {
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
nd4j::ops::deconv2d_tf op; nd4j::ops::deconv2d_tf op;
auto results = op.execute({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto results = op.evaluate({&outShape, &weights, &input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());

File diff suppressed because one or more lines are too long

View File

@ -46,7 +46,7 @@ TEST_F(DataTypesValidationTests, Basic_Test_1) {
input.linspace(1); input.linspace(1);
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
auto result = op.execute({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
ASSERT_EQ(ND4J_STATUS_VALIDATION, result->status()); ASSERT_EQ(ND4J_STATUS_VALIDATION, result->status());
@ -62,7 +62,7 @@ TEST_F(DataTypesValidationTests, Basic_Test_2) {
input.linspace(1); input.linspace(1);
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
auto result = op.execute({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);

View File

@ -161,7 +161,7 @@ TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) {
auto exp = NDArrayFactory::create<double>('c', {3,4}); auto exp = NDArrayFactory::create<double>('c', {3,4});
exp.linspace(0.9, 0.9); exp.linspace(0.9, 0.9);
nd4j::ops::apply_sgd op; nd4j::ops::apply_sgd op;
auto result = op.execute({&x, &y}, {1.}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &y}, {1.}, {});
ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto z = result->at(0); auto z = result->at(0);
@ -175,7 +175,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) {
auto y = NDArrayFactory::create<double>('c', {1,4}, {0.1,0.2,0.3,0.4}); auto y = NDArrayFactory::create<double>('c', {1,4}, {0.1,0.2,0.3,0.4});
auto exp = NDArrayFactory::create<double>('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4}); auto exp = NDArrayFactory::create<double>('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4});
nd4j::ops::assign op; nd4j::ops::assign op;
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &y});
ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto z = result->at(0); auto z = result->at(0);
@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) {
auto exp1 = NDArrayFactory::create<double>('c', {3,4}); // zero auto exp1 = NDArrayFactory::create<double>('c', {3,4}); // zero
auto exp2 = NDArrayFactory::create<double>('c', {1,4}, {3, 6, 9, 12}); auto exp2 = NDArrayFactory::create<double>('c', {1,4}, {3, 6, 9, 12});
nd4j::ops::assign_bp op; nd4j::ops::assign_bp op;
auto result = op.execute({&x, &y, &eps}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &y, &eps});
ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto z1 = result->at(0); auto z1 = result->at(0);
auto z2 = result->at(1); auto z2 = result->at(1);
@ -208,7 +208,7 @@ TEST_F(DeclarableOpsTests1, AXpY_Test_1) {
auto exp = NDArrayFactory::create<double>('c', {3,4}); auto exp = NDArrayFactory::create<double>('c', {3,4});
exp.linspace(3, 3); exp.linspace(3, 3);
nd4j::ops::axpy op; nd4j::ops::axpy op;
auto result = op.execute({&x, &y}, {2.}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &y}, {2.});
ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto z = result->at(0); auto z = result->at(0);
@ -249,7 +249,7 @@ TEST_F(DeclarableOpsTests1, TestTensorMmul1) {
NDArray exp('c', {2, 2}, {650.0, 1586.0, 1586.0, 4250.0}, nd4j::DataType::FLOAT32); NDArray exp('c', {2, 2}, {650.0, 1586.0, 1586.0, 4250.0}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2}); auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -269,7 +269,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot2) {
NDArray exp('c', {2, 2}, {2300.0, 2444.0, 2444.0, 2600.0}, nd4j::DataType::FLOAT32); NDArray exp('c', {2, 2}, {2300.0, 2444.0, 2444.0, 2600.0}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2}); auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -289,7 +289,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot3) {
NDArray exp('f', {2, 2}, {1090.0, 2818.0, 1168.0, 3040.0}, nd4j::DataType::FLOAT32); NDArray exp('f', {2, 2}, {1090.0, 2818.0, 1168.0, 3040.0}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2}); auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -309,7 +309,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) {
NDArray exp('f', {2, 2}, {1090.0, 1168.0, 2818.0, 3040.0}, nd4j::DataType::FLOAT32); NDArray exp('f', {2, 2}, {1090.0, 1168.0, 2818.0, 3040.0}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2}); auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -506,7 +506,7 @@ TEST_F(DeclarableOpsTests1, SubtractTest_2) {
nd4j::ops::subtract subOp; nd4j::ops::subtract subOp;
auto res = subOp.execute({&x, &y}, {}, {}); auto res = subOp.evaluate({&x, &y});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
@ -767,7 +767,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) {
nd4j::ops::reversesubtract subOp; nd4j::ops::reversesubtract subOp;
auto res = subOp.execute({&x, &y}, {}, {}); auto res = subOp.evaluate({&x, &y});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(&exp)); ASSERT_TRUE(res->at(0)->equalsTo(&exp));
@ -792,7 +792,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) {
nd4j::ops::reversesubtract subOp; nd4j::ops::reversesubtract subOp;
auto res = subOp.execute({&x, &y}, {}, {}); auto res = subOp.evaluate({&x, &y});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(&exp)); ASSERT_TRUE(res->at(0)->equalsTo(&exp));
@ -815,7 +815,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) {
ASSERT_TRUE(z.equalsTo(&exp)); ASSERT_TRUE(z.equalsTo(&exp));
nd4j::ops::reversesubtract subOp; nd4j::ops::reversesubtract subOp;
auto res = subOp.execute({&x, &y}, {}, {}); auto res = subOp.evaluate({&x, &y});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(&exp)); ASSERT_TRUE(res->at(0)->equalsTo(&exp));
@ -841,7 +841,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) {
nd4j::ops::reversemod subOp; nd4j::ops::reversemod subOp;
auto res = subOp.execute({&x, &y}, {}, {}); auto res = subOp.evaluate({&x, &y});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(&exp)); ASSERT_TRUE(res->at(0)->equalsTo(&exp));
@ -868,7 +868,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) {
nd4j::ops::reversemod subOp; nd4j::ops::reversemod subOp;
auto res = subOp.execute({&x, &y}, {}, {}); auto res = subOp.evaluate({&x, &y});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(&exp)); ASSERT_TRUE(res->at(0)->equalsTo(&exp));
@ -1157,7 +1157,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) {
nd4j::ops::divide div; nd4j::ops::divide div;
auto res = div.execute({&x, &y}, {}, {}); auto res = div.evaluate({&x, &y});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(exp)); ASSERT_TRUE(res->at(0)->equalsTo(exp));
@ -1176,7 +1176,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) {
exp.assign(3); exp.assign(3);
nd4j::ops::divide_no_nan div; nd4j::ops::divide_no_nan div;
auto res = div.execute({&x, &y}, {}, {}); auto res = div.evaluate({&x, &y});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(exp)); ASSERT_TRUE(res->at(0)->equalsTo(exp));
@ -1192,7 +1192,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) {
auto exp = NDArrayFactory::create<float>({2, 2, 0, 2, 2}); auto exp = NDArrayFactory::create<float>({2, 2, 0, 2, 2});
nd4j::ops::divide_no_nan div; nd4j::ops::divide_no_nan div;
auto res = div.execute({&x, &y}, {}, {}); auto res = div.evaluate({&x, &y});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(exp)); ASSERT_TRUE(res->at(0)->equalsTo(exp));
@ -1212,7 +1212,7 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {
nd4j::ops::reversedivide div; nd4j::ops::reversedivide div;
auto res = div.execute({&x, &y}, {}, {}); auto res = div.evaluate({&x, &y});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
@ -1469,7 +1469,7 @@ TEST_F(DeclarableOpsTests1, Test_Cast_1) {
yExp.linspace(1); yExp.linspace(1);
nd4j::ops::cast op; nd4j::ops::cast op;
auto result = op.execute({&x}, {}, {3}); auto result = op.evaluate({&x}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1673,7 +1673,7 @@ TEST_F(DeclarableOpsTests1, Reshape3) {
auto x = NDArrayFactory::create<float>('c', {3, 4, 5}); auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x}, {}, {-99, 3, 4, 5}); auto result = op.evaluate({&x}, {}, {-99, 3, 4, 5});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1688,7 +1688,7 @@ TEST_F(DeclarableOpsTests1, Reshape4) {
auto x = NDArrayFactory::create<float>('c', {3, 4, 5}); auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x}, {}, {3, 4, 5}); auto result = op.evaluate({&x}, {}, {3, 4, 5});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1703,7 +1703,7 @@ TEST_F(DeclarableOpsTests1, Reshape5) {
auto x = NDArrayFactory::create<float>('c', {3, 4, 5}); auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x}, {}, {5, 4, 3}); auto result = op.evaluate({&x}, {}, {5, 4, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1715,7 +1715,7 @@ TEST_F(DeclarableOpsTests1, Reshape6){
auto exp = NDArrayFactory::create<float>('c', {4, 15}); auto exp = NDArrayFactory::create<float>('c', {4, 15});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x}, {}, {4, -1}); auto result = op.evaluate({&x}, {}, {4, -1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1732,7 +1732,7 @@ TEST_F(DeclarableOpsTests1, Reshape7){
auto exp = NDArrayFactory::create<float>('c', {60}); auto exp = NDArrayFactory::create<float>('c', {60});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x}, {}, {-1}); auto result = op.evaluate({&x}, {}, {-1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2217,7 +2217,7 @@ TEST_F(DeclarableOpsTests1, IsMax1) {
exp.p<bool>(2, 2, true); exp.p<bool>(2, 2, true);
nd4j::ops::ismax ismaxOp; nd4j::ops::ismax ismaxOp;
auto result = ismaxOp.execute({&x}, {}, {1}); auto result = ismaxOp.evaluate({&x}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2239,7 +2239,7 @@ TEST_F(DeclarableOpsTests1, IsMax2) {
exp.p<bool>(2, 2, true); exp.p<bool>(2, 2, true);
nd4j::ops::ismax ismaxOp; nd4j::ops::ismax ismaxOp;
auto result = ismaxOp.execute({&x}, {}, {0, 1}); auto result = ismaxOp.evaluate({&x}, {}, {0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2261,7 +2261,7 @@ TEST_F(DeclarableOpsTests1, IsMax3) {
//exp.p<bool>(2, 2, true); //exp.p<bool>(2, 2, true);
nd4j::ops::ismax ismaxOp; nd4j::ops::ismax ismaxOp;
auto result = ismaxOp.execute({&x}, {}, {0}); auto result = ismaxOp.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2279,7 +2279,7 @@ TEST_F(DeclarableOpsTests1, IsMax4) {
auto e = NDArrayFactory::create<bool>('c', {6}, {false, false, false, true, false, false}); auto e = NDArrayFactory::create<bool>('c', {6}, {false, false, false, true, false, false});
nd4j::ops::ismax op; nd4j::ops::ismax op;
auto result = op.execute({&x}, {&z}, {}, {}, {}); auto result = op.execute({&x}, {&z});
ASSERT_EQ(Status::OK(), result); ASSERT_EQ(Status::OK(), result);
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
@ -2343,7 +2343,7 @@ TEST_F(DeclarableOpsTests1, sru_test1) {
mask.assign(1.); mask.assign(1.);
nd4j::ops::sru op; nd4j::ops::sru op;
auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, {}); auto results = op.evaluate({&input, &weights, &bias, &init, &mask});
ASSERT_TRUE(results->size() == 2); ASSERT_TRUE(results->size() == 2);
auto output = results->at(0); auto output = results->at(0);
@ -2390,7 +2390,7 @@ TEST_F(DeclarableOpsTests1, sru_bp) {
inGradH.assign(0.5); inGradH.assign(0.5);
nd4j::ops::sru_bp bp; nd4j::ops::sru_bp bp;
auto resultsBP = bp.execute({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {}); auto resultsBP = bp.evaluate({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {});
ASSERT_TRUE(resultsBP->size() == 4); ASSERT_TRUE(resultsBP->size() == 4);
auto gradX = resultsBP->at(0); auto gradX = resultsBP->at(0);
@ -2429,7 +2429,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_1) {
mask.assign(1.); mask.assign(1.);
nd4j::ops::sru_bi op; nd4j::ops::sru_bi op;
auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, {}); auto results = op.evaluate({&input, &weights, &bias, &init, &mask}, {}, {});
ASSERT_TRUE(results->size() == 2); ASSERT_TRUE(results->size() == 2);
auto output = results->at(0); auto output = results->at(0);
@ -2480,7 +2480,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
inGradH.assign(0.5); inGradH.assign(0.5);
nd4j::ops::sru_bi_bp bp; nd4j::ops::sru_bi_bp bp;
auto resultsBP = bp.execute({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {}); auto resultsBP = bp.evaluate({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {});
ASSERT_TRUE(resultsBP->size() == 4); ASSERT_TRUE(resultsBP->size() == 4);
auto gradX = resultsBP->at(0); auto gradX = resultsBP->at(0);
@ -2504,7 +2504,7 @@ TEST_F(DeclarableOpsTests1, ArgMax1) {
nd4j::ops::argmax op; nd4j::ops::argmax op;
auto result = op.execute({&x}, {}, {1}); auto result = op.evaluate({&x}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2525,7 +2525,7 @@ TEST_F(DeclarableOpsTests1, ArgMax2) {
nd4j::ops::argmax op; nd4j::ops::argmax op;
auto result = op.execute({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2547,7 +2547,7 @@ TEST_F(DeclarableOpsTests1, ArgMax3) {
nd4j::ops::argmax op; nd4j::ops::argmax op;
auto result = op.execute({&x, &dim}, {}, {}); auto result = op.evaluate({&x, &dim}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2568,7 +2568,7 @@ TEST_F(DeclarableOpsTests1, ArgMax4) {
nd4j::ops::argmax op; nd4j::ops::argmax op;
auto result = op.execute({&x, &dim}, {}, {}); auto result = op.evaluate({&x, &dim}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2590,7 +2590,7 @@ TEST_F(DeclarableOpsTests1, ArgMax5) {
nd4j::ops::argmax op; nd4j::ops::argmax op;
auto result = op.execute({&x, &dim}, {}, {}); auto result = op.evaluate({&x, &dim}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2610,12 +2610,12 @@ TEST_F(DeclarableOpsTests1, ArgMax6) {
nd4j::ops::argmax op; nd4j::ops::argmax op;
auto expected = op.execute({&x}, {}, {2}); auto expected = op.evaluate({&x}, {}, {2});
ASSERT_EQ(Status::OK(), expected->status()); ASSERT_EQ(Status::OK(), expected->status());
auto exp = expected->at(0); auto exp = expected->at(0);
auto result = op.execute({&x, &dim}, {}, {}); auto result = op.evaluate({&x, &dim}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2636,7 +2636,7 @@ TEST_F(DeclarableOpsTests1, ArgMin1) {
nd4j::ops::argmin op; nd4j::ops::argmin op;
auto result = op.execute({&x}, {}, {1}); auto result = op.evaluate({&x}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2659,7 +2659,7 @@ TEST_F(DeclarableOpsTests1, SquareTests1) {
nd4j::ops::square op; nd4j::ops::square op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2677,7 +2677,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_1) {
nd4j::ops::onehot op; nd4j::ops::onehot op;
auto result = op.execute({&indices}, {1.0f, 0.0f}, {-1, 3}); auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2695,7 +2695,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_2) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::onehot op; nd4j::ops::onehot op;
auto result = op.execute({&indices}, {1.0f, 0.0f}, {-1, 3}); auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2715,7 +2715,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_3) {
nd4j::ops::onehot op; nd4j::ops::onehot op;
auto result = op.execute({&indices}, {1.0f, 0.0f}, {-1, 3}); auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2736,7 +2736,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_4) {
nd4j::ops::onehot op; nd4j::ops::onehot op;
auto result = op.execute({&indices, &depth}, {1.0f, 0.0f}, {}); auto result = op.evaluate({&indices, &depth}, {1.0f, 0.0f}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2757,7 +2757,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) {
nd4j::ops::onehot op; nd4j::ops::onehot op;
auto result = op.execute({&indices, &depth, &on, &off}, {}, {}); auto result = op.evaluate({&indices, &depth, &on, &off}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2769,11 +2769,24 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) {
} }
TEST_F(DeclarableOpsTests1, OneHotTests_6) { TEST_F(DeclarableOpsTests1, OneHotTests_6) {
auto indices = NDArrayFactory::create<float>('c', {3}, {0., 1., 2.}); auto indices = NDArrayFactory::create<float>('c', {3}, {0.f, 1.f, 2.f});
auto e = NDArrayFactory::create<float>('c', {3, 3}, {1., 0., 0., 0., 1., 0., 0., 0., 1.}); auto e = NDArrayFactory::create<float>('c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f});
nd4j::ops::onehot op; nd4j::ops::onehot op;
auto result = op.execute({&indices}, {1.0, 0.0}, {0, 3}); auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3});
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests1, OneHotTests_7) {
auto indices = NDArrayFactory::create<int>('c', {3}, {0, 1, 2});
auto e = NDArrayFactory::create<float16>('c', {3, 3}, {1., 0., 0., 0., 1., 0., 0., 0., 1.});
nd4j::ops::onehot op;
auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3}, {}, {nd4j::DataType::HALF}, false);
auto z = result->at(0); auto z = result->at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
@ -2788,7 +2801,7 @@ TEST_F(DeclarableOpsTests1, FillAs_1) {
float scalar = 119.f; float scalar = 119.f;
nd4j::ops::fill_as op; nd4j::ops::fill_as op;
auto result = op.execute({&x}, {scalar}, {}); auto result = op.evaluate({&x}, {scalar}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2824,7 +2837,7 @@ TEST_F(DeclarableOpsTests1, Stack_1) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input2}, {}, {0}); auto results = op.evaluate({&input1, &input2}, {}, {0});
auto output = results->at(0); auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
@ -2852,7 +2865,7 @@ TEST_F(DeclarableOpsTests1, Stack_2) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input2}, {}, {1}); auto results = op.evaluate({&input1, &input2}, {}, {1});
auto output = results->at(0); auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
@ -2880,7 +2893,7 @@ TEST_F(DeclarableOpsTests1, Stack_3) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input2}, {}, {0}); auto results = op.evaluate({&input1, &input2}, {}, {0});
auto output = results->at(0); auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
@ -2907,7 +2920,7 @@ TEST_F(DeclarableOpsTests1, Stack_4) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input2}, {}, {1}); auto results = op.evaluate({&input1, &input2}, {}, {1});
auto output = results->at(0); auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
@ -2934,7 +2947,7 @@ TEST_F(DeclarableOpsTests1, Stack_5) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input2}, {}, {0}); auto results = op.evaluate({&input1, &input2}, {}, {0});
auto output = results->at(0); auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
@ -2961,7 +2974,7 @@ TEST_F(DeclarableOpsTests1, Stack_6) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input2}, {}, {1}); auto results = op.evaluate({&input1, &input2}, {}, {1});
auto output = results->at(0); auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
@ -2985,7 +2998,7 @@ TEST_F(DeclarableOpsTests1, Stack_7) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input1, &input1}, {}, {0}); auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
auto output = results->at(0); auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
@ -3008,7 +3021,7 @@ TEST_F(DeclarableOpsTests1, Stack_8) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input1, &input1}, {}, {0}); auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
auto output = results->at(0); auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
@ -3031,7 +3044,7 @@ TEST_F(DeclarableOpsTests1, Stack_9) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input1, &input1}, {}, {1}); auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
auto output = results->at(0); auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
@ -3054,7 +3067,7 @@ TEST_F(DeclarableOpsTests1, Stack_10) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input1, &input1}, {}, {1}); auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
auto output = results->at(0); auto output = results->at(0);
//expected.printShapeInfo("exp"); //expected.printShapeInfo("exp");
@ -3079,7 +3092,7 @@ TEST_F(DeclarableOpsTests1, Stack_11) {
NDArray expected(expBuff, expShape); NDArray expected(expBuff, expShape);
nd4j::ops::stack op; nd4j::ops::stack op;
auto results = op.execute({&input1, &input1, &input1}, {}, {}); auto results = op.evaluate({&input1, &input1, &input1}, {}, {});
auto output = results->at(0); auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
@ -3095,7 +3108,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) {
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {}, {1, 5, 1}); auto result = op.evaluate({}, {}, {1, 5, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -3122,7 +3135,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) {
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({&start, &stop, &step}, {}, {}); auto result = op.evaluate({&start, &stop, &step}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -3142,7 +3155,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) {
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {1.f, 5.f, 1.f}, {}); auto result = op.evaluate({}, {1.f, 5.f, 1.f}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -3161,7 +3174,7 @@ TEST_F(DeclarableOpsTests1, softmax_test1) {
auto expOutput = NDArrayFactory::create<double>('c', {3, 3}, {1.14195199e-01, 8.43794734e-01, 4.20100661e-02, 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, 9.02116571e-05, 2.68917160e-01, 7.30992629e-01}); auto expOutput = NDArrayFactory::create<double>('c', {3, 3}, {1.14195199e-01, 8.43794734e-01, 4.20100661e-02, 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, 9.02116571e-05, 2.68917160e-01, 7.30992629e-01});
nd4j::ops::softmax op; nd4j::ops::softmax op;
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&input}, {}, {}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -3177,7 +3190,7 @@ TEST_F(DeclarableOpsTests1, softmax_test2) {
auto expOutput = NDArrayFactory::create<double>('c', {3, 3, 3}, {4.73142e-02,4.73847e-02,6.69062e-03, 9.50330e-01,8.67881e-04,9.92976e-01, 2.35563e-03,9.51747e-01,3.33106e-04, 4.74259e-02,2.26032e-06,4.74259e-02, 2.91395e-07,9.99998e-01,3.94360e-08, 9.52574e-01,1.12535e-07,9.52574e-01, 7.58256e-10,4.74259e-02,1.22325e-11, 1.00000e+00,1.32293e-11,1.19203e-01, 3.77513e-11,9.52574e-01,8.80797e-01}); auto expOutput = NDArrayFactory::create<double>('c', {3, 3, 3}, {4.73142e-02,4.73847e-02,6.69062e-03, 9.50330e-01,8.67881e-04,9.92976e-01, 2.35563e-03,9.51747e-01,3.33106e-04, 4.74259e-02,2.26032e-06,4.74259e-02, 2.91395e-07,9.99998e-01,3.94360e-08, 9.52574e-01,1.12535e-07,9.52574e-01, 7.58256e-10,4.74259e-02,1.22325e-11, 1.00000e+00,1.32293e-11,1.19203e-01, 3.77513e-11,9.52574e-01,8.80797e-01});
nd4j::ops::softmax op; nd4j::ops::softmax op;
auto results = op.execute({&input}, {}, {1}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&input}, {}, {1}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -3193,7 +3206,7 @@ TEST_F(DeclarableOpsTests1, softmax_test3) {
auto expOutput = NDArrayFactory::create<double>('c', {3, 3, 3}, {2.47262e-03,1.23395e-04,3.35350e-04, 1.23395e-04,4.53979e-05,1.23395e-04, 6.14417e-06,1.23395e-04,5.56530e-09, 9.97527e-01,1.12521e-07,9.99665e-01, 1.52281e-08,9.99955e-01,2.06090e-09, 9.99994e-01,2.78912e-10,6.69285e-03, 3.05146e-07,9.99876e-01,4.13855e-08, 9.99877e-01,5.60254e-09,9.99877e-01, 7.58251e-10,9.99877e-01,9.93307e-01}); auto expOutput = NDArrayFactory::create<double>('c', {3, 3, 3}, {2.47262e-03,1.23395e-04,3.35350e-04, 1.23395e-04,4.53979e-05,1.23395e-04, 6.14417e-06,1.23395e-04,5.56530e-09, 9.97527e-01,1.12521e-07,9.99665e-01, 1.52281e-08,9.99955e-01,2.06090e-09, 9.99994e-01,2.78912e-10,6.69285e-03, 3.05146e-07,9.99876e-01,4.13855e-08, 9.99877e-01,5.60254e-09,9.99877e-01, 7.58251e-10,9.99877e-01,9.93307e-01});
nd4j::ops::softmax op; nd4j::ops::softmax op;
auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&input}, {}, {0}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -3209,7 +3222,7 @@ TEST_F(DeclarableOpsTests1, softmax_test4) {
auto expOutput = NDArrayFactory::create<double>('c', {1, 5}, {0.01198,0.08855,0.00441,0.24072,0.65434}); auto expOutput = NDArrayFactory::create<double>('c', {1, 5}, {0.01198,0.08855,0.00441,0.24072,0.65434});
nd4j::ops::softmax op; nd4j::ops::softmax op;
auto results = op.execute({&input}, {}, {1}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&input}, {}, {1}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -3225,7 +3238,7 @@ TEST_F(DeclarableOpsTests1, softmax_test5) {
auto expOutput = NDArrayFactory::create<double>('c', {1, 5}, {1,1,1,1,1}); auto expOutput = NDArrayFactory::create<double>('c', {1, 5}, {1,1,1,1,1});
nd4j::ops::softmax op; nd4j::ops::softmax op;
auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&input}, {}, {0});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -3241,7 +3254,7 @@ TEST_F(DeclarableOpsTests1, softmax_test6) {
auto expOutput = NDArrayFactory::create<double>('c', {5, 1}, {0.01198,0.08855,0.00441,0.24072,0.65434}); auto expOutput = NDArrayFactory::create<double>('c', {5, 1}, {0.01198,0.08855,0.00441,0.24072,0.65434});
nd4j::ops::softmax op; nd4j::ops::softmax op;
auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&input}, {}, {0}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -3257,7 +3270,7 @@ TEST_F(DeclarableOpsTests1, softmax_test7) {
auto expOutput = NDArrayFactory::create<double>('c', {5, 1}, {1,1,1,1,1}); auto expOutput = NDArrayFactory::create<double>('c', {5, 1}, {1,1,1,1,1});
nd4j::ops::softmax op; nd4j::ops::softmax op;
auto results = op.execute({&input}, {}, {1}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&input}, {}, {1}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -3273,7 +3286,7 @@ TEST_F(DeclarableOpsTests1, softmax_test8) {
auto expOutput = NDArrayFactory::create<double>('c', {5}, {0.01198,0.08855,0.00441,0.24072,0.65434}); auto expOutput = NDArrayFactory::create<double>('c', {5}, {0.01198,0.08855,0.00441,0.24072,0.65434});
nd4j::ops::softmax op; nd4j::ops::softmax op;
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&input}, {}, {}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -3294,7 +3307,7 @@ TEST_F(DeclarableOpsTests1, Test_Stack_Edge_1) {
nd4j::ops::stack op; nd4j::ops::stack op;
auto result = op.execute({&input}, {}, {0}); auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -3316,7 +3329,7 @@ TEST_F(DeclarableOpsTests1, Test_Stack_Edge_2) {
nd4j::ops::stack op; nd4j::ops::stack op;
auto result = op.execute({&input}, {}, {0}); auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -3338,7 +3351,7 @@ TEST_F(DeclarableOpsTests1, Test_Stack_Edge_3) {
nd4j::ops::stack op; nd4j::ops::stack op;
auto result = op.execute({&input}, {}, {1}); auto result = op.evaluate({&input}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -3364,7 +3377,7 @@ TEST_F(DeclarableOpsTests1, Reverse_1 ) {
NDArray output(shapeInfo); NDArray output(shapeInfo);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {0,1,2}); auto results = op.evaluate({&input}, {}, {0,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3389,7 +3402,7 @@ TEST_F(DeclarableOpsTests1, Reverse_2 ) {
NDArray output(shapeInfo); NDArray output(shapeInfo);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {}, {}, true); auto results = op.evaluate({&input}, {}, {}, {}, {}, true);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3414,7 +3427,7 @@ TEST_F(DeclarableOpsTests1, Reverse_3 ) {
NDArray output(shapeInfo); NDArray output(shapeInfo);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {1,2}); auto results = op.evaluate({&input}, {}, {1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3440,7 +3453,7 @@ TEST_F(DeclarableOpsTests1, Reverse_4 ) {
NDArray output(shapeInfo); NDArray output(shapeInfo);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {0,2}); auto results = op.evaluate({&input}, {}, {0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3466,7 +3479,7 @@ TEST_F(DeclarableOpsTests1, Reverse_5 ) {
NDArray output(shapeInfo); NDArray output(shapeInfo);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {0,1}); auto results = op.evaluate({&input}, {}, {0,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3491,7 +3504,7 @@ TEST_F(DeclarableOpsTests1, Reverse_6 ) {
NDArray output(shapeInfo); NDArray output(shapeInfo);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {2}, {}, true); auto results = op.evaluate({&input}, {}, {2}, {}, {}, true);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3518,7 +3531,7 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) {
NDArray output(shapeInfo); NDArray output(shapeInfo);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {1}); auto results = op.evaluate({&input}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3547,7 +3560,7 @@ TEST_F(DeclarableOpsTests1, Reverse_8 ) {
NDArray output(shapeInfo); NDArray output(shapeInfo);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {2,1}); auto results = op.evaluate({&input}, {}, {2,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3573,7 +3586,7 @@ TEST_F(DeclarableOpsTests1, Reverse_9 ) {
NDArray output(shapeInfo); NDArray output(shapeInfo);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {0}); auto results = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3591,7 +3604,7 @@ TEST_F(DeclarableOpsTests1, Reverse_10 ) {
auto e = NDArrayFactory::create<double>('c', {4, 3}, {0.09966054, 0.1592365, 1.5375735, -1.0355669, 1.144433, 0.677872,0.85020787, -0.67863184, 0.48456487, -1.1660044, 0.20998026, 0.13950661}); auto e = NDArrayFactory::create<double>('c', {4, 3}, {0.09966054, 0.1592365, 1.5375735, -1.0355669, 1.144433, 0.677872,0.85020787, -0.67863184, 0.48456487, -1.1660044, 0.20998026, 0.13950661});
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto result = op.execute({&x, &i}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &i}, {}, {}, {});
auto z = result->at(0); auto z = result->at(0);
@ -3612,7 +3625,7 @@ TEST_F(DeclarableOpsTests1, Reverse_11 ) {
input.linspace(1); input.linspace(1);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {0, 1, 2}); auto results = op.evaluate({&input}, {}, {0, 1, 2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3633,7 +3646,7 @@ TEST_F(DeclarableOpsTests1, Reverse_12 ) {
//input.linspace(1); //input.linspace(1);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {0}); auto results = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3655,7 +3668,7 @@ TEST_F(DeclarableOpsTests1, Reverse_13 ) {
//input.linspace(1); //input.linspace(1);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {-1}); auto results = op.evaluate({&input}, {}, {-1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3676,7 +3689,7 @@ TEST_F(DeclarableOpsTests1, Reverse_14 ) {
//input.linspace(1); //input.linspace(1);
nd4j::ops::reverse op; nd4j::ops::reverse op;
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&input}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3694,7 +3707,7 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) {
nd4j::ops::expose op; nd4j::ops::expose op;
auto result = op.execute({&input0, &input1}, {}, {}); auto result = op.evaluate({&input0, &input1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());

View File

@ -60,7 +60,7 @@ TEST_F(DeclarableOpsTests10, Test_ArgMax_1) {
nd4j::ops::argmax op; nd4j::ops::argmax op;
auto result = op.execute({&x}, {}, {}, {}); auto result = op.evaluate({&x});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -79,7 +79,7 @@ TEST_F(DeclarableOpsTests10, Test_ArgMax_2) {
x.linspace(1.0); x.linspace(1.0);
nd4j::ops::argmax op; nd4j::ops::argmax op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = *result->at(0); auto z = *result->at(0);
@ -98,7 +98,7 @@ TEST_F(DeclarableOpsTests10, Test_And_1) {
auto e = NDArrayFactory::create<double>('c', {4}, {0, 0, 0, 1}); auto e = NDArrayFactory::create<double>('c', {4}, {0, 0, 0, 1});
nd4j::ops::boolean_and op; nd4j::ops::boolean_and op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result->at(0));
@ -112,7 +112,7 @@ TEST_F(DeclarableOpsTests10, Test_Or_1) {
auto e = NDArrayFactory::create<double>('c', {4}, {1, 1, 0, 1}); auto e = NDArrayFactory::create<double>('c', {4}, {1, 1, 0, 1});
nd4j::ops::boolean_or op; nd4j::ops::boolean_or op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result->at(0));
@ -127,7 +127,7 @@ TEST_F(DeclarableOpsTests10, Test_Not_1) {
auto e = NDArrayFactory::create<bool>('c', {4}, {false, false, true, false}); auto e = NDArrayFactory::create<bool>('c', {4}, {false, false, true, false});
nd4j::ops::boolean_not op; nd4j::ops::boolean_not op;
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto res = result->at(0); auto res = result->at(0);
@ -141,7 +141,7 @@ TEST_F(DeclarableOpsTests10, Test_Size_at_1) {
auto e = NDArrayFactory::create<Nd4jLong>(20); auto e = NDArrayFactory::create<Nd4jLong>(20);
nd4j::ops::size_at op; nd4j::ops::size_at op;
auto result = op.execute({&x}, {}, {1}); auto result = op.evaluate({&x}, {1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result->at(0));
@ -161,7 +161,7 @@ TEST_F(DeclarableOpsTests10, MirrorPad_SGO_Test_1) {
nd4j::ops::mirror_pad op; nd4j::ops::mirror_pad op;
auto res = op.execute({&in, &pad}, {10.0}, {0}, {}, false, nd4j::DataType::DOUBLE); auto res = op.evaluate({&in, &pad}, {10.0}, {0});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -175,7 +175,7 @@ TEST_F(DeclarableOpsTests10, Unique_SGO_Test_1) {
auto exp = NDArrayFactory::create<double>({3., 4., 1., 0., 2.}); auto exp = NDArrayFactory::create<double>({3., 4., 1., 0., 2.});
nd4j::ops::unique op; nd4j::ops::unique op;
auto res = op.execute({&input}, {}, {}); auto res = op.evaluate({&input}, {}, {});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto res1 = res->at(0); auto res1 = res->at(0);
auto res2 = res->at(1); auto res2 = res->at(1);
@ -192,7 +192,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_1) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {6, 2}, {0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 2LL, 0LL, 2LL, 1LL, 2LL, 2LL}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {6, 2}, {0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 2LL, 0LL, 2LL, 1LL, 2LL, 2LL});
nd4j::ops::Where op; nd4j::ops::Where op;
auto res = op.execute({&input}, {}, {}); auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0); auto resA = res->at(0);
@ -209,7 +209,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_02) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5, 3}, {0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 1LL, 0LL}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {5, 3}, {0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 1LL, 0LL});
nd4j::ops::Where op; nd4j::ops::Where op;
auto res = op.execute({&input}, {}, {}); auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0); auto resA = res->at(0);
@ -227,7 +227,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) {
auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 0, 0, 1}); auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 0, 0, 1});
auto exp3 = NDArrayFactory::create<Nd4jLong>({0, 1, 0, 1, 0}); auto exp3 = NDArrayFactory::create<Nd4jLong>({0, 1, 0, 1, 0});
nd4j::ops::where_np op; nd4j::ops::where_np op;
auto res = op.execute({&cond3d}, {}, {}); auto res = op.evaluate({&cond3d}, {}, {});
ASSERT_TRUE(res->size() == 3); ASSERT_TRUE(res->size() == 3);
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto res1 = res->at(0); auto res1 = res->at(0);
@ -251,7 +251,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) {
auto exp1 = NDArrayFactory::create<Nd4jLong>({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2}); auto exp1 = NDArrayFactory::create<Nd4jLong>({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2});
auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4}); auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4});
nd4j::ops::where_np op; nd4j::ops::where_np op;
auto res = op.execute({&cond2d}, {}, {}); auto res = op.evaluate({&cond2d}, {}, {});
ASSERT_TRUE(res->size() == 2); ASSERT_TRUE(res->size() == 2);
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
ASSERT_TRUE(exp1.equalsTo(res->at(0))); ASSERT_TRUE(exp1.equalsTo(res->at(0)));
@ -267,7 +267,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_2) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {4,1}, {0, 2, 3, 4}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {4,1}, {0, 2, 3, 4});
nd4j::ops::Where op; nd4j::ops::Where op;
auto res = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::INT64); auto res = op.evaluate({&input});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0); auto resA = res->at(0);
// resA->printIndexedBuffer("Result A"); // resA->printIndexedBuffer("Result A");
@ -285,7 +285,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_3) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0});
nd4j::ops::Where op; nd4j::ops::Where op;
auto res = op.execute({&input}, {}, {}); auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0); auto resA = res->at(0);
//resA->printIndexedBuffer("Result A"); //resA->printIndexedBuffer("Result A");
@ -303,7 +303,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_4) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0});
nd4j::ops::Where op; nd4j::ops::Where op;
auto res = op.execute({&input}, {}, {}); auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0); auto resA = res->at(0);
ASSERT_TRUE(resA->isEmpty()); ASSERT_TRUE(resA->isEmpty());
@ -322,7 +322,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_5) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {3, 1}, {0, 3, 4}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {3, 1}, {0, 3, 4});
nd4j::ops::Where op; nd4j::ops::Where op;
auto res = op.execute({&input}, {}, {}); auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0); auto resA = res->at(0);
//ASSERT_TRUE(resA->isEmpty()); //ASSERT_TRUE(resA->isEmpty());
@ -340,7 +340,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_4) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0});
nd4j::ops::where_np op; nd4j::ops::where_np op;
auto res = op.execute({&input}, {}, {}); auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0); auto resA = res->at(0);
ASSERT_TRUE(resA->isEmpty()); ASSERT_TRUE(resA->isEmpty());
@ -361,7 +361,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_1) {
auto exp = NDArrayFactory::create<double>(0.6); auto exp = NDArrayFactory::create<double>(0.6);
nd4j::ops::cosine_distance_loss op; nd4j::ops::cosine_distance_loss op;
auto res = op.execute({&predictions, &weights, &labels}, {}, {3, 1}); auto res = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0); auto resA = res->at(0);
@ -379,7 +379,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) {
auto exp = NDArrayFactory::create<double>(0.6); auto exp = NDArrayFactory::create<double>(0.6);
nd4j::ops::cosine_distance_loss op; nd4j::ops::cosine_distance_loss op;
auto res = op.execute({&predictions, &weights, &labels}, {}, {2, 1}); auto res = op.evaluate({&predictions, &weights, &labels}, {}, {2, 1});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK); ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0); auto resA = res->at(0);
@ -402,7 +402,7 @@ TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_1) {
exp.p(1, 2, 0, 0.); exp.p(1, 2, 0, 0.);
nd4j::ops::matrix_band_part op; nd4j::ops::matrix_band_part op;
auto results = op.execute({&x}, {}, {1, 1}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&x}, {}, {1, 1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
//results->at(0)->printIndexedBuffer("MBP Test1"); //results->at(0)->printIndexedBuffer("MBP Test1");
@ -422,7 +422,7 @@ TEST_F(DeclarableOpsTests10, atan2_test1) {
0.33172, 0.69614, 0.81846, 0.87776, 0.91253, 0.93533, 0.95141, 0.96336, 0.97259, 0.97993, 0.98591, 1.01266,}); 0.33172, 0.69614, 0.81846, 0.87776, 0.91253, 0.93533, 0.95141, 0.96336, 0.97259, 0.97993, 0.98591, 1.01266,});
nd4j::ops::tf_atan2 op; nd4j::ops::tf_atan2 op;
auto result = op.execute({&y, &x}, {}, {}); auto result = op.evaluate({&y, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -442,7 +442,7 @@ TEST_F(DeclarableOpsTests10, atan2_test2) {
3.11208, 2.99987, 2.83399, 2.57869, 2.207 , 1.77611, 1.41664, 1.17298, 1.01458, 0.90829, 0.8336 , 0.77879}); 3.11208, 2.99987, 2.83399, 2.57869, 2.207 , 1.77611, 1.41664, 1.17298, 1.01458, 0.90829, 0.8336 , 0.77879});
nd4j::ops::tf_atan2 op; nd4j::ops::tf_atan2 op;
auto result = op.execute({&y, &x}, {}, {}); auto result = op.evaluate({&y, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
// z->printIndexedBuffer(); // z->printIndexedBuffer();
@ -465,7 +465,7 @@ TEST_F(DeclarableOpsTests10, atan2_test3) {
-1.54128, -1.42907, -1.2632 , -1.00789,-0.63621, -0.20531, 0.15416, 0.39782, 0.55622, 0.6625 , 0.7372 , 0.79201}); -1.54128, -1.42907, -1.2632 , -1.00789,-0.63621, -0.20531, 0.15416, 0.39782, 0.55622, 0.6625 , 0.7372 , 0.79201});
nd4j::ops::tf_atan2 op; nd4j::ops::tf_atan2 op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -485,7 +485,7 @@ TEST_F(DeclarableOpsTests10, atan2_test4) {
3.05688, 3.03942, 3.01293, 2.9681 , 2.18167, 1.87635, 1.50156, 1.14451, 1.13674, 0.97626, 0.84423, 0.7372 }); 3.05688, 3.03942, 3.01293, 2.9681 , 2.18167, 1.87635, 1.50156, 1.14451, 1.13674, 0.97626, 0.84423, 0.7372 });
nd4j::ops::tf_atan2 op; nd4j::ops::tf_atan2 op;
auto result = op.execute({&x, &y}, {}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -505,7 +505,7 @@ TEST_F(DeclarableOpsTests10, atan2_test5) {
-1.48608, -1.46862, -1.44214, -1.3973 ,-0.61088, -0.30556, 0.06924, 0.42629, 0.43405, 0.59453, 0.72657, 0.8336 }); -1.48608, -1.46862, -1.44214, -1.3973 ,-0.61088, -0.30556, 0.06924, 0.42629, 0.43405, 0.59453, 0.72657, 0.8336 });
nd4j::ops::tf_atan2 op; nd4j::ops::tf_atan2 op;
auto result = op.execute({&y, &x}, {}, {}, {}); auto result = op.evaluate({&y, &x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -524,7 +524,7 @@ TEST_F(DeclarableOpsTests10, atan2_test6) {
auto exp = NDArrayFactory::create<double>('c', {1,3,4}, {-2.25712, -1.68608, -1.44214, -0.54006,-2.77695, -2.16855, 0.34972, 0.24585, 2.71267, 1.74453, 1.45312, 0.8336 }); auto exp = NDArrayFactory::create<double>('c', {1,3,4}, {-2.25712, -1.68608, -1.44214, -0.54006,-2.77695, -2.16855, 0.34972, 0.24585, 2.71267, 1.74453, 1.45312, 0.8336 });
nd4j::ops::tf_atan2 op; nd4j::ops::tf_atan2 op;
auto result = op.execute({&y, &x}, {}, {}, {}); auto result = op.evaluate({&y, &x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -546,7 +546,7 @@ TEST_F(DeclarableOpsTests10, IGamma_Test1) {
0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735}); 0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735});
nd4j::ops::igamma op; nd4j::ops::igamma op;
auto result = op.execute({&y, &x}, {}, {}, {}); auto result = op.evaluate({&y, &x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
// z->printBuffer("OUtput"); // z->printBuffer("OUtput");
@ -568,7 +568,7 @@ TEST_F(DeclarableOpsTests10, IGamma_Test2) {
0.999996, 0.999914, 0.999564, 0.998773}); 0.999996, 0.999914, 0.999564, 0.998773});
nd4j::ops::igammac op; nd4j::ops::igammac op;
auto result = op.execute({&y, &x}, {}, {}, {}); auto result = op.evaluate({&y, &x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
// z->printBuffer("OUtput"); // z->printBuffer("OUtput");
@ -591,7 +591,7 @@ TEST_F(DeclarableOpsTests10, LGamma_Test1) {
}); });
nd4j::ops::lgamma op; nd4j::ops::lgamma op;
auto result = op.execute({&x}, {}, {}, {}); auto result = op.evaluate({&x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
// z->printBuffer("OUtput"); // z->printBuffer("OUtput");
@ -610,7 +610,7 @@ TEST_F(DeclarableOpsTests10, range_test10) {
auto exp = NDArrayFactory::create<double>('c', {5}, {0.,1.,2.,3.,4.}); auto exp = NDArrayFactory::create<double>('c', {5}, {0.,1.,2.,3.,4.});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({&limit}, {}, {}, {}); auto result = op.evaluate({&limit}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -632,7 +632,7 @@ TEST_F(DeclarableOpsTests10, range_test11) {
auto exp = NDArrayFactory::create<double>('c', {5}, {0.5,1.5,2.5,3.5,4.5}); auto exp = NDArrayFactory::create<double>('c', {5}, {0.5,1.5,2.5,3.5,4.5});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({&start, &limit}, {}, {}, {}); auto result = op.evaluate({&start, &limit}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -650,7 +650,7 @@ TEST_F(DeclarableOpsTests10, range_test12) {
auto exp = NDArrayFactory::create<float>('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f}); auto exp = NDArrayFactory::create<float>('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {0.5, 5, 0.5}, {}, {}); auto result = op.evaluate({}, {0.5, 5, 0.5}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -671,7 +671,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test1) {
nd4j::ops::top_k op; nd4j::ops::top_k op;
auto result = op.execute({&x}, {}, {4}, {false}); auto result = op.evaluate({&x}, {}, {4}, {false});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -681,7 +681,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test1) {
ASSERT_TRUE(expUnsorted.isSameShape(z)); ASSERT_TRUE(expUnsorted.isSameShape(z));
ASSERT_TRUE(expUnsorted.equalsTo(z)); ASSERT_TRUE(expUnsorted.equalsTo(z));
auto result2 = op.execute({&x}, {}, {5}, {true}); auto result2 = op.evaluate({&x}, {}, {5}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result2->status()); ASSERT_EQ(ND4J_STATUS_OK, result2->status());
@ -704,7 +704,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test2) {
nd4j::ops::top_k op; nd4j::ops::top_k op;
auto result = op.execute({&x}, {}, {5}, {false}); auto result = op.evaluate({&x}, {}, {5}, {false});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -714,7 +714,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test2) {
ASSERT_TRUE(expUnsorted.isSameShape(z)); ASSERT_TRUE(expUnsorted.isSameShape(z));
ASSERT_TRUE(expUnsorted.equalsTo(z)); ASSERT_TRUE(expUnsorted.equalsTo(z));
auto result2 = op.execute({&x}, {}, {5}, {true}); auto result2 = op.evaluate({&x}, {}, {5}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result2->status()); ASSERT_EQ(ND4J_STATUS_OK, result2->status());
@ -738,7 +738,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test1
logits.linspace(0.1, 0.1); logits.linspace(0.1, 0.1);
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op; nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&labels, &logits});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -760,7 +760,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test2
logits.linspace(0.1, 0.1); logits.linspace(0.1, 0.1);
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op; nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&labels, &logits});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -782,7 +782,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test3
logits.linspace(0.1, 0.1); logits.linspace(0.1, 0.1);
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op; nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&labels, &logits});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -804,7 +804,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4
logits.linspace(0.1, 0.1); logits.linspace(0.1, 0.1);
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op; nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&labels, &logits});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -825,7 +825,7 @@ TEST_F(DeclarableOpsTests10, split_test4) {
auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f}); auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f});
nd4j::ops::split op; nd4j::ops::split op;
auto results = op.execute({&input, &axis}, {}, {2}, {}); auto results = op.evaluate({&input, &axis}, {}, {2}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -849,7 +849,7 @@ TEST_F(DeclarableOpsTests10, split_test5) {
auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f}); auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});
nd4j::ops::split op; nd4j::ops::split op;
auto results = op.execute({&input}, {}, {2,-1},{}); auto results = op.evaluate({&input}, {}, {2,-1},{});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -872,7 +872,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {2, 1, 1, 0, 2}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {2, 1, 1, 0, 2});
nd4j::ops::histogram_fixed_width op; nd4j::ops::histogram_fixed_width op;
auto results = op.execute({&input, &range}, {}, {5}, {}); auto results = op.evaluate({&input, &range}, {}, {5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -892,7 +892,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test2) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {5, 2, 5, 3, 9}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {5, 2, 5, 3, 9});
nd4j::ops::histogram_fixed_width op; nd4j::ops::histogram_fixed_width op;
auto results = op.execute({&input, &range}, {}, {5}, {}); auto results = op.evaluate({&input, &range}, {}, {5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -912,7 +912,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test3) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {5, 2, 5, 4, 8}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {5, 2, 5, 4, 8});
nd4j::ops::histogram_fixed_width op; nd4j::ops::histogram_fixed_width op;
auto results = op.execute({&input, &range}, {}, {5}, {}); auto results = op.evaluate({&input, &range}, {}, {5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -937,7 +937,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test4) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {22, 17, 24, 19, 18}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {22, 17, 24, 19, 18});
nd4j::ops::histogram_fixed_width op; nd4j::ops::histogram_fixed_width op;
auto results = op.execute({&input, &range}, {}, {5}, {}); auto results = op.evaluate({&input, &range}, {}, {5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -963,7 +963,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test5) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {23, 15, 24, 17, 21}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {23, 15, 24, 17, 21});
nd4j::ops::histogram_fixed_width op; nd4j::ops::histogram_fixed_width op;
auto results = op.execute({&input, &range}, {}, {5}, {}); auto results = op.evaluate({&input, &range}, {}, {5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -986,7 +986,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test6) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {3, 1, 2, 0, 1}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {3, 1, 2, 0, 1});
nd4j::ops::histogram_fixed_width op; nd4j::ops::histogram_fixed_width op;
auto results = op.execute({&input, &range, &bins}, {}, {}, {}); auto results = op.evaluate({&input, &range, &bins}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1010,7 +1010,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_1) {
//input.linspace(1.f); //input.linspace(1.f);
nd4j::ops::nth_element op; nd4j::ops::nth_element op;
auto results = op.execute({&input, &n}, {}, {}); auto results = op.evaluate({&input, &n}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1032,7 +1032,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_2) {
// input.linspace(1.f); // input.linspace(1.f);
nd4j::ops::nth_element op; nd4j::ops::nth_element op;
auto results = op.execute({&input, &n}, {}, {}); auto results = op.evaluate({&input, &n}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1054,7 +1054,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_3) {
//input.linspace(1.f); //input.linspace(1.f);
nd4j::ops::nth_element op; nd4j::ops::nth_element op;
auto results = op.execute({&input, &n}, {}, {1}); // with reverse = true auto results = op.evaluate({&input, &n}, {}, {1}); // with reverse = true
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1076,7 +1076,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_4) {
//input.linspace(1.f); //input.linspace(1.f);
nd4j::ops::nth_element op; nd4j::ops::nth_element op;
auto results = op.execute({&input, &n}, {}, {}); auto results = op.evaluate({&input, &n}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1097,7 +1097,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_04) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::nth_element op; nd4j::ops::nth_element op;
auto results = op.execute({&input, &n}, {}, {}); auto results = op.evaluate({&input, &n}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1118,7 +1118,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_5) {
// input.linspace(1.f); // input.linspace(1.f);
nd4j::ops::nth_element op; nd4j::ops::nth_element op;
auto results = op.execute({&input, &n}, {}, {1}); auto results = op.evaluate({&input, &n}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1140,7 +1140,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_6) {
// input.linspace(1.f); // input.linspace(1.f);
nd4j::ops::nth_element op; nd4j::ops::nth_element op;
auto results = op.execute({&input, &n}, {}, {0}); auto results = op.evaluate({&input, &n}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1160,7 +1160,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_06) {
// input.linspace(1.f); // input.linspace(1.f);
nd4j::ops::nth_element op; nd4j::ops::nth_element op;
auto results = op.execute({&input, &n}, {}, {1}); auto results = op.evaluate({&input, &n}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1186,7 +1186,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) {
//input.linspace(1.f); //input.linspace(1.f);
nd4j::ops::nth_element op; nd4j::ops::nth_element op;
auto results = op.execute({&input, &n}, {}, {0}); auto results = op.evaluate({&input, &n}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1213,7 +1213,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_8) {
//input.linspace(1.f); //input.linspace(1.f);
nd4j::ops::nth_element op; nd4j::ops::nth_element op;
auto results = op.execute({&input, &n}, {}, {1}); auto results = op.evaluate({&input, &n}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1235,7 +1235,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test1) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.evaluate({&input, &shape}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1257,7 +1257,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test2) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.evaluate({&input, &shape}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1279,7 +1279,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test3) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.evaluate({&input, &shape}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1299,7 +1299,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test4) {
auto exp = NDArrayFactory::create<double>('c', {3,3}, {10.f, 10.f, 10.f,10.f, 10.f, 10.f, 10.f, 10.f, 10.f}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {10.f, 10.f, 10.f,10.f, 10.f, 10.f, 10.f, 10.f, 10.f});
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.evaluate({&input, &shape}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1319,7 +1319,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test5) {
auto exp = NDArrayFactory::create<double>('c', {3}, {10.f, 10.f, 10.f}); auto exp = NDArrayFactory::create<double>('c', {3}, {10.f, 10.f, 10.f});
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.evaluate({&input, &shape}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1339,7 +1339,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test6) {
auto exp = NDArrayFactory::create<double>('c', {1}, {10.f}); auto exp = NDArrayFactory::create<double>('c', {1}, {10.f});
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.evaluate({&input, &shape}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1359,7 +1359,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test7) {
auto exp = NDArrayFactory::create<double>('c', {1}, {10.}); auto exp = NDArrayFactory::create<double>('c', {1}, {10.});
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.evaluate({&input, &shape}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1381,7 +1381,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test8) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.evaluate({&input, &shape}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1403,7 +1403,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test9) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.evaluate({&input, &shape}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1425,7 +1425,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.evaluate({&input, &shape}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1481,7 +1481,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
auto results = op.execute({&input}, {}, {10, 10}); auto results = op.evaluate({&input}, {}, {10, 10});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1503,7 +1503,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) {
auto size = NDArrayFactory::create<int>({65,65}); auto size = NDArrayFactory::create<int>({65,65});
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256}); auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
auto results = op.execute({&input, &size}, {}, {}, {false}); auto results = op.evaluate({&input, &size}, {}, {}, {false});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1522,7 +1522,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) {
auto size = NDArrayFactory::create<int>({65,65}); auto size = NDArrayFactory::create<int>({65,65});
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256}); auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
auto results = op.execute({&input, &size}, {}, {}, {true}); auto results = op.evaluate({&input, &size}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1566,7 +1566,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
auto results = op.execute({&input}, {}, {4, 5}, {false, true}); auto results = op.evaluate({&input}, {}, {4, 5}, {false, true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1613,7 +1613,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
auto results = op.execute({&input}, {}, {4, 5}, {false, true}); auto results = op.evaluate({&input}, {}, {4, 5}, {false, true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1669,7 +1669,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
auto results = op.execute({&input}, {}, {10, 10}); auto results = op.evaluate({&input}, {}, {10, 10});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1824,7 +1824,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
//input.linspace(1); //input.linspace(1);
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
auto results = op.execute({&input}, {}, {9, 9}); auto results = op.evaluate({&input}, {}, {9, 9});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1883,7 +1883,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2013,7 +2013,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
auto results = op.execute({&input}, {}, {10, 10}, {true}); auto results = op.evaluate({&input}, {}, {10, 10}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2142,7 +2142,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
auto results = op.execute({&input, &size}, {}, {}, {true}); auto results = op.evaluate({&input, &size}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2166,7 +2166,7 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) {
8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.});
nd4j::ops::lin_space op; nd4j::ops::lin_space op;
auto result = op.execute({&start, &finish, &num}, {}, {}); auto result = op.evaluate({&start, &finish, &num}, {}, {});
ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto res = result->at(0); auto res = result->at(0);
@ -2208,7 +2208,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_nearest_neighbor op; nd4j::ops::resize_nearest_neighbor op;
auto results = op.execute({&input}, {}, {4, 5}, {false, false}); auto results = op.evaluate({&input}, {}, {4, 5}, {false, false});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2256,7 +2256,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_nearest_neighbor op; nd4j::ops::resize_nearest_neighbor op;
auto results = op.execute({&input}, {}, {4, 5}); auto results = op.evaluate({&input}, {}, {4, 5});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2304,7 +2304,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_nearest_neighbor op; nd4j::ops::resize_nearest_neighbor op;
auto results = op.execute({&input}, {}, {4,5}, {false, true}); auto results = op.evaluate({&input}, {}, {4,5}, {false, true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2351,7 +2351,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) {
input.linspace(1); input.linspace(1);
nd4j::ops::resize_nearest_neighbor op; nd4j::ops::resize_nearest_neighbor op;
auto results = op.execute({&input}, {}, {4, 5}); auto results = op.evaluate({&input}, {}, {4, 5});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2373,7 +2373,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) {
NDArray expected = NDArrayFactory::create<double>(2.5206409f); NDArray expected = NDArrayFactory::create<double>(2.5206409f);
nd4j::ops::reduce_logsumexp op; nd4j::ops::reduce_logsumexp op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2394,7 +2394,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) {
NDArray expected = NDArrayFactory::create<double>({1.0986123f, 1.8619947f, 1.0986123f}); NDArray expected = NDArrayFactory::create<double>({1.0986123f, 1.8619947f, 1.0986123f});
nd4j::ops::reduce_logsumexp op; nd4j::ops::reduce_logsumexp op;
auto results = op.execute({&input}, {}, {0}); auto results = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2414,7 +2414,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) {
NDArray expected = NDArrayFactory::create<float>('c', {1,3}, {1.0986123f, 1.8619947f, 1.0986123f}); NDArray expected = NDArrayFactory::create<float>('c', {1,3}, {1.0986123f, 1.8619947f, 1.0986123f});
nd4j::ops::reduce_logsumexp op; nd4j::ops::reduce_logsumexp op;
auto results = op.execute({&input}, {1.f}, {0}); auto results = op.evaluate({&input}, {1.f}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2435,7 +2435,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
boxes.linspace(1.f); boxes.linspace(1.f);
nd4j::ops::non_max_suppression op; nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scores}, {}, {3}); auto results = op.evaluate({&boxes, &scores}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2457,7 +2457,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
NDArray expected = NDArrayFactory::create<int>('c', {3}, {3,0,5}); NDArray expected = NDArrayFactory::create<int>('c', {3}, {3,0,5});
nd4j::ops::non_max_suppression op; nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales}, {0.5}, {3}); auto results = op.evaluate({&boxes, &scales}, {0.5}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2479,7 +2479,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) {
NDArray expected = NDArrayFactory::create<int>('c', {1}, {1}); NDArray expected = NDArrayFactory::create<int>('c', {1}, {1});
nd4j::ops::non_max_suppression op; nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales}, {0.5, 0.5}, {2}); auto results = op.evaluate({&boxes, &scales}, {0.5, 0.5}, {2});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2502,7 +2502,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) {
NDArray threshold = NDArrayFactory::create(0.5f); NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(0.5); NDArray scoreThreshold = NDArrayFactory::create(0.5);
nd4j::ops::non_max_suppression op; nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2524,7 +2524,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) {
NDArray threshold = NDArrayFactory::create(0.5f); NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>()); NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
nd4j::ops::non_max_suppression op; nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2547,7 +2547,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) {
NDArray threshold = NDArrayFactory::create(0.5f); NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>()); NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
nd4j::ops::non_max_suppression_v3 op; nd4j::ops::non_max_suppression_v3 op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2571,7 +2571,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) {
NDArray threshold = NDArrayFactory::create(0.5f); NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>()); NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
nd4j::ops::non_max_suppression_v3 op; nd4j::ops::non_max_suppression_v3 op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2594,7 +2594,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) {
NDArray threshold = NDArrayFactory::create(0.5f); NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(0.5f); NDArray scoreThreshold = NDArrayFactory::create(0.5f);
nd4j::ops::non_max_suppression_v3 op; nd4j::ops::non_max_suppression_v3 op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -2619,7 +2619,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) {
NDArray expected = NDArrayFactory::create<int>('c', {1,}, {3}); NDArray expected = NDArrayFactory::create<int>('c', {1,}, {3});
nd4j::ops::non_max_suppression_overlaps op; nd4j::ops::non_max_suppression_overlaps op;
auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {}); auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2644,7 +2644,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) {
NDArray expected = NDArrayFactory::create<int>('c', {3,}, {1,1,1}); NDArray expected = NDArrayFactory::create<int>('c', {3,}, {1,1,1});
nd4j::ops::non_max_suppression_overlaps op; nd4j::ops::non_max_suppression_overlaps op;
auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {}); auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2669,7 +2669,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) {
NDArray expected = NDArrayFactory::create<int>('c', {5,}, {1,1,1,1,1}); NDArray expected = NDArrayFactory::create<int>('c', {5,}, {1,1,1,1,1});
nd4j::ops::non_max_suppression_overlaps op; nd4j::ops::non_max_suppression_overlaps op;
auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {}); auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2693,7 +2693,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
NDArray expected = NDArrayFactory::create<double>('c', {1,1,1,1}, {2.5f}); NDArray expected = NDArrayFactory::create<double>('c', {1,1,1,1}, {2.5f});
nd4j::ops::crop_and_resize op; nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {}); auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2718,7 +2718,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) {
NDArray expected = NDArrayFactory::create<float>('c', {1,1,1,1}, {4.f}); NDArray expected = NDArrayFactory::create<float>('c', {1,1,1,1}, {4.f});
nd4j::ops::crop_and_resize op; nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1}); auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2742,7 +2742,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32); NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32);
nd4j::ops::crop_and_resize op; nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {0}); auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2766,7 +2766,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32); NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32);
nd4j::ops::crop_and_resize op; nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1}); auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2790,7 +2790,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) {
NDArray expected('c', {1, 10, 10,3}, nd4j::DataType::FLOAT32); NDArray expected('c', {1, 10, 10,3}, nd4j::DataType::FLOAT32);
nd4j::ops::crop_and_resize op; nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1}); auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2826,7 +2826,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) {
}); });
images.linspace(1.); images.linspace(1.);
nd4j::ops::draw_bounding_boxes op; nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {}); auto results = op.evaluate({&images, &boxes, &colors}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2859,7 +2859,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) {
73.1f , 74.1f, 75.1f, 76.1f, 77.1f , 78.1f, 79.1f , 80.1f , 81.1f }); 73.1f , 74.1f, 75.1f, 76.1f, 77.1f , 78.1f, 79.1f , 80.1f , 81.1f });
images.linspace(1.1); images.linspace(1.1);
nd4j::ops::draw_bounding_boxes op; nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {}); auto results = op.evaluate({&images, &boxes, &colors}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2912,7 +2912,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) {
0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f, 0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f,
0.5793f, 0.573f , 0.1822f, 0.642f , 0.9143f}); 0.5793f, 0.573f , 0.1822f, 0.642f , 0.9143f});
nd4j::ops::draw_bounding_boxes op; nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {}); auto results = op.evaluate({&images, &boxes, &colors}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0); auto result = results->at(0);
@ -2937,7 +2937,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32); NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
nd4j::ops::fake_quant_with_min_max_vars op; nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.evaluate({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2958,7 +2958,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) {
NDArray max = NDArrayFactory::create<double>(0.1); NDArray max = NDArrayFactory::create<double>(0.1);
nd4j::ops::fake_quant_with_min_max_vars op; nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.evaluate({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2979,7 +2979,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) {
NDArray max = NDArrayFactory::create<double>('c', {1}, {0.1}); NDArray max = NDArrayFactory::create<double>('c', {1}, {0.1});
nd4j::ops::fake_quant_with_min_max_vars_per_channel op; nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.evaluate({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3003,7 +3003,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) {
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
nd4j::ops::fake_quant_with_min_max_vars_per_channel op; nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.evaluate({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3026,7 +3026,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) {
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
nd4j::ops::fake_quant_with_min_max_vars_per_channel op; nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {8}, {true}); auto results = op.evaluate({&x, &min, &max}, {}, {8}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3050,7 +3050,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) {
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
nd4j::ops::fake_quant_with_min_max_vars_per_channel op; nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {6}, {true}); auto results = op.evaluate({&x, &min, &max}, {}, {6}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3073,7 +3073,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) {
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
nd4j::ops::fake_quant_with_min_max_vars_per_channel op; nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {6}, {false}); auto results = op.evaluate({&x, &min, &max}, {}, {6}, {false});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3108,7 +3108,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
NDArray max = NDArrayFactory::create<float>({65.f, 70.f, 90.f}); NDArray max = NDArrayFactory::create<float>({65.f, 70.f, 90.f});
x.linspace(1.); x.linspace(1.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op; nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.evaluate({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3161,7 +3161,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
NDArray max = NDArrayFactory::create<float>({20.f, 21.f, 22.f, 23.f}); NDArray max = NDArrayFactory::create<float>({20.f, 21.f, 22.f, 23.f});
x.linspace(-60.); x.linspace(-60.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op; nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.evaluate({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3195,7 +3195,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
NDArray max = NDArrayFactory::create<float>('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); NDArray max = NDArrayFactory::create<float>('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
// x.linspace(-60.); // x.linspace(-60.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op; nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.evaluate({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3241,7 +3241,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) {
NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f}); NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
x.linspace(0., 0.01); x.linspace(0., 0.01);
nd4j::ops::fake_quant_with_min_max_vars op; nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.evaluate({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3266,7 +3266,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) {
NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f}); NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
x.linspace(0., 0.1); x.linspace(0., 0.1);
nd4j::ops::fake_quant_with_min_max_vars op; nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.evaluate({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());

View File

@ -43,7 +43,7 @@ TEST_F(DeclarableOpsTests11, test_listdiff_1) {
auto y = NDArrayFactory::create<int>('c',{2}, {3, 1}); auto y = NDArrayFactory::create<int>('c',{2}, {3, 1});
nd4j::ops::listdiff op; nd4j::ops::listdiff op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -68,7 +68,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test1) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0}, {}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -100,7 +100,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test2) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -130,7 +130,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -162,7 +162,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test4) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -195,7 +195,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test5) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -227,7 +227,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test6) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -253,7 +253,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -288,7 +288,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test8) {
weights.p(3, 0.); weights.p(3, 0.);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -325,7 +325,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test9) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -357,7 +357,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test10) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -383,7 +383,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test11) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -420,7 +420,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test12) {
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -459,7 +459,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {
weights.t<double>(2) = 0.; weights.t<double>(2) = 0.;
nd4j::ops::log_loss_grad op; nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -642,7 +642,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {
auto size = NDArrayFactory::create<int>({30, 30}); auto size = NDArrayFactory::create<int>({30, 30});
nd4j::ops::resize_bicubic op; nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0); NDArray* result = results->at(0);
@ -716,7 +716,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {
input.linspace(1); input.linspace(1);
auto size = NDArrayFactory::create<int>({10, 8}); auto size = NDArrayFactory::create<int>({10, 8});
nd4j::ops::resize_bicubic op; nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -753,7 +753,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {
input.linspace(1); input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6}); auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_bicubic op; nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -790,7 +790,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {
input.linspace(1); input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 8}); auto size = NDArrayFactory::create<int>({6, 8});
nd4j::ops::resize_bicubic op; nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -833,7 +833,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {
input.linspace(1); input.linspace(1);
auto size = NDArrayFactory::create<int>({8, 8}); auto size = NDArrayFactory::create<int>({8, 8});
nd4j::ops::resize_bicubic op; nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -963,7 +963,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {
auto size = NDArrayFactory::create<int>({30, 30}); auto size = NDArrayFactory::create<int>({30, 30});
nd4j::ops::resize_bicubic op; nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0); NDArray* result = results->at(0);
@ -1021,7 +1021,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) {
}); });
auto size = NDArrayFactory::create<int>({9, 9}); auto size = NDArrayFactory::create<int>({9, 9});
nd4j::ops::resize_bicubic op; nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1074,7 +1074,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
auto size = NDArrayFactory::create<int>({9, 9}); auto size = NDArrayFactory::create<int>({9, 9});
nd4j::ops::resize_bicubic op; nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {}, {true, false}); auto results = op.evaluate({&input, &size}, {}, {}, {true, false});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1135,7 +1135,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
input.linspace(1); input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6}); auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1162,7 +1162,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) {
input.linspace(1); input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6}); auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1190,7 +1190,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) {
input.linspace(1); input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6}); auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1228,7 +1228,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) {
//input.linspace(1); //input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6}); auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1266,7 +1266,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) {
//input.linspace(1); //input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6}); auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1304,7 +1304,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) {
//input.linspace(1); //input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6}); auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {}, {true}); auto results = op.evaluate({&input, &size}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1342,7 +1342,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) {
//input.linspace(1); //input.linspace(1);
// auto size = NDArrayFactory::create<int>({6, 6}); // auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {6, 6}, {true}); auto results = op.evaluate({&input}, {}, {6, 6}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1372,7 +1372,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) {
//input.linspace(1); //input.linspace(1);
// auto size = NDArrayFactory::create<int>({6, 6}); // auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {6, 6}, {true}); auto results = op.evaluate({&input}, {}, {6, 6}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1399,7 +1399,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) {
//input.linspace(1); //input.linspace(1);
auto size = NDArrayFactory::create<int>({10, 10}); auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {}); auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1426,7 +1426,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) {
//input.linspace(1); //input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10}); //auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {10, 10}); auto results = op.evaluate({&input}, {}, {10, 10});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1453,7 +1453,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) {
//input.linspace(1); //input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10}); //auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {6, 9}); auto results = op.evaluate({&input}, {}, {6, 9});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1480,7 +1480,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) {
//input.linspace(1); //input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10}); //auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {10, 15}); auto results = op.evaluate({&input}, {}, {10, 15});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1507,7 +1507,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) {
//input.linspace(1); //input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10}); //auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op; nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {9, 9}); auto results = op.evaluate({&input}, {}, {9, 9});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1558,7 +1558,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1590,7 +1590,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1618,7 +1618,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {1}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1650,7 +1650,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {1}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1679,7 +1679,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1711,7 +1711,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1737,7 +1737,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1770,7 +1770,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) {
weights.p(3, 0.); weights.p(3, 0.);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1805,7 +1805,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1837,7 +1837,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1863,7 +1863,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1896,7 +1896,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) {
weights.t<double>(3) = 0.; weights.t<double>(3) = 0.;
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1933,7 +1933,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) {
weights.t<double>(2) = 0.; weights.t<double>(2) = 0.;
nd4j::ops::mean_sqerr_loss_grad op; nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1956,7 +1956,7 @@ TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) {
auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0}); auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
auto exp = NDArrayFactory::create<float>('c', {4}, {9, 1,1, 9}); auto exp = NDArrayFactory::create<float>('c', {4}, {9, 1,1, 9});
nd4j::ops::squaredsubtract op; nd4j::ops::squaredsubtract op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(exp.equalsTo(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0)));
@ -1968,7 +1968,7 @@ TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) {
auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0}); auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
auto exp = NDArrayFactory::create<float>('c', {2, 4}, {9, 1,1, 9, 9, 1, 1, 9}); auto exp = NDArrayFactory::create<float>('c', {2, 4}, {9, 1,1, 9, 9, 1, 1, 9});
nd4j::ops::squaredsubtract op; nd4j::ops::squaredsubtract op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(exp.equalsTo(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result; delete result;
@ -1980,7 +1980,7 @@ TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) {
auto exp = NDArrayFactory::create<float>('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48}); auto exp = NDArrayFactory::create<float>('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48});
auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1,2,3,4,5,6,7,8}); auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1,2,3,4,5,6,7,8});
nd4j::ops::squaredsubtract_bp op; nd4j::ops::squaredsubtract_bp op;
auto result = op.execute({&x, &y, &eps}, {}, {}); auto result = op.evaluate({&x, &y, &eps}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(exp.equalsTo(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result; delete result;
@ -2003,7 +2003,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2035,7 +2035,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2063,7 +2063,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {1}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2095,7 +2095,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {1}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2124,7 +2124,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2156,7 +2156,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2182,7 +2182,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2215,7 +2215,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) {
weights.p(3, 0.); weights.p(3, 0.);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2250,7 +2250,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2282,7 +2282,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2308,7 +2308,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2341,7 +2341,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) {
weights.t<double>(3) = 0.; weights.t<double>(3) = 0.;
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2378,7 +2378,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) {
weights.t<double>(2) = 0.; weights.t<double>(2) = 0.;
nd4j::ops::absolute_difference_loss_grad op; nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2407,7 +2407,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_1) {
y.linspace(1); y.linspace(1);
exp.linspace(2,2); exp.linspace(2,2);
nd4j::ops::add op; nd4j::ops::add op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2428,7 +2428,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_2) {
y.linspace(1); y.linspace(1);
exp.linspace(2,2); exp.linspace(2,2);
nd4j::ops::add op; nd4j::ops::add op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2449,7 +2449,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_3) {
y.linspace(1); y.linspace(1);
exp.linspace(2,2); exp.linspace(2,2);
nd4j::ops::add op; nd4j::ops::add op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2478,7 +2478,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.}, {0}); auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2514,7 +2514,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {0}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2550,7 +2550,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {1}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2582,7 +2582,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {1}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2613,7 +2613,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2645,7 +2645,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2671,7 +2671,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2705,7 +2705,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) {
weights.p(3, 0.); weights.p(3, 0.);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2741,7 +2741,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2773,7 +2773,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2799,7 +2799,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) {
weights.assign(0.5); weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2834,7 +2834,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) {
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2872,7 +2872,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) {
weights.t<double>(2) = 0.; weights.t<double>(2) = 0.;
nd4j::ops::sigm_cross_entropy_loss_grad op; nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2901,7 +2901,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_4) {
y.linspace(1); y.linspace(1);
exp.linspace(2,2); exp.linspace(2,2);
nd4j::ops::add op; nd4j::ops::add op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2922,7 +2922,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_5) {
y.linspace(1); y.linspace(1);
exp.linspace(1); exp.linspace(1);
nd4j::ops::subtract op; nd4j::ops::subtract op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2943,7 +2943,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_6) {
y.linspace(1); y.linspace(1);
exp.linspace(1); exp.linspace(1);
nd4j::ops::subtract op; nd4j::ops::subtract op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2968,7 +2968,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) {
nd4j::ops::softmax_cross_entropy_loss_grad op; nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.}, {0}); auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2999,7 +2999,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) {
nd4j::ops::softmax_cross_entropy_loss_grad op; nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.}, {1}); auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3030,7 +3030,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
nd4j::ops::softmax_cross_entropy_loss_grad op; nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.}, {1}); auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3061,7 +3061,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
nd4j::ops::softmax_cross_entropy_loss_grad op; nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.}, {2}); auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3092,7 +3092,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) {
nd4j::ops::softmax_cross_entropy_loss_grad op; nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.}, {3}); auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3123,7 +3123,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) {
nd4j::ops::softmax_cross_entropy_loss_grad op; nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2}); auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3154,7 +3154,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) {
nd4j::ops::softmax_cross_entropy_loss_grad op; nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.}, {3}); auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3196,7 +3196,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) {
nd4j::ops::softmax_cross_entropy_loss_grad op; nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.}, {2}); auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3237,7 +3237,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) {
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&logits, &labels}, {}, {}); auto results = op.evaluate({&logits, &labels}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3261,7 +3261,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) {
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&logits, &labels}, {}, {1}); auto results = op.evaluate({&logits, &labels}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3284,7 +3284,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) {
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&logits, &labels}, {}, {0}); auto results = op.evaluate({&logits, &labels}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3306,7 +3306,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) {
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&logits, &labels}, {}, {1}); auto results = op.evaluate({&logits, &labels}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3328,7 +3328,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&logits, &labels}, {}, {0}); auto results = op.evaluate({&logits, &labels}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3350,7 +3350,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) {
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&logits, &labels}, {}, {0}); auto results = op.evaluate({&logits, &labels}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3372,7 +3372,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) {
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&logits, &labels}, {}, {0}); auto results = op.evaluate({&logits, &labels}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3394,7 +3394,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) {
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&logits, &labels}, {}, {0}); auto results = op.evaluate({&logits, &labels}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3421,7 +3421,7 @@ TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) {
dLdpExp.assign(1.0); dLdpExp.assign(1.0);
nd4j::ops::multiply_bp op; nd4j::ops::multiply_bp op;
auto results = op.execute({&x, &y, &dLdp}, {}, {}); auto results = op.evaluate({&x, &y, &dLdp}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3444,7 +3444,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) {
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&labels, &logits}, {}, {}); auto results = op.evaluate({&labels, &logits}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3468,7 +3468,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&labels, &logits}, {}, {}); auto results = op.evaluate({&labels, &logits}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3490,7 +3490,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&labels, &logits}, {}, {}); auto results = op.evaluate({&labels, &logits}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3514,7 +3514,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) {
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&labels, &logits}, {}, {}); auto results = op.evaluate({&labels, &logits}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3536,7 +3536,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) {
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.execute({&labels, &logits}, {}, {}); auto results = op.evaluate({&labels, &logits}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());

View File

@ -44,7 +44,7 @@ TEST_F(DeclarableOpsTests12, test_any_validation_1) {
auto y = NDArrayFactory::create<int>('c', {2}, {1, 0}); auto y = NDArrayFactory::create<int>('c', {2}, {1, 0});
nd4j::ops::transpose op; nd4j::ops::transpose op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -69,7 +69,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test1) {
nd4j::ops::cosine_distance_loss_grad op; nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0, -1}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, -1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -101,7 +101,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test2) {
nd4j::ops::cosine_distance_loss_grad op; nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 0}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -135,7 +135,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) {
nd4j::ops::cosine_distance_loss_grad op; nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 0}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -169,7 +169,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) {
nd4j::ops::cosine_distance_loss_grad op; nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {1, 1}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -204,7 +204,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) {
nd4j::ops::cosine_distance_loss_grad op; nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2, 0}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -238,7 +238,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test6) {
nd4j::ops::cosine_distance_loss_grad op; nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3, 1}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -274,7 +274,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test7) {
nd4j::ops::cosine_distance_loss_grad op; nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2, 0}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -310,7 +310,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test8) {
nd4j::ops::cosine_distance_loss_grad op; nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3, 1}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -346,7 +346,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) {
nd4j::ops::cosine_distance_loss_grad op; nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 2}); auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -422,7 +422,7 @@ TEST_F(DeclarableOpsTests12, TestDivideBP_2) {
eps.linspace(1.); eps.linspace(1.);
nd4j::ops::divide_bp op; nd4j::ops::divide_bp op;
Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector<NDArray*>{&output1, &output2}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output1.equalsTo(exp1)); ASSERT_TRUE(output1.equalsTo(exp1));
@ -443,7 +443,7 @@ TEST_F(DeclarableOpsTests12, TestReverseDivideBP_1) {
eps.linspace(1.); eps.linspace(1.);
nd4j::ops::reversedivide_bp op; nd4j::ops::reversedivide_bp op;
Nd4jStatus status = op.execute({&y, &x, &eps}, {&output2, &output1}, {}, {}, {}); Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector<NDArray*>{&output2, &output1}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
//ASSERT_TRUE(output.e<double>(0) == 47.); //ASSERT_TRUE(output.e<double>(0) == 47.);
@ -467,7 +467,7 @@ TEST_F(DeclarableOpsTests12, TestReverseDivideBP_2) {
exp1.assign(1.); exp1.assign(1.);
exp2.assign(-2.); exp2.assign(-2.);
nd4j::ops::reversedivide_bp op; nd4j::ops::reversedivide_bp op;
Nd4jStatus status = op.execute({&y, &x, &eps}, {&output2, &output1}, {}, {}, {}); Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector<NDArray*>{&output2, &output1}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output1.equalsTo(exp1)); ASSERT_TRUE(output1.equalsTo(exp1));
@ -539,7 +539,7 @@ TEST_F(DeclarableOpsTests12, TestMaximumBP_1) {
//exp1.assign(1.); //exp1.assign(1.);
//exp2.assign(-2.); //exp2.assign(-2.);
nd4j::ops::maximum_bp op; nd4j::ops::maximum_bp op;
Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector<NDArray*>{&output1, &output2}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output1.equalsTo(exp1)); ASSERT_TRUE(output1.equalsTo(exp1));
@ -564,7 +564,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) {
//exp1.assign(1.); //exp1.assign(1.);
//exp2.assign(-2.); //exp2.assign(-2.);
nd4j::ops::minimum_bp op; nd4j::ops::minimum_bp op;
Nd4jStatus status = op.execute({&x, &y, &eps}, {&output2, &output1}, {}, {}, {}); Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector<NDArray*>{&output2, &output1}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output1.equalsTo(exp1)); ASSERT_TRUE(output1.equalsTo(exp1));
@ -716,7 +716,7 @@ TEST_F(DeclarableOpsTests12, tensormmul_6) {
NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32); NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {1,0, 1,1}); auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -743,7 +743,7 @@ TEST_F(DeclarableOpsTests12, reduceMeanBp_4) {
exp = 0.333333; exp = 0.333333;
nd4j::ops::reduce_mean_bp op; nd4j::ops::reduce_mean_bp op;
auto result = op.execute({&x, &gradO}, {}, {0}); auto result = op.evaluate({&x, &gradO}, {}, {0});
auto output = result->at(0); auto output = result->at(0);
// output->printShapeInfo(); // output->printShapeInfo();
@ -765,7 +765,7 @@ TEST_F(DeclarableOpsTests12, reduceMeanBp_5) {
exp = 0.2; exp = 0.2;
nd4j::ops::reduce_mean_bp op; nd4j::ops::reduce_mean_bp op;
auto result = op.execute({&x, &gradO}, {}, {1}); auto result = op.evaluate({&x, &gradO}, {}, {1});
auto output = result->at(0); auto output = result->at(0);
// output->printShapeInfo(); // output->printShapeInfo();
@ -783,7 +783,7 @@ TEST_F(DeclarableOpsTests12, reduceSqnormBp_1) {
NDArray gradO('c', {8,6,1}, nd4j::DataType::DOUBLE); NDArray gradO('c', {8,6,1}, nd4j::DataType::DOUBLE);
nd4j::ops::reduce_sqnorm_bp op; nd4j::ops::reduce_sqnorm_bp op;
auto result = op.execute({&x, &gradO}, {1}, {2}); auto result = op.evaluate({&x, &gradO}, {1}, {2});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -937,7 +937,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_1) {
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 1., 1}, {5}); auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {5});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp); ASSERT_EQ(*gradI, exp);
@ -968,7 +968,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_2) {
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 1., 1}, {2}); auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {2});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp); ASSERT_EQ(*gradI, exp);
@ -999,7 +999,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_3) {
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 1., 1}, {7}); auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {7});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp); ASSERT_EQ(*gradI, exp);
@ -1030,7 +1030,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_4) {
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 1., 1}, {12}); auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {12});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp); ASSERT_EQ(*gradI, exp);
@ -1053,7 +1053,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_5) {
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 1., 0.5}, {2}); auto results = op.evaluate({&input, &gradO}, {1., 1., 0.5}, {2});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp); ASSERT_EQ(*gradI, exp);
@ -1072,7 +1072,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_6) {
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {10}); auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {10});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp); ASSERT_EQ(*gradI, exp);
@ -1126,7 +1126,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_9) {
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {3}); auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {3});
auto gradI = results->at(0); auto gradI = results->at(0);
// for (int i = 0; i < exp.lengthOf(); ++i) // for (int i = 0; i < exp.lengthOf(); ++i)
@ -1146,7 +1146,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_10) {
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {1}); auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {1});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp); ASSERT_EQ(*gradI, exp);
@ -1167,7 +1167,7 @@ TEST_F(DeclarableOpsTests12, lrn_1) {
nd4j::ops::lrn op; nd4j::ops::lrn op;
auto results = op.execute({&input}, {1., 2., 0.5}, {2}); auto results = op.evaluate({&input}, {1., 2., 0.5}, {2});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(*output, exp); ASSERT_EQ(*output, exp);
@ -1183,7 +1183,7 @@ TEST_F(DeclarableOpsTests12, lrn_2) {
nd4j::ops::lrn op; nd4j::ops::lrn op;
auto results = op.execute({&input}, {0.1, 2., 0.5}, {5}); auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(*output, exp); ASSERT_EQ(*output, exp);
@ -1198,7 +1198,7 @@ TEST_F(DeclarableOpsTests12, lrn_3) {
nd4j::ops::lrn op; nd4j::ops::lrn op;
auto results = op.execute({&input}, {0.1, 2., 0.5}, {5}); auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(*output, exp); ASSERT_EQ(*output, exp);
@ -1213,7 +1213,7 @@ TEST_F(DeclarableOpsTests12, lrn_4) {
nd4j::ops::lrn op; nd4j::ops::lrn op;
auto results = op.execute({&input}, {0.1, 2., 0.5}, {0}); auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(*output, exp); ASSERT_EQ(*output, exp);
@ -1228,7 +1228,7 @@ TEST_F(DeclarableOpsTests12, lrn_5) {
nd4j::ops::lrn op; nd4j::ops::lrn op;
auto results = op.execute({&input}, {0.1, 2., 0.5}, {0}); auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(*output, exp); ASSERT_EQ(*output, exp);
@ -1268,7 +1268,7 @@ TEST_F(DeclarableOpsTests12, inTopK_2) {
nd4j::ops::in_top_k op; nd4j::ops::in_top_k op;
auto res = op.execute({&input, &idx}, {}, {1}, {}, false, nd4j::DataType::BOOL); auto res = op.evaluate({&input, &idx}, {}, {1});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
//res->at(0)->printIndexedBuffer("IN_TOP_K output"); //res->at(0)->printIndexedBuffer("IN_TOP_K output");
@ -1283,7 +1283,7 @@ TEST_F(DeclarableOpsTests12, inTopK_3) {
auto expV = NDArrayFactory::create<bool>('c', {2}, {true, false}); auto expV = NDArrayFactory::create<bool>('c', {2}, {true, false});
nd4j::ops::in_top_k op; nd4j::ops::in_top_k op;
auto result = op.execute({&x, &y}, {}, {2}); auto result = op.evaluate({&x, &y}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -1303,7 +1303,7 @@ TEST_F(DeclarableOpsTests12, inTopK_4) {
auto expV = NDArrayFactory::create<bool>('c', {6}, {true, false, true, false, false, true}); auto expV = NDArrayFactory::create<bool>('c', {6}, {true, false, true, false, false, true});
nd4j::ops::in_top_k op; nd4j::ops::in_top_k op;
auto result = op.execute({&x, &y}, {}, {2}); auto result = op.evaluate({&x, &y}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -1324,7 +1324,7 @@ TEST_F(DeclarableOpsTests12, inTopK_5) {
auto expV = NDArrayFactory::create<bool>('f', {6}, {true, false, false, false, false, false }); auto expV = NDArrayFactory::create<bool>('f', {6}, {true, false, false, false, false, false });
nd4j::ops::in_top_k op; nd4j::ops::in_top_k op;
auto result = op.execute({&x, &y}, {}, {2}); auto result = op.evaluate({&x, &y}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -1345,7 +1345,7 @@ TEST_F(DeclarableOpsTests12, cube_1) {
nd4j::ops::cube op; nd4j::ops::cube op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1368,7 +1368,7 @@ TEST_F(DeclarableOpsTests12, cube_bp_1) {
nd4j::ops::cube_bp op; nd4j::ops::cube_bp op;
auto result = op.execute({&x, &gradO}, {}, {}); auto result = op.evaluate({&x, &gradO});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1391,7 +1391,7 @@ TEST_F(DeclarableOpsTests12, pad_tests1) {
NDArray expected('c', {4,7}, {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}, nd4j::DataType::FLOAT32); NDArray expected('c', {4,7}, {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}, nd4j::DataType::FLOAT32);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1418,7 +1418,7 @@ TEST_F(DeclarableOpsTests12, pad_tests2) {
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7}); auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1445,7 +1445,7 @@ TEST_F(DeclarableOpsTests12, pad_tests3) {
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7}); auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1476,7 +1476,7 @@ TEST_F(DeclarableOpsTests12, pad_tests4) {
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1510,7 +1510,7 @@ TEST_F(DeclarableOpsTests12, pad_tests5) {
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1537,7 +1537,7 @@ TEST_F(DeclarableOpsTests12, pad_tests6) {
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1563,7 +1563,7 @@ TEST_F(DeclarableOpsTests12, pad_tests7)
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1589,7 +1589,7 @@ TEST_F(DeclarableOpsTests12, pad_tests8)
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1615,7 +1615,7 @@ TEST_F(DeclarableOpsTests12, pad_tests9)
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1638,7 +1638,7 @@ TEST_F(DeclarableOpsTests12, pad_tests10) {
input = 1.f; input = 1.f;
//input.assign(1.); //input.assign(1.);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1660,7 +1660,7 @@ TEST_F(DeclarableOpsTests12, pad_tests11) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1688,7 +1688,7 @@ TEST_F(DeclarableOpsTests12, pad_tests12) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1710,7 +1710,7 @@ TEST_F(DeclarableOpsTests12, pad_tests13) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1732,7 +1732,7 @@ TEST_F(DeclarableOpsTests12, pad_tests14) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1753,7 +1753,7 @@ TEST_F(DeclarableOpsTests12, pad_tests15) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1774,7 +1774,7 @@ TEST_F(DeclarableOpsTests12, pad_tests16) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1795,7 +1795,7 @@ TEST_F(DeclarableOpsTests12, pad_tests17) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1816,7 +1816,7 @@ TEST_F(DeclarableOpsTests12, pad_tests18) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1837,7 +1837,7 @@ TEST_F(DeclarableOpsTests12, pad_tests19) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1858,7 +1858,7 @@ TEST_F(DeclarableOpsTests12, pad_tests20) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1880,7 +1880,7 @@ TEST_F(DeclarableOpsTests12, pad_tests21) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1903,7 +1903,7 @@ TEST_F(DeclarableOpsTests12, pad_tests22) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1926,7 +1926,7 @@ TEST_F(DeclarableOpsTests12, pad_tests23) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1950,7 +1950,7 @@ TEST_F(DeclarableOpsTests12, pad_tests24) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1972,7 +1972,7 @@ TEST_F(DeclarableOpsTests12, pad_tests25) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1994,7 +1994,7 @@ TEST_F(DeclarableOpsTests12, pad_tests26) {
input.linspace(1.f); input.linspace(1.f);
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2054,7 +2054,7 @@ TEST_F(DeclarableOpsTests12, pad_tests29) {
nd4j::ops::pad op; nd4j::ops::pad op;
auto res = op.execute({&in, &pad}, {10.0}, {0}); auto res = op.evaluate({&in, &pad}, {10.0}, {0});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res; delete res;
@ -2071,7 +2071,7 @@ TEST_F(DeclarableOpsTests12, pad_tests30) {
nd4j::ops::pad op; nd4j::ops::pad op;
auto res = op.execute({&in, &pad}, {10.0}, {2}); auto res = op.evaluate({&in, &pad}, {10.0}, {2});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res; delete res;
@ -2089,7 +2089,7 @@ TEST_F(DeclarableOpsTests12, pad_tests31) {
nd4j::ops::pad op; nd4j::ops::pad op;
auto res = op.execute({&in, &pad}, {10.0}, {1}); auto res = op.evaluate({&in, &pad}, {10.0}, {1});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res; delete res;
@ -2105,7 +2105,7 @@ TEST_F(DeclarableOpsTests12, pad_tests32) {
nd4j::ops::pad op; nd4j::ops::pad op;
auto res = op.execute({&in, &pad}, {10.0}, {2}); auto res = op.evaluate({&in, &pad}, {10.0}, {2});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res; delete res;
@ -2128,7 +2128,7 @@ TEST_F(DeclarableOpsTests12, pad_tests33) {
11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2.}); 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2.});
nd4j::ops::pad op; nd4j::ops::pad op;
auto res = op.execute({&in, &pad}, {10.0}, {2}); auto res = op.evaluate({&in, &pad}, {10.0}, {2});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res; delete res;
@ -2163,7 +2163,7 @@ TEST_F(DeclarableOpsTests12, Pad_1) {
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2190,7 +2190,7 @@ TEST_F(DeclarableOpsTests12, Pad_2) {
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2217,7 +2217,7 @@ TEST_F(DeclarableOpsTests12, Pad_3) {
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2244,7 +2244,7 @@ TEST_F(DeclarableOpsTests12, Pad_4) {
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2271,7 +2271,7 @@ TEST_F(DeclarableOpsTests12, Pad_5) {
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2298,7 +2298,7 @@ TEST_F(DeclarableOpsTests12, Pad_6) {
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2324,7 +2324,7 @@ TEST_F(DeclarableOpsTests12, Pad_7)
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.evaluate({&input, &paddings}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2350,7 +2350,7 @@ TEST_F(DeclarableOpsTests12, Pad_8)
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.evaluate({&input, &paddings}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2376,7 +2376,7 @@ TEST_F(DeclarableOpsTests12, Pad_9)
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.evaluate({&input, &paddings}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2395,7 +2395,7 @@ TEST_F(DeclarableOpsTests12, Test_Expose_1) {
nd4j::ops::expose op; nd4j::ops::expose op;
auto result = op.execute({&input0, &input1}, {}, {}); auto result = op.evaluate({&input0, &input1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2420,7 +2420,7 @@ TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) {
nd4j::ops::pad op; nd4j::ops::pad op;
auto res = op.execute({&in, &pad}, {10.0}, {0}); auto res = op.evaluate({&in, &pad}, {10.0}, {0});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
// res->at(0)->printIndexedBuffer("PAD_SGO"); // res->at(0)->printIndexedBuffer("PAD_SGO");
// exp.printIndexedBuffer("PAD_EXP"); // exp.printIndexedBuffer("PAD_EXP");
@ -2436,7 +2436,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_1) {
auto pExp = NDArrayFactory::create<int>('c', {3}, {0, 1, 2}); auto pExp = NDArrayFactory::create<int>('c', {3}, {0, 1, 2});
nd4j::ops::lu op; nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {}); auto res = op.evaluate({&in});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
auto p = res->at(1); auto p = res->at(1);
@ -2457,7 +2457,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_2) {
auto expP = NDArrayFactory::create<int>({2, 0, 1}); auto expP = NDArrayFactory::create<int>({2, 0, 1});
nd4j::ops::lu op; nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {}); auto res = op.evaluate({&in});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
auto p = res->at(1); auto p = res->at(1);
@ -2480,7 +2480,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_3) {
auto expP = NDArrayFactory::create<int>({2, 1, 0}); auto expP = NDArrayFactory::create<int>({2, 1, 0});
nd4j::ops::lu op; nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {}); auto res = op.evaluate({&in});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
auto p = res->at(1); auto p = res->at(1);
@ -2522,7 +2522,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_4) {
auto expP = NDArrayFactory::create<int>({1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); auto expP = NDArrayFactory::create<int>({1, 2, 7, 3, 6, 8, 5, 4, 0, 9});
nd4j::ops::lu op; nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {}); auto res = op.evaluate({&in});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
auto p = res->at(1); auto p = res->at(1);
@ -2592,7 +2592,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_5) {
}); });
nd4j::ops::lu op; nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {}); auto res = op.evaluate({&in});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
auto p = res->at(1); auto p = res->at(1);
@ -2613,7 +2613,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_1_2) {
nd4j::ops::lu op; nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {}); auto res = op.evaluate({&in});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
auto p = res->at(1); auto p = res->at(1);
@ -2641,7 +2641,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_3_2) {
auto expP = NDArrayFactory::create<int>('c', {2,3}, {2, 1, 0, 2, 1, 0}); auto expP = NDArrayFactory::create<int>('c', {2,3}, {2, 1, 0, 2, 1, 0});
nd4j::ops::lu op; nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {}); auto res = op.evaluate({&in});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
auto p = res->at(1); auto p = res->at(1);
@ -2669,7 +2669,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_3_3) {
auto expP = NDArrayFactory::create<int>('c', {2,3}, {2, 1, 0, 0, 2, 1}); auto expP = NDArrayFactory::create<int>('c', {2,3}, {2, 1, 0, 0, 2, 1});
nd4j::ops::lu op; nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {}); auto res = op.evaluate({&in});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
auto p = res->at(1); auto p = res->at(1);
@ -2697,7 +2697,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_4_1) {
auto expP = NDArrayFactory::create<int>('c', {2,2}, {0, 1, 0, 1}); auto expP = NDArrayFactory::create<int>('c', {2,2}, {0, 1, 0, 1});
nd4j::ops::lu op; nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {}); auto res = op.evaluate({&in});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
auto p = res->at(1); auto p = res->at(1);
@ -2725,7 +2725,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_4_2) {
auto expP = NDArrayFactory::create<Nd4jLong>('c', {2,2}, {0, 1, 0, 1}); auto expP = NDArrayFactory::create<Nd4jLong>('c', {2,2}, {0, 1, 0, 1});
nd4j::ops::lu op; nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {nd4j::DataType::INT64}); auto res = op.evaluate({&in}, {}, {nd4j::DataType::INT64});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
auto p = res->at(1); auto p = res->at(1);
@ -2750,7 +2750,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1) {
auto expR = NDArrayFactory::create<double>('c', {5,3}, { auto expR = NDArrayFactory::create<double>('c', {5,3}, {
-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. }); -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. });
nd4j::ops::qr op; nd4j::ops::qr op;
auto res = op.execute({&in}, {}, {}, {true}); auto res = op.evaluate({&in}, {}, {}, {true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto q = res->at(0); auto q = res->at(0);
@ -2762,7 +2762,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1) {
// q->printShapeInfo("Q shape"); // q->printShapeInfo("Q shape");
// r->printShapeInfo("R shape"); // r->printShapeInfo("R shape");
nd4j::ops::matmul opMul; nd4j::ops::matmul opMul;
auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false); auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false);
auto exp = res2->at(0);//->printIndexedBuffer("Result as result"); auto exp = res2->at(0);//->printIndexedBuffer("Result as result");
ASSERT_TRUE(exp->isSameShape(in)); ASSERT_TRUE(exp->isSameShape(in));
// ASSERT_TRUE(q->isSameShape(expQ)); // ASSERT_TRUE(q->isSameShape(expQ));
@ -2797,7 +2797,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1_1) {
-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0.
}); });
nd4j::ops::qr op; nd4j::ops::qr op;
auto res = op.execute({&in}, {}, {}, {true}); auto res = op.evaluate({&in}, {}, {}, {true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto q = res->at(0); auto q = res->at(0);
@ -2809,7 +2809,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1_1) {
// q->printShapeInfo("Q shape"); // q->printShapeInfo("Q shape");
// r->printShapeInfo("R shape"); // r->printShapeInfo("R shape");
nd4j::ops::matmul opMul; nd4j::ops::matmul opMul;
auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false); auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false);
auto exp = res2->at(0);//->printIndexedBuffer("Result as result"); auto exp = res2->at(0);//->printIndexedBuffer("Result as result");
ASSERT_TRUE(exp->isSameShape(in)); ASSERT_TRUE(exp->isSameShape(in));
// ASSERT_TRUE(q->isSameShape(expQ)); // ASSERT_TRUE(q->isSameShape(expQ));
@ -2836,7 +2836,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_2) {
}); });
nd4j::ops::qr op; nd4j::ops::qr op;
auto res = op.execute({&in}, {}, {}, {false}); auto res = op.evaluate({&in}, {}, {}, {false});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto q = res->at(0); auto q = res->at(0);
@ -2847,7 +2847,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_2) {
// r->printIndexedBuffer("Upper triangular 5x3"); // r->printIndexedBuffer("Upper triangular 5x3");
nd4j::ops::matmul opMul; nd4j::ops::matmul opMul;
auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false); auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false);
auto exp = res2->at(0);//->printIndexedBuffer("Result as result"); auto exp = res2->at(0);//->printIndexedBuffer("Result as result");
ASSERT_TRUE(exp->isSameShape(in)); ASSERT_TRUE(exp->isSameShape(in));
ASSERT_TRUE(exp->equalsTo(in)); ASSERT_TRUE(exp->equalsTo(in));
@ -2874,7 +2874,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_1) {
nd4j::ops::triangular_solve op; nd4j::ops::triangular_solve op;
auto res = op.execute({&a, &b}, {}, {}); auto res = op.evaluate({&a, &b});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
@ -2903,7 +2903,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_2) {
nd4j::ops::triangular_solve op; nd4j::ops::triangular_solve op;
auto res = op.execute({&a, &b}, {}, {}); auto res = op.evaluate({&a, &b});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
@ -2940,7 +2940,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_3) {
nd4j::ops::triangular_solve op; nd4j::ops::triangular_solve op;
auto res = op.execute({&a, &b}, {}, {}); auto res = op.evaluate({&a, &b});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
@ -2969,7 +2969,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_4) {
nd4j::ops::triangular_solve op; nd4j::ops::triangular_solve op;
auto res = op.execute({&a, &b}, {}, {}, {false}); auto res = op.evaluate({&a, &b}, {}, {}, {false});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);
@ -2999,7 +2999,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
nd4j::ops::triangular_solve op; nd4j::ops::triangular_solve op;
auto res = op.execute({&a, &b}, {}, {}, {false, true}); auto res = op.evaluate({&a, &b}, {}, {}, {false, true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0); auto z = res->at(0);

View File

@ -58,7 +58,7 @@ TEST_F(DeclarableOpsTests13, test_pow_1) {
auto e = NDArrayFactory::create<float>('c', {2, 2}, {8.f, 8.f, 8.f, 8.f}); auto e = NDArrayFactory::create<float>('c', {2, 2}, {8.f, 8.f, 8.f, 8.f});
nd4j::ops::Pow op; nd4j::ops::Pow op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -73,7 +73,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_1) {
auto limit = NDArrayFactory::create<int>(0); auto limit = NDArrayFactory::create<int>(0);
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({&start, &limit}, {}, {}); auto result = op.evaluate({&start, &limit});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -85,7 +85,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_1) {
TEST_F(DeclarableOpsTests13, test_empty_range_2) { TEST_F(DeclarableOpsTests13, test_empty_range_2) {
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {1.0, 1.0}, {}); auto result = op.evaluate({}, {1.0, 1.0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -97,7 +97,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_2) {
TEST_F(DeclarableOpsTests13, test_empty_range_3) { TEST_F(DeclarableOpsTests13, test_empty_range_3) {
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {}, {1, 1}); auto result = op.evaluate({}, {1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -143,7 +143,7 @@ TEST_F(DeclarableOpsTests13, test_listdiff_1) {
auto oi = NDArrayFactory::create<int>('c', {2}); auto oi = NDArrayFactory::create<int>('c', {2});
nd4j::ops::listdiff op; nd4j::ops::listdiff op;
auto result = op.execute({&x, &y}, {&od, &oi}, {}, {}, {}); auto result = op.execute({&x, &y}, std::vector<NDArray*>{&od, &oi}, {}, {}, {});
ASSERT_EQ(Status::OK(), result); ASSERT_EQ(Status::OK(), result);
} }
@ -152,7 +152,7 @@ TEST_F(DeclarableOpsTests13, test_greater_1) {
auto y = NDArrayFactory::create<float>('c', {1, 4}); auto y = NDArrayFactory::create<float>('c', {1, 4});
nd4j::ops::greater op; nd4j::ops::greater op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -165,7 +165,7 @@ TEST_F(DeclarableOpsTests13, test_eval_reduction_shape_1) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {2}, {1, 2}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {2}, {1, 2});
nd4j::ops::evaluate_reduction_shape op; nd4j::ops::evaluate_reduction_shape op;
auto result = op.execute({&x, &y}, {}, {}, {true}); auto result = op.evaluate({&x, &y}, {true});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -218,7 +218,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_1) {
auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
auto exp = NDArrayFactory::create<double>('c', {2,3}, {1.2,2.2,3.2,4.2,5.2,6.2}); auto exp = NDArrayFactory::create<double>('c', {2,3}, {1.2,2.2,3.2,4.2,5.2,6.2});
nd4j::ops::barnes_gains op; nd4j::ops::barnes_gains op;
auto result = op.execute({&x, &y, &eps}, {}, {}); auto result = op.evaluate({&x, &y, &eps});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
//result->at(0)->printBuffer("Gains out"); //result->at(0)->printBuffer("Gains out");
ASSERT_TRUE(exp.equalsTo(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0)));
@ -232,7 +232,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_2) {
auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
auto exp = NDArrayFactory::create<double>('c', {2,3}, {1.2, 0.01, 3.2, 0.01, 5.2, 0.01}); auto exp = NDArrayFactory::create<double>('c', {2,3}, {1.2, 0.01, 3.2, 0.01, 5.2, 0.01});
nd4j::ops::barnes_gains op; nd4j::ops::barnes_gains op;
auto result = op.execute({&x, &y, &eps}, {}, {}); auto result = op.evaluate({&x, &y, &eps}, {}, {});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
//result->at(0)->printBuffer("Gains out"); //result->at(0)->printBuffer("Gains out");
ASSERT_TRUE(exp.equalsTo(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0)));
@ -247,7 +247,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_3) {
auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
auto exp = NDArrayFactory::create<double>('c', {2,3}, {0.01, 2.2, 0.01, 4.2, 0.01, 6.2}); auto exp = NDArrayFactory::create<double>('c', {2,3}, {0.01, 2.2, 0.01, 4.2, 0.01, 6.2});
nd4j::ops::barnes_gains op; nd4j::ops::barnes_gains op;
auto result = op.execute({&x, &y, &eps}, {}, {}); auto result = op.evaluate({&x, &y, &eps}, {}, {});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
//result->at(0)->printBuffer("Gains out"); //result->at(0)->printBuffer("Gains out");
ASSERT_TRUE(exp.equalsTo(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0)));
@ -269,7 +269,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) {
// auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
// auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2}); // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
nd4j::ops::barnes_edge_forces op; nd4j::ops::barnes_edge_forces op;
auto result = op.execute({&rows, &cols, &vals, &data}, {}, {1}); auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {1});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
@ -293,7 +293,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) {
// auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
// auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2}); // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
nd4j::ops::barnes_edge_forces op; nd4j::ops::barnes_edge_forces op;
auto result = op.execute({&rows, &cols, &vals, &data}, {}, {2}); auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {2});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
@ -317,7 +317,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
// auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
// auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2}); // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
nd4j::ops::barnes_edge_forces op; nd4j::ops::barnes_edge_forces op;
auto result = op.execute({&rows, &cols, &vals, &data}, {}, {11}); auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {11});
//nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf()); //nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf());
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
@ -340,7 +340,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) {
// auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
// auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2}); // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
nd4j::ops::barnes_symmetrized op; nd4j::ops::barnes_symmetrized op;
auto result = op.execute({&rows, &cols, &vals}, {}, {1}); auto result = op.evaluate({&rows, &cols, &vals}, {}, {1});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
//result->at(2)->printBuffer("Symmetrized1"); //result->at(2)->printBuffer("Symmetrized1");
ASSERT_TRUE(exp.equalsTo(result->at(2))); ASSERT_TRUE(exp.equalsTo(result->at(2)));
@ -359,7 +359,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) {
// auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
// auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2}); // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
nd4j::ops::barnes_symmetrized op; nd4j::ops::barnes_symmetrized op;
auto result = op.execute({&rows, &cols, &vals}, {}, {3}); auto result = op.evaluate({&rows, &cols, &vals}, {}, {3});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
//result->at(2)->printBuffer("Symmetrized2"); //result->at(2)->printBuffer("Symmetrized2");
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i))); // ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
@ -378,7 +378,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) {
// auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
// auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2}); // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
nd4j::ops::barnes_symmetrized op; nd4j::ops::barnes_symmetrized op;
auto result = op.execute({&rows, &cols, &vals}, {}, {11}); auto result = op.evaluate({&rows, &cols, &vals}, {}, {11});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
//result->at(2)->printBuffer("Symmetrized3"); //result->at(2)->printBuffer("Symmetrized3");
//exp.printBuffer("EXPect symm3"); //exp.printBuffer("EXPect symm3");
@ -402,7 +402,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {
// auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
// auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2}); // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
nd4j::ops::barnes_symmetrized op; nd4j::ops::barnes_symmetrized op;
auto result = op.execute({&rows, &cols, &vals}, {}, {11}); auto result = op.evaluate({&rows, &cols, &vals}, {}, {11});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
auto res = result->at(2); auto res = result->at(2);
// res->printBuffer("Symmetrized4"); // res->printBuffer("Symmetrized4");
@ -428,7 +428,7 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) {
// auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
// auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2}); // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
nd4j::ops::cell_contains op; nd4j::ops::cell_contains op;
auto result = op.execute({&corners, &width, &point}, {}, {5}); auto result = op.evaluate({&corners, &width, &point}, {}, {5});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
ASSERT_TRUE(result->at(0)->e<bool>(0)); ASSERT_TRUE(result->at(0)->e<bool>(0));
//result->at(2)->printBuffer("Symmetrized3"); //result->at(2)->printBuffer("Symmetrized3");
@ -446,7 +446,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_1) {
NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32); NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op; nd4j::ops::adjust_hue op;
std::unique_ptr<nd4j::ResultSet> results (op.execute({&input, &factor}, {}, {2})); std::unique_ptr<nd4j::ResultSet> results (op.evaluate({&input, &factor}, {}, {2}));
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -467,7 +467,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_2) {
nd4j::ops::adjust_hue op; nd4j::ops::adjust_hue op;
std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {0.9}, {2})); std::unique_ptr<nd4j::ResultSet> results(op.evaluate({&input}, {0.9}, {2}));
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -487,7 +487,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_3) {
NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32); NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op; nd4j::ops::adjust_hue op;
std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {-0.9}, {2})); std::unique_ptr<nd4j::ResultSet> results(op.evaluate({&input}, {-0.9}, {2}));
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -506,7 +506,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_4) {
NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, nd4j::DataType::FLOAT32); NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op; nd4j::ops::adjust_hue op;
std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {0.5}, {1})); std::unique_ptr<nd4j::ResultSet> results(op.evaluate({&input}, {0.5}, {1}));
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -525,7 +525,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) {
NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, nd4j::DataType::FLOAT32); NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op; nd4j::ops::adjust_hue op;
std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {0.5}, {0})); std::unique_ptr<nd4j::ResultSet> results(op.evaluate({&input}, {0.5}, {0}));
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -545,7 +545,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32); NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_saturation op; nd4j::ops::adjust_saturation op;
auto results = op.execute({&input, &factor}, {}, {2}); auto results = op.evaluate({&input, &factor}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -564,7 +564,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::DOUBLE); NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::DOUBLE);
nd4j::ops::adjust_saturation op; nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {10}, {2}); auto results = op.evaluate({&input}, {10}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -585,7 +585,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_3) {
NDArray exp ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, nd4j::DataType::FLOAT32); NDArray exp ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_saturation op; nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {-10}, {2}); auto results = op.evaluate({&input}, {-10}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -605,7 +605,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_4) {
NDArray exp ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5, 190,255, 163.5,128.5, 230,134}, nd4j::DataType::FLOAT32); NDArray exp ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5, 190,255, 163.5,128.5, 230,134}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_saturation op; nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {0.5}, {1}); auto results = op.evaluate({&input}, {0.5}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -625,7 +625,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) {
NDArray exp ('c', {3,2,2}, {50,118.5, 190,255, 100,220, 163.5,128.5, 78,112.5, 230,134}, nd4j::DataType::FLOAT32); NDArray exp ('c', {3,2,2}, {50,118.5, 190,255, 100,220, 163.5,128.5, 78,112.5, 230,134}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_saturation op; nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {0.5}, {0}); auto results = op.evaluate({&input}, {0.5}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -646,7 +646,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_1) {
e.assign(512); e.assign(512);
nd4j::ops::shift_bits op; nd4j::ops::shift_bits op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -664,7 +664,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_1) {
e.assign(32); e.assign(32);
nd4j::ops::rshift_bits op; nd4j::ops::rshift_bits op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -682,7 +682,7 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) {
e.assign(512); e.assign(512);
nd4j::ops::cyclic_shift_bits op; nd4j::ops::cyclic_shift_bits op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -700,7 +700,7 @@ TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) {
e.assign(32); e.assign(32);
nd4j::ops::cyclic_rshift_bits op; nd4j::ops::cyclic_rshift_bits op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -719,7 +719,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_2) {
e.assign(512); e.assign(512);
nd4j::ops::shift_bits op; nd4j::ops::shift_bits op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -738,7 +738,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_2) {
e.assign(32); e.assign(32);
nd4j::ops::rshift_bits op; nd4j::ops::rshift_bits op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -757,7 +757,7 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) {
e.assign(512); e.assign(512);
nd4j::ops::cyclic_shift_bits op; nd4j::ops::cyclic_shift_bits op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -776,7 +776,7 @@ TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) {
e.assign(32); e.assign(32);
nd4j::ops::cyclic_rshift_bits op; nd4j::ops::cyclic_rshift_bits op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -794,7 +794,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_3) {
e.assign(512); e.assign(512);
nd4j::ops::shift_bits op; nd4j::ops::shift_bits op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -817,7 +817,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_1) {
exp.linspace(1); exp.linspace(1);
nd4j::ops::space_to_batch_nd op; nd4j::ops::space_to_batch_nd op;
auto result = op.execute({&x, &blockShape, &paddings}, {}, {}); auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -844,7 +844,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_2) {
x.linspace(1); x.linspace(1);
nd4j::ops::space_to_batch_nd op; nd4j::ops::space_to_batch_nd op;
auto result = op.execute({&x, &blockShape, &paddings}, {}, {}); auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -875,7 +875,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_3) {
x.linspace(1); x.linspace(1);
nd4j::ops::space_to_batch_nd op; nd4j::ops::space_to_batch_nd op;
auto result = op.execute({&x, &blockShape, &paddings}, {}, {}); auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -901,7 +901,7 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_1) {
exp.linspace(1); exp.linspace(1);
nd4j::ops::batch_to_space_nd op; nd4j::ops::batch_to_space_nd op;
auto result = op.execute({&x, &blockShape, &crop}, {}, {}); auto result = op.evaluate({&x, &blockShape, &crop}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -924,7 +924,7 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_2) {
x.linspace(1); x.linspace(1);
nd4j::ops::batch_to_space_nd op; nd4j::ops::batch_to_space_nd op;
auto result = op.execute({&x, &blockShape, &crop}, {}, {}); auto result = op.evaluate({&x, &blockShape, &crop}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -948,7 +948,7 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_3) {
x.linspace(1); x.linspace(1);
nd4j::ops::batch_to_space_nd op; nd4j::ops::batch_to_space_nd op;
auto result = op.execute({&x, &blockShape, &crop}, {}, {}); auto result = op.evaluate({&x, &blockShape, &crop}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -974,7 +974,7 @@ TEST_F(DeclarableOpsTests13, mergemax_1) {
nd4j::ops::mergemax op; nd4j::ops::mergemax op;
auto result = op.execute({&x1, &x2, &x3}, {}, {}); auto result = op.evaluate({&x1, &x2, &x3}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1040,9 +1040,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) {
hI = 1.; hI = 1.;
cI = 2.; cI = 2.;
std::initializer_list<double> tArgs = {cellClip}; std::vector<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
auto expH = NDArrayFactory::create<float>('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f, auto expH = NDArrayFactory::create<float>('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f,
0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f, 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f,
@ -1053,7 +1053,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) {
auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f}); auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f});
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1110,9 +1110,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) {
hI = 1.; hI = 1.;
cI = 2.; cI = 2.;
std::initializer_list<double> tArgs = {cellClip}; std::vector<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
auto expH = NDArrayFactory::create<float>('c', {bS, sL, nOut}, {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, 0.485999f, 0.485999f, 0.485999f, auto expH = NDArrayFactory::create<float>('c', {bS, sL, nOut}, {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, 0.485999f, 0.485999f, 0.485999f,
0.596965f, 0.596965f, 0.596965f, 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f, 0.596965f, 0.596965f, 0.596965f, 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f,
@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) {
auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, 1.301922f, 1.301922f, 1.301922f}); auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, 1.301922f, 1.301922f, 1.301922f});
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1178,9 +1178,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) {
hI = 1.; hI = 1.;
cI = 2.; cI = 2.;
std::initializer_list<double> tArgs = {cellClip}; std::vector<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f, NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f,
0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f, 0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f,
@ -1190,7 +1190,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) {
NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, nd4j::DataType::FLOAT32); NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1256,9 +1256,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
cI({0,1, 0,0, 0,0}) = 2; cI({0,1, 0,0, 0,0}) = 2;
cI({1,2, 0,0, 0,0}) = -2; cI({1,2, 0,0, 0,0}) = -2;
std::initializer_list<double> tArgs = {cellClip}; std::vector<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, 2 * nOut}, { NDArray expH('c', {sL, bS, 2 * nOut}, {
0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f, 0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f,
@ -1275,7 +1275,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
-0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, nd4j::DataType::FLOAT32); -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1340,9 +1340,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) {
cI({0,1, 0,0, 0,0}) = 2; cI({0,1, 0,0, 0,0}) = 2;
cI({1,2, 0,0, 0,0}) = -2; cI({1,2, 0,0, 0,0}) = -2;
std::initializer_list<double> tArgs = {cellClip}; std::vector<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {bS, sL, 2*nOut}, { NDArray expH('c', {bS, sL, 2*nOut}, {
0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f, 0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f,
@ -1357,7 +1357,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) {
-0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, nd4j::DataType::FLOAT32); -0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1426,9 +1426,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) {
cI({0,1, 0,0, 0,0}) = 2; cI({0,1, 0,0, 0,0}) = 2;
cI({1,2, 0,0, 0,0}) = -2; cI({1,2, 0,0, 0,0}) = -2;
std::initializer_list<double> tArgs = {cellClip}; std::vector<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, { NDArray expH('c', {sL, bS, nOut}, {
0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f, 0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f,
@ -1443,7 +1443,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) {
nd4j::DataType::FLOAT32); nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1519,7 +1519,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_7) {
NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089,1.197228, 1.197228, 1.197228}, nd4j::DataType::FLOAT32); NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089,1.197228, 1.197228, 1.197228}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1597,7 +1597,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) {
NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, nd4j::DataType::FLOAT32); NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1684,7 +1684,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) {
-0.292174f, -0.292174f, -0.292174f}, nd4j::DataType::FLOAT32); -0.292174f, -0.292174f, -0.292174f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1769,7 +1769,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) {
NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, nd4j::DataType::FLOAT32); NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1849,7 +1849,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) {
NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, nd4j::DataType::FLOAT32); NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1940,7 +1940,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, nd4j::DataType::FLOAT32); 0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1977,7 +1977,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test1) {
nd4j::ops::batchnorm op; nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2010,7 +2010,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test2) {
nd4j::ops::batchnorm op; nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2039,7 +2039,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test3) {
nd4j::ops::batchnorm op; nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2067,7 +2067,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test4) {
nd4j::ops::batchnorm op; nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2095,7 +2095,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test5) {
nd4j::ops::batchnorm op; nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2124,7 +2124,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test6) {
nd4j::ops::batchnorm op; nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2193,7 +2193,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test8) {
nd4j::ops::batchnorm op; nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2236,7 +2236,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) {
nd4j::ops::batchnorm op; nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3,4}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3,4});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2271,7 +2271,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) {
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2314,7 +2314,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) {
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2356,7 +2356,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) {
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2395,7 +2395,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) {
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2439,7 +2439,7 @@ return;
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2484,7 +2484,7 @@ return;
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2532,7 +2532,7 @@ return;
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2581,7 +2581,7 @@ return;
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2635,7 +2635,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) {
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2687,7 +2687,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) {
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2751,7 +2751,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) {
nd4j::ops::batchnorm_bp op; nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3}); auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());

View File

@ -44,7 +44,7 @@ TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) {
exp.assign(4.0f); exp.assign(4.0f);
nd4j::ops::fill op; nd4j::ops::fill op;
auto result = op.execute({&x}, {4.0f},{}, {}); auto result = op.evaluate({&x}, {4.0f});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -62,7 +62,7 @@ TEST_F(DeclarableOpsTests14, Test_Reshape_CF_1) {
r.streamline('f'); r.streamline('f');
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x}, {}, {3, 2}, {}); auto result = op.evaluate({&x}, {3, 2});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -96,7 +96,7 @@ TEST_F(DeclarableOpsTests14, Multiply_test) {
e.assign(1.0); e.assign(1.0);
nd4j::ops::multiply op; nd4j::ops::multiply op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
auto f = result->at(0); auto f = result->at(0);
NDArray r = *f; NDArray r = *f;
@ -113,7 +113,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) {
auto e = NDArrayFactory::create<Nd4jLong>('c', {2}, {5, 4}); auto e = NDArrayFactory::create<Nd4jLong>('c', {2}, {5, 4});
nd4j::ops::evaluate_reduction_shape op; nd4j::ops::evaluate_reduction_shape op;
auto result = op.execute({&x, &y}, {}, {}, {false, false}); auto result = op.evaluate({&x, &y}, {}, {}, {false, false});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -128,7 +128,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) {
auto e = NDArrayFactory::create<Nd4jLong>('c', {3}, {5, 1, 4}); auto e = NDArrayFactory::create<Nd4jLong>('c', {3}, {5, 1, 4});
nd4j::ops::evaluate_reduction_shape op; nd4j::ops::evaluate_reduction_shape op;
auto result = op.execute({&x, &y}, {}, {}, {true, false}); auto result = op.evaluate({&x, &y}, {}, {}, {true, false});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -183,7 +183,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) {
nd4j::ops::add op; nd4j::ops::add op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result->at(0));
@ -200,7 +200,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
nd4j::ops::subtract op; nd4j::ops::subtract op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result->at(0));
@ -213,7 +213,7 @@ TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
auto y = NDArrayFactory::create<int>(1); auto y = NDArrayFactory::create<int>(1);
nd4j::ops::fill op; nd4j::ops::fill op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -250,13 +250,13 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_1) {
auto e = NDArrayFactory::create<float>('c', {1, 0}); auto e = NDArrayFactory::create<float>('c', {1, 0});
nd4j::ops::stack op; nd4j::ops::stack op;
auto result = op.execute({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
nd4j::ops::reduce_min sumOp; nd4j::ops::reduce_min sumOp;
auto res2 = sumOp.execute({&e}, {1.}, {1}); auto res2 = sumOp.evaluate({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK()); ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0); auto out = res2->at(0);
@ -270,7 +270,7 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_2) {
auto e = NDArrayFactory::create<float>('c', {0}); auto e = NDArrayFactory::create<float>('c', {0});
nd4j::ops::stack op; nd4j::ops::stack op;
auto result = op.execute({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -284,7 +284,7 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_3) {
auto e = NDArrayFactory::create<float>('c', {2, 0}); auto e = NDArrayFactory::create<float>('c', {2, 0});
nd4j::ops::stack op; nd4j::ops::stack op;
auto result = op.execute({&x, &x}, {}, {0}); auto result = op.evaluate({&x, &x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -298,7 +298,7 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_4) {
auto e = NDArrayFactory::create<float>('c', {2, 0}); auto e = NDArrayFactory::create<float>('c', {2, 0});
nd4j::ops::stack op; nd4j::ops::stack op;
auto result = op.execute({&x, &x}, {}, {0}); auto result = op.evaluate({&x, &x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -311,7 +311,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) {
auto e = NDArrayFactory::create<float>('c', {1, 0}); auto e = NDArrayFactory::create<float>('c', {1, 0});
nd4j::ops::reduce_min sumOp; nd4j::ops::reduce_min sumOp;
auto res2 = sumOp.execute({&e}, {1.}, {1}); auto res2 = sumOp.evaluate({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK()); ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0); auto out = res2->at(0);
@ -323,7 +323,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
auto e = NDArrayFactory::create<float>('c', {1, 0}); auto e = NDArrayFactory::create<float>('c', {1, 0});
nd4j::ops::reduce_max sumOp; nd4j::ops::reduce_max sumOp;
auto res2 = sumOp.execute({&e}, {1.}, {1}); auto res2 = sumOp.evaluate({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK()); ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0); auto out = res2->at(0);
@ -335,7 +335,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
auto e = NDArrayFactory::create<float>('c', {1, 0}); auto e = NDArrayFactory::create<float>('c', {1, 0});
nd4j::ops::reduce_sum sumOp; nd4j::ops::reduce_sum sumOp;
auto res2 = sumOp.execute({&e}, {1.}, {1}); auto res2 = sumOp.evaluate({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK()); ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0); auto out = res2->at(0);
ASSERT_EQ(out->e<float>(0), 0.f); ASSERT_EQ(out->e<float>(0), 0.f);
@ -346,7 +346,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
auto e = NDArrayFactory::create<float>('c', {1, 0}); auto e = NDArrayFactory::create<float>('c', {1, 0});
nd4j::ops::reduce_mean sumOp; nd4j::ops::reduce_mean sumOp;
auto res2 = sumOp.execute({&e}, {1.}, {1}); auto res2 = sumOp.evaluate({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK()); ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0); auto out = res2->at(0);
// out->printShapeInfo("ReduceMean empty shape with keep dims"); // out->printShapeInfo("ReduceMean empty shape with keep dims");
@ -366,7 +366,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) {
matrix.linspace(1); matrix.linspace(1);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0}); auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -387,7 +387,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) {
matrix.linspace(1); matrix.linspace(1);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -405,7 +405,7 @@ TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
nd4j::ops::argmax op; nd4j::ops::argmax op;
//nd4j::ops::reduce_max op; //nd4j::ops::reduce_max op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -432,7 +432,7 @@ TEST_F(DeclarableOpsTests14, test_empty_tanh_5) {
auto x = NDArrayFactory::create<float>('c', {32, 0}); auto x = NDArrayFactory::create<float>('c', {32, 0});
nd4j::ops::tanh op; nd4j::ops::tanh op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -450,7 +450,7 @@ TEST_F(DeclarableOpsTests14, repeat_1) {
NDArray e('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); NDArray e('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});
nd4j::ops::repeat op; nd4j::ops::repeat op;
auto result = op.execute({&x}, {}, {2, 0}); auto result = op.evaluate({&x}, {}, {2, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -468,7 +468,7 @@ TEST_F(DeclarableOpsTests14, repeat_2) {
NDArray e('c', {2, 6}, {1, 1, 2, 2, 3, 3,4, 4, 5, 5, 6, 6}); NDArray e('c', {2, 6}, {1, 1, 2, 2, 3, 3,4, 4, 5, 5, 6, 6});
nd4j::ops::repeat op; nd4j::ops::repeat op;
auto result = op.execute({&x}, {}, {2, 1}); auto result = op.evaluate({&x}, {}, {2, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -486,7 +486,7 @@ TEST_F(DeclarableOpsTests14, repeat_3) {
NDArray e('c', {2, 6}, {1, 2, 2, 3, 3, 3,4, 5, 5, 6, 6, 6}); NDArray e('c', {2, 6}, {1, 2, 2, 3, 3, 3,4, 5, 5, 6, 6, 6});
nd4j::ops::repeat op; nd4j::ops::repeat op;
auto result = op.execute({&x}, {}, {1,2,3, 1}); auto result = op.evaluate({&x}, {}, {1,2,3, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -504,7 +504,7 @@ TEST_F(DeclarableOpsTests14, repeat_4) {
NDArray e('c', {7, 3}, {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6}); NDArray e('c', {7, 3}, {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6});
nd4j::ops::repeat op; nd4j::ops::repeat op;
auto result = op.execute({&x}, {}, {3,4, 0}); auto result = op.evaluate({&x}, {}, {3,4, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -522,7 +522,7 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
NDArray e('c', {2, 4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); NDArray e('c', {2, 4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22, 23, 24});
nd4j::ops::repeat op; nd4j::ops::repeat op;
auto result = op.execute({&x}, {}, {1,2,1, 1}); auto result = op.evaluate({&x}, {}, {1,2,1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);

View File

@ -49,7 +49,7 @@ TEST_F(DeclarableOpsTests15, Test_NormalizeMoments_1) {
auto z1 = NDArrayFactory::create<double>('c', {10}); auto z1 = NDArrayFactory::create<double>('c', {10});
nd4j::ops::normalize_moments op; nd4j::ops::normalize_moments op;
auto result = op.execute({&w, &x, &y}, {&z0, &z1}, {1e-4}, {}, {}); auto result = op.execute({&w, &x, &y}, std::vector<NDArray*>{&z0, &z1}, {1e-4}, {}, {});
ASSERT_EQ(Status::OK(), result); ASSERT_EQ(Status::OK(), result);
} }
@ -87,7 +87,7 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
auto eps = NDArrayFactory::create<float>('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); auto eps = NDArrayFactory::create<float>('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::standardize_bp op; nd4j::ops::standardize_bp op;
auto result = op.execute({&x, &eps}, {}, {0}, {}); auto result = op.evaluate({&x, &eps}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
} }
@ -103,7 +103,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
x.linspace(1.); x.linspace(1.);
nd4j::ops::adjust_contrast op; nd4j::ops::adjust_contrast op;
auto result = op.execute({&x, &factor}, {}, {}, {}); auto result = op.evaluate({&x, &factor}, {}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0); auto out = result->at(0);
@ -121,7 +121,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
}); });
x.linspace(1.); x.linspace(1.);
nd4j::ops::adjust_contrast op; nd4j::ops::adjust_contrast op;
auto result = op.execute({&x}, {2.}, {}, {}); auto result = op.evaluate({&x}, {2.});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0); auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Constrast"); // out->printIndexedBuffer("Adjusted Constrast");
@ -139,7 +139,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
}); });
x.linspace(1.); x.linspace(1.);
nd4j::ops::adjust_contrast_v2 op; nd4j::ops::adjust_contrast_v2 op;
auto result = op.execute({&x}, {2.}, {}, {}); auto result = op.evaluate({&x}, {2.});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0); auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Constrast"); // out->printIndexedBuffer("Adjusted Constrast");
@ -157,7 +157,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
}); });
x.linspace(1.); x.linspace(1.);
nd4j::ops::adjust_contrast_v2 op; nd4j::ops::adjust_contrast_v2 op;
auto result = op.execute({&x}, {2.}, {}, {}); auto result = op.evaluate({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0); auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Constrast"); // out->printIndexedBuffer("Adjusted Constrast");
@ -172,7 +172,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) {
}); });
x.linspace(1.); x.linspace(1.);
nd4j::ops::adjust_contrast_v2 op; nd4j::ops::adjust_contrast_v2 op;
auto result = op.execute({&x}, {2.}, {}, {}); auto result = op.evaluate({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0); auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Constrast"); // out->printIndexedBuffer("Adjusted Constrast");
@ -302,7 +302,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) {
}); });
nd4j::ops::adjust_contrast op; nd4j::ops::adjust_contrast op;
auto result = op.execute({&x}, {2.}, {}, {}); auto result = op.evaluate({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0); auto out = result->at(0);
// out->printBuffer("Adjusted Constrast6"); // out->printBuffer("Adjusted Constrast6");
@ -407,7 +407,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) {
}); });
// x.linspace(1.); // x.linspace(1.);
nd4j::ops::adjust_contrast_v2 op; nd4j::ops::adjust_contrast_v2 op;
auto result = op.execute({&x}, {2.}, {}, {}); auto result = op.evaluate({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0); auto out = result->at(0);
// out->printBuffer("Adjusted Constrast7"); // out->printBuffer("Adjusted Constrast7");
@ -423,7 +423,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
auto e = NDArrayFactory::create<double>('c', {2, 2}, {2., 512., 8192., 131072.032 }); auto e = NDArrayFactory::create<double>('c', {2, 2}, {2., 512., 8192., 131072.032 });
x.linspace(1.); x.linspace(1.);
nd4j::ops::bitcast op; nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {}); auto result = op.evaluate({&x}, {(int) nd4j::DataType::DOUBLE});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0); auto out = result->at(0);
// out->printIndexedBuffer("Casted result"); // out->printIndexedBuffer("Casted result");
@ -437,7 +437,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
0.f, 2.312f, 0.f, 2.375f, 0.f, 2.438f, 0.f, 2.5f}); 0.f, 2.312f, 0.f, 2.375f, 0.f, 2.438f, 0.f, 2.5f});
x.linspace(1.); x.linspace(1.);
nd4j::ops::bitcast op; nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {}); auto result = op.evaluate({&x}, {(int) nd4j::DataType::HALF});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0); auto out = result->at(0);
ASSERT_TRUE(e.equalsTo(out)); ASSERT_TRUE(e.equalsTo(out));
@ -450,7 +450,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_3) {
x.linspace(1.); x.linspace(1.);
nd4j::ops::bitcast op; nd4j::ops::bitcast op;
try { try {
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); auto result = op.evaluate({&x}, {(int) nd4j::DataType::INT64});
ASSERT_NE(Status::OK(), result->status()); ASSERT_NE(Status::OK(), result->status());
delete result; delete result;
} catch (std::exception& e) { } catch (std::exception& e) {
@ -478,7 +478,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) {
x.linspace(1.); x.linspace(1.);
nd4j::ops::bitcast op; nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
// e.printIndexedBuffer("Double to int64"); // e.printIndexedBuffer("Double to int64");
auto res = result->at(0); auto res = result->at(0);
@ -497,7 +497,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_5) {
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL, auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL,
3314989625590692528LL}); 3314989625590692528LL});
nd4j::ops::bitcast op; nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto res = result->at(0); auto res = result->at(0);
// res->printIndexedBuffer("BITCAST5"); // res->printIndexedBuffer("BITCAST5");
@ -515,7 +515,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_6) {
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL, auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL,
5476460161268730496LL}); 5476460161268730496LL});
nd4j::ops::bitcast op; nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto res = result->at(0); auto res = result->at(0);
// res->printIndexedBuffer("BITCAST6"); // res->printIndexedBuffer("BITCAST6");
@ -532,7 +532,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, { auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {
4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL}); 4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL});
nd4j::ops::bitcast op; nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto res = result->at(0); auto res = result->at(0);
// res->printIndexedBuffer("BITCAST7"); // res->printIndexedBuffer("BITCAST7");
@ -549,7 +549,7 @@ TEST_F(DeclarableOpsTests15, test_matmul_bp_1) {
auto gB = NDArrayFactory::create<double>('c', {1, 4}); auto gB = NDArrayFactory::create<double>('c', {1, 4});
nd4j::ops::matmul_bp op; nd4j::ops::matmul_bp op;
auto status = op.execute({&a, &b, &gI}, {&gA, &gB}, {}, {1, 0, 0}, {}); auto status = op.execute({&a, &b, &gI}, std::vector<NDArray*>{&gA, &gB}, {}, {1, 0, 0}, {});
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -573,7 +573,7 @@ TEST_F(DeclarableOpsTests15, test_check_numeric_1) {
auto y = NDArrayFactory::string("shouldn't ever trigger"); auto y = NDArrayFactory::string("shouldn't ever trigger");
nd4j::ops::check_numerics op; nd4j::ops::check_numerics op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -617,7 +617,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
auto b = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); auto b = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
nd4j::ops::layer_norm op; nd4j::ops::layer_norm op;
auto result = op.execute({&x, &g, &b}, {}, {0}, {false}); auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
} }
@ -629,7 +629,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f}); auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::layer_norm_bp op; nd4j::ops::layer_norm_bp op;
auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {false}); auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
} }
@ -662,9 +662,9 @@ TEST_F(DeclarableOpsTests15, test_hashCode_1) {
y.linspace(2.); y.linspace(2.);
nd4j::ops::hashcode op; nd4j::ops::hashcode op;
auto resultA0 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64); auto resultA0 = op.evaluate({&x});
auto resultA1 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64); auto resultA1 = op.evaluate({&x});
auto resultB0 = op.execute({&y}, {}, {}, {}, false, nd4j::DataType::INT64); auto resultB0 = op.evaluate({&y});
// resultA0->at(0)->printIndexedBuffer("A0"); // resultA0->at(0)->printIndexedBuffer("A0");
// resultA1->at(0)->printIndexedBuffer("A1"); // resultA1->at(0)->printIndexedBuffer("A1");
// resultB0->at(0)->printIndexedBuffer("B0"); // resultB0->at(0)->printIndexedBuffer("B0");
@ -684,9 +684,9 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) {
y.linspace(2.); y.linspace(2.);
nd4j::ops::hashcode op; nd4j::ops::hashcode op;
auto resultA0 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64); auto resultA0 = op.evaluate({&x});
auto resultA1 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64); auto resultA1 = op.evaluate({&x});
auto resultB0 = op.execute({&y}, {}, {}, {}, false, nd4j::DataType::INT64); auto resultB0 = op.evaluate({&y});
// resultA0->at(0)->printIndexedBuffer("A0"); // resultA0->at(0)->printIndexedBuffer("A0");
// resultA1->at(0)->printIndexedBuffer("A1"); // resultA1->at(0)->printIndexedBuffer("A1");
@ -705,7 +705,7 @@ TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) {
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f}); auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&array}, {}, {1, 1}); auto result = op.evaluate({&array}, {}, {1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -742,7 +742,7 @@ TEST_F(DeclarableOpsTests15, test_rank_2) {
auto e = NDArrayFactory::create<int>('c', {}, {2}); auto e = NDArrayFactory::create<int>('c', {}, {2});
nd4j::ops::rank op; nd4j::ops::rank op;
auto result = op.execute({&array}, {}, {}); auto result = op.evaluate({&array}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -764,7 +764,7 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
auto x8 = NDArrayFactory::create<float>('c', {12}); auto x8 = NDArrayFactory::create<float>('c', {12});
nd4j::ops::lstmBlock op; nd4j::ops::lstmBlock op;
auto result = op.execute({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {2.0, 0.3}, {0, 0}); auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {2.0, 0.3}, {0, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -790,7 +790,7 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_2) {
auto x8 = NDArrayFactory::create<float>('f', {4 * nIn}); auto x8 = NDArrayFactory::create<float>('f', {4 * nIn});
nd4j::ops::lstmBlock op; nd4j::ops::lstmBlock op;
auto result = op.execute({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {1.0, 0.0}, {0, 1}); auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {1.0, 0.0}, {0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -860,7 +860,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) {
NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32); NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32);
NDArray expected('c', { 1 }, { 55 }, nd4j::DataType::INT32); NDArray expected('c', { 1 }, { 55 }, nd4j::DataType::INT32);
nd4j::ops::rgb_to_grs op; nd4j::ops::rgb_to_grs op;
auto result = op.execute({&rgbs}, {}, {}); auto result = op.evaluate({&rgbs}, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -876,7 +876,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_2) {
auto rgbs = NDArrayFactory::create<int>('f', { 3 }, { 1, 120, -25 }); auto rgbs = NDArrayFactory::create<int>('f', { 3 }, { 1, 120, -25 });
auto expected = NDArrayFactory::create<int>('f', { 1 }, { 67 }); auto expected = NDArrayFactory::create<int>('f', { 1 }, { 67 });
nd4j::ops::rgb_to_grs op; nd4j::ops::rgb_to_grs op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -892,7 +892,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_3) {
NDArray rgbs('c', { 4, 3 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32); NDArray rgbs('c', { 4, 3 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32);
NDArray expected('c', { 4, 1 }, { 41, 105, 101, 101 }, nd4j::DataType::INT32); NDArray expected('c', { 4, 1 }, { 41, 105, 101, 101 }, nd4j::DataType::INT32);
nd4j::ops::rgb_to_grs op; nd4j::ops::rgb_to_grs op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -910,7 +910,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_4) {
rgbs.permutei({1,0}); rgbs.permutei({1,0});
NDArray expected('c', { 2, 1 }, { 138, 58 }, nd4j::DataType::INT32); NDArray expected('c', { 2, 1 }, { 138, 58 }, nd4j::DataType::INT32);
nd4j::ops::rgb_to_grs op; nd4j::ops::rgb_to_grs op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -926,7 +926,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_5) {
NDArray rgbs('c', { 3, 4 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32); NDArray rgbs('c', { 3, 4 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32);
NDArray expected('c', { 1, 4 }, { 50, 100, 105, 94 }, nd4j::DataType::INT32); NDArray expected('c', { 1, 4 }, { 50, 100, 105, 94 }, nd4j::DataType::INT32);
nd4j::ops::rgb_to_grs op; nd4j::ops::rgb_to_grs op;
auto result = op.execute({ &rgbs }, {}, {0}); auto result = op.evaluate({ &rgbs }, {}, {0});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -943,7 +943,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_6) {
auto expected = NDArrayFactory::create<float>('c', { 5,4,1 }, {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f}); auto expected = NDArrayFactory::create<float>('c', { 5,4,1 }, {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f});
nd4j::ops::rgb_to_grs op; nd4j::ops::rgb_to_grs op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -960,7 +960,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_7) {
auto expected = NDArrayFactory::create<float>('c', { 5,1,4 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f,2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f }); auto expected = NDArrayFactory::create<float>('c', { 5,1,4 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f,2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f });
nd4j::ops::rgb_to_grs op; nd4j::ops::rgb_to_grs op;
auto result = op.execute({ &rgbs }, {}, {1}); auto result = op.evaluate({ &rgbs }, {}, {1});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -976,7 +976,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) {
auto rgbs = NDArrayFactory::create<float>('c', { 3,5,4 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); auto rgbs = NDArrayFactory::create<float>('c', { 3,5,4 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f});
try { try {
nd4j::ops::rgb_to_grs op; nd4j::ops::rgb_to_grs op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
ASSERT_EQ(Status::THROW(), result->status()); ASSERT_EQ(Status::THROW(), result->status());
delete result; delete result;
} catch (std::exception& e) { } catch (std::exception& e) {
@ -991,7 +991,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_9) {
auto expected = NDArrayFactory::create<float>('f', { 2,2,1 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f }); auto expected = NDArrayFactory::create<float>('f', { 2,2,1 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f });
nd4j::ops::rgb_to_grs op; nd4j::ops::rgb_to_grs op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1007,7 +1007,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_1) {
NDArray rgbs('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32); NDArray rgbs('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32); NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
nd4j::ops::rgb_to_yuv op; nd4j::ops::rgb_to_yuv op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1026,7 +1026,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) {
NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, nd4j::DataType::FLOAT32); NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, nd4j::DataType::FLOAT32);
nd4j::ops::rgb_to_yuv op; nd4j::ops::rgb_to_yuv op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1043,7 +1043,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) {
NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32); NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
nd4j::ops::rgb_to_yuv op; nd4j::ops::rgb_to_yuv op;
auto result = op.execute({ &rgbs }, {}, { 0 }); auto result = op.evaluate({ &rgbs }, {}, { 0 });
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
@ -1059,7 +1059,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_4) {
NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32); NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32);
nd4j::ops::rgb_to_yuv op; nd4j::ops::rgb_to_yuv op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1076,7 +1076,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) {
NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, nd4j::DataType::FLOAT32); NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, nd4j::DataType::FLOAT32);
nd4j::ops::rgb_to_yuv op; nd4j::ops::rgb_to_yuv op;
auto result = op.execute({ &rgbs }, {}, { 1 }); auto result = op.evaluate({ &rgbs }, {}, { 1 });
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1091,7 +1091,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) {
NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32); NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
try { try {
nd4j::ops::rgb_to_yuv op; nd4j::ops::rgb_to_yuv op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
ASSERT_EQ(Status::THROW(), result->status()); ASSERT_EQ(Status::THROW(), result->status());
delete result; delete result;
} }
@ -1107,7 +1107,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) {
NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32); NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
nd4j::ops::rgb_to_yuv op; nd4j::ops::rgb_to_yuv op;
auto result = op.execute({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1123,7 +1123,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) {
NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32); NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
NDArray expected('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32); NDArray expected('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
nd4j::ops::yuv_to_rgb op; nd4j::ops::yuv_to_rgb op;
auto result = op.execute({ &yuv }, {}, {}); auto result = op.evaluate({ &yuv }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1139,7 +1139,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) {
NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32); NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
NDArray expected('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32); NDArray expected('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
nd4j::ops::yuv_to_rgb op; nd4j::ops::yuv_to_rgb op;
auto result = op.execute({ &yuv }, {}, {}); auto result = op.evaluate({ &yuv }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1156,7 +1156,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_3) {
NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32); NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
nd4j::ops::yuv_to_rgb op; nd4j::ops::yuv_to_rgb op;
auto result = op.execute({ &yuv }, {}, { 0 }); auto result = op.evaluate({ &yuv }, {}, { 0 });
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
@ -1172,7 +1172,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_4) {
NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32); NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32);
nd4j::ops::yuv_to_rgb op; nd4j::ops::yuv_to_rgb op;
auto result = op.execute({ &yuv }, {}, {}); auto result = op.evaluate({ &yuv }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1189,7 +1189,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) {
NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, nd4j::DataType::FLOAT32); NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, nd4j::DataType::FLOAT32);
nd4j::ops::yuv_to_rgb op; nd4j::ops::yuv_to_rgb op;
auto result = op.execute({ &yuv }, {}, { 1 }); auto result = op.evaluate({ &yuv }, {}, { 1 });
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1204,7 +1204,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) {
NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32); NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
try { try {
nd4j::ops::yuv_to_rgb op; nd4j::ops::yuv_to_rgb op;
auto result = op.execute({ &yuv }, {}, {}); auto result = op.evaluate({ &yuv }, {}, {});
ASSERT_EQ(Status::THROW(), result->status()); ASSERT_EQ(Status::THROW(), result->status());
delete result; delete result;
} }
@ -1220,7 +1220,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) {
NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, nd4j::DataType::FLOAT32); NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
nd4j::ops::yuv_to_rgb op; nd4j::ops::yuv_to_rgb op;
auto result = op.execute({ &yuv }, {}, {}); auto result = op.evaluate({ &yuv }, {}, {});
auto output = result->at(0); auto output = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1246,7 +1246,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test1) {
dLdz.assign(1.0); dLdz.assign(1.0);
nd4j::ops::Pow_bp op; nd4j::ops::Pow_bp op;
auto results = op.execute({ &x, &y, &dLdz }, {}, {}); auto results = op.evaluate({ &x, &y, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1275,7 +1275,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test2) {
dLdz.linspace(0.1, 0.1); dLdz.linspace(0.1, 0.1);
nd4j::ops::Pow_bp op; nd4j::ops::Pow_bp op;
auto results = op.execute({ &x, &y, &dLdz }, {}, {}); auto results = op.evaluate({ &x, &y, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto* dLdx = results->at(0); auto* dLdx = results->at(0);
@ -1305,7 +1305,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test3) {
dLdz.linspace(0.1, 0.1); dLdz.linspace(0.1, 0.1);
nd4j::ops::Pow_bp op; nd4j::ops::Pow_bp op;
auto resultsY = op.execute({ &xY, &yY, &dLdz }, {}, {}); auto resultsY = op.evaluate({ &xY, &yY, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsY->status()); ASSERT_EQ(ND4J_STATUS_OK, resultsY->status());
@ -1337,7 +1337,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test4) {
xX.assign(2.0); xX.assign(2.0);
yX.assign(4.0); yX.assign(4.0);
auto resultsX = op.execute({ &xX, &yX, &dLdz }, {}, {}); auto resultsX = op.evaluate({ &xX, &yX, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsX->status()); ASSERT_EQ(ND4J_STATUS_OK, resultsX->status());
@ -1369,7 +1369,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test5) {
dLdyExp.assign(pow(3, 4) * log(3)); dLdyExp.assign(pow(3, 4) * log(3));
nd4j::ops::Pow_bp op; nd4j::ops::Pow_bp op;
auto results = op.execute({ &xConst, &yConst, &dLdz }, {}, {}); auto results = op.evaluate({ &xConst, &yConst, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto* dLdx = results->at(0); auto* dLdx = results->at(0);
@ -1399,7 +1399,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test6) {
NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, nd4j::DataType::FLOAT32); NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, nd4j::DataType::FLOAT32);
nd4j::ops::Pow_bp op; nd4j::ops::Pow_bp op;
auto resultsXC = op.execute({ &xConst, &y, &dLdzC }, {}, {}); auto resultsXC = op.evaluate({ &xConst, &y, &dLdzC }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsXC->status()); ASSERT_EQ(ND4J_STATUS_OK, resultsXC->status());
auto* dLdxXC = resultsXC->at(0); auto* dLdxXC = resultsXC->at(0);
@ -1428,7 +1428,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test7) {
auto dLdyExpYs = NDArrayFactory::create<float>(79.85056f); auto dLdyExpYs = NDArrayFactory::create<float>(79.85056f);
nd4j::ops::Pow_bp op; nd4j::ops::Pow_bp op;
auto resultsYs = op.execute({ &x, &Y, &dLdzC }, {}, {}); auto resultsYs = op.evaluate({ &x, &Y, &dLdzC }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsYs->status()); ASSERT_EQ(ND4J_STATUS_OK, resultsYs->status());
auto* dLdxY = resultsYs->at(0); auto* dLdxY = resultsYs->at(0);
@ -1454,7 +1454,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test8) {
NDArray dLdyExp = NDArrayFactory::create<float>(pow(4.f, 2.f) * log(4.f) * 0.1f); NDArray dLdyExp = NDArrayFactory::create<float>(pow(4.f, 2.f) * log(4.f) * 0.1f);
nd4j::ops::Pow_bp op; nd4j::ops::Pow_bp op;
auto results = op.execute({ &X, &Y, &dLdz }, {}, {}); auto results = op.evaluate({ &X, &Y, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1484,7 +1484,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test9) {
y.assign(2.0); y.assign(2.0);
dLdz.linspace(0.1, 0.1); dLdz.linspace(0.1, 0.1);
auto results = op.execute({ &x, &y, &dLdz }, {}, {}); auto results = op.evaluate({ &x, &y, &dLdz }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto* dLdx = results->at(0); auto* dLdx = results->at(0);
@ -1513,7 +1513,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test10) {
yB.assign(2.0); yB.assign(2.0);
nd4j::ops::Pow_bp op; nd4j::ops::Pow_bp op;
auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {}); auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsB->status()); ASSERT_EQ(ND4J_STATUS_OK, resultsB->status());
@ -1540,7 +1540,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) {
NDArray dLdzB('c', { 3,2,3 }, { .1,.2,.3, .1,.2,.3, .1,.4,.1, .2,.1,.1, .3,.1,.5, .1, .7, .1 }, nd4j::DataType::FLOAT32); NDArray dLdzB('c', { 3,2,3 }, { .1,.2,.3, .1,.2,.3, .1,.4,.1, .2,.1,.1, .3,.1,.5, .1, .7, .1 }, nd4j::DataType::FLOAT32);
nd4j::ops::Pow_bp op; nd4j::ops::Pow_bp op;
auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {}); auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsB->status()); ASSERT_EQ(ND4J_STATUS_OK, resultsB->status());
auto* dLdxB = resultsB->at(0); auto* dLdxB = resultsB->at(0);

View File

@ -46,7 +46,7 @@ TEST_F(DeclarableOpsTests16, scatter_upd_1) {
auto e = NDArrayFactory::create<float>('c', { 3 }, { 3.f, 1.f, 1.f }); auto e = NDArrayFactory::create<float>('c', { 3 }, { 3.f, 1.f, 1.f });
nd4j::ops::scatter_upd op; nd4j::ops::scatter_upd op;
auto result = op.execute({ &x, &y, &w }, {}, {}); auto result = op.evaluate({ &x, &y, &w });
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -66,7 +66,7 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) {
x.linspace(1); x.linspace(1);
nd4j::ops::scatter_upd op; nd4j::ops::scatter_upd op;
auto result = op.execute({ &x, &indices, &updates }, {}, {}); auto result = op.evaluate({ &x, &indices, &updates });
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -135,7 +135,7 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
auto e = NDArrayFactory::create<Nd4jLong>(18); auto e = NDArrayFactory::create<Nd4jLong>(18);
nd4j::ops::bits_hamming_distance op; nd4j::ops::bits_hamming_distance op;
auto result = op.execute({ &x, &y }, {}, {}); auto result = op.evaluate({ &x, &y });
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -166,7 +166,7 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) {
auto e = NDArrayFactory::create<Nd4jLong>('c', { 1, 0, 2 }); auto e = NDArrayFactory::create<Nd4jLong>('c', { 1, 0, 2 });
nd4j::ops::cast op; nd4j::ops::cast op;
auto result = op.execute({ &x }, {}, { 10 }); auto result = op.evaluate({&x}, {10});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result->at(0));

View File

@ -48,7 +48,7 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
nd4j::ops::compat_sparse_to_dense op; nd4j::ops::compat_sparse_to_dense op;
auto result = op.execute({&ranges, &shape, &values, &def}, {}, {}); auto result = op.evaluate({&ranges, &shape, &values, &def});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -63,7 +63,7 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
nd4j::ops::compat_sparse_to_dense op; nd4j::ops::compat_sparse_to_dense op;
auto result = op.execute({&ranges, &shape, &values, &def}, {}, {}); auto result = op.evaluate({&ranges, &shape, &values, &def});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -77,7 +77,7 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"}); auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"});
nd4j::ops::compat_string_split op; nd4j::ops::compat_string_split op;
auto result = op.execute({&x, &delimiter}, {}, {}); auto result = op.evaluate({&x, &delimiter});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(2, result->size()); ASSERT_EQ(2, result->size());

File diff suppressed because it is too large Load Diff

View File

@ -43,7 +43,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_1) {
auto exp = x.tile(reps); auto exp = x.tile(reps);
nd4j::ops::tile op; nd4j::ops::tile op;
auto result = op.execute({&x, &rep_vector}, {}, {}); auto result = op.evaluate({&x, &rep_vector});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -61,7 +61,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_2) {
auto exp = x.tile(reps); auto exp = x.tile(reps);
nd4j::ops::tile op; nd4j::ops::tile op;
auto result = op.execute({&x}, {}, {2, 2}); auto result = op.evaluate({&x}, {}, {2, 2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -77,7 +77,7 @@ TEST_F(DeclarableOpsTests3, Test_Permute_1) {
auto exp= NDArrayFactory::create<float>('c', {2, 4, 3}); auto exp= NDArrayFactory::create<float>('c', {2, 4, 3});
nd4j::ops::permute op; nd4j::ops::permute op;
auto result = op.execute({&x, &permute}, {}, {}); auto result = op.evaluate({&x, &permute});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -92,7 +92,7 @@ TEST_F(DeclarableOpsTests3, Test_Permute_2) {
auto exp= NDArrayFactory::create<float>('c', {4, 3, 2}); auto exp= NDArrayFactory::create<float>('c', {4, 3, 2});
nd4j::ops::permute op; nd4j::ops::permute op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -110,7 +110,7 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) {
// auto expI= NDArrayFactory::create<float>('c', {3}, {0, 1, 4}); // auto expI= NDArrayFactory::create<float>('c', {3}, {0, 1, 4});
nd4j::ops::unique op; nd4j::ops::unique op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(2, result->size()); ASSERT_EQ(2, result->size());
@ -136,7 +136,7 @@ TEST_F(DeclarableOpsTests3, Test_Unique_2) {
auto expC= NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 2, 1}); auto expC= NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 2, 1});
nd4j::ops::unique_with_counts op; nd4j::ops::unique_with_counts op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(3, result->size()); ASSERT_EQ(3, result->size());
@ -169,7 +169,7 @@ TEST_F(DeclarableOpsTests3, Test_Rint_1) {
auto exp= NDArrayFactory::create<float>('c', {1, 7}, {-2.f, -2.f, -0.f, 0.f, 2.f, 2.f, 2.f}); auto exp= NDArrayFactory::create<float>('c', {1, 7}, {-2.f, -2.f, -0.f, 0.f, 2.f, 2.f, 2.f});
nd4j::ops::rint op; nd4j::ops::rint op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -188,7 +188,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) {
std::vector<int> dims({1}); std::vector<int> dims({1});
nd4j::ops::norm op; nd4j::ops::norm op;
auto result0 = op.execute({&x}, {0.}, {}); auto result0 = op.evaluate({&x}, {0.}, {});
auto z0 = result0->at(0); auto z0 = result0->at(0);
auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false);
@ -197,7 +197,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) {
delete result0; delete result0;
auto result1 = op.execute({&x}, {1.}, {1}); auto result1 = op.evaluate({&x}, {1.}, {1});
ASSERT_EQ(result1->status(), ND4J_STATUS_OK); ASSERT_EQ(result1->status(), ND4J_STATUS_OK);
auto z1 = result1->at(0); auto z1 = result1->at(0);
// z1->printIndexedBuffer("Z1"); // z1->printIndexedBuffer("Z1");
@ -210,7 +210,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) {
delete result1; delete result1;
auto result4 = op.execute({&x}, {4.}, {1}); auto result4 = op.evaluate({&x}, {4.}, {1});
auto z4 = result4->at(0); auto z4 = result4->at(0);
auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false);
@ -230,7 +230,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {
std::vector<int> dims({1}); std::vector<int> dims({1});
nd4j::ops::norm op; nd4j::ops::norm op;
auto result0 = op.execute({&x}, {0}, {}); auto result0 = op.evaluate({&x}, {0}, {});
auto z0 = result0->at(0); auto z0 = result0->at(0);
auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false);
@ -239,7 +239,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {
delete result0; delete result0;
auto result1 = op.execute({&x, &axis}, {1}, {}); auto result1 = op.evaluate({&x, &axis}, {1}, {});
auto z1 = result1->at(0); auto z1 = result1->at(0);
auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false);
@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {
delete result1; delete result1;
auto result4 = op.execute({&x, &axis}, {4}, {}); auto result4 = op.evaluate({&x, &axis}, {4}, {});
auto z4 = result4->at(0); auto z4 = result4->at(0);
auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false);
@ -264,7 +264,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_1) {
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0}); auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0});
nd4j::ops::clipbyavgnorm op; nd4j::ops::clipbyavgnorm op;
auto result = op.execute({&x}, {0.8}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {0.8}, {});
auto z = result->at(0); auto z = result->at(0);
@ -279,7 +279,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) {
auto exp= NDArrayFactory::create<float>('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f}); auto exp= NDArrayFactory::create<float>('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f});
nd4j::ops::clipbyavgnorm op; nd4j::ops::clipbyavgnorm op;
auto result = op.execute({&x}, {0.9}, {}); auto result = op.evaluate({&x}, {0.9}, {});
auto z = result->at(0); auto z = result->at(0);
@ -295,7 +295,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) {
auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0}); auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0});
nd4j::ops::clipbynorm op; nd4j::ops::clipbynorm op;
auto result = op.execute({&x}, {4.0}, {}); auto result = op.evaluate({&x}, {4.0}, {});
auto z = result->at(0); auto z = result->at(0);
@ -310,7 +310,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) {
auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
nd4j::ops::clipbynorm op; nd4j::ops::clipbynorm op;
auto result = op.execute({&x}, {6.0}, {}); auto result = op.evaluate({&x}, {6.0}, {});
auto z = result->at(0); auto z = result->at(0);
@ -340,7 +340,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) {
xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
nd4j::ops::clipbynorm op; nd4j::ops::clipbynorm op;
auto result = op.execute({&x}, {1.0}, {1}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {1.0}, {1});
auto z = result->at(0); auto z = result->at(0);
auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true); auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true);
@ -360,7 +360,7 @@ TEST_F(DeclarableOpsTests3, Test_ListDiff_1) {
auto exp1= NDArrayFactory::create<Nd4jLong>('c', {3}, {1, 3, 5}); auto exp1= NDArrayFactory::create<Nd4jLong>('c', {3}, {1, 3, 5});
nd4j::ops::listdiff op; nd4j::ops::listdiff op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -386,7 +386,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_1) {
auto exp= NDArrayFactory::create<float>('c', {17}, { 0.3f, -0.03f, -0.36f, -0.69f, -1.02f, -1.35f, -1.68f, -2.01f, -2.34f, -2.67f, -3.f, -3.33f, -3.66f, -3.99f, -4.32f, -4.65f, -4.98f}); auto exp= NDArrayFactory::create<float>('c', {17}, { 0.3f, -0.03f, -0.36f, -0.69f, -1.02f, -1.35f, -1.68f, -2.01f, -2.34f, -2.67f, -3.f, -3.33f, -3.66f, -3.99f, -4.32f, -4.65f, -4.98f});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({&start, &stop, &step}, {}, {}); auto result = op.evaluate({&start, &stop, &step});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -406,7 +406,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_2) {
auto exp= NDArrayFactory::create<float>('c', {2}, {2.f, 1.f}); auto exp= NDArrayFactory::create<float>('c', {2}, {2.f, 1.f});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({&start, &stop, &step}, {}, {}); auto result = op.evaluate({&start, &stop, &step});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -425,7 +425,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_3) {
auto exp= NDArrayFactory::create<float>('c', {2}, {0.f, 1.f}); auto exp= NDArrayFactory::create<float>('c', {2}, {0.f, 1.f});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({&start, &stop, &step}, {}, {}); auto result = op.evaluate({&start, &stop, &step});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -442,7 +442,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_4) {
auto exp= NDArrayFactory::create<float>('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f}); auto exp= NDArrayFactory::create<float>('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {-10., 10., 1.666}, {}); auto result = op.evaluate({}, {-10., 10., 1.666}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -459,7 +459,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_5) {
auto exp= NDArrayFactory::create<float>('c', {2}, {2.f, 1.f}); auto exp= NDArrayFactory::create<float>('c', {2}, {2.f, 1.f});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {2, 0, -1}, {}); auto result = op.evaluate({}, {2, 0, -1}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -475,7 +475,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_6) {
auto exp= NDArrayFactory::create<float>('c', {2}, {0.f, 1.f}); auto exp= NDArrayFactory::create<float>('c', {2}, {0.f, 1.f});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {0, 2, 1}, {}); auto result = op.evaluate({}, {0, 2, 1}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -491,7 +491,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_7) {
auto exp= NDArrayFactory::create<float>('c', {10}, {10.f, 8.334f, 6.668f, 5.002f, 3.336f, 1.67f, 0.004f, -1.662f, -3.328f, -4.994f}); auto exp= NDArrayFactory::create<float>('c', {10}, {10.f, 8.334f, 6.668f, 5.002f, 3.336f, 1.67f, 0.004f, -1.662f, -3.328f, -4.994f});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {10,-5,-1.666}, {}); auto result = op.evaluate({}, {10,-5,-1.666}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -509,7 +509,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_8) {
auto exp= NDArrayFactory::create<int>('c', {2}, {2, 1}); auto exp= NDArrayFactory::create<int>('c', {2}, {2, 1});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {}, {2, 0, -1}); auto result = op.evaluate({}, {}, {2, 0, -1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -525,7 +525,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_9) {
auto exp= NDArrayFactory::create<int>('c', {2}, {0, 1}); auto exp= NDArrayFactory::create<int>('c', {2}, {0, 1});
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({}, {}, {0, 2, 1}); auto result = op.evaluate({}, {}, {0, 2, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -546,7 +546,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) {
auto exp = MmulHelper::mmul(&x, &y); auto exp = MmulHelper::mmul(&x, &y);
nd4j::ops::batched_gemm op; nd4j::ops::batched_gemm op;
auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 3, 3, 3, 3, 3, 3, 3}); auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 3, 3, 3, 3, 3, 3, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(3, result->size()); ASSERT_EQ(3, result->size());
@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) {
auto exp = MmulHelper::mmul(&x, &y); auto exp = MmulHelper::mmul(&x, &y);
nd4j::ops::batched_gemm op; nd4j::ops::batched_gemm op;
auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 3, 3, 3, 3, 3, 3, 3}); auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 3, 3, 3, 3, 3, 3, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(3, result->size()); ASSERT_EQ(3, result->size());
@ -602,7 +602,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) {
auto exp = MmulHelper::mmul(&x, &y); auto exp = MmulHelper::mmul(&x, &y);
nd4j::ops::batched_gemm op; nd4j::ops::batched_gemm op;
auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 111, 3, 3, 3, 3, 3, 3, 3}); auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 111, 3, 3, 3, 3, 3, 3, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(3, result->size()); ASSERT_EQ(3, result->size());
@ -630,7 +630,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) {
auto exp = MmulHelper::mmul(&x, &y); auto exp = MmulHelper::mmul(&x, &y);
nd4j::ops::batched_gemm op; nd4j::ops::batched_gemm op;
auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 5, 4, 3, 5, 3, 5, 3}); auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 5, 4, 3, 5, 3, 5, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(3, result->size()); ASSERT_EQ(3, result->size());
@ -658,7 +658,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) {
auto exp = MmulHelper::mmul(&x, &y); auto exp = MmulHelper::mmul(&x, &y);
nd4j::ops::batched_gemm op; nd4j::ops::batched_gemm op;
auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 5, 4, 3, 3, 4, 5, 3}); auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 5, 4, 3, 3, 4, 5, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(3, result->size()); ASSERT_EQ(3, result->size());
@ -687,7 +687,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) {
auto exp = MmulHelper::mmul(&x, &y); auto exp = MmulHelper::mmul(&x, &y);
nd4j::ops::batched_gemm op; nd4j::ops::batched_gemm op;
auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 2, 3, 5, 2, 5, 2, 3}); auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 2, 3, 5, 2, 5, 2, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(3, result->size()); ASSERT_EQ(3, result->size());
@ -717,7 +717,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {
// exp->printShapeInfo("exp shape"); // exp->printShapeInfo("exp shape");
nd4j::ops::batched_gemm op; nd4j::ops::batched_gemm op;
auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(3, result->size()); ASSERT_EQ(3, result->size());
@ -744,7 +744,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) {
nd4j::ops::batched_gemm op; nd4j::ops::batched_gemm op;
try { try {
auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
delete result; delete result;
ASSERT_TRUE(false); ASSERT_TRUE(false);
} catch (std::invalid_argument &e) { } catch (std::invalid_argument &e) {
@ -775,7 +775,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_1) {
auto exp= NDArrayFactory::create<double>('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0}); auto exp= NDArrayFactory::create<double>('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0});
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {1, 1}); auto result = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -794,7 +794,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_2) {
auto exp= NDArrayFactory::create<double>('f', {3, 3}, {70.0, 158.0, 246.0, 80.0, 184.0, 288.0, 90.0, 210.0, 330.0}); auto exp= NDArrayFactory::create<double>('f', {3, 3}, {70.0, 158.0, 246.0, 80.0, 184.0, 288.0, 90.0, 210.0, 330.0});
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {0, 0}); auto result = op.evaluate({&x, &y}, {}, {0, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -813,7 +813,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_3) {
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {1, 0}); auto result = op.evaluate({&x, &y}, {}, {1, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -832,7 +832,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_4) {
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {0, 1}); auto result = op.evaluate({&x, &y}, {}, {0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -851,7 +851,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_5) {
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -870,7 +870,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) {
auto exp= NDArrayFactory::create<double>('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); auto exp= NDArrayFactory::create<double>('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16});
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -889,7 +889,7 @@ TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) {
auto exp= NDArrayFactory::create<double>('c', {1, 3}, {2, 3, 4}); auto exp= NDArrayFactory::create<double>('c', {1, 3}, {2, 3, 4});
nd4j::ops::reversedivide op; nd4j::ops::reversedivide op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -920,7 +920,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test1) {
auto expCt= NDArrayFactory::create<float>('c', {batchSize, inSize}, {2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f}); auto expCt= NDArrayFactory::create<float>('c', {batchSize, inSize}, {2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f});
nd4j::ops::sruCell op; nd4j::ops::sruCell op;
auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {}); auto results = op.evaluate({&xt, &ct_1, &w, &b});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -956,7 +956,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test2) {
auto expCt= NDArrayFactory::create<float>('c', {batchSize, inSize}, {2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f}); auto expCt= NDArrayFactory::create<float>('c', {batchSize, inSize}, {2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f});
nd4j::ops::sruCell op; nd4j::ops::sruCell op;
auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {}); auto results = op.evaluate({&xt, &ct_1, &w, &b});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -991,7 +991,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test3) {
auto expCt= NDArrayFactory::create<float>('c', {batchSize, inSize}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); auto expCt= NDArrayFactory::create<float>('c', {batchSize, inSize}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f});
nd4j::ops::sruCell op; nd4j::ops::sruCell op;
auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {}); auto results = op.evaluate({&xt, &ct_1, &w, &b});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1030,7 +1030,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test1) {
auto expHt = NDArrayFactory::create<float>('c', {batchSize, numUnits}, {1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f}); auto expHt = NDArrayFactory::create<float>('c', {batchSize, numUnits}, {1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f});
nd4j::ops::gruCell op; nd4j::ops::gruCell op;
auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {}); auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1066,7 +1066,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test2) {
auto expHt= NDArrayFactory::create<float>('c', {batchSize, numUnits}, {0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f}); auto expHt= NDArrayFactory::create<float>('c', {batchSize, numUnits}, {0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f});
nd4j::ops::gruCell op; nd4j::ops::gruCell op;
auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {}); auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1102,7 +1102,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test3) {
auto expHt= NDArrayFactory::create<float>('c', {batchSize, numUnits}, {0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f}); auto expHt= NDArrayFactory::create<float>('c', {batchSize, numUnits}, {0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f});
nd4j::ops::gruCell op; nd4j::ops::gruCell op;
auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {}); auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test1) {
auto expected= NDArrayFactory::create<double>('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); auto expected= NDArrayFactory::create<double>('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2});
nd4j::ops::invert_permutation op; nd4j::ops::invert_permutation op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1140,7 +1140,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test2) {
auto expected= NDArrayFactory::create<double>('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); auto expected= NDArrayFactory::create<double>('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2});
nd4j::ops::invert_permutation op; nd4j::ops::invert_permutation op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1159,7 +1159,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test3) {
auto expected= NDArrayFactory::create<double>('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7}); auto expected= NDArrayFactory::create<double>('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7});
nd4j::ops::invert_permutation op; nd4j::ops::invert_permutation op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1180,7 +1180,7 @@ TEST_F(DeclarableOpsTests3, diag_test1) {
auto expected= NDArrayFactory::create<double>('c', {3,2,3,2}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); auto expected= NDArrayFactory::create<double>('c', {3,2,3,2}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6});
nd4j::ops::diag op; nd4j::ops::diag op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1201,7 +1201,7 @@ TEST_F(DeclarableOpsTests3, diag_test2) {
auto expected= NDArrayFactory::create<double>('c', {2,3,2,3}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); auto expected= NDArrayFactory::create<double>('c', {2,3,2,3}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6});
nd4j::ops::diag op; nd4j::ops::diag op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1222,7 +1222,7 @@ TEST_F(DeclarableOpsTests3, diag_test_vector) {
auto expected= NDArrayFactory::create<double>('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); auto expected= NDArrayFactory::create<double>('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4});
nd4j::ops::diag op; nd4j::ops::diag op;
auto results = op.execute({input}, {}, {}); auto results = op.evaluate({input});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1246,7 +1246,7 @@ TEST_F(DeclarableOpsTests3, diag_test_col_vector) {
auto expected= NDArrayFactory::create<double>('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); auto expected= NDArrayFactory::create<double>('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4});
nd4j::ops::diag op; nd4j::ops::diag op;
auto results = op.execute({input}, {}, {}); auto results = op.evaluate({input}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1267,7 +1267,7 @@ TEST_F(DeclarableOpsTests3, diag_test3) {
auto expected= NDArrayFactory::create<double>('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); auto expected= NDArrayFactory::create<double>('c', {3,3}, {1,0,0, 0,2,0, 0,0,3});
nd4j::ops::diag op; nd4j::ops::diag op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1288,7 +1288,7 @@ TEST_F(DeclarableOpsTests3, diag_test4) {
auto expected= NDArrayFactory::create<double>('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); auto expected= NDArrayFactory::create<double>('c', {3,3}, {1,0,0, 0,2,0, 0,0,3});
nd4j::ops::diag op; nd4j::ops::diag op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1309,7 +1309,7 @@ TEST_F(DeclarableOpsTests3, diag_test5) {
auto expected= NDArrayFactory::create<double>('c', {1,1}, {2}); auto expected= NDArrayFactory::create<double>('c', {1,1}, {2});
nd4j::ops::diag op; nd4j::ops::diag op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1330,7 +1330,7 @@ TEST_F(DeclarableOpsTests3, diag_test6) {
auto expected= NDArrayFactory::create<double>('c', {2,2,2,2,2,2}, {1,0,0,0, 0,0,0,0, 0,2,0,0, 0,0,0,0, 0,0,3,0, 0,0,0,0, 0,0,0,4, 0,0,0,0, 0,0,0,0, 5,0,0,0, 0,0,0,0, 0,6,0,0, 0,0,0,0, 0,0,7,0, 0,0,0,0, 0,0,0,8}); auto expected= NDArrayFactory::create<double>('c', {2,2,2,2,2,2}, {1,0,0,0, 0,0,0,0, 0,2,0,0, 0,0,0,0, 0,0,3,0, 0,0,0,0, 0,0,0,4, 0,0,0,0, 0,0,0,0, 5,0,0,0, 0,0,0,0, 0,6,0,0, 0,0,0,0, 0,0,7,0, 0,0,0,0, 0,0,0,8});
nd4j::ops::diag op; nd4j::ops::diag op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1353,7 +1353,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test1) {
auto expected= NDArrayFactory::create<double>('c', {4,3,2}, {1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0}); auto expected= NDArrayFactory::create<double>('c', {4,3,2}, {1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0});
nd4j::ops::matrix_set_diag op; nd4j::ops::matrix_set_diag op;
auto results = op.execute({&input, &diagonal}, {}, {}); auto results = op.evaluate({&input, &diagonal}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1376,7 +1376,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) {
auto expected= NDArrayFactory::create<float>('c', {1,1,2}, {1.f, 0.f}); auto expected= NDArrayFactory::create<float>('c', {1,1,2}, {1.f, 0.f});
nd4j::ops::matrix_set_diag op; nd4j::ops::matrix_set_diag op;
auto results = op.execute({&input, &diagonal}, {}, {}); auto results = op.evaluate({&input, &diagonal}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1399,7 +1399,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test3) {
auto expected= NDArrayFactory::create<double>('c', {2,1,4}, {1,0,0,0,1,0,0,0}); auto expected= NDArrayFactory::create<double>('c', {2,1,4}, {1,0,0,0,1,0,0,0});
nd4j::ops::matrix_set_diag op; nd4j::ops::matrix_set_diag op;
auto results = op.execute({&input, &diagonal}, {}, {}); auto results = op.evaluate({&input, &diagonal}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1422,7 +1422,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test4) {
auto expected= NDArrayFactory::create<double>('c', {2,1,4,1}, {1,0,0,0,1,0,0,0}); auto expected= NDArrayFactory::create<double>('c', {2,1,4,1}, {1,0,0,0,1,0,0,0});
nd4j::ops::matrix_set_diag op; nd4j::ops::matrix_set_diag op;
auto results = op.execute({&input, &diagonal}, {}, {}); auto results = op.evaluate({&input, &diagonal}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1443,7 +1443,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test1) {
auto expected= NDArrayFactory::create<double>('c', {2}, {1,4}); auto expected= NDArrayFactory::create<double>('c', {2}, {1,4});
nd4j::ops::diag_part op; nd4j::ops::diag_part op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1465,7 +1465,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test2) {
auto expected= NDArrayFactory::create<double>('c', {2,2}, {1,6,11,16}); auto expected= NDArrayFactory::create<double>('c', {2,2}, {1,6,11,16});
nd4j::ops::diag_part op; nd4j::ops::diag_part op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1486,7 +1486,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test3) {
auto expected= NDArrayFactory::create<double>('c', {2,2,2}, {1,10,19,28,37,46,55,64}); auto expected= NDArrayFactory::create<double>('c', {2,2,2}, {1,10,19,28,37,46,55,64});
nd4j::ops::diag_part op; nd4j::ops::diag_part op;
auto results = op.execute({&input}, {}, {}); auto results = op.evaluate({&input}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1512,7 +1512,7 @@ TEST_F(DeclarableOpsTests3, betainc_test1) {
auto expected = NDArrayFactory::create<float16>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); auto expected = NDArrayFactory::create<float16>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1538,7 +1538,7 @@ TEST_F(DeclarableOpsTests3, betainc_test2) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1564,7 +1564,7 @@ TEST_F(DeclarableOpsTests3, betainc_test3) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1590,7 +1590,7 @@ TEST_F(DeclarableOpsTests3, betainc_test4) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, 1.14644360e-05f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, 1.14644360e-05f});
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1616,7 +1616,7 @@ TEST_F(DeclarableOpsTests3, betainc_test5) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1642,7 +1642,7 @@ TEST_F(DeclarableOpsTests3, betainc_test6) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, 8.39227685e-10f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, 8.39227685e-10f});
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1668,7 +1668,7 @@ TEST_F(DeclarableOpsTests3, betainc_test7) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, 0.99999998f, 0.99999999f, 1.f, 1.f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, 0.99999998f, 0.99999999f, 1.f, 1.f});
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1694,7 +1694,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f});
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1720,7 +1720,7 @@ TEST_F(DeclarableOpsTests3, betainc_test9) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1746,7 +1746,7 @@ TEST_F(DeclarableOpsTests3, betainc_test10) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1767,7 +1767,7 @@ TEST_F(DeclarableOpsTests3, betainc_test11) {
NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, nd4j::DataType::FLOAT32); NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, nd4j::DataType::FLOAT32);
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1789,7 +1789,7 @@ TEST_F(DeclarableOpsTests3, betainc_test12) {
NDArray expected('c', {4}, {0.9999995 , 0.8594694 , 0.999988 , 0.49124345}, nd4j::DataType::FLOAT32); NDArray expected('c', {4}, {0.9999995 , 0.8594694 , 0.999988 , 0.49124345}, nd4j::DataType::FLOAT32);
nd4j::ops::betainc op; nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {}); auto results = op.evaluate({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1813,7 +1813,7 @@ TEST_F(DeclarableOpsTests3, zeta_test1) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f});
nd4j::ops::zeta op; nd4j::ops::zeta op;
auto results = op.execute({&x, &q}, {}, {}); auto results = op.evaluate({&x, &q}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1837,7 +1837,7 @@ TEST_F(DeclarableOpsTests3, zeta_test2) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f});
nd4j::ops::zeta op; nd4j::ops::zeta op;
auto results = op.execute({&x, &q}, {}, {}); auto results = op.evaluate({&x, &q}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1861,7 +1861,7 @@ TEST_F(DeclarableOpsTests3, zeta_test3) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f});
nd4j::ops::zeta op; nd4j::ops::zeta op;
auto results = op.execute({&x, &q}, {}, {}); auto results = op.evaluate({&x, &q}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1886,7 +1886,7 @@ TEST_F(DeclarableOpsTests3, zeta_test4) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f});
nd4j::ops::zeta op; nd4j::ops::zeta op;
auto results = op.execute({&x, &q}, {}, {}); auto results = op.evaluate({&x, &q}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1910,7 +1910,7 @@ TEST_F(DeclarableOpsTests3, zeta_test5) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f});
nd4j::ops::zeta op; nd4j::ops::zeta op;
auto results = op.execute({&x, &q}, {}, {}); auto results = op.evaluate({&x, &q}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1934,7 +1934,7 @@ TEST_F(DeclarableOpsTests3, zeta_test6) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 98.50514758f, 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 98.50514758f, 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f});
nd4j::ops::zeta op; nd4j::ops::zeta op;
auto results = op.execute({&x, &q}, {}, {}); auto results = op.evaluate({&x, &q}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1958,7 +1958,7 @@ TEST_F(DeclarableOpsTests3, zeta_test7) {
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, 4.56065812e-10f}); auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, 4.56065812e-10f});
nd4j::ops::zeta op; nd4j::ops::zeta op;
auto results = op.execute({&x, &q}, {}, {}); auto results = op.evaluate({&x, &q}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1982,7 +1982,7 @@ TEST_F(DeclarableOpsTests3, zeta_test8) {
auto expected= NDArrayFactory::create<double>('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); auto expected= NDArrayFactory::create<double>('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});
nd4j::ops::zeta op; nd4j::ops::zeta op;
auto results = op.execute({&x, &q}, {}, {}); auto results = op.evaluate({&x, &q}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2054,7 +2054,7 @@ TEST_F(DeclarableOpsTests3, Test_SplitV_Validation_1) {
auto z1 = NDArrayFactory::create<float>('c', {3, 7}); auto z1 = NDArrayFactory::create<float>('c', {3, 7});
nd4j::ops::split_v op; nd4j::ops::split_v op;
auto status = op.execute({&x, &indices, &axis}, {&z0, &z1}, {}, {}, {}); auto status = op.execute({&x, &indices, &axis}, std::vector<NDArray*>{&z0, &z1}, {}, {}, {});
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -2070,7 +2070,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) {
auto expected= NDArrayFactory::create<double>('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08}); auto expected= NDArrayFactory::create<double>('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08});
nd4j::ops::polygamma op; nd4j::ops::polygamma op;
auto results = op.execute({&n, &x}, {}, {}); auto results = op.evaluate({&n, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2097,7 +2097,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test2) {
//ASSERT_FALSE(true); //ASSERT_FALSE(true);
nd4j::ops::polygamma op; nd4j::ops::polygamma op;
auto results = op.execute({&n, &x}, {}, {}); auto results = op.evaluate({&n, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2120,7 +2120,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
auto expected= NDArrayFactory::create<double>('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07}); auto expected= NDArrayFactory::create<double>('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07});
nd4j::ops::polygamma op; nd4j::ops::polygamma op;
auto results = op.execute({&n, &x}, {}, {}); auto results = op.evaluate({&n, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2141,7 +2141,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test4) {
1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE); 1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE);
nd4j::ops::polygamma op; nd4j::ops::polygamma op;
auto results = op.execute({&n, &x}, {}, {}); auto results = op.evaluate({&n, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2161,7 +2161,7 @@ TEST_F(DeclarableOpsTests3, digamma_1) {
std::numeric_limits<double>::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE); std::numeric_limits<double>::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE);
nd4j::ops::digamma op; nd4j::ops::digamma op;
auto results = op.execute({&x}, {}, {}); auto results = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2182,7 +2182,7 @@ TEST_F(DeclarableOpsTests3, svd_test1) {
auto expV= NDArrayFactory::create<double>('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543}); auto expV= NDArrayFactory::create<double>('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {1, 1, 16}); auto results = op.evaluate({&x}, {}, {1, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2219,7 +2219,7 @@ TEST_F(DeclarableOpsTests3, svd_test2) {
auto expV= NDArrayFactory::create<double>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); auto expV= NDArrayFactory::create<double>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {1, 1, 16}); auto results = op.evaluate({&x}, {}, {1, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2256,7 +2256,7 @@ TEST_F(DeclarableOpsTests3, svd_test3) {
auto expV= NDArrayFactory::create<double>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); auto expV= NDArrayFactory::create<double>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {0, 1, 16}); auto results = op.evaluate({&x}, {}, {0, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2293,7 +2293,7 @@ TEST_F(DeclarableOpsTests3, svd_test4) {
auto expV= NDArrayFactory::create<double>('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065}); auto expV= NDArrayFactory::create<double>('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {1, 1, 16}); auto results = op.evaluate({&x}, {}, {1, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2330,7 +2330,7 @@ TEST_F(DeclarableOpsTests3, svd_test5) {
auto expV= NDArrayFactory::create<double>('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151}); auto expV= NDArrayFactory::create<double>('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {0, 1, 16}); auto results = op.evaluate({&x}, {}, {0, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2385,7 +2385,7 @@ TEST_F(DeclarableOpsTests3, svd_test6) {
-0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428}); -0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {1, 1, 16}); auto results = op.evaluate({&x}, {}, {1, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2423,7 +2423,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) {
39.34498, 32.54861, 17.52492, 7.03003, 2.2399,44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); 39.34498, 32.54861, 17.52492, 7.03003, 2.2399,44.72126, 32.3164 , 16.60139, 6.88783, 0.78122});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {0, 0, 16}); auto results = op.evaluate({&x}, {}, {0, 0, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2622,7 +2622,7 @@ TEST_F(DeclarableOpsTests3, svd_test9) {
1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01,-4.39400000e-02, 2.17750000e-01,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01, -1.74620000e-01}); 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01,-4.39400000e-02, 2.17750000e-01,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01, -1.74620000e-01});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {1, 1, 16}); auto results = op.evaluate({&x}, {}, {1, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2681,7 +2681,7 @@ TEST_F(DeclarableOpsTests3, svd_test10) {
-4.39400000e-02,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01}); -4.39400000e-02,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {0, 1, 16}); auto results = op.evaluate({&x}, {}, {0, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2726,7 +2726,7 @@ TEST_F(DeclarableOpsTests3, svd_test11) {
-0.43596, 0.83108, -0.34531}); -0.43596, 0.83108, -0.34531});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {0, 1, 16}); auto results = op.evaluate({&x}, {}, {0, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2761,7 +2761,7 @@ TEST_F(DeclarableOpsTests3, svd_test12) {
NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371}); NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {1, 0, 16}); auto results = op.evaluate({&x}, {}, {1, 0, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2780,7 +2780,7 @@ TEST_F(DeclarableOpsTests3, elu_test1) {
auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9});
nd4j::ops::elu op; nd4j::ops::elu op;
auto results = op.execute({&x}, {0.5}, {}); auto results = op.evaluate({&x}, {0.5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2799,7 +2799,7 @@ TEST_F(DeclarableOpsTests3, elu_bp_test1) {
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2});
nd4j::ops::elu_bp op; nd4j::ops::elu_bp op;
auto results = op.execute({ &x, &eps }, {0.5}, {}); auto results = op.evaluate({ &x, &eps }, {0.5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2816,7 +2816,7 @@ TEST_F(DeclarableOpsTests3, lrelu_test1) {
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9});
nd4j::ops::lrelu op; nd4j::ops::lrelu op;
auto results = op.execute({&x}, {0.2}, {}); auto results = op.evaluate({&x}, {0.2}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2833,7 +2833,7 @@ TEST_F(DeclarableOpsTests3, lrelu_bp_test1) {
auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2});
nd4j::ops::lrelu_bp op; nd4j::ops::lrelu_bp op;
auto results = op.execute({&x, &eps}, {0.2}, {}); auto results = op.evaluate({&x, &eps}, {0.2}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2850,7 +2850,7 @@ TEST_F(DeclarableOpsTests3, selu_test1) {
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309});
nd4j::ops::selu op; nd4j::ops::selu op;
auto results = op.execute({&x}, {}, {}); auto results = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2868,7 +2868,7 @@ TEST_F(DeclarableOpsTests3, selu_test2) {
auto exp = NDArrayFactory::create<double>('c', {3,3}, {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, 2.101402, 2.101402}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, 2.101402, 2.101402});
nd4j::ops::selu_bp op; nd4j::ops::selu_bp op;
auto results = op.execute({&x, &eps}, {0.2}, {}); auto results = op.evaluate({&x, &eps}, {0.2}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2888,7 +2888,7 @@ TEST_F(DeclarableOpsTests3, EQScalarTests_1) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::eq_scalar op; nd4j::ops::eq_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_TRUE(res); ASSERT_TRUE(res);
} }
@ -2900,7 +2900,7 @@ TEST_F(DeclarableOpsTests3, EQScalarTests_2) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::eq_scalar op; nd4j::ops::eq_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_FALSE(res); ASSERT_FALSE(res);
} }
@ -2911,7 +2911,7 @@ TEST_F(DeclarableOpsTests3, GTScalarTests_1) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::gt_scalar op; nd4j::ops::gt_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_FALSE(res); ASSERT_FALSE(res);
} }
@ -2922,7 +2922,7 @@ TEST_F(DeclarableOpsTests3, GTScalarTests_2) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::gt_scalar op; nd4j::ops::gt_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_TRUE(res); ASSERT_TRUE(res);
} }
@ -2933,7 +2933,7 @@ TEST_F(DeclarableOpsTests3, GTEScalarTests_1) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::gte_scalar op; nd4j::ops::gte_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_TRUE(res); ASSERT_TRUE(res);
} }
@ -2944,7 +2944,7 @@ TEST_F(DeclarableOpsTests3, GTEScalarTests_2) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::gte_scalar op; nd4j::ops::gte_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_TRUE(res); ASSERT_TRUE(res);
} }
@ -2955,7 +2955,7 @@ TEST_F(DeclarableOpsTests3, GTEScalarTests_3) {
auto scalar = NDArrayFactory::create(2.0f); auto scalar = NDArrayFactory::create(2.0f);
nd4j::ops::gte_scalar op; nd4j::ops::gte_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_FALSE(res); ASSERT_FALSE(res);
} }
@ -2966,7 +2966,7 @@ TEST_F(DeclarableOpsTests3, LTEScalarTests_1) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::lte_scalar op; nd4j::ops::lte_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_TRUE(res); ASSERT_TRUE(res);
} }
@ -2977,7 +2977,7 @@ TEST_F(DeclarableOpsTests3, LTEScalarTests_2) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::lte_scalar op; nd4j::ops::lte_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_FALSE(res); ASSERT_FALSE(res);
} }
@ -2988,7 +2988,7 @@ TEST_F(DeclarableOpsTests3, LTEScalarTests_3) {
auto scalar = NDArrayFactory::create(2.0f); auto scalar = NDArrayFactory::create(2.0f);
nd4j::ops::lte_scalar op; nd4j::ops::lte_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_TRUE(res); ASSERT_TRUE(res);
} }
@ -2999,7 +2999,7 @@ TEST_F(DeclarableOpsTests3, NEQScalarTests_1) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::neq_scalar op; nd4j::ops::neq_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_FALSE(res); ASSERT_FALSE(res);
} }
@ -3011,7 +3011,7 @@ TEST_F(DeclarableOpsTests3, NEQScalarTests_2) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::neq_scalar op; nd4j::ops::neq_scalar op;
auto res = op.evaluate({&x, &scalar}); auto res = op.verify({&x, &scalar});
ASSERT_TRUE(res); ASSERT_TRUE(res);
} }
@ -3022,7 +3022,7 @@ TEST_F(DeclarableOpsTests3, NOOPTests_1) {
auto scalar = NDArrayFactory::create(1.0f); auto scalar = NDArrayFactory::create(1.0f);
nd4j::ops::noop op; nd4j::ops::noop op;
auto res = op.execute({&x, &scalar}, {}, {}); auto res = op.evaluate({&x, &scalar}, {}, {});
ASSERT_TRUE(res->status() == nd4j::Status::OK()); ASSERT_TRUE(res->status() == nd4j::Status::OK());
delete res; delete res;
} }

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -50,7 +50,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) {
matrix.linspace(1); matrix.linspace(1);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -71,7 +71,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {
matrix.linspace(1); matrix.linspace(1);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -92,7 +92,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
//matrix.linspace(1); //matrix.linspace(1);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -114,7 +114,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
//matrix.linspace(1); //matrix.linspace(1);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -133,7 +133,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) {
auto s = NDArrayFactory::create_<int>('c', {1}, {1}); auto s = NDArrayFactory::create_<int>('c', {1}, {1});
nd4j::ops::ones_as opOnes; nd4j::ops::ones_as opOnes;
//auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f}); //auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
auto onesRes = opOnes.execute({&matrix}, {}, {}); auto onesRes = opOnes.evaluate({&matrix});
//matrix.linspace(1); //matrix.linspace(1);
ASSERT_EQ(onesRes->status(), Status::OK()); ASSERT_EQ(onesRes->status(), Status::OK());
@ -181,7 +181,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) {
//matrix.linspace(1); //matrix.linspace(1);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -201,7 +201,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
//matrix.linspace(1); //matrix.linspace(1);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2}); auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -222,7 +222,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
//matrix.linspace(1); //matrix.linspace(1);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0}); auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -244,7 +244,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
grad.linspace(1); grad.linspace(1);
nd4j::ops::strided_slice_bp op; nd4j::ops::strided_slice_bp op;
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -266,7 +266,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
//grad.linspace(1); //grad.linspace(1);
nd4j::ops::strided_slice_bp op; nd4j::ops::strided_slice_bp op;
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -288,7 +288,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
grad.linspace(1); grad.linspace(1);
nd4j::ops::strided_slice_bp op; nd4j::ops::strided_slice_bp op;
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1}); auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -302,7 +302,7 @@ TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f}); auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});
nd4j::ops::test_scalar op; nd4j::ops::test_scalar op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -321,7 +321,7 @@ TEST_F(DeclarableOpsTests6, Test_Order_1) {
exp.linspace(1); exp.linspace(1);
nd4j::ops::order op; nd4j::ops::order op;
auto result = op.execute({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -336,7 +336,7 @@ TEST_F(DeclarableOpsTests6, cumSum_1) {
auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1.f, 3.f, 6.f, 10.f}); auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1.f, 3.f, 6.f, 10.f});
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0}); auto result = op.evaluate({&x}, {}, {0, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -352,7 +352,7 @@ TEST_F(DeclarableOpsTests6, cumSum_2) {
auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f}); auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f});
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 1}); auto result = op.evaluate({&x}, {}, {0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -369,7 +369,7 @@ TEST_F(DeclarableOpsTests6, cumSum_3) {
auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f}); auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f});
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 0}); auto result = op.evaluate({&x}, {}, {0, 0, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -385,7 +385,7 @@ TEST_F(DeclarableOpsTests6, cumSum_4) {
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.});
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 1, 0}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {}, {0, 1, 0}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -401,7 +401,7 @@ TEST_F(DeclarableOpsTests6, cumSum_5) {
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {6.f, 5.f, 3.f, 15.f, 11.f, 6.f, 24.f, 17.f, 9.f,}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {6.f, 5.f, 3.f, 15.f, 11.f, 6.f, 24.f, 17.f, 9.f,});
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 1, 1}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {}, {0, 1, 1}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -416,7 +416,7 @@ TEST_F(DeclarableOpsTests6, cumSum_6) {
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f});
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 1, 0}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {}, {1, 1, 0}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -431,7 +431,7 @@ TEST_F(DeclarableOpsTests6, cumSum_7) {
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 1, 1}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {}, {1, 1, 1}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -447,7 +447,7 @@ TEST_F(DeclarableOpsTests6, cumSum_8) {
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x, &axis}, {}, {1, 1}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &axis}, {}, {1, 1}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -475,7 +475,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
exclusive = 0; reverse = 0; exclusive = 0; reverse = 0;
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
ASSERT_TRUE(expFF.equalsTo(z)); ASSERT_TRUE(expFF.equalsTo(z));
@ -484,7 +484,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
//************************************// //************************************//
exclusive = 1; reverse = 0; exclusive = 1; reverse = 0;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE); result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
z = result->at(0); z = result->at(0);
ASSERT_TRUE(expTF.equalsTo(z)); ASSERT_TRUE(expTF.equalsTo(z));
@ -493,7 +493,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
//************************************// //************************************//
exclusive = 0; reverse = 1; exclusive = 0; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE); result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
z = result->at(0); z = result->at(0);
ASSERT_TRUE(expFT.equalsTo(z)); ASSERT_TRUE(expFT.equalsTo(z));
@ -502,7 +502,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
//************************************// //************************************//
exclusive = 1; reverse = 1; exclusive = 1; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE); result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
z = result->at(0); z = result->at(0);
ASSERT_TRUE(expTT.equalsTo(z)); ASSERT_TRUE(expTT.equalsTo(z));
@ -516,7 +516,7 @@ TEST_F(DeclarableOpsTests6, cumSum_10) {
auto y = NDArrayFactory::create<int>(-3); auto y = NDArrayFactory::create<int>(-3);
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x, &y}, {}, {1, 1}, {}); auto result = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -531,7 +531,7 @@ TEST_F(DeclarableOpsTests6, cumSum_11) {
x.linspace(1); x.linspace(1);
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 1, 1}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {}, {0, 1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -550,7 +550,7 @@ TEST_F(DeclarableOpsTests6, cumSum_12) {
x.linspace(1); x.linspace(1);
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 1}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {}, {0, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -569,7 +569,7 @@ TEST_F(DeclarableOpsTests6, cumSum_13) {
x.linspace(1); x.linspace(1);
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 1, 1}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {}, {1, 1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -588,7 +588,7 @@ TEST_F(DeclarableOpsTests6, cumSum_14) {
x.linspace(1); x.linspace(1);
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 1, 0}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {}, {1, 1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -607,7 +607,7 @@ TEST_F(DeclarableOpsTests6, cumSum_15) {
x.linspace(1); x.linspace(1);
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 1, 2}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {}, {0, 1, 2});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -623,7 +623,7 @@ TEST_F(DeclarableOpsTests6, cumSum_16) {
NDArray x('f', {3, 4}, nd4j::DataType::FLOAT32); NDArray x('f', {3, 4}, nd4j::DataType::FLOAT32);
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 1}); auto result = op.evaluate({&x}, {}, {0, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -659,7 +659,7 @@ TEST_F(DeclarableOpsTests6, cumSum_17) {
} }
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 1}); auto result = op.evaluate({&x}, {}, {0, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -692,7 +692,7 @@ TEST_F(DeclarableOpsTests6, cumSum_18) {
} }
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 0, 1}); auto result = op.evaluate({&x}, {}, {1, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -725,7 +725,7 @@ TEST_F(DeclarableOpsTests6, cumSum_19) {
} }
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 1, 1}); auto result = op.evaluate({&x}, {}, {0, 1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -759,7 +759,7 @@ TEST_F(DeclarableOpsTests6, cumSum_20) {
} }
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 1, 1}); auto result = op.evaluate({&x}, {}, {1, 1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -778,7 +778,7 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) {
auto exp = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); auto exp = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
nd4j::ops::mergemaxindex op; nd4j::ops::mergemaxindex op;
auto ress = op.execute({&x, &y, &z}, {}, {}, {}); auto ress = op.evaluate({&x, &y, &z}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
// ress->at(0)->printIndexedBuffer("MergeMaxIndex Result is "); // ress->at(0)->printIndexedBuffer("MergeMaxIndex Result is ");
@ -797,7 +797,7 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
nd4j::ops::mergemaxindex op; nd4j::ops::mergemaxindex op;
auto ress = op.execute({&x, &y, &z}, {}, {nd4j::DataType::INT64}, {}); auto ress = op.evaluate({&x, &y, &z}, {}, {nd4j::DataType::INT64});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
// ress->at(0)->printIndexedBuffer("MergeMaxIndex2 Result is "); // ress->at(0)->printIndexedBuffer("MergeMaxIndex2 Result is ");
@ -814,7 +814,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) {
auto shape = NDArrayFactory::create<Nd4jLong>({2, 2}); auto shape = NDArrayFactory::create<Nd4jLong>({2, 2});
nd4j::ops::dropout op; nd4j::ops::dropout op;
auto ress = op.execute({&x, &shape}, {0.2f}, {113}, {}, false, nd4j::DataType::DOUBLE); auto ress = op.evaluate({&x, &shape}, {0.2f}, {113});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
//ress->at(0)->printIndexedBuffer("Result is "); //ress->at(0)->printIndexedBuffer("Result is ");
@ -830,7 +830,7 @@ TEST_F(DeclarableOpsTests6, TestMod_1) {
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0}); auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0});
nd4j::ops::mod op; nd4j::ops::mod op;
auto ress = op.execute({&x, &y}, {}, {}, {}); auto ress = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
// ress->at(0)->printIndexedBuffer("MOD Result is "); // ress->at(0)->printIndexedBuffer("MOD Result is ");
@ -848,7 +848,7 @@ TEST_F(DeclarableOpsTests6, TestMod_BP_1) {
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}); auto exp = NDArrayFactory::create<double>('c', {2, 2, 2});
nd4j::ops::mod_bp op; nd4j::ops::mod_bp op;
auto ress = op.execute({&x, &y, &eps}, {}, {}, {}); auto ress = op.evaluate({&x, &y, &eps});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
// ress->at(0)->printIndexedBuffer("MOD_BP Result is "); // ress->at(0)->printIndexedBuffer("MOD_BP Result is ");
@ -867,7 +867,7 @@ TEST_F(DeclarableOpsTests6, TestRank_1) {
auto exp = NDArrayFactory::create<int>(3); auto exp = NDArrayFactory::create<int>(3);
nd4j::ops::rank op; nd4j::ops::rank op;
auto ress = op.execute({&x}, {}, {}, {}); auto ress = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
@ -881,7 +881,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_2) {
nd4j::ops::dropout op; nd4j::ops::dropout op;
auto ress = op.execute({&x}, {0.4f}, {113}, {}, false, nd4j::DataType::DOUBLE); auto ress = op.evaluate({&x}, {0.4f}, {113});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
@ -896,7 +896,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_3) {
nd4j::ops::dropout op; nd4j::ops::dropout op;
auto ress = op.execute({&x, &shape}, {0.4f}, {113}, {}, false, nd4j::DataType::DOUBLE); auto ress = op.evaluate({&x, &shape}, {0.4f}, {113});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
@ -913,7 +913,7 @@ TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) {
nd4j::ops::max_pool_with_argmax op; nd4j::ops::max_pool_with_argmax op;
auto ress = op.execute({&x}, {}, {1,1,1,1,1,1,1,1,1}); auto ress = op.evaluate({&x}, {}, {1,1,1,1,1,1,1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
ASSERT_TRUE(expI.isSameShape(ress->at(0))); ASSERT_TRUE(expI.isSameShape(ress->at(0)));
@ -942,7 +942,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_1) {
nd4j::ops::sufficient_statistics op; nd4j::ops::sufficient_statistics op;
auto ress = op.execute({&x, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto ress = op.evaluate({&x, &axis});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
ASSERT_EQ(ress->at(0)->e<double>(0), count); ASSERT_EQ(ress->at(0)->e<double>(0), count);
@ -974,7 +974,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_2) {
nd4j::ops::sufficient_statistics op; nd4j::ops::sufficient_statistics op;
auto ress = op.execute({&x, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto ress = op.evaluate({&x, &axis});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
ASSERT_EQ(ress->at(0)->e<double>(0), count); ASSERT_EQ(ress->at(0)->e<double>(0), count);
@ -996,7 +996,7 @@ TEST_F(DeclarableOpsTests6, BinCount_1) {
nd4j::ops::bincount op; nd4j::ops::bincount op;
auto res = op.execute({&x}, {}, {}); auto res = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1021,7 +1021,7 @@ TEST_F(DeclarableOpsTests6, BinCount_2) {
nd4j::ops::bincount op; nd4j::ops::bincount op;
auto res = op.execute({&x, &weights}, {}, {}); auto res = op.evaluate({&x, &weights});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1046,7 +1046,7 @@ TEST_F(DeclarableOpsTests6, BinCount_3) {
nd4j::ops::bincount op; nd4j::ops::bincount op;
auto res = op.execute({&x, &weights}, {}, {0, 2}); auto res = op.evaluate({&x, &weights}, {}, {0, 2});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1071,7 +1071,7 @@ TEST_F(DeclarableOpsTests6, BinCount_4) {
nd4j::ops::bincount op; nd4j::ops::bincount op;
auto res = op.execute({&x, &weights}, {}, {4, 4}); auto res = op.evaluate({&x, &weights}, {}, {4, 4});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1097,7 +1097,7 @@ TEST_F(DeclarableOpsTests6, BinCount_5) {
nd4j::ops::bincount op; nd4j::ops::bincount op;
auto res = op.execute({&x, &weights, &minV, &maxV}, {}, {}); auto res = op.evaluate({&x, &weights, &minV, &maxV});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printBuffer("BC out"); // res->at(0)->printBuffer("BC out");
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1116,7 +1116,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) {
nd4j::ops::broadcast_dynamic_shape op; nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}); auto res = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1135,7 +1135,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
nd4j::ops::broadcast_dynamic_shape op; nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64); auto res = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1153,7 +1153,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) {
nd4j::ops::broadcast_dynamic_shape op; nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}); auto res = op.evaluate({&x, &y}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1172,7 +1172,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
nd4j::ops::broadcast_dynamic_shape op; nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64); auto res = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
//res->at(0)->printBuffer("Shape SGO 4"); //res->at(0)->printBuffer("Shape SGO 4");
@ -1191,7 +1191,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4}); auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
nd4j::ops::broadcast_dynamic_shape op; nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64); auto res = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1209,7 +1209,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3}); auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3});
nd4j::ops::broadcast_dynamic_shape op; nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64); auto res = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1274,7 +1274,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
// auto expNorm(8.660254); // auto expNorm(8.660254);
nd4j::ops::clip_by_global_norm op; nd4j::ops::clip_by_global_norm op;
auto result = op.execute({&x}, {0.8}, {}); auto result = op.evaluate({&x}, {0.8}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1316,7 +1316,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) {
); );
nd4j::ops::clip_by_global_norm op; nd4j::ops::clip_by_global_norm op;
auto result = op.execute({&x, &a}, {1.8}, {}); auto result = op.evaluate({&x, &a}, {1.8}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1346,7 +1346,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) {
); );
nd4j::ops::clip_by_global_norm op; nd4j::ops::clip_by_global_norm op;
auto result = op.execute({&x, &a}, {0.8}, {}); auto result = op.evaluate({&x, &a}, {0.8}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1372,7 +1372,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) {
auto exp = NDArrayFactory::create<double>({36.0, -48.0}); auto exp = NDArrayFactory::create<double>({36.0, -48.0});
nd4j::ops::matrix_determinant op; nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1393,7 +1393,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {
auto exp = NDArrayFactory::create<double>({-2.0, -2.0}); auto exp = NDArrayFactory::create<double>({-2.0, -2.0});
nd4j::ops::matrix_determinant op; nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1414,7 +1414,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {
NDArray exp('c', {1}, {-54.0}); NDArray exp('c', {1}, {-54.0});
nd4j::ops::matrix_determinant op; nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1435,7 +1435,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {
auto exp = NDArrayFactory::create<double>('c', {1}, {189.0}); auto exp = NDArrayFactory::create<double>('c', {1}, {189.0});
nd4j::ops::matrix_determinant op; nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1459,7 +1459,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {
x.p(12, 12.0); x.p(12, 12.0);
nd4j::ops::matrix_determinant op; nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1483,7 +1483,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) {
x.p(12, 12.0); x.p(12, 12.0);
nd4j::ops::matrix_determinant op; nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1505,7 +1505,7 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
auto exp = NDArrayFactory::create<double>({3.58351893845611, 3.871201010907891}); auto exp = NDArrayFactory::create<double>({3.58351893845611, 3.871201010907891});
nd4j::ops::log_matrix_determinant op; nd4j::ops::log_matrix_determinant op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1524,7 +1524,7 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
auto exp = NDArrayFactory::create<double>({ 3.5835189, 4.159008}); auto exp = NDArrayFactory::create<double>({ 3.5835189, 4.159008});
nd4j::ops::logdet op; nd4j::ops::logdet op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1542,7 +1542,7 @@ TEST_F(DeclarableOpsTests6, LogDet_2) {
auto exp = NDArrayFactory::create<double>('c', {1}, { 3.5835189}); auto exp = NDArrayFactory::create<double>('c', {1}, { 3.5835189});
nd4j::ops::logdet op; nd4j::ops::logdet op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1561,7 +1561,7 @@ TEST_F(DeclarableOpsTests6, LogDet_3) {
auto exp = NDArrayFactory::create<double>( 3.5835189); auto exp = NDArrayFactory::create<double>( 3.5835189);
nd4j::ops::logdet op; nd4j::ops::logdet op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1605,7 +1605,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1624,7 +1624,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f}); auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f});
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1643,7 +1643,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f }); auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1662,7 +1662,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f }); auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1707,7 +1707,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1740,7 +1740,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1774,7 +1774,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1808,7 +1808,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1842,7 +1842,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_04) {
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests6, ReluLayer_1) {
26.2, 31.65, 60.7}); 26.2, 31.65, 60.7});
nd4j::ops::relu_layer op; nd4j::ops::relu_layer op;
auto result = op.execute({&x, &w, &b}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &w, &b});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1923,7 +1923,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test1) {
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527}); auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});
nd4j::ops::static_rnn op; nd4j::ops::static_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1966,7 +1966,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test2) {
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648}); auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648});
nd4j::ops::static_rnn op; nd4j::ops::static_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0}, {}, {}); auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2009,7 +2009,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test3) {
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2 , 0.2 , 0.2 , 0.2}); auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2 , 0.2 , 0.2 , 0.2});
nd4j::ops::static_rnn op; nd4j::ops::static_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2051,7 +2051,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test4) {
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882}); auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882});
nd4j::ops::static_rnn op; nd4j::ops::static_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2093,7 +2093,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test5) {
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653}); auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653});
nd4j::ops::static_rnn op; nd4j::ops::static_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b}, {}, {}); auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2144,7 +2144,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) {
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25}); auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25});
nd4j::ops::static_bidirectional_rnn op; nd4j::ops::static_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {}); auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2197,7 +2197,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) {
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.}); auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.});
nd4j::ops::static_bidirectional_rnn op; nd4j::ops::static_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {}); auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2250,7 +2250,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) {
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713}); auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713});
nd4j::ops::static_bidirectional_rnn op; nd4j::ops::static_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {}); auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2296,7 +2296,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) {
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527}); auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});
nd4j::ops::dynamic_rnn op; nd4j::ops::dynamic_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1}); auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2341,7 +2341,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) {
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782}); auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782});
nd4j::ops::dynamic_rnn op; nd4j::ops::dynamic_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2383,7 +2383,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) {
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, 0.98120782, 0.98120782}); auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, 0.98120782, 0.98120782});
nd4j::ops::dynamic_rnn op; nd4j::ops::dynamic_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0}, {}, {}); auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2424,7 +2424,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) {
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.57368608, 0.57368608, 0.57368608, 0.57368608}); auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.57368608, 0.57368608, 0.57368608, 0.57368608});
nd4j::ops::dynamic_rnn op; nd4j::ops::dynamic_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2465,7 +2465,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) {
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97486307, 0.97486307, 0.97486307, 0.97486307,0.98119833, 0.98119833, 0.98119833, 0.98119833}); auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97486307, 0.97486307, 0.97486307, 0.97486307,0.98119833, 0.98119833, 0.98119833, 0.98119833});
nd4j::ops::dynamic_rnn op; nd4j::ops::dynamic_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b}, {}, {}); auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2521,7 +2521,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) {
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25 , 0.25 , 0.25}); auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25 , 0.25 , 0.25});
nd4j::ops::dynamic_bidirectional_rnn op; nd4j::ops::dynamic_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1}, {}, false, nd4j::DataType::DOUBLE); auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2581,7 +2581,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) {
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25 , 0.25 , 0.25}); auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25 , 0.25 , 0.25});
nd4j::ops::dynamic_bidirectional_rnn op; nd4j::ops::dynamic_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {}); auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2637,7 +2637,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) {
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.}); auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.});
nd4j::ops::dynamic_bidirectional_rnn op; nd4j::ops::dynamic_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {}); auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2696,7 +2696,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) {
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357}); auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357});
nd4j::ops::dynamic_bidirectional_rnn op; nd4j::ops::dynamic_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {}); auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2749,7 +2749,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234}); auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234});
nd4j::ops::dynamic_bidirectional_rnn op; nd4j::ops::dynamic_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {}); auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2776,7 +2776,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_1) {
auto e = NDArrayFactory::create<double>('c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f}); auto e = NDArrayFactory::create<double>('c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f});
nd4j::ops::diag op; nd4j::ops::diag op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result->at(0));
@ -2789,7 +2789,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {
auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f}); auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});
nd4j::ops::diag op; nd4j::ops::diag op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result->at(0));
@ -2802,7 +2802,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f}); auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});
nd4j::ops::diag op; nd4j::ops::diag op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result->at(0));

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -51,7 +51,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {
nd4j::ops::reduce_stdev_bp op; nd4j::ops::reduce_stdev_bp op;
auto result = op.execute({&x, &gradO2}, {0,0}, {1}); auto result = op.evaluate({&x, &gradO2}, {0,0}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
// output->printIndexedBuffer(); // output->printIndexedBuffer();
@ -59,7 +59,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
delete result; delete result;
result = op.execute({&x, &gradO1}, {1,0}, {1}); result = op.evaluate({&x, &gradO1}, {1,0}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
output = result->at(0); output = result->at(0);
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
@ -80,7 +80,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) {
nd4j::ops::reduce_stdev_bp op; nd4j::ops::reduce_stdev_bp op;
auto result = op.execute({&x, &gradO2, &axis}, {}, {}, {false, false}); auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
// output->printIndexedBuffer(); // output->printIndexedBuffer();
@ -88,7 +88,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) {
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
delete result; delete result;
result = op.execute({&x, &gradO1}, {1,0}, {1}); result = op.evaluate({&x, &gradO1}, {1,0}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
output = result->at(0); output = result->at(0);
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {1}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -272,7 +272,7 @@ TEST_F(DeclarableOpsTests9, concat_test2) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {1}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -296,7 +296,7 @@ TEST_F(DeclarableOpsTests9, concat_test3) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {0}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -316,7 +316,7 @@ TEST_F(DeclarableOpsTests9, concat_test4) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {1}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -336,7 +336,7 @@ TEST_F(DeclarableOpsTests9, concat_test5) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {0}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -356,7 +356,7 @@ TEST_F(DeclarableOpsTests9, concat_test6) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {0}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -376,7 +376,7 @@ TEST_F(DeclarableOpsTests9, concat_test7) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {0}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -394,7 +394,7 @@ TEST_F(DeclarableOpsTests9, concat_test8) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0}, {}, {0}); auto result = op.evaluate({&x0}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -412,7 +412,7 @@ TEST_F(DeclarableOpsTests9, concat_test9) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0}, {}, {0}); auto result = op.evaluate({&x0}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -437,7 +437,7 @@ TEST_F(DeclarableOpsTests9, concat_test10) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {1}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -462,7 +462,7 @@ TEST_F(DeclarableOpsTests9, concat_test11) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {1}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -487,7 +487,7 @@ TEST_F(DeclarableOpsTests9, concat_test12) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {1}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -512,7 +512,7 @@ TEST_F(DeclarableOpsTests9, concat_test13) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2}, {}, {1}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -532,7 +532,7 @@ TEST_F(DeclarableOpsTests9, concat_test14) {
x1 = 2.; x1 = 2.;
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1}, {}, {0}, {}); auto result = op.evaluate({&x0, &x1}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -555,7 +555,7 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
auto exp = NDArrayFactory::create<double>('c', {3}, {1, 0, 3}); auto exp = NDArrayFactory::create<double>('c', {3}, {1, 0, 3});
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x, &y}, {}, {0}); auto result = op.evaluate({&x, &y}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
auto exp = NDArrayFactory::create<double>('c', {0,2,3}); auto exp = NDArrayFactory::create<double>('c', {0,2,3});
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x, &y}, {}, {0}); auto result = op.evaluate({&x, &y}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -594,7 +594,7 @@ TEST_F(DeclarableOpsTests9, concat_test17) {
x1 = 2.; x1 = 2.;
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1}, {}, {0}, {}); auto result = op.evaluate({&x0, &x1}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -675,7 +675,7 @@ TEST_F(DeclarableOpsTests9, concat_test20) {
x3.assign(4.0); x3.assign(4.0);
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {}); auto result = op.evaluate({&x0, &x1, &x2, &x3}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -763,7 +763,7 @@ TEST_F(DeclarableOpsTests9, concat_test25) {
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &axis}, {}, {}, {true}); auto result = op.evaluate({&x0, &x1, &axis}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -784,7 +784,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test1) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op; nd4j::ops::tile_bp op;
auto results = op.execute({&input, &gradO}, {}, {2, 3}); auto results = op.evaluate({&input, &gradO}, {}, {2, 3});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -804,7 +804,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test2) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op; nd4j::ops::tile_bp op;
auto results = op.execute({&input, &gradO}, {}, {1, 3}); auto results = op.evaluate({&input, &gradO}, {}, {1, 3});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.isSameShape(gradI));
@ -823,7 +823,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test3) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op; nd4j::ops::tile_bp op;
auto results = op.execute({&input, &gradO}, {}, {1, 1}); auto results = op.evaluate({&input, &gradO}, {}, {1, 1});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -843,7 +843,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test4) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op; nd4j::ops::tile_bp op;
auto results = op.execute({&input, &gradO}, {}, {2}); auto results = op.evaluate({&input, &gradO}, {}, {2});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -863,7 +863,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test5) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op; nd4j::ops::tile_bp op;
auto results = op.execute({&input, &gradO}, {}, {1}); auto results = op.evaluate({&input, &gradO}, {}, {1});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -883,7 +883,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test6) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op; nd4j::ops::tile_bp op;
auto results = op.execute({&input, &gradO}, {}, {1, 3, 2}); auto results = op.evaluate({&input, &gradO}, {}, {1, 3, 2});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -904,7 +904,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test7) {
gradO.linspace(0.01, 0.01); gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op; nd4j::ops::tile_bp op;
auto results = op.execute({&input, &reps, &gradO}, {}, {}); auto results = op.evaluate({&input, &reps, &gradO}, {}, {});
auto gradI = results->at(0); auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -922,7 +922,7 @@ TEST_F(DeclarableOpsTests9, tile_test1) {
auto expOut = NDArrayFactory::create<double>('c', {2, 6,}, {1.,2.,3.,4.,5.,6., 1.,2.,3.,4.,5.,6.}); auto expOut = NDArrayFactory::create<double>('c', {2, 6,}, {1.,2.,3.,4.,5.,6., 1.,2.,3.,4.,5.,6.});
nd4j::ops::tile op; nd4j::ops::tile op;
auto results = op.execute({&input, &reps}, {}, {}); auto results = op.evaluate({&input, &reps}, {}, {});
auto out = results->at(0); auto out = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -944,7 +944,7 @@ TEST_F(DeclarableOpsTests9, matmul_test1) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -966,7 +966,7 @@ TEST_F(DeclarableOpsTests9, matmul_test2) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -987,7 +987,7 @@ TEST_F(DeclarableOpsTests9, matmul_test3) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1009,7 +1009,7 @@ TEST_F(DeclarableOpsTests9, matmul_test4) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1031,7 +1031,7 @@ TEST_F(DeclarableOpsTests9, matmul_test5) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1}); auto results = op.evaluate({&x, &y}, {}, {1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1052,7 +1052,7 @@ TEST_F(DeclarableOpsTests9, matmul_test6) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1}); auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1075,7 +1075,7 @@ TEST_F(DeclarableOpsTests9, matmul_test7) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {0, 1}); auto results = op.evaluate({&x, &y}, {}, {0, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1100,7 +1100,7 @@ TEST_F(DeclarableOpsTests9, matmul_test8) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {0, 1}); auto results = op.evaluate({&x, &y}, {}, {0, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1125,7 +1125,7 @@ TEST_F(DeclarableOpsTests9, matmul_test9) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1}); auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1142,7 +1142,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_BP_1) {
NDArray shape('c', {2}, {2, 2}); NDArray shape('c', {2}, {2, 2});
nd4j::ops::dropout_bp op; nd4j::ops::dropout_bp op;
auto ress = op.execute({&x, &errs, &shape}, {0.2f}, {113}); auto ress = op.evaluate({&x, &errs, &shape}, {0.2f}, {113});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
//ress->at(0)->printIndexedBuffer("Result is "); //ress->at(0)->printIndexedBuffer("Result is ");
@ -1159,7 +1159,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_1) {
//NDArray<float> shape({2.f, 2.f}); //NDArray<float> shape({2.f, 2.f});
nd4j::ops::dropout op; nd4j::ops::dropout op;
x.linspace(1); x.linspace(1);
auto ress = op.execute({&x}, {0.2f}, {113}); auto ress = op.evaluate({&x}, {0.2f}, {113});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
NDArray* res = ress->at(0); //->printIndexedBuffer("Result is "); NDArray* res = ress->at(0); //->printIndexedBuffer("Result is ");
@ -1167,7 +1167,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_1) {
//res->printIndexedBuffer("Result for Dropout_1"); //res->printIndexedBuffer("Result for Dropout_1");
auto countZero = res->reduceNumber(reduce::CountZero); auto countZero = res->reduceNumber(reduce::CountZero);
ASSERT_NEAR(countZero.e<Nd4jLong>(0), 80, 5); ASSERT_NEAR(countZero.e<Nd4jLong>(0), 80, 5);
auto ress2 = op.execute({&x}, {0.2f}, {113}); auto ress2 = op.evaluate({&x}, {0.2f}, {113});
ASSERT_EQ(ND4J_STATUS_OK, ress2->status()); ASSERT_EQ(ND4J_STATUS_OK, ress2->status());
NDArray* res2 = ress2->at(0); NDArray* res2 = ress2->at(0);
@ -1214,7 +1214,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
*/ */
nd4j::ops::dropout op; nd4j::ops::dropout op;
auto ress = op.execute({&x1}, {0.5f}, {119}); auto ress = op.evaluate({&x1}, {0.5f}, {119});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
//ress->at(0)->printIndexedBuffer("01Dropout result is "); //ress->at(0)->printIndexedBuffer("01Dropout result is ");
@ -1225,11 +1225,11 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
//NDArray<float> exp('c', {10,10}, {4.f, 0.f, 12.f, 0.f, 20.f, 24.f, 0.f, 32.f, 0.f, 0.f, 0.f, 0.f, 52.f, 56.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 84.f, 88.f, 0.f, 0.f, 0.f, 0.f, 108.f, 0.f, 0.f, 120.f, 0.f, 0.f, 132.f, 0.f, 0.f, 0.f, 0.f, 0.f, 156.f, 0.f, 164.f, 168.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 200.f, 204.f, 0.f, 0.f, 0.f, 220.f, 0.f, 0.f, 232.f, 236.f, 240.f, 0.f, 248.f, 0.f, 0.f, 260.f, 0.f, 0.f, 0.f, 276.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 316.f, 0.f, 324.f, 0.f, 0.f, 336.f, 0.f, 0.f, 0.f, 0.f, 356.f, 0.f, 0.f, 368.f, 0.f, 0.f, 0.f, 384.f, 388.f, 0.f, 0.f, 400.f}); //NDArray<float> exp('c', {10,10}, {4.f, 0.f, 12.f, 0.f, 20.f, 24.f, 0.f, 32.f, 0.f, 0.f, 0.f, 0.f, 52.f, 56.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 84.f, 88.f, 0.f, 0.f, 0.f, 0.f, 108.f, 0.f, 0.f, 120.f, 0.f, 0.f, 132.f, 0.f, 0.f, 0.f, 0.f, 0.f, 156.f, 0.f, 164.f, 168.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 200.f, 204.f, 0.f, 0.f, 0.f, 220.f, 0.f, 0.f, 232.f, 236.f, 240.f, 0.f, 248.f, 0.f, 0.f, 260.f, 0.f, 0.f, 0.f, 276.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 316.f, 0.f, 324.f, 0.f, 0.f, 336.f, 0.f, 0.f, 0.f, 0.f, 356.f, 0.f, 0.f, 368.f, 0.f, 0.f, 0.f, 384.f, 388.f, 0.f, 0.f, 400.f});
//02Dropout result is [4.000000, 0.000000, 12.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 36.000000, 0.000000, 0.000000, 0.000000, 0.000000, 56.000000, 60.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 88.000000, 0.000000, 96.000000, 0.000000, 0.000000, 108.000000, 0.000000, 0.000000, 120.000000, 0.000000, 128.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 156.000000, 0.000000, 164.000000, 0.000000, 0.000000, 0.000000, 0.000000, 184.000000, 0.000000, 0.000000, 0.000000, 200.000000, 0.000000, 0.000000, 0.000000, 216.000000, 0.000000, 0.000000, 0.000000, 232.000000, 0.000000, 240.000000, 0.000000, 248.000000, 0.000000, 0.000000, 260.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 308.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 348.000000, 0.000000, 356.000000, 0.000000, 0.000000, 0.000000, 0.000000, 376.000000, 0.000000, 384.000000, 0.000000, 0.000000, 0.000000, 400.000000] //02Dropout result is [4.000000, 0.000000, 12.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 36.000000, 0.000000, 0.000000, 0.000000, 0.000000, 56.000000, 60.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 88.000000, 0.000000, 96.000000, 0.000000, 0.000000, 108.000000, 0.000000, 0.000000, 120.000000, 0.000000, 128.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 156.000000, 0.000000, 164.000000, 0.000000, 0.000000, 0.000000, 0.000000, 184.000000, 0.000000, 0.000000, 0.000000, 200.000000, 0.000000, 0.000000, 0.000000, 216.000000, 0.000000, 0.000000, 0.000000, 232.000000, 0.000000, 240.000000, 0.000000, 248.000000, 0.000000, 0.000000, 260.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 308.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 348.000000, 0.000000, 356.000000, 0.000000, 0.000000, 0.000000, 0.000000, 376.000000, 0.000000, 384.000000, 0.000000, 0.000000, 0.000000, 400.000000]
auto ressX = op2.execute({&x1, &x1}, {0.5f}, {119}); // , false, nd4j::DataType::FLOAT32); // skipped due given by default auto ressX = op2.evaluate({&x1, &x1}, {0.5f}, {119}); // , false, nd4j::DataType::FLOAT32); // skipped due given by default
//x0.printIndexedBuffer("X0"); //x0.printIndexedBuffer("X0");
//x1.printIndexedBuffer("X1"); //x1.printIndexedBuffer("X1");
ASSERT_EQ(ND4J_STATUS_OK, ressX->status()); ASSERT_EQ(ND4J_STATUS_OK, ressX->status());
auto ressY = op2.execute({&x1, &x0}, {0.5f}, {119}); auto ressY = op2.evaluate({&x1, &x0}, {0.5f}, {119});
ASSERT_EQ(ND4J_STATUS_OK, ressY->status()); ASSERT_EQ(ND4J_STATUS_OK, ressY->status());
//ressY->at(0)->printIndexedBuffer("BP"); //ressY->at(0)->printIndexedBuffer("BP");
//ress->at(0)->printIndexedBuffer("FF"); //ress->at(0)->printIndexedBuffer("FF");
@ -1264,17 +1264,17 @@ TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) {
nd4j::ops::dropout op; nd4j::ops::dropout op;
auto ress = op.execute({&x}, {0.5f}, {119}); auto ress = op.evaluate({&x}, {0.5f}, {119});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
// ress->at(0)->printIndexedBuffer("01Dropout result is "); // ress->at(0)->printIndexedBuffer("01Dropout result is ");
nd4j::ops::dropout_bp op2; nd4j::ops::dropout_bp op2;
auto ressX = op2.execute({&x, &x}, {0.5f}, {119}); auto ressX = op2.evaluate({&x, &x}, {0.5f}, {119});
ASSERT_EQ(ND4J_STATUS_OK, ressX->status()); ASSERT_EQ(ND4J_STATUS_OK, ressX->status());
auto ressY = op2.execute({&x, &x}, {0.5f}, {119}); auto ressY = op2.evaluate({&x, &x}, {0.5f}, {119});
ASSERT_EQ(ND4J_STATUS_OK, ressY->status()); ASSERT_EQ(ND4J_STATUS_OK, ressY->status());
//ress->at(0)->printIndexedBuffer("FF Dropout result is "); //ress->at(0)->printIndexedBuffer("FF Dropout result is ");
@ -1307,12 +1307,12 @@ TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) {
nd4j::ops::alpha_dropout_bp op; nd4j::ops::alpha_dropout_bp op;
auto ress = op.execute({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); auto ress = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
ASSERT_EQ(ND4J_STATUS_OK, ress->status()); ASSERT_EQ(ND4J_STATUS_OK, ress->status());
NDArray* res = ress->at(0); NDArray* res = ress->at(0);
auto ress2 = op.execute({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); auto ress2 = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
ASSERT_EQ(ND4J_STATUS_OK, ress2->status()); ASSERT_EQ(ND4J_STATUS_OK, ress2->status());
NDArray* res2 = ress2->at(0); NDArray* res2 = ress2->at(0);
@ -1336,7 +1336,7 @@ TEST_F(DeclarableOpsTests9, matmul_test10) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1}); auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1356,7 +1356,7 @@ TEST_F(DeclarableOpsTests9, matmul_test11) {
x.linspace(1.); x.linspace(1.);
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1}); auto results = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
auto z = results->at(0); auto z = results->at(0);
@ -1377,7 +1377,7 @@ TEST_F(DeclarableOpsTests9, matmul_test12) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1}); auto results = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
auto z = results->at(0); auto z = results->at(0);
@ -1398,7 +1398,7 @@ TEST_F(DeclarableOpsTests9, matmul_test13) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {0, 0, 1}); auto results = op.evaluate({&x, &y}, {}, {0, 0, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1419,7 +1419,7 @@ TEST_F(DeclarableOpsTests9, matmul_test14) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 0, 1}); auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1440,7 +1440,7 @@ TEST_F(DeclarableOpsTests9, matmul_test15) {
y.linspace(0.5, 0.5); y.linspace(0.5, 0.5);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 0, 1}); auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1464,7 +1464,7 @@ TEST_F(DeclarableOpsTests9, matmul_test16) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1, 1}); auto results = op.evaluate({&x, &y}, {}, {1, 1, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1485,7 +1485,7 @@ TEST_F(DeclarableOpsTests9, matmul_test17) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 0}); auto results = op.evaluate({&x, &y}, {}, {1, 0});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1506,7 +1506,7 @@ TEST_F(DeclarableOpsTests9, matmul_test18) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {0, 1}); auto results = op.evaluate({&x, &y}, {}, {0, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1527,7 +1527,7 @@ TEST_F(DeclarableOpsTests9, matmul_test19) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1549,7 +1549,7 @@ TEST_F(DeclarableOpsTests9, matmul_test20) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1,1,1}); auto results = op.evaluate({&x, &y}, {}, {1,1,1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1571,7 +1571,7 @@ TEST_F(DeclarableOpsTests9, matmul_test21) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {}); auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1593,7 +1593,7 @@ TEST_F(DeclarableOpsTests9, matmul_test22) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1}); auto results = op.evaluate({&x, &y}, {}, {1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1615,7 +1615,7 @@ TEST_F(DeclarableOpsTests9, matmul_test23) {
y.linspace(0.1, 0.1); y.linspace(0.1, 0.1);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1}); auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1634,7 +1634,7 @@ TEST_F(DeclarableOpsTests9, matmul_test24) {
auto exp = NDArrayFactory::create<double>(6.); auto exp = NDArrayFactory::create<double>(6.);
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1}); auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0); auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -1650,7 +1650,7 @@ TEST_F(DeclarableOpsTests9, test_range_int_1) {
auto x2 = NDArrayFactory::create<int>(1); auto x2 = NDArrayFactory::create<int>(1);
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({&x0, &x1, &x2}, {}, {}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1664,7 +1664,7 @@ TEST_F(DeclarableOpsTests9, test_range_empty_1) {
auto x2 = NDArrayFactory::create<int>(1); auto x2 = NDArrayFactory::create<int>(1);
nd4j::ops::range op; nd4j::ops::range op;
auto result = op.execute({&x0, &x1, &x2}, {}, {}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1703,7 +1703,7 @@ TEST_F(DeclarableOpsTests9, test_unstack_1) {
x.linspace(1.0); x.linspace(1.0);
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(5, result->size()); ASSERT_EQ(5, result->size());
@ -1721,7 +1721,7 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) {
auto z5 = NDArrayFactory::create<double>(5); auto z5 = NDArrayFactory::create<double>(5);
std::vector<NDArray*> z({&z1, &z2, &z3, &z4, &z5}); std::vector<NDArray*> z({&z1, &z2, &z3, &z4, &z5});
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(5, result->size()); ASSERT_EQ(5, result->size());
for (size_t i = 0; i < result->size(); i++) { for (size_t i = 0; i < result->size(); i++) {
@ -1758,7 +1758,7 @@ TEST_F(DeclarableOpsTests9, clipbynorm_test12) {
} }
nd4j::ops::clipbynorm op; nd4j::ops::clipbynorm op;
auto result = op.execute({&y}, {clip}, {axis}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&y}, {clip}, {axis});
auto outFF = result->at(0); auto outFF = result->at(0);
ASSERT_TRUE(expect.isSameShape(outFF)); ASSERT_TRUE(expect.isSameShape(outFF));
@ -1852,7 +1852,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
exclusive = 0; reverse = 0; exclusive = 0; reverse = 0;
nd4j::ops::cumprod op; nd4j::ops::cumprod op;
auto result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
ASSERT_TRUE(expFF.equalsTo(z)); ASSERT_TRUE(expFF.equalsTo(z));
@ -1861,7 +1861,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
//************************************// //************************************//
exclusive = 1; reverse = 0; exclusive = 1; reverse = 0;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE); result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
z = result->at(0); z = result->at(0);
ASSERT_TRUE(expTF.equalsTo(z)); ASSERT_TRUE(expTF.equalsTo(z));
@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
//************************************// //************************************//
exclusive = 0; reverse = 1; exclusive = 0; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE); result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
z = result->at(0); z = result->at(0);
ASSERT_TRUE(expFT.equalsTo(z)); ASSERT_TRUE(expFT.equalsTo(z));
@ -1879,7 +1879,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
//************************************// //************************************//
exclusive = 1; reverse = 1; exclusive = 1; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE); result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
z = result->at(0); z = result->at(0);
ASSERT_TRUE(expTT.equalsTo(z)); ASSERT_TRUE(expTT.equalsTo(z));
@ -1910,7 +1910,7 @@ TEST_F(DeclarableOpsTests9, cumprod_2) {
} }
nd4j::ops::cumprod op; nd4j::ops::cumprod op;
auto result = op.execute({&x}, {}, {0, 0, 1}); auto result = op.evaluate({&x}, {}, {0, 0, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2111,7 +2111,7 @@ TEST_F(DeclarableOpsTests9, prelu_test1) {
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2129,7 +2129,7 @@ TEST_F(DeclarableOpsTests9, prelu_test2) {
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2147,7 +2147,7 @@ TEST_F(DeclarableOpsTests9, prelu_test3) {
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2165,7 +2165,7 @@ TEST_F(DeclarableOpsTests9, prelu_test4) {
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2183,7 +2183,7 @@ TEST_F(DeclarableOpsTests9, prelu_test5) {
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, 2.4f, -6.f, -8.f, 1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, 2.4f, -6.f, -8.f, 1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {1}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2201,7 +2201,7 @@ TEST_F(DeclarableOpsTests9, prelu_test6) {
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {1,0}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {1,0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2220,7 +2220,7 @@ TEST_F(DeclarableOpsTests9, prelu_test7) {
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {1,0}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {1,0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2238,7 +2238,7 @@ TEST_F(DeclarableOpsTests9, prelu_test8) {
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {1,0,1,0,1,0}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {1,0,1,0,1,0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2256,7 +2256,7 @@ TEST_F(DeclarableOpsTests9, prelu_test9) {
auto exp = NDArrayFactory::create<double>('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f}); auto exp = NDArrayFactory::create<double>('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2274,7 +2274,7 @@ TEST_F(DeclarableOpsTests9, prelu_test10) {
auto exp = NDArrayFactory::create<double>('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f}); auto exp = NDArrayFactory::create<double>('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {1}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2299,7 +2299,7 @@ TEST_F(DeclarableOpsTests9, prelu_test11) {
62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {1,3}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {1,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2323,7 +2323,7 @@ TEST_F(DeclarableOpsTests9, prelu_test12) {
53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {-1, 2}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {-1, 2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2347,7 +2347,7 @@ TEST_F(DeclarableOpsTests9, prelu_test13) {
53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {-1, 2}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {-1, 2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2372,7 +2372,7 @@ TEST_F(DeclarableOpsTests9, prelu_test14) {
55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
nd4j::ops::prelu op; nd4j::ops::prelu op;
auto result = op.execute({&x, &alpha}, {}, {-2}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x, &alpha}, {}, {-2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2391,7 +2391,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
nd4j::ops::thresholdedrelu op; nd4j::ops::thresholdedrelu op;
auto result = op.execute({&x}, {theta}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {theta});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2411,7 +2411,7 @@ TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
nd4j::ops::compare_and_bitpack op; nd4j::ops::compare_and_bitpack op;
auto result = op.execute({&x, &threshold}, {}, {}, {}); auto result = op.evaluate({&x, &threshold}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
// output->printIndexedBuffer("Packed to uint8"); // output->printIndexedBuffer("Packed to uint8");
@ -2429,7 +2429,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
nd4j::ops::thresholdedrelu op; nd4j::ops::thresholdedrelu op;
auto result = op.execute({&x}, {theta}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.evaluate({&x}, {theta});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0); auto output = result->at(0);
@ -2544,7 +2544,7 @@ TEST_F(DeclarableOpsTests9, multiply_test1) {
y.linspace(0.1f, 0.1f); y.linspace(0.1f, 0.1f);
nd4j::ops::multiply op; nd4j::ops::multiply op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2564,7 +2564,7 @@ TEST_F(DeclarableOpsTests9, multiply_test2) {
// y.linspace(0.1f, 0.1f); // y.linspace(0.1f, 0.1f);
nd4j::ops::multiply op; nd4j::ops::multiply op;
auto result = op.execute({&y, &x}, {}, {}); auto result = op.evaluate({&y, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2584,7 +2584,7 @@ TEST_F(DeclarableOpsTests9, multiply_test3) {
y.linspace(0.1f, 0.1f); y.linspace(0.1f, 0.1f);
nd4j::ops::multiply op; nd4j::ops::multiply op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2603,7 +2603,7 @@ TEST_F(DeclarableOpsTests9, multiply_test4) {
x.linspace(1.f); x.linspace(1.f);
nd4j::ops::multiply op; nd4j::ops::multiply op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2621,7 +2621,7 @@ TEST_F(DeclarableOpsTests9, multiply_test5) {
auto exp = NDArrayFactory::create<double>(0.1f); auto exp = NDArrayFactory::create<double>(0.1f);
nd4j::ops::multiply op; nd4j::ops::multiply op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -2643,8 +2643,8 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test1) {
nd4j::ops::multiply opFF; nd4j::ops::multiply opFF;
nd4j::ops::multiply_bp opBP; nd4j::ops::multiply_bp opBP;
auto resFF = opFF.execute({&x, &y}, {}, {}); auto resFF = opFF.evaluate({&x, &y}, {}, {});
auto resBP = opBP.execute({&x, &y, &dLdz}, {}, {}); auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {});
// resFF->at(0)->printIndexedBuffer("Multiply 1x1"); // resFF->at(0)->printIndexedBuffer("Multiply 1x1");
// resBP->at(0)->printIndexedBuffer("Multiply BP 1x1 x"); // resBP->at(0)->printIndexedBuffer("Multiply BP 1x1 x");
// resBP->at(1)->printIndexedBuffer("Multyply BP 1x1 y");*/ // resBP->at(1)->printIndexedBuffer("Multyply BP 1x1 y");*/
@ -2800,7 +2800,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_2) {
// resFF->at(0)->printIndexedBuffer("FF floormod"); // resFF->at(0)->printIndexedBuffer("FF floormod");
// delete resFF; // delete resFF;
nd4j::ops::floormod_bp opBP; nd4j::ops::floormod_bp opBP;
auto resBP = opBP.execute({&x, &y, &dLdz}, {}, {}); auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {});
ASSERT_TRUE(resBP->status() == ND4J_STATUS_OK); ASSERT_TRUE(resBP->status() == ND4J_STATUS_OK);
// resBP->at(0)->printIndexedBuffer("BP floormod /dx"); // resBP->at(0)->printIndexedBuffer("BP floormod /dx");
@ -2832,10 +2832,10 @@ TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) {
dLdzZ.assign(3); dLdzZ.assign(3);
nd4j::ops::dynamic_partition op1; nd4j::ops::dynamic_partition op1;
auto res1 = op1.execute({&x, &y}, {}, {3}); auto res1 = op1.evaluate({&x, &y}, {}, {3});
nd4j::ops::dynamic_partition_bp op2; nd4j::ops::dynamic_partition_bp op2;
auto res2 = op2.execute({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3}); auto res2 = op2.evaluate({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3});
ASSERT_TRUE(res2->status() == ND4J_STATUS_OK); ASSERT_TRUE(res2->status() == ND4J_STATUS_OK);
ASSERT_TRUE(res2->size() == 2); ASSERT_TRUE(res2->size() == 2);
// printf("How many: %ul\n", res2->size()); // printf("How many: %ul\n", res2->size());
@ -2879,7 +2879,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
eps.assign(1.f); eps.assign(1.f);
nd4j::ops::floormod_bp op; nd4j::ops::floormod_bp op;
auto result = op.execute({&x, &y, &eps}, {}, {}); auto result = op.evaluate({&x, &y, &eps}, {}, {});
ASSERT_TRUE(result->size() == 2); ASSERT_TRUE(result->size() == 2);
auto gradX = result->at(0); auto gradX = result->at(0);
@ -2924,7 +2924,7 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
const OpArgsHolder argsHolderFF({&x, &hi, &W, &Wc, &b, &bc}, {}, {}); const OpArgsHolder argsHolderFF({&x, &hi, &W, &Wc, &b, &bc}, {}, {});
nd4j::ops::gruCell op; nd4j::ops::gruCell op;
auto results = op.execute(argsHolderFF); auto results = op.evaluate(argsHolderFF);
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2964,7 +2964,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
nd4j::ops::cholesky op; nd4j::ops::cholesky op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto res = result->at(0); auto res = result->at(0);
// res->printIndexedBuffer("Output for Cholesky1"); // res->printIndexedBuffer("Output for Cholesky1");
@ -2980,7 +2980,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
nd4j::ops::cholesky op; nd4j::ops::cholesky op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto res = result->at(0); auto res = result->at(0);
// res->printIndexedBuffer("Output for Cholesky 2"); // res->printIndexedBuffer("Output for Cholesky 2");
@ -2996,7 +2996,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
nd4j::ops::cholesky op; nd4j::ops::cholesky op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto res = result->at(0); auto res = result->at(0);
// res->printIndexedBuffer("Output for Cholesky 3"); // res->printIndexedBuffer("Output for Cholesky 3");

View File

@ -50,7 +50,7 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
nd4j::ops::choose op; nd4j::ops::choose op;
//greater than test //greater than test
auto result = op.execute({&x}, {0.0},{3}); auto result = op.evaluate({&x}, {0.0},{3});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(1); auto z = result->at(1);

View File

@ -66,7 +66,7 @@ TEST_F(EmptyTests, Test_Concat_1) {
ASSERT_TRUE(empty->isEmpty()); ASSERT_TRUE(empty->isEmpty());
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({empty, vector}, {}, {0}); auto result = op.evaluate({empty, vector}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -91,7 +91,7 @@ TEST_F(EmptyTests, Test_Concat_2) {
ASSERT_TRUE(empty->isEmpty()); ASSERT_TRUE(empty->isEmpty());
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({empty, scalar1, scalar2}, {}, {0}); auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -116,7 +116,7 @@ TEST_F(EmptyTests, Test_Concat_3) {
ASSERT_TRUE(empty.isEmpty()); ASSERT_TRUE(empty.isEmpty());
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&empty, &scalar1, &scalar2}, {}, {0}); auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -135,7 +135,7 @@ TEST_F(EmptyTests, Test_Concat_4) {
ASSERT_TRUE(empty.isEmpty()); ASSERT_TRUE(empty.isEmpty());
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&scalar1, &empty, &scalar2}, {}, {0}); auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -151,7 +151,7 @@ TEST_F(EmptyTests, Test_Reshape_1) {
auto empty = NDArrayFactory::empty_<int>(); auto empty = NDArrayFactory::empty_<int>();
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&vector, empty}, {}, {}); auto result = op.evaluate({&vector, empty}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -167,7 +167,7 @@ TEST_F(EmptyTests, Test_Reshape_2) {
auto empty = NDArrayFactory::empty_<Nd4jLong>(); auto empty = NDArrayFactory::empty_<Nd4jLong>();
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&vector, empty}, {}, {}, {}, true); auto result = op.evaluate({&vector, empty}, {}, {}, {}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -184,7 +184,7 @@ TEST_F(EmptyTests, Test_Reshape_3) {
auto e = NDArrayFactory::create<float>('c', {10, 0}); auto e = NDArrayFactory::create<float>('c', {10, 0});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -213,7 +213,7 @@ TEST_F(EmptyTests, test_empty_scatter_1) {
x.linspace(1.0f); x.linspace(1.0f);
nd4j::ops::scatter_upd op; nd4j::ops::scatter_upd op;
auto result = op.execute({&x, &indices, &updates}, {}, {}, {true}); auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -311,12 +311,12 @@ TEST_F(EmptyTests, test_empty_reshape_1) {
auto e1 = NDArrayFactory::create<float>('c', {0, 1}); auto e1 = NDArrayFactory::create<float>('c', {0, 1});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result0 = op.execute({&x0, &shape0}, {}, {}); auto result0 = op.evaluate({&x0, &shape0}, {}, {});
ASSERT_EQ(Status::OK(), result0->status()); ASSERT_EQ(Status::OK(), result0->status());
auto z0 = result0->at(0); auto z0 = result0->at(0);
ASSERT_EQ(e0, *z0); ASSERT_EQ(e0, *z0);
auto result1 = op.execute({&x1, &shape1}, {}, {}); auto result1 = op.evaluate({&x1, &shape1}, {}, {});
ASSERT_EQ(Status::OK(), result1->status()); ASSERT_EQ(Status::OK(), result1->status());
auto z1 = result1->at(0); auto z1 = result1->at(0);
ASSERT_EQ(e1, *z1); ASSERT_EQ(e1, *z1);
@ -332,7 +332,7 @@ TEST_F(EmptyTests, test_empty_matmul_1) {
auto e = NDArrayFactory::create<float>('c', {0, 0}); auto e = NDArrayFactory::create<float>('c', {0, 0});
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -347,7 +347,7 @@ TEST_F(EmptyTests, test_empty_matmul_2) {
auto e = NDArrayFactory::create<float>('c', {1, 0, 0}); auto e = NDArrayFactory::create<float>('c', {1, 0, 0});
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);

View File

@ -46,7 +46,7 @@ TEST_F(IndexingTests, StridedSlice_1) {
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1}); auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -65,7 +65,7 @@ TEST_F(IndexingTests, StridedSlice_2) {
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1}); auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -85,7 +85,7 @@ TEST_F(IndexingTests, StridedSlice_3) {
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2}); auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -108,7 +108,7 @@ TEST_F(IndexingTests, SimpleSlice_1) {
nd4j::ops::slice op; nd4j::ops::slice op;
auto result = op.execute({&input}, {}, {1,0,0, 1,1,3}); auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -134,7 +134,7 @@ TEST_F(IndexingTests, SimpleSlice_2) {
nd4j::ops::slice op; nd4j::ops::slice op;
auto result = op.execute({&input}, {}, {1,0,0, 1,2,3}); auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -159,7 +159,7 @@ TEST_F(IndexingTests, SimpleSlice_3) {
nd4j::ops::slice op; nd4j::ops::slice op;
auto result = op.execute({&input}, {}, {1,0,0, 2,1,3}); auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -179,7 +179,7 @@ TEST_F(IndexingTests, SimpleSlice_4) {
nd4j::ops::slice op; nd4j::ops::slice op;
auto result = op.execute({&input, &start, &stop}, {}, {}); auto result = op.evaluate({&input, &start, &stop});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -202,7 +202,7 @@ TEST_F(IndexingTests, MaskedSlice_0) {
exp.assign(2.0f); exp.assign(2.0f);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix}, {}, {0,0,0,0,0, 1, 2, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -228,7 +228,7 @@ TEST_F(IndexingTests, MaskedSlice_00) {
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -252,7 +252,7 @@ TEST_F(IndexingTests, MaskedSlice_1) {
exp.assign(2.0f); exp.assign(2.0f);
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix}, {}, {0,0,0,0,1, 1, 2, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -273,7 +273,7 @@ TEST_F(IndexingTests, MaskedSlice_2) {
// output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -293,7 +293,7 @@ TEST_F(IndexingTests, MaskedSlice_3) {
// output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -313,7 +313,7 @@ TEST_F(IndexingTests, MaskedSlice_4) {
// output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -336,7 +336,7 @@ TEST_F(IndexingTests, Live_Slice_1) {
// output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3}); auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -359,7 +359,7 @@ TEST_F(IndexingTests, Test_StridedSlice_1) {
auto exp = NDArrayFactory::create<float>({5.0f, 2}); auto exp = NDArrayFactory::create<float>({5.0f, 2});
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -379,7 +379,7 @@ TEST_F(IndexingTests, Test_StridedSlice_2) {
auto exp = NDArrayFactory::create<float>('c', {1}, {5.0}); auto exp = NDArrayFactory::create<float>('c', {1}, {5.0});
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -402,7 +402,7 @@ TEST_F(IndexingTests, Test_StridedSlice_3) {
auto exp = NDArrayFactory::create<float>('c', {1}, {6.0}); auto exp = NDArrayFactory::create<float>('c', {1}, {6.0});
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -423,7 +423,7 @@ TEST_F(IndexingTests, Test_StridedSlice_4) {
auto exp = NDArrayFactory::create<float>({5.0f, 2}); auto exp = NDArrayFactory::create<float>({5.0f, 2});
nd4j::ops::strided_slice op; nd4j::ops::strided_slice op;
auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1}); // auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());

View File

@ -62,7 +62,7 @@ TEST_F(LegacyOpsTests, TransformTests_2) {
exp.assign(-1.0); exp.assign(-1.0);
nd4j::ops::LegacyTransformSameOp op(transform::Neg); // Neg nd4j::ops::LegacyTransformSameOp op(transform::Neg); // Neg
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -119,7 +119,7 @@ TEST_F(LegacyOpsTests, PWT_Tests_2) {
exp.assign(6.0); exp.assign(6.0);
nd4j::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply nd4j::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
auto z = result->at(0); auto z = result->at(0);
@ -152,7 +152,7 @@ TEST_F(LegacyOpsTests, Scalar_Test_2) {
auto y = NDArrayFactory::create<float>(5.0f); auto y = NDArrayFactory::create<float>(5.0f);
nd4j::ops::LegacyScalarOp op(scalar::Add, y); nd4j::ops::LegacyScalarOp op(scalar::Add, y);
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
auto z = result->at(0); auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -167,7 +167,7 @@ TEST_F(LegacyOpsTests, ReduceTests_1) {
int opNum = reduce::Sum; int opNum = reduce::Sum;
nd4j::ops::LegacyReduceSameOp op(opNum); nd4j::ops::LegacyReduceSameOp op(opNum);
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -186,7 +186,7 @@ TEST_F(LegacyOpsTests, ReduceTests_2) {
nd4j::ops::LegacyReduceSameOp op(reduce::Sum); nd4j::ops::LegacyReduceSameOp op(reduce::Sum);
auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1}); auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
auto result = op.execute({&x, &axis}, {}, {}); auto result = op.evaluate({&x, &axis}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -208,7 +208,7 @@ TEST_F(LegacyOpsTests, ReduceTests_3) {
nd4j::ops::LegacyReduceSameOp op(reduce::Sum); nd4j::ops::LegacyReduceSameOp op(reduce::Sum);
auto result = op.execute({&x, &indices}, {}, {}); auto result = op.evaluate({&x, &indices}, {}, {});
auto z = result->at(0); auto z = result->at(0);
auto exp = x.reduceAlongDimension(reduce::Sum,{1}); auto exp = x.reduceAlongDimension(reduce::Sum,{1});
@ -228,7 +228,7 @@ TEST_F(LegacyOpsTests, ReduceTests_4) {
nd4j::ops::LegacyReduceSameOp op(reduce::Sum); nd4j::ops::LegacyReduceSameOp op(reduce::Sum);
auto result = op.execute({&x, &indices}, {}, {}, {true}); auto result = op.evaluate({&x, &indices}, {}, {}, {true});
auto z = result->at(0); auto z = result->at(0);
auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true); auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true);
// indices.printShapeInfo("Indices shape"); // indices.printShapeInfo("Indices shape");
@ -247,7 +247,7 @@ TEST_F(LegacyOpsTests, ReduceTests_5) {
int opNum = reduce::Mean; int opNum = reduce::Mean;
nd4j::ops::LegacyReduceFloatOp op(opNum); nd4j::ops::LegacyReduceFloatOp op(opNum);
ResultSet* result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -266,7 +266,7 @@ TEST_F(LegacyOpsTests, ReduceTests_6) {
auto axis = NDArrayFactory::create<int>('c', {1}, {1}); auto axis = NDArrayFactory::create<int>('c', {1}, {1});
nd4j::ops::LegacyReduceFloatOp op(reduce::Mean); nd4j::ops::LegacyReduceFloatOp op(reduce::Mean);
auto result = op.execute({&x, &axis}, {}, {}); auto result = op.evaluate({&x, &axis}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -288,7 +288,7 @@ TEST_F(LegacyOpsTests, ReduceTests_7) {
nd4j::ops::LegacyReduceFloatOp op(reduce::Mean); nd4j::ops::LegacyReduceFloatOp op(reduce::Mean);
auto result = op.execute({&x, &indices}, {}, {}); auto result = op.evaluate({&x, &indices}, {}, {});
auto z = result->at(0); auto z = result->at(0);
auto exp = x.reduceAlongDimension(reduce::Mean,{1}); auto exp = x.reduceAlongDimension(reduce::Mean,{1});
@ -308,7 +308,7 @@ TEST_F(LegacyOpsTests, ReduceTests_8) {
nd4j::ops::LegacyReduceFloatOp op(reduce::Mean); nd4j::ops::LegacyReduceFloatOp op(reduce::Mean);
auto result = op.execute({&x, &indices}, {}, {}, {true}); auto result = op.evaluate({&x, &indices}, {}, {}, {true});
auto z = result->at(0); auto z = result->at(0);
auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true); auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true);
@ -329,7 +329,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_1) {
nd4j::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); nd4j::ops::LegacyIndexReduceOp op(indexreduce::IndexMax);
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());
@ -349,7 +349,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
auto exp = NDArrayFactory::create<Nd4jLong>({4,4,4,4,4}); auto exp = NDArrayFactory::create<Nd4jLong>({4,4,4,4,4});
nd4j::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); nd4j::ops::LegacyIndexReduceOp op(indexreduce::IndexMax);
auto result = op.execute({&x, &indices}, {}, {}); auto result = op.evaluate({&x, &indices}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result->size());

View File

@ -133,7 +133,7 @@ TEST_F(MultiDataTypeTests, Basic_Test_7) {
auto e = NDArrayFactory::create<float>('c', {2, 3}, {0.f, 2.f, 4.f, 6.f, 8.f, 10.f}); auto e = NDArrayFactory::create<float>('c', {2, 3}, {0.f, 2.f, 4.f, 6.f, 8.f, 10.f});
nd4j::ops::add op; nd4j::ops::add op;
auto result = op.execute({&x, &y},{}, {}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);

View File

@ -65,7 +65,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
auto inferenceVector = NDArrayFactory::empty<float>(); auto inferenceVector = NDArrayFactory::empty<float>();
nd4j::ops::skipgram op; nd4j::ops::skipgram op;
auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto row0 = syn0({0,1, 0,0}, true); auto row0 = syn0({0,1, 0,0}, true);
@ -106,7 +106,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
auto inferenceVector = NDArrayFactory::empty<float>(); auto inferenceVector = NDArrayFactory::empty<float>();
nd4j::ops::skipgram op; nd4j::ops::skipgram op;
auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto row0 = syn0({0,1, 0,0}, true); auto row0 = syn0({0,1, 0,0}, true);
@ -157,8 +157,8 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
auto inferenceVector = NDArrayFactory::empty<float>(); auto inferenceVector = NDArrayFactory::empty<float>();
nd4j::ops::skipgram op; nd4j::ops::skipgram op;
auto result0 = op.execute({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true); auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
auto result1 = op.execute({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true); auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result0->status()); ASSERT_EQ(Status::OK(), result0->status());
auto row00 = syn00({0,1, 0,0}, true); auto row00 = syn00({0,1, 0,0}, true);
@ -191,7 +191,7 @@ TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
auto inferenceVector = NDArrayFactory::empty<float>(); auto inferenceVector = NDArrayFactory::empty<float>();
nd4j::ops::skipgram op; nd4j::ops::skipgram op;
auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
@ -226,7 +226,7 @@ TEST_F(NlpTests, basic_sg_ns_test_1) {
auto inferenceVector = NDArrayFactory::empty<float>(); auto inferenceVector = NDArrayFactory::empty<float>();
nd4j::ops::skipgram op; nd4j::ops::skipgram op;
auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto row0 = syn0({1,2, 0,0}, true); auto row0 = syn0({1,2, 0,0}, true);
@ -268,7 +268,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
auto inferenceVector = NDArrayFactory::empty<float>(); auto inferenceVector = NDArrayFactory::empty<float>();
nd4j::ops::cbow op; nd4j::ops::cbow op;
auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, true); auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto row_s0_0 = syn0({0,1, 0,0}, true); auto row_s0_0 = syn0({0,1, 0,0}, true);
@ -322,7 +322,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
auto inferenceVector = NDArrayFactory::empty<float>(); auto inferenceVector = NDArrayFactory::empty<float>();
nd4j::ops::cbow op; nd4j::ops::cbow op;
auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, true); auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto row_s0_0 = syn0({0,1, 0,0}, true); auto row_s0_0 = syn0({0,1, 0,0}, true);
@ -371,7 +371,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
expTable.assign(0.5); expTable.assign(0.5);
nd4j::ops::skipgram op; nd4j::ops::skipgram op;
auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto row0 = syn0({0,1, 0,0}, true); auto row0 = syn0({0,1, 0,0}, true);
@ -415,7 +415,7 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
negTable.linspace(0.0); negTable.linspace(0.0);
nd4j::ops::skipgram op; nd4j::ops::skipgram op;
auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto row0 = syn0({0,0, 0,0}, true); auto row0 = syn0({0,0, 0,0}, true);
@ -452,7 +452,7 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
auto inferenceVector = NDArrayFactory::empty<float>(); auto inferenceVector = NDArrayFactory::empty<float>();
nd4j::ops::cbow op; nd4j::ops::cbow op;
auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, true); auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto exp0 = NDArrayFactory::create<float>('c', {1, 10}); auto exp0 = NDArrayFactory::create<float>('c', {1, 10});

View File

@ -41,7 +41,7 @@ TEST_F(ParityOpsTests, TestZeroAs1) {
nd4j::ops::zeros_as op; nd4j::ops::zeros_as op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
auto z = result->at(0); auto z = result->at(0);
@ -60,7 +60,7 @@ TEST_F(ParityOpsTests, TestMaximum1) {
nd4j::ops::maximum op; nd4j::ops::maximum op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
auto z = result->at(0); auto z = result->at(0);
@ -80,7 +80,7 @@ TEST_F(ParityOpsTests, TestMinimum1) {
nd4j::ops::minimum op; nd4j::ops::minimum op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
auto z = result->at(0); auto z = result->at(0);
@ -99,7 +99,7 @@ TEST_F(ParityOpsTests, TestTear1) {
nd4j::ops::tear op; nd4j::ops::tear op;
auto result = op.execute({&input}, {}, {1}); auto result = op.evaluate({&input}, {}, {1});
ASSERT_EQ(10, result->size()); ASSERT_EQ(10, result->size());
@ -119,7 +119,7 @@ TEST_F(ParityOpsTests, TestUnstack1) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {0}); auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(10, result->size()); ASSERT_EQ(10, result->size());
@ -141,7 +141,7 @@ TEST_F(ParityOpsTests, TestUnstack2) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {2}); auto result = op.evaluate({&input}, {}, {2});
ASSERT_EQ(6, result->size()); ASSERT_EQ(6, result->size());
@ -158,7 +158,7 @@ TEST_F(ParityOpsTests, TestUnstack3) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {2}); auto result = op.evaluate({&input}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -177,7 +177,7 @@ TEST_F(ParityOpsTests, TestUnstack4) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {1}); auto result = op.evaluate({&input}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -195,7 +195,7 @@ TEST_F(ParityOpsTests, TestUnstack5) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {0}); auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -213,7 +213,7 @@ TEST_F(ParityOpsTests, TestUnstack6) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {0}); auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -231,7 +231,7 @@ TEST_F(ParityOpsTests, TestUnstack7) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {1}); auto result = op.evaluate({&input}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -249,7 +249,7 @@ TEST_F(ParityOpsTests, TestUnstack8) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {0}); auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -267,7 +267,7 @@ TEST_F(ParityOpsTests, TestUnstack9) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {1}); auto result = op.evaluate({&input}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -286,7 +286,7 @@ TEST_F(ParityOpsTests, TestUnstack10) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {0}); auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.isSameShape(result->at(0)));
@ -304,7 +304,7 @@ TEST_F(ParityOpsTests, TestUnstack11) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {2}); auto result = op.evaluate({&input}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.isSameShape(result->at(0)));
@ -320,7 +320,7 @@ TEST_F(ParityOpsTests, TestUnstack12) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
auto result = op.execute({&input}, {}, {1}); auto result = op.evaluate({&input}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_TRUE(result->size() == 0); ASSERT_TRUE(result->size() == 0);
@ -334,7 +334,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest1) {
auto reshaped = input.reshape('c', {5, 1, 5}); auto reshaped = input.reshape('c', {5, 1, 5});
nd4j::ops::expand_dims op; nd4j::ops::expand_dims op;
auto result = op.execute({&input}, {}, {1}); auto result = op.evaluate({&input}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -353,7 +353,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest2) {
auto reshaped = input.reshape('c', {1, 3, 4}); auto reshaped = input.reshape('c', {1, 3, 4});
nd4j::ops::expand_dims op; nd4j::ops::expand_dims op;
auto result = op.execute({&input}, {}, {0}); auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -372,7 +372,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest3) {
auto reshaped = input.reshape('c', {3, 1, 4}); auto reshaped = input.reshape('c', {3, 1, 4});
nd4j::ops::expand_dims op; nd4j::ops::expand_dims op;
auto result = op.execute({&input}, {}, {-2}); auto result = op.evaluate({&input}, {}, {-2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -390,7 +390,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest4) {
auto reshaped = input.reshape('c', {1, 3, 4}); auto reshaped = input.reshape('c', {1, 3, 4});
nd4j::ops::expand_dims op; nd4j::ops::expand_dims op;
auto result = op.execute({&input}, {}, {-3}); auto result = op.evaluate({&input}, {}, {-3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -408,7 +408,7 @@ TEST_F(ParityOpsTests, Test_Shape_1) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {4}, {3, 4, 5, 6}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {4}, {3, 4, 5, 6});
nd4j::ops::shape_of op; nd4j::ops::shape_of op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -426,7 +426,7 @@ TEST_F(ParityOpsTests, Test_Equals_1) {
auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {1, 0, 1, 0, 1}); auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {1, 0, 1, 0, 1});
nd4j::ops::equals op; nd4j::ops::equals op;
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -444,7 +444,7 @@ TEST_F(ParityOpsTests, Test_NotEquals_1) {
auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 1, 0, 1, 0}); auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 1, 0, 1, 0});
nd4j::ops::not_equals op; nd4j::ops::not_equals op;
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -461,7 +461,7 @@ TEST_F(ParityOpsTests, Test_Less_1) {
auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {1, 1, 0, 0, 0}); auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {1, 1, 0, 0, 0});
nd4j::ops::less op; nd4j::ops::less op;
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -478,7 +478,7 @@ TEST_F(ParityOpsTests, Test_LessEquals_1) {
auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {1, 1, 1, 0, 0}); auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {1, 1, 1, 0, 0});
nd4j::ops::less_equal op; nd4j::ops::less_equal op;
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -495,7 +495,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_1) {
auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 0, 1, 1, 1}); auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 0, 1, 1, 1});
nd4j::ops::greater_equal op; nd4j::ops::greater_equal op;
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -512,7 +512,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_2) {
auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 0, 1, 1, 1}); auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 0, 1, 1, 1});
nd4j::ops::greater_equal op; nd4j::ops::greater_equal op;
auto result = op.execute({&x, &y}, {}, {}, {}, false); auto result = op.evaluate({&x, &y}, {}, {}, {}, {}, false);
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -529,7 +529,7 @@ TEST_F(ParityOpsTests, Test_Greater_1) {
auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 0, 0, 1, 1}); auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 0, 0, 1, 1});
nd4j::ops::greater op; nd4j::ops::greater op;
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -547,7 +547,7 @@ TEST_F(ParityOpsTests, Test_Where_1) {
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 6, 5, 4, 7, 8, 9}); auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 6, 5, 4, 7, 8, 9});
nd4j::ops::Where op; nd4j::ops::Where op;
auto result = op.execute({&mask, &x, &y}, {}, {}); auto result = op.evaluate({&mask, &x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -567,7 +567,7 @@ TEST_F(ParityOpsTests, Test_Where_2) {
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1});
nd4j::ops::Where op; nd4j::ops::Where op;
auto result = op.execute({&mask, &x, &y}, {}, {}); auto result = op.evaluate({&mask, &x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -584,7 +584,7 @@ TEST_F(ParityOpsTests, Test_Where_3) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5, 3}, {0, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 1, 1, 2}); auto exp = NDArrayFactory::create<Nd4jLong>('c', {5, 3}, {0, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 1, 1, 2});
nd4j::ops::Where op; nd4j::ops::Where op;
auto result = op.execute({&mask}, {}, {}); auto result = op.evaluate({&mask}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -604,7 +604,7 @@ TEST_F(ParityOpsTests, Test_Select_1) {
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1});
nd4j::ops::select op; nd4j::ops::select op;
auto result = op.execute({&mask, &x, &y}, {}, {}); auto result = op.evaluate({&mask, &x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -622,7 +622,7 @@ TEST_F(ParityOpsTests, Test_Select_2) {
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1, 8, 3, 6}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1, 8, 3, 6});
nd4j::ops::select op; nd4j::ops::select op;
auto result = op.execute({&mask, &x, &y}, {}, {}); auto result = op.evaluate({&mask, &x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -641,7 +641,7 @@ TEST_F(ParityOpsTests, Test_Select_3) {
auto exp = NDArrayFactory::create<float>('c', {1, 1}, {2}); auto exp = NDArrayFactory::create<float>('c', {1, 1}, {2});
nd4j::ops::select op; nd4j::ops::select op;
auto result = op.execute({&mask, &x, &y}, {}, {}); auto result = op.evaluate({&mask, &x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -660,7 +660,7 @@ TEST_F(ParityOpsTests, Test_Reshape_TF_1) {
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x, &shape}, {}, {}); auto result = op.evaluate({&x, &shape}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -677,7 +677,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) {
auto bias = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5}); auto bias = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5});
nd4j::ops::biasadd op; nd4j::ops::biasadd op;
auto result = op.execute({&x, &bias}, {}, {}); auto result = op.evaluate({&x, &bias}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -697,7 +697,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {2, 3, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {2, 3, 3, 4});
nd4j::ops::scatter_add op; nd4j::ops::scatter_add op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -715,7 +715,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
auto exp = NDArrayFactory::create<float>('c', {1, 4}, {2, 3, 4, 5}); auto exp = NDArrayFactory::create<float>('c', {1, 4}, {2, 3, 4, 5});
nd4j::ops::scatter_add op; nd4j::ops::scatter_add op;
auto result = op.execute({&vec, &idc, &updates}, {}, {}); auto result = op.evaluate({&vec, &idc, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -732,7 +732,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8});
nd4j::ops::scatter_add op; nd4j::ops::scatter_add op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -749,7 +749,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8});
nd4j::ops::scatter_add op; nd4j::ops::scatter_add op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -766,7 +766,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_5) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.});
nd4j::ops::scatter_add op; nd4j::ops::scatter_add op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -784,7 +784,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13});
nd4j::ops::scatter_add op; nd4j::ops::scatter_add op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -801,7 +801,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_7) {
auto exp = NDArrayFactory::create<float>('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f}); auto exp = NDArrayFactory::create<float>('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f});
nd4j::ops::scatter_add op; nd4j::ops::scatter_add op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -850,7 +850,7 @@ TEST_F(ParityOpsTests, scatterMax_test1) {
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10, 2, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10, 2, 3, 4});
nd4j::ops::scatter_max op; nd4j::ops::scatter_max op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -867,7 +867,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) {
auto exp = NDArrayFactory::create<float>('c', {1, 4}, {10, 2, 30, 4}); auto exp = NDArrayFactory::create<float>('c', {1, 4}, {10, 2, 30, 4});
nd4j::ops::scatter_max op; nd4j::ops::scatter_max op;
auto result = op.execute({&vec, &idc, &updates}, {}, {}, {true}); auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -884,7 +884,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8});
nd4j::ops::scatter_max op; nd4j::ops::scatter_max op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -901,7 +901,7 @@ TEST_F(ParityOpsTests, scatterMax_test4) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8});
nd4j::ops::scatter_max op; nd4j::ops::scatter_max op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {true}, {true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {true}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -918,7 +918,7 @@ TEST_F(ParityOpsTests, scatterMax_test5) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {10, 2, 10, 2, 10, 2, 2, 10, 2, 10, 2, 10}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {10, 2, 10, 2, 10, 2, 2, 10, 2, 10, 2, 10});
nd4j::ops::scatter_max op; nd4j::ops::scatter_max op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -935,7 +935,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {2, 1, 2, 1, 1, 2, 1, 2}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {2, 1, 2, 1, 1, 2, 1, 2});
nd4j::ops::scatter_max op; nd4j::ops::scatter_max op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -953,7 +953,7 @@ TEST_F(ParityOpsTests, scatterMin_test1) {
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-1, 1, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-1, 1, 3, 4});
nd4j::ops::scatter_min op; nd4j::ops::scatter_min op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -970,7 +970,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) {
auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 3, 1}); auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 3, 1});
nd4j::ops::scatter_min op; nd4j::ops::scatter_min op;
auto result = op.execute({&vec, &idc, &updates}, {}, {}, {true}); auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -987,7 +987,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8});
nd4j::ops::scatter_min op; nd4j::ops::scatter_min op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1004,7 +1004,7 @@ TEST_F(ParityOpsTests, scatterMin_test4) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8});
nd4j::ops::scatter_min op; nd4j::ops::scatter_min op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1036,7 +1036,7 @@ TEST_F(ParityOpsTests, scatterND_test1) {
auto exp = NDArrayFactory::create<float>('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f}); auto exp = NDArrayFactory::create<float>('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {false, true}); auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1058,7 +1058,7 @@ TEST_F(ParityOpsTests, scatterND_test2) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {}); auto result = op.evaluate({&indices, &updates, &shape}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1083,7 +1083,7 @@ TEST_F(ParityOpsTests, scatterND_test3) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {false, true}); auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1103,7 +1103,7 @@ TEST_F(ParityOpsTests, scatterND_test4) {
auto exp = NDArrayFactory::create<float>('c', {8}, {0.f, 11.f, 0.f, 10.f, 9.f, 0.f, 0.f, 12.f}); auto exp = NDArrayFactory::create<float>('c', {8}, {0.f, 11.f, 0.f, 10.f, 9.f, 0.f, 0.f, 12.f});
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {}); auto result = op.evaluate({&indices, &updates, &shape}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1123,7 +1123,7 @@ TEST_F(ParityOpsTests, scatterND_test5) {
auto exp = NDArrayFactory::create<float>('c', {8}, {0.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); auto exp = NDArrayFactory::create<float>('c', {8}, {0.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true}); auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1150,7 +1150,7 @@ TEST_F(ParityOpsTests, scatterND_test6) {
updates.linspace(1); updates.linspace(1);
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {}); auto result = op.evaluate({&indices, &updates, &shape}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1177,7 +1177,7 @@ TEST_F(ParityOpsTests, scatterND_test7) {
updates.linspace(1); updates.linspace(1);
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true, true}); auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true, true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1198,7 +1198,7 @@ TEST_F(ParityOpsTests, scatterND_test8) {
auto exp = NDArrayFactory::create<float>('c', {6,4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); auto exp = NDArrayFactory::create<float>('c', {6,4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {true}); auto result = op.evaluate({&indices, &updates, &shape}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1233,7 +1233,7 @@ TEST_F(ParityOpsTests, scatterND_add_test1) {
auto exp = NDArrayFactory::create<float>('c', {8}, {1.f, 13.f, 3.f, 14.f, 14.f, 6.f, 7.f, 20.f}); auto exp = NDArrayFactory::create<float>('c', {8}, {1.f, 13.f, 3.f, 14.f, 14.f, 6.f, 7.f, 20.f});
nd4j::ops::scatter_nd_add op; nd4j::ops::scatter_nd_add op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1256,7 +1256,7 @@ TEST_F(ParityOpsTests, scatterND_add_test2) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_add op; nd4j::ops::scatter_nd_add op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1280,7 +1280,7 @@ TEST_F(ParityOpsTests, scatterND_add_test3) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_add op; nd4j::ops::scatter_nd_add op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1307,7 +1307,7 @@ TEST_F(ParityOpsTests, scatterND_add_test4) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_add op; nd4j::ops::scatter_nd_add op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1343,7 +1343,7 @@ TEST_F(ParityOpsTests, scatterND_add_test5) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_add op; nd4j::ops::scatter_nd_add op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1376,7 +1376,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test1) {
auto exp = NDArrayFactory::create<float>('c', {8}, {1.f, -9.f, 3.f, -6.f, -4.f, 6.f, 7.f, -4.f}); auto exp = NDArrayFactory::create<float>('c', {8}, {1.f, -9.f, 3.f, -6.f, -4.f, 6.f, 7.f, -4.f});
nd4j::ops::scatter_nd_sub op; nd4j::ops::scatter_nd_sub op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1399,7 +1399,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test2) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_sub op; nd4j::ops::scatter_nd_sub op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1424,7 +1424,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test3) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_sub op; nd4j::ops::scatter_nd_sub op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1451,7 +1451,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test4) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_sub op; nd4j::ops::scatter_nd_sub op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1487,7 +1487,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test5) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_sub op; nd4j::ops::scatter_nd_sub op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1508,7 +1508,7 @@ TEST_F(ParityOpsTests, scatterND_update_test1) {
auto exp = NDArrayFactory::create<float>('c', {8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f}); auto exp = NDArrayFactory::create<float>('c', {8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f});
nd4j::ops::scatter_nd_update op; nd4j::ops::scatter_nd_update op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1531,7 +1531,7 @@ TEST_F(ParityOpsTests, scatterND_update_test2) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_update op; nd4j::ops::scatter_nd_update op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1555,7 +1555,7 @@ TEST_F(ParityOpsTests, scatterND_update_test3) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_update op; nd4j::ops::scatter_nd_update op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1583,7 +1583,7 @@ TEST_F(ParityOpsTests, scatterND_update_test4) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_update op; nd4j::ops::scatter_nd_update op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1619,7 +1619,7 @@ TEST_F(ParityOpsTests, scatterND_update_test5) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd_update op; nd4j::ops::scatter_nd_update op;
auto result = op.execute({&input, &indices, &updates}, {}, {}); auto result = op.evaluate({&input, &indices, &updates}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1652,7 +1652,7 @@ TEST_F(ParityOpsTests, scatter_update_1) {
NDArray exp('c', {2,2}, {30,40,10,20}, nd4j::DataType::INT32); NDArray exp('c', {2,2}, {30,40,10,20}, nd4j::DataType::INT32);
nd4j::ops::scatter_update op; nd4j::ops::scatter_update op;
auto results = op.execute({&x, &updates}, {}, {6, 1,1, 2,1,0}); auto results = op.evaluate({&x, &updates}, {}, {6, 1,1, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
// x.printBuffer(); // x.printBuffer();
@ -1672,7 +1672,7 @@ TEST_F(ParityOpsTests, scatter_update_2) {
NDArray exp('c', {2,2}, {20,10,40,30}, nd4j::DataType::INT32); NDArray exp('c', {2,2}, {20,10,40,30}, nd4j::DataType::INT32);
nd4j::ops::scatter_update op; nd4j::ops::scatter_update op;
auto results = op.execute({&x, &updates}, {}, {6, 1,0, 2,1,0}); auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1691,7 +1691,7 @@ TEST_F(ParityOpsTests, scatter_update_3) {
NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, nd4j::DataType::INT32); NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, nd4j::DataType::INT32);
nd4j::ops::scatter_update op; nd4j::ops::scatter_update op;
auto results = op.execute({&x, &updates}, {}, {6, 2,1,2, 2,1,0}); auto results = op.evaluate({&x, &updates}, {}, {6, 2,1,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1710,7 +1710,7 @@ TEST_F(ParityOpsTests, scatter_update_4) {
NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, nd4j::DataType::INT32); NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, nd4j::DataType::INT32);
nd4j::ops::scatter_update op; nd4j::ops::scatter_update op;
auto results = op.execute({&x, &updates}, {}, {6, 1,0, 2,3,0}); auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,3,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());

View File

@ -257,7 +257,7 @@ TEST_F(RNGTests, Test_Gaussian_21) {
ASSERT_FALSE(x0.equalsTo(nexp1)); ASSERT_FALSE(x0.equalsTo(nexp1));
ASSERT_FALSE(x0.equalsTo(nexp2)); ASSERT_FALSE(x0.equalsTo(nexp2));
nd4j::ops::moments op; nd4j::ops::moments op;
auto result = op.execute({&x0}, {}, {}); auto result = op.evaluate({&x0}, {}, {});
//x0.printIndexedBuffer("X0 Normal"); //x0.printIndexedBuffer("X0 Normal");
//x1.printIndexedBuffer("X1 Normal"); //x1.printIndexedBuffer("X1 Normal");
ASSERT_TRUE(result->status() == Status::OK()); ASSERT_TRUE(result->status() == Status::OK());
@ -289,7 +289,7 @@ TEST_F(RNGTests, Test_Gaussian_22) {
ASSERT_FALSE(x0.equalsTo(nexp1)); ASSERT_FALSE(x0.equalsTo(nexp1));
ASSERT_FALSE(x0.equalsTo(nexp2)); ASSERT_FALSE(x0.equalsTo(nexp2));
nd4j::ops::moments op; nd4j::ops::moments op;
auto result = op.execute({&x0}, {}, {}); auto result = op.evaluate({&x0}, {}, {});
//x0.printIndexedBuffer("X0 Normal"); //x0.printIndexedBuffer("X0 Normal");
//x1.printIndexedBuffer("X1 Normal"); //x1.printIndexedBuffer("X1 Normal");
ASSERT_TRUE(result->status() == Status::OK()); ASSERT_TRUE(result->status() == Status::OK());
@ -412,14 +412,14 @@ TEST_F(RNGTests, Test_Truncated_21) {
ASSERT_NEAR(mean.e<float>(0), 1.f, 0.002); ASSERT_NEAR(mean.e<float>(0), 1.f, 0.002);
ASSERT_NEAR(deviation.e<float>(0), 2.f, 0.5); ASSERT_NEAR(deviation.e<float>(0), 2.f, 0.5);
nd4j::ops::moments op; nd4j::ops::moments op;
auto result = op.execute({&x0}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
// result->at(0)->printBuffer("MEAN"); // result->at(0)->printBuffer("MEAN");
// result->at(1)->printBuffer("VARIANCE"); // result->at(1)->printBuffer("VARIANCE");
delete result; delete result;
nd4j::ops::reduce_min minOp; nd4j::ops::reduce_min minOp;
nd4j::ops::reduce_max maxOp; nd4j::ops::reduce_max maxOp;
auto minRes = minOp.execute({&x1}, {}, {}, {}); auto minRes = minOp.evaluate({&x1}, {}, {}, {});
auto maxRes = maxOp.execute({&x0}, {}, {}, {}); auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
// minRes->at(0)->printBuffer("MIN for Truncated"); // minRes->at(0)->printBuffer("MIN for Truncated");
// maxRes->at(0)->printBuffer("MAX for Truncated"); // maxRes->at(0)->printBuffer("MAX for Truncated");
@ -459,14 +459,14 @@ TEST_F(RNGTests, Test_Truncated_22) {
ASSERT_NEAR(mean.e<float>(0), 2.f, 0.01); ASSERT_NEAR(mean.e<float>(0), 2.f, 0.01);
ASSERT_NEAR(deviation.e<float>(0), 4.f, 0.52); ASSERT_NEAR(deviation.e<float>(0), 4.f, 0.52);
nd4j::ops::moments op; nd4j::ops::moments op;
auto result = op.execute({&x0}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
// result->at(0)->printBuffer("MEAN"); // result->at(0)->printBuffer("MEAN");
// result->at(1)->printBuffer("VARIANCE"); // result->at(1)->printBuffer("VARIANCE");
delete result; delete result;
nd4j::ops::reduce_min minOp; nd4j::ops::reduce_min minOp;
nd4j::ops::reduce_max maxOp; nd4j::ops::reduce_max maxOp;
auto minRes = minOp.execute({&x1}, {}, {}, {}); auto minRes = minOp.evaluate({&x1}, {}, {}, {});
auto maxRes = maxOp.execute({&x0}, {}, {}, {}); auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
// minRes->at(0)->printBuffer("MIN for Truncated2"); // minRes->at(0)->printBuffer("MIN for Truncated2");
// maxRes->at(0)->printBuffer("MAX for Truncated2"); // maxRes->at(0)->printBuffer("MAX for Truncated2");
@ -506,14 +506,14 @@ TEST_F(RNGTests, Test_Truncated_23) {
ASSERT_NEAR(mean.e<float>(0), 0.f, 0.01); ASSERT_NEAR(mean.e<float>(0), 0.f, 0.01);
ASSERT_NEAR(deviation.e<float>(0), 1.f, 0.5); ASSERT_NEAR(deviation.e<float>(0), 1.f, 0.5);
nd4j::ops::moments op; nd4j::ops::moments op;
auto result = op.execute({&x0}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.evaluate({&x0});
// result->at(0)->printBuffer("MEAN"); // result->at(0)->printBuffer("MEAN");
// result->at(1)->printBuffer("VARIANCE"); // result->at(1)->printBuffer("VARIANCE");
delete result; delete result;
nd4j::ops::reduce_min minOp; nd4j::ops::reduce_min minOp;
nd4j::ops::reduce_max maxOp; nd4j::ops::reduce_max maxOp;
auto minRes = minOp.execute({&x1}, {}, {}, {}); auto minRes = minOp.evaluate({&x1}, {}, {}, {});
auto maxRes = maxOp.execute({&x0}, {}, {}, {}); auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
// minRes->at(0)->printBuffer("MIN for Truncated3"); // minRes->at(0)->printBuffer("MIN for Truncated3");
// maxRes->at(0)->printBuffer("MAX for Truncated3"); // maxRes->at(0)->printBuffer("MAX for Truncated3");
@ -686,7 +686,7 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) {
nd4j::ops::random_normal op; nd4j::ops::random_normal op;
auto result = op.execute({&x}, {0.0, 1.0f}, {}); auto result = op.evaluate({&x}, {0.0, 1.0f}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -707,7 +707,7 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) {
nd4j::ops::random_bernoulli op; nd4j::ops::random_bernoulli op;
auto result = op.execute({&x}, {0.5f}, {}); auto result = op.evaluate({&x}, {0.5f}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -728,7 +728,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
nd4j::ops::random_exponential op; nd4j::ops::random_exponential op;
auto result = op.execute({&x}, {0.25f}, {0}); auto result = op.evaluate({&x}, {0.25f}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -752,7 +752,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
nd4j::ops::random_exponential op; nd4j::ops::random_exponential op;
auto result = op.execute({&x, &y}, {0.25f}, {0}); auto result = op.evaluate({&x, &y}, {0.25f}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -776,7 +776,7 @@ TEST_F(RNGTests, Test_PoissonDistribution_1) {
nd4j::ops::random_poisson op; nd4j::ops::random_poisson op;
auto result = op.execute({&x, &la}, {}, {}); auto result = op.evaluate({&x, &la}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -796,7 +796,7 @@ TEST_F(RNGTests, Test_GammaDistribution_1) {
nd4j::ops::random_gamma op; nd4j::ops::random_gamma op;
auto result = op.execute({&x, &al}, {}, {}); auto result = op.evaluate({&x, &al}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -817,7 +817,7 @@ TEST_F(RNGTests, Test_GammaDistribution_2) {
be.assign(1.0); be.assign(1.0);
nd4j::ops::random_gamma op; nd4j::ops::random_gamma op;
auto result = op.execute({&x, &al, &be}, {}, {}); auto result = op.evaluate({&x, &al, &be}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -838,7 +838,7 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
be.assign(2.0); be.assign(2.0);
nd4j::ops::random_gamma op; nd4j::ops::random_gamma op;
auto result = op.execute({&x, &al, &be}, {}, {}); auto result = op.evaluate({&x, &al, &be}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -857,7 +857,7 @@ TEST_F(RNGTests, Test_UniformDistribution_04) {
nd4j::ops::randomuniform op; nd4j::ops::randomuniform op;
auto result = op.execute({&x, &al, &be}, {}, {DataType::INT32}); auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -878,7 +878,7 @@ namespace nd4j {
auto min = NDArrayFactory::create(0.0); auto min = NDArrayFactory::create(0.0);
auto max = NDArrayFactory::create(1.0); auto max = NDArrayFactory::create(1.0);
nd4j::ops::randomuniform op; nd4j::ops::randomuniform op;
op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, false); op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, {}, false);
list.emplace_back(arrayR); list.emplace_back(arrayR);
} }
@ -1013,14 +1013,14 @@ TEST_F(RNGTests, test_multinomial_1) {
nd4j::ops::random_multinomial op; nd4j::ops::random_multinomial op;
RandomGenerator rng(1234, 1234); RandomGenerator rng(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, false) ); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, {}, false) );
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32); NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64); NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64);
auto result = op.execute({ &probsZ, &samples }, { }, { 1, INT64 }); auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 });
auto outputZ = result->at(0); auto outputZ = result->at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
@ -1038,7 +1038,7 @@ TEST_F(RNGTests, test_multinomial_2) {
nd4j::ops::random_multinomial op; nd4j::ops::random_multinomial op;
RandomGenerator rng(1234, 1234); RandomGenerator rng(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false)); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
@ -1047,7 +1047,7 @@ TEST_F(RNGTests, test_multinomial_2) {
NDArray output2('c', { 20, 3 }, nd4j::DataType::INT64); NDArray output2('c', { 20, 3 }, nd4j::DataType::INT64);
rng.setStates(1234, 1234); rng.setStates(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, false)); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, {}, false));
ASSERT_TRUE(expected2.isSameShape(output2)); ASSERT_TRUE(expected2.isSameShape(output2));
ASSERT_TRUE(expected2.equalsTo(output2)); ASSERT_TRUE(expected2.equalsTo(output2));
} }
@ -1061,10 +1061,10 @@ TEST_F(RNGTests, test_multinomial_3) {
RandomGenerator rng(1234, 1234); RandomGenerator rng(1234, 1234);
nd4j::ops::random_multinomial op; nd4j::ops::random_multinomial op;
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, false)); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, {}, false));
rng.setStates(1234, 1234); rng.setStates(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false)); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
@ -1078,10 +1078,10 @@ TEST_F(RNGTests, test_multinomial_4) {
RandomGenerator rng(1234, 1234); RandomGenerator rng(1234, 1234);
nd4j::ops::random_multinomial op; nd4j::ops::random_multinomial op;
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, false)); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, {}, false));
rng.setStates(1234, 1234); rng.setStates(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, false)); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, {}, false));
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
@ -1101,7 +1101,7 @@ TEST_F(RNGTests, test_multinomial_5) {
NDArray output('c', { Samples, batchValue }, nd4j::DataType::INT64); NDArray output('c', { Samples, batchValue }, nd4j::DataType::INT64);
RandomGenerator rng(1234, 1234); RandomGenerator rng(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, false)); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false));
auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
auto mean = output.meanNumber(); auto mean = output.meanNumber();
@ -1115,7 +1115,7 @@ TEST_F(RNGTests, test_multinomial_5) {
ASSERT_TRUE(value >= 0 && value < ClassValue); ASSERT_TRUE(value >= 0 && value < ClassValue);
} }
auto resultR = op.execute({ &probs, &samples }, { }, { 1 }); auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 });
auto outputR = resultR->at(0); auto outputR = resultR->at(0);
ASSERT_EQ(Status::OK(), resultR->status()); ASSERT_EQ(Status::OK(), resultR->status());
@ -1148,7 +1148,7 @@ TEST_F(RNGTests, test_multinomial_6) {
// without seed // without seed
NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32); NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);
auto resultR = op.execute({ &probsR, &samples }, { }, { 0 }); auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 });
auto outputR = resultR->at(0); auto outputR = resultR->at(0);
ASSERT_EQ(Status::OK(), resultR->status()); ASSERT_EQ(Status::OK(), resultR->status());
@ -1180,7 +1180,7 @@ TEST_F(RNGTests, test_multinomial_6) {
NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32); NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);
NDArray output('c', { batchValue, Samples }, nd4j::DataType::INT64); NDArray output('c', { batchValue, Samples }, nd4j::DataType::INT64);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false)); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE); NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);

View File

@ -94,7 +94,7 @@ TEST_F(ScalarTests, Test_Concat_1) {
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&t, &u, &v}, {}, {0}); auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -114,7 +114,7 @@ TEST_F(ScalarTests, Test_Concat_2) {
auto exp = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5}); auto exp = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5});
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&t, &u, &v}, {}, {0}); auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -135,7 +135,7 @@ TEST_F(ScalarTests, Test_Concat_3) {
auto exp = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5}); auto exp = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5});
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&t, &u, &v}, {}, {0}); auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -154,7 +154,7 @@ TEST_F(ScalarTests, Test_ExpandDims_1) {
auto exp = NDArrayFactory::create<float>('c', {1}, {2.0f}); auto exp = NDArrayFactory::create<float>('c', {1}, {2.0f});
nd4j::ops::expand_dims op; nd4j::ops::expand_dims op;
auto result = op.execute({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -171,7 +171,7 @@ TEST_F(ScalarTests, Test_Squeeze_1) {
auto exp = NDArrayFactory::create<float>(2.0f); auto exp = NDArrayFactory::create<float>(2.0f);
nd4j::ops::squeeze op; nd4j::ops::squeeze op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -188,7 +188,7 @@ TEST_F(ScalarTests, Test_Reshape_1) {
auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f}); auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x}, {}, {-99, 1, 1, 1}); auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -205,7 +205,7 @@ TEST_F(ScalarTests, Test_Permute_1) {
auto exp = NDArrayFactory::create<float>(3.0f); auto exp = NDArrayFactory::create<float>(3.0f);
nd4j::ops::permute op; nd4j::ops::permute op;
auto result = op.execute({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -224,7 +224,7 @@ TEST_F(ScalarTests, Test_Stack_1) {
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
nd4j::ops::stack op; nd4j::ops::stack op;
auto result = op.execute({&t, &u, &v}, {}, {0}); auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -243,7 +243,7 @@ TEST_F(ScalarTests, Test_Stack_2) {
auto exp = NDArrayFactory::create<float>('c', {4, 1, 1}, {1, 2, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {4, 1, 1}, {1, 2, 3, 4});
nd4j::ops::stack op; nd4j::ops::stack op;
auto result = op.execute({&t, &u, &v, &w}, {}, {0}); auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -265,7 +265,7 @@ TEST_F(ScalarTests, Test_Concat_Scalar_1) {
auto exp = NDArrayFactory::create<float>('c', {4, 1}, {1, 2, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {4, 1}, {1, 2, 3, 4});
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&t, &u, &v, &w}, {}, {0}); auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -285,7 +285,7 @@ TEST_F(ScalarTests, Test_Concat_Scalar_2) {
auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1, 2, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1, 2, 3, 4});
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&t, &u, &v, &w}, {}, {1}); auto result = op.evaluate({&t, &u, &v, &w}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);

View File

@ -307,7 +307,7 @@ TEST_F(ShapeTests, Tests_Transpose_119_2) {
auto exp = x.transpose(); auto exp = x.transpose();
nd4j::ops::transpose op; nd4j::ops::transpose op;
auto result = op.execute({&x},{}, {}); auto result = op.evaluate({&x});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);

View File

@ -68,7 +68,7 @@ TEST_F(SingleDimTests, Test_Concat_1) {
auto exp = NDArrayFactory::create<float>('c', {6}, {1, 2, 3, 4, 5, 6}); auto exp = NDArrayFactory::create<float>('c', {6}, {1, 2, 3, 4, 5, 6});
nd4j::ops::concat op; nd4j::ops::concat op;
auto result = op.execute({&x, &y}, {}, {0}); auto result = op.evaluate({&x, &y}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -102,7 +102,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) {
auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
nd4j::ops::expand_dims op; nd4j::ops::expand_dims op;
auto result = op.execute({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -120,7 +120,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) {
auto exp = NDArrayFactory::create<float>('c', {3, 1}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {3, 1}, {1, 2, 3});
nd4j::ops::expand_dims op; nd4j::ops::expand_dims op;
auto result = op.execute({&x}, {}, {1}); auto result = op.evaluate({&x}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -140,7 +140,7 @@ TEST_F(SingleDimTests, Test_Squeeze_1) {
auto exp = NDArrayFactory::create<float>(3.0f); auto exp = NDArrayFactory::create<float>(3.0f);
nd4j::ops::squeeze op; nd4j::ops::squeeze op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -157,7 +157,7 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
nd4j::ops::squeeze op; nd4j::ops::squeeze op;
auto result = op.execute({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -173,7 +173,7 @@ TEST_F(SingleDimTests, Test_Reshape_1) {
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x}, {}, {-99, 3}); auto result = op.evaluate({&x}, {}, {-99, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -189,7 +189,7 @@ TEST_F(SingleDimTests, Test_Reshape_2) {
auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
nd4j::ops::reshape op; nd4j::ops::reshape op;
auto result = op.execute({&x}, {}, {-99, 1, 3}); auto result = op.evaluate({&x}, {}, {-99, 1, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -206,7 +206,7 @@ TEST_F(SingleDimTests, Test_Permute_1) {
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
nd4j::ops::permute op; nd4j::ops::permute op;
auto result = op.execute({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);

View File

@ -23,8 +23,10 @@ public final class DType {
public static final byte QINT16 = 16; public static final byte QINT16 = 16;
public static final byte BFLOAT16 = 17; public static final byte BFLOAT16 = 17;
public static final byte UTF8 = 50; public static final byte UTF8 = 50;
public static final byte UTF16 = 51;
public static final byte UTF32 = 52;
public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", }; public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", "UTF16", "UTF32", };
public static String name(int e) { return names[e]; } public static String name(int e) { return names[e]; }
} }

View File

@ -75,28 +75,28 @@ public final class FlatNode extends Table {
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; } public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
public static int createFlatNode(FlatBufferBuilder builder, public static int createFlatNode(FlatBufferBuilder builder,
int id, int id,
int nameOffset, int nameOffset,
byte opType, byte opType,
long opNum, long opNum,
int propertiesOffset, int propertiesOffset,
int inputOffset, int inputOffset,
int inputPairedOffset, int inputPairedOffset,
int outputOffset, int outputOffset,
int extraParamsOffset, int extraParamsOffset,
int extraIntegerOffset, int extraIntegerOffset,
int extraBoolsOffset, int extraBoolsOffset,
int dimensionsOffset, int dimensionsOffset,
int device, int device,
int scope_id, int scope_id,
int scope_nameOffset, int scope_nameOffset,
int outputNamesOffset, int outputNamesOffset,
int opNameOffset, int opNameOffset,
int outputTypesOffset, int outputTypesOffset,
int scalarOffset, int scalarOffset,
int controlDepsOffset, int controlDepsOffset,
int varControlDepsOffset, int varControlDepsOffset,
int controlDepForOffset) { int controlDepForOffset) {
builder.startObject(22); builder.startObject(22);
FlatNode.addOpNum(builder, opNum); FlatNode.addOpNum(builder, opNum);
FlatNode.addControlDepFor(builder, controlDepForOffset); FlatNode.addControlDepFor(builder, controlDepForOffset);

View File

@ -37,16 +37,16 @@ public final class FlatVariable extends Table {
public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; } public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; }
public static int createFlatVariable(FlatBufferBuilder builder, public static int createFlatVariable(FlatBufferBuilder builder,
int idOffset, int idOffset,
int nameOffset, int nameOffset,
byte dtype, byte dtype,
int shapeOffset, int shapeOffset,
int ndarrayOffset, int ndarrayOffset,
int device, int device,
byte variabletype, byte variabletype,
int controlDepsOffset, int controlDepsOffset,
int controlDepForOpOffset, int controlDepForOpOffset,
int controlDepsForVarOffset) { int controlDepsForVarOffset) {
builder.startObject(10); builder.startObject(10);
FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset); FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset);
FlatVariable.addControlDepForOp(builder, controlDepForOpOffset); FlatVariable.addControlDepForOp(builder, controlDepForOpOffset);
@ -88,4 +88,3 @@ public final class FlatVariable extends Table {
public static void finishSizePrefixedFlatVariableBuffer(FlatBufferBuilder builder, int offset) { builder.finishSizePrefixed(offset); } public static void finishSizePrefixedFlatVariableBuffer(FlatBufferBuilder builder, int offset) { builder.finishSizePrefixed(offset); }
} }

View File

@ -27,15 +27,15 @@ public final class UIEvent extends Table {
public int plugin() { int o = __offset(20); return o != 0 ? bb.getShort(o + bb_pos) & 0xFFFF : 0; } public int plugin() { int o = __offset(20); return o != 0 ? bb.getShort(o + bb_pos) & 0xFFFF : 0; }
public static int createUIEvent(FlatBufferBuilder builder, public static int createUIEvent(FlatBufferBuilder builder,
byte eventType, byte eventType,
byte eventSubType, byte eventSubType,
int nameIdx, int nameIdx,
long timestamp, long timestamp,
int iteration, int iteration,
int epoch, int epoch,
short variableId, short variableId,
int frameIterOffset, int frameIterOffset,
int plugin) { int plugin) {
builder.startObject(9); builder.startObject(9);
UIEvent.addTimestamp(builder, timestamp); UIEvent.addTimestamp(builder, timestamp);
UIEvent.addFrameIter(builder, frameIterOffset); UIEvent.addFrameIter(builder, frameIterOffset);

View File

@ -32,12 +32,12 @@ public final class UIOp extends Table {
public ByteBuffer uiLabelExtraInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 14, 1); } public ByteBuffer uiLabelExtraInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 14, 1); }
public static int createUIOp(FlatBufferBuilder builder, public static int createUIOp(FlatBufferBuilder builder,
int nameOffset, int nameOffset,
int opNameOffset, int opNameOffset,
int inputsOffset, int inputsOffset,
int outputsOffset, int outputsOffset,
int controlDepsOffset, int controlDepsOffset,
int uiLabelExtraOffset) { int uiLabelExtraOffset) {
builder.startObject(6); builder.startObject(6);
UIOp.addUiLabelExtra(builder, uiLabelExtraOffset); UIOp.addUiLabelExtra(builder, uiLabelExtraOffset);
UIOp.addControlDeps(builder, controlDepsOffset); UIOp.addControlDeps(builder, controlDepsOffset);

View File

@ -47,19 +47,19 @@ public final class UIVariable extends Table {
public FlatArray constantValue(FlatArray obj) { int o = __offset(28); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; } public FlatArray constantValue(FlatArray obj) { int o = __offset(28); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
public static int createUIVariable(FlatBufferBuilder builder, public static int createUIVariable(FlatBufferBuilder builder,
int idOffset, int idOffset,
int nameOffset, int nameOffset,
byte type, byte type,
byte datatype, byte datatype,
int shapeOffset, int shapeOffset,
int controlDepsOffset, int controlDepsOffset,
int outputOfOpOffset, int outputOfOpOffset,
int inputsForOpOffset, int inputsForOpOffset,
int controlDepsForOpOffset, int controlDepsForOpOffset,
int controlDepsForVarOffset, int controlDepsForVarOffset,
int gradientVariableOffset, int gradientVariableOffset,
int uiLabelExtraOffset, int uiLabelExtraOffset,
int constantValueOffset) { int constantValueOffset) {
builder.startObject(13); builder.startObject(13);
UIVariable.addConstantValue(builder, constantValueOffset); UIVariable.addConstantValue(builder, constantValueOffset);
UIVariable.addUiLabelExtra(builder, uiLabelExtraOffset); UIVariable.addUiLabelExtra(builder, uiLabelExtraOffset);

View File

@ -3096,6 +3096,9 @@ public native void setGraphContextInputArray(OpaqueContext ptr, int index, Point
public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
public native void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments);
public native void setGraphContextDArguments(OpaqueContext ptr, IntBuffer arguments, int numberOfArguments);
public native void setGraphContextDArguments(OpaqueContext ptr, int[] arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
@ -6435,6 +6438,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments); public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments);
public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments);
public native void setDArguments(@Cast("nd4j::DataType*") IntPointer arguments, int numberOfArguments);
public native void setDArguments(@Cast("nd4j::DataType*") IntBuffer arguments, int numberOfArguments);
public native void setDArguments(@Cast("nd4j::DataType*") int[] arguments, int numberOfArguments);
public native void setTArguments(@StdVector DoublePointer tArgs); public native void setTArguments(@StdVector DoublePointer tArgs);
public native void setTArguments(@StdVector DoubleBuffer tArgs); public native void setTArguments(@StdVector DoubleBuffer tArgs);
@ -6444,6 +6450,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs); public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs);
public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs); public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs);
public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs); public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs);
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntPointer dArgs);
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
@ -6547,6 +6556,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native @StdVector DoublePointer getTArguments(); public native @StdVector DoublePointer getTArguments();
public native @StdVector IntPointer getIArguments(); public native @StdVector IntPointer getIArguments();
public native @Cast("bool*") @StdVector BooleanPointer getBArguments(); public native @Cast("bool*") @StdVector BooleanPointer getBArguments();
public native @Cast("nd4j::DataType*") @StdVector IntPointer getDArguments();
public native @StdVector IntPointer getAxis(); public native @StdVector IntPointer getAxis();
public native @Cast("samediff::Engine") int engine(); public native @Cast("samediff::Engine") int engine();
@ -6554,6 +6564,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native @Cast("size_t") long numT(); public native @Cast("size_t") long numT();
public native @Cast("size_t") long numI(); public native @Cast("size_t") long numI();
public native @Cast("size_t") long numB(); public native @Cast("size_t") long numB();
public native @Cast("size_t") long numD();
public native IntIntPair input(int idx); public native IntIntPair input(int idx);
@ -9418,39 +9429,43 @@ public static final int PREALLOC_SIZE = 33554432;
*/ */
public native @Cast("Nd4jStatus") int execute(Context block); public native @Cast("Nd4jStatus") int execute(Context block);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder); public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
@ -9649,8 +9664,9 @@ public static final int PREALLOC_SIZE = 33554432;
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public BooleanOp(Pointer p) { super(p); } public BooleanOp(Pointer p) { super(p); }
public native @Cast("bool") boolean evaluate(@ByRef NDArrayVector args);
public native @Cast("bool") boolean evaluate(@ByRef Context block); public native @Cast("bool") boolean verify(@Const @ByRef NDArrayVector args);
public native @Cast("bool") boolean verify(@ByRef Context block);
public native @Cast("Nd4jStatus") int execute(Context block); public native @Cast("Nd4jStatus") int execute(Context block);

View File

@ -3099,6 +3099,9 @@ public native void setGraphContextInputArray(OpaqueContext ptr, int index, Point
public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
public native void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments);
public native void setGraphContextDArguments(OpaqueContext ptr, IntBuffer arguments, int numberOfArguments);
public native void setGraphContextDArguments(OpaqueContext ptr, int[] arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
@ -6438,6 +6441,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments); public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments);
public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments);
public native void setDArguments(@Cast("nd4j::DataType*") IntPointer arguments, int numberOfArguments);
public native void setDArguments(@Cast("nd4j::DataType*") IntBuffer arguments, int numberOfArguments);
public native void setDArguments(@Cast("nd4j::DataType*") int[] arguments, int numberOfArguments);
public native void setTArguments(@StdVector DoublePointer tArgs); public native void setTArguments(@StdVector DoublePointer tArgs);
public native void setTArguments(@StdVector DoubleBuffer tArgs); public native void setTArguments(@StdVector DoubleBuffer tArgs);
@ -6447,6 +6453,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs); public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs);
public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs); public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs);
public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs); public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs);
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntPointer dArgs);
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
@ -6550,6 +6559,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native @StdVector DoublePointer getTArguments(); public native @StdVector DoublePointer getTArguments();
public native @StdVector IntPointer getIArguments(); public native @StdVector IntPointer getIArguments();
public native @Cast("bool*") @StdVector BooleanPointer getBArguments(); public native @Cast("bool*") @StdVector BooleanPointer getBArguments();
public native @Cast("nd4j::DataType*") @StdVector IntPointer getDArguments();
public native @StdVector IntPointer getAxis(); public native @StdVector IntPointer getAxis();
public native @Cast("samediff::Engine") int engine(); public native @Cast("samediff::Engine") int engine();
@ -6557,6 +6567,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native @Cast("size_t") long numT(); public native @Cast("size_t") long numT();
public native @Cast("size_t") long numI(); public native @Cast("size_t") long numI();
public native @Cast("size_t") long numB(); public native @Cast("size_t") long numB();
public native @Cast("size_t") long numD();
public native IntIntPair input(int idx); public native IntIntPair input(int idx);
@ -11130,7 +11141,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #define INPUT_LIST(INDEX) reinterpret_cast<nd4j::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList()) // #define INPUT_LIST(INDEX) reinterpret_cast<nd4j::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())
// #define D_ARG(INDEX) block.getDArguments()->at(INDEX)
// #define INT_ARG(INDEX) block.getIArguments()->at(INDEX) // #define INT_ARG(INDEX) block.getIArguments()->at(INDEX)
// #define I_ARG(INDEX) INT_ARG(INDEX)
// #define T_ARG(INDEX) block.getTArguments()->at(INDEX) // #define T_ARG(INDEX) block.getTArguments()->at(INDEX)
// #define B_ARG(INDEX) block.getBArguments()->at(INDEX) // #define B_ARG(INDEX) block.getBArguments()->at(INDEX)
@ -11629,39 +11642,43 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
*/ */
public native @Cast("Nd4jStatus") int execute(Context block); public native @Cast("Nd4jStatus") int execute(Context block);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder); public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
@ -11860,8 +11877,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public BooleanOp(Pointer p) { super(p); } public BooleanOp(Pointer p) { super(p); }
public native @Cast("bool") boolean evaluate(@ByRef NDArrayVector args);
public native @Cast("bool") boolean evaluate(@ByRef Context block); public native @Cast("bool") boolean verify(@Const @ByRef NDArrayVector args);
public native @Cast("bool") boolean verify(@ByRef Context block);
public native @Cast("Nd4jStatus") int execute(Context block); public native @Cast("Nd4jStatus") int execute(Context block);

View File

@ -1,4 +1,4 @@
//Generated by flatc compiler (version 1.9.0) //Generated by flatc compiler (version 1.10.0)
//If you make any local changes, they will be lost //If you make any local changes, they will be lost
//source: graph.fbs //source: graph.fbs
@ -31,17 +31,17 @@ public final class GraphInferenceServerGrpc {
private GraphInferenceServerGrpc() {} private GraphInferenceServerGrpc() {}
public static final String SERVICE_NAME = "nd4j.graph.GraphInferenceServer"; public static final String SERVICE_NAME = "org.nd4j.graph.GraphInferenceServer";
// Static method descriptors that strictly reflect the proto. // Static method descriptors that strictly reflect the proto.
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
@Deprecated // Use {@link #getRegisterGraphMethod()} instead. @Deprecated // Use {@link #getRegisterGraphMethod()} instead.
public static final io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph, public static final io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph,
org.nd4j.graph.FlatResponse> METHOD_REGISTER_GRAPH = getRegisterGraphMethod(); org.nd4j.graph.FlatResponse> METHOD_REGISTER_GRAPH = getRegisterGraphMethod();
private static volatile io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph, private static volatile io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph,
org.nd4j.graph.FlatResponse> getRegisterGraphMethod; org.nd4j.graph.FlatResponse> getRegisterGraphMethod;
private static volatile FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatGraph> extractorOfFlatGraph; private static volatile FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatGraph> extractorOfFlatGraph;
private static FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatGraph> getExtractorOfFlatGraph() { private static FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatGraph> getExtractorOfFlatGraph() {
if (extractorOfFlatGraph != null) return extractorOfFlatGraph; if (extractorOfFlatGraph != null) return extractorOfFlatGraph;
@ -55,7 +55,7 @@ public final class GraphInferenceServerGrpc {
return extractorOfFlatGraph; return extractorOfFlatGraph;
} }
} }
private static volatile FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatResponse> extractorOfFlatResponse; private static volatile FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatResponse> extractorOfFlatResponse;
private static FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatResponse> getExtractorOfFlatResponse() { private static FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatResponse> getExtractorOfFlatResponse() {
if (extractorOfFlatResponse != null) return extractorOfFlatResponse; if (extractorOfFlatResponse != null) return extractorOfFlatResponse;
@ -69,7 +69,7 @@ public final class GraphInferenceServerGrpc {
return extractorOfFlatResponse; return extractorOfFlatResponse;
} }
} }
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
public static io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph, public static io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph,
org.nd4j.graph.FlatResponse> getRegisterGraphMethod() { org.nd4j.graph.FlatResponse> getRegisterGraphMethod() {
@ -77,11 +77,11 @@ public final class GraphInferenceServerGrpc {
if ((getRegisterGraphMethod = GraphInferenceServerGrpc.getRegisterGraphMethod) == null) { if ((getRegisterGraphMethod = GraphInferenceServerGrpc.getRegisterGraphMethod) == null) {
synchronized (GraphInferenceServerGrpc.class) { synchronized (GraphInferenceServerGrpc.class) {
if ((getRegisterGraphMethod = GraphInferenceServerGrpc.getRegisterGraphMethod) == null) { if ((getRegisterGraphMethod = GraphInferenceServerGrpc.getRegisterGraphMethod) == null) {
GraphInferenceServerGrpc.getRegisterGraphMethod = getRegisterGraphMethod = GraphInferenceServerGrpc.getRegisterGraphMethod = getRegisterGraphMethod =
io.grpc.MethodDescriptor.<org.nd4j.graph.FlatGraph, org.nd4j.graph.FlatResponse>newBuilder() io.grpc.MethodDescriptor.<org.nd4j.graph.FlatGraph, org.nd4j.graph.FlatResponse>newBuilder()
.setType(io.grpc.MethodDescriptor.MethodType.UNARY) .setType(io.grpc.MethodDescriptor.MethodType.UNARY)
.setFullMethodName(generateFullMethodName( .setFullMethodName(generateFullMethodName(
"nd4j.graph.GraphInferenceServer", "RegisterGraph")) "org.nd4j.graph.GraphInferenceServer", "RegisterGraph"))
.setSampledToLocalTracing(true) .setSampledToLocalTracing(true)
.setRequestMarshaller(FlatbuffersUtils.marshaller( .setRequestMarshaller(FlatbuffersUtils.marshaller(
org.nd4j.graph.FlatGraph.class, getExtractorOfFlatGraph())) org.nd4j.graph.FlatGraph.class, getExtractorOfFlatGraph()))
@ -94,15 +94,15 @@ public final class GraphInferenceServerGrpc {
} }
return getRegisterGraphMethod; return getRegisterGraphMethod;
} }
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
@Deprecated // Use {@link #getForgetGraphMethod()} instead. @Deprecated // Use {@link #getForgetGraphMethod()} instead.
public static final io.grpc.MethodDescriptor<org.nd4j.graph.FlatDropRequest, public static final io.grpc.MethodDescriptor<org.nd4j.graph.FlatDropRequest,
org.nd4j.graph.FlatResponse> METHOD_FORGET_GRAPH = getForgetGraphMethod(); org.nd4j.graph.FlatResponse> METHOD_FORGET_GRAPH = getForgetGraphMethod();
private static volatile io.grpc.MethodDescriptor<org.nd4j.graph.FlatDropRequest, private static volatile io.grpc.MethodDescriptor<org.nd4j.graph.FlatDropRequest,
org.nd4j.graph.FlatResponse> getForgetGraphMethod; org.nd4j.graph.FlatResponse> getForgetGraphMethod;
private static volatile FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatDropRequest> extractorOfFlatDropRequest; private static volatile FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatDropRequest> extractorOfFlatDropRequest;
private static FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatDropRequest> getExtractorOfFlatDropRequest() { private static FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatDropRequest> getExtractorOfFlatDropRequest() {
if (extractorOfFlatDropRequest != null) return extractorOfFlatDropRequest; if (extractorOfFlatDropRequest != null) return extractorOfFlatDropRequest;
@ -116,7 +116,7 @@ public final class GraphInferenceServerGrpc {
return extractorOfFlatDropRequest; return extractorOfFlatDropRequest;
} }
} }
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
public static io.grpc.MethodDescriptor<org.nd4j.graph.FlatDropRequest, public static io.grpc.MethodDescriptor<org.nd4j.graph.FlatDropRequest,
org.nd4j.graph.FlatResponse> getForgetGraphMethod() { org.nd4j.graph.FlatResponse> getForgetGraphMethod() {
@ -124,11 +124,11 @@ public final class GraphInferenceServerGrpc {
if ((getForgetGraphMethod = GraphInferenceServerGrpc.getForgetGraphMethod) == null) { if ((getForgetGraphMethod = GraphInferenceServerGrpc.getForgetGraphMethod) == null) {
synchronized (GraphInferenceServerGrpc.class) { synchronized (GraphInferenceServerGrpc.class) {
if ((getForgetGraphMethod = GraphInferenceServerGrpc.getForgetGraphMethod) == null) { if ((getForgetGraphMethod = GraphInferenceServerGrpc.getForgetGraphMethod) == null) {
GraphInferenceServerGrpc.getForgetGraphMethod = getForgetGraphMethod = GraphInferenceServerGrpc.getForgetGraphMethod = getForgetGraphMethod =
io.grpc.MethodDescriptor.<org.nd4j.graph.FlatDropRequest, org.nd4j.graph.FlatResponse>newBuilder() io.grpc.MethodDescriptor.<org.nd4j.graph.FlatDropRequest, org.nd4j.graph.FlatResponse>newBuilder()
.setType(io.grpc.MethodDescriptor.MethodType.UNARY) .setType(io.grpc.MethodDescriptor.MethodType.UNARY)
.setFullMethodName(generateFullMethodName( .setFullMethodName(generateFullMethodName(
"nd4j.graph.GraphInferenceServer", "ForgetGraph")) "org.nd4j.graph.GraphInferenceServer", "ForgetGraph"))
.setSampledToLocalTracing(true) .setSampledToLocalTracing(true)
.setRequestMarshaller(FlatbuffersUtils.marshaller( .setRequestMarshaller(FlatbuffersUtils.marshaller(
org.nd4j.graph.FlatDropRequest.class, getExtractorOfFlatDropRequest())) org.nd4j.graph.FlatDropRequest.class, getExtractorOfFlatDropRequest()))
@ -141,15 +141,48 @@ public final class GraphInferenceServerGrpc {
} }
return getForgetGraphMethod; return getForgetGraphMethod;
} }
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
@Deprecated // Use {@link #getInferenceRequestMethod()} instead. @Deprecated // Use {@link #getReplaceGraphMethod()} instead.
public static final io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph,
org.nd4j.graph.FlatResponse> METHOD_REPLACE_GRAPH = getReplaceGraphMethod();
private static volatile io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph,
org.nd4j.graph.FlatResponse> getReplaceGraphMethod;
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
public static io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph,
org.nd4j.graph.FlatResponse> getReplaceGraphMethod() {
io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph, org.nd4j.graph.FlatResponse> getReplaceGraphMethod;
if ((getReplaceGraphMethod = GraphInferenceServerGrpc.getReplaceGraphMethod) == null) {
synchronized (GraphInferenceServerGrpc.class) {
if ((getReplaceGraphMethod = GraphInferenceServerGrpc.getReplaceGraphMethod) == null) {
GraphInferenceServerGrpc.getReplaceGraphMethod = getReplaceGraphMethod =
io.grpc.MethodDescriptor.<org.nd4j.graph.FlatGraph, org.nd4j.graph.FlatResponse>newBuilder()
.setType(io.grpc.MethodDescriptor.MethodType.UNARY)
.setFullMethodName(generateFullMethodName(
"org.nd4j.graph.GraphInferenceServer", "ReplaceGraph"))
.setSampledToLocalTracing(true)
.setRequestMarshaller(FlatbuffersUtils.marshaller(
org.nd4j.graph.FlatGraph.class, getExtractorOfFlatGraph()))
.setResponseMarshaller(FlatbuffersUtils.marshaller(
org.nd4j.graph.FlatResponse.class, getExtractorOfFlatResponse()))
.setSchemaDescriptor(null)
.build();
}
}
}
return getReplaceGraphMethod;
}
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
@Deprecated // Use {@link #getInferenceRequestMethod()} instead.
public static final io.grpc.MethodDescriptor<org.nd4j.graph.FlatInferenceRequest, public static final io.grpc.MethodDescriptor<org.nd4j.graph.FlatInferenceRequest,
org.nd4j.graph.FlatResult> METHOD_INFERENCE_REQUEST = getInferenceRequestMethod(); org.nd4j.graph.FlatResult> METHOD_INFERENCE_REQUEST = getInferenceRequestMethod();
private static volatile io.grpc.MethodDescriptor<org.nd4j.graph.FlatInferenceRequest, private static volatile io.grpc.MethodDescriptor<org.nd4j.graph.FlatInferenceRequest,
org.nd4j.graph.FlatResult> getInferenceRequestMethod; org.nd4j.graph.FlatResult> getInferenceRequestMethod;
private static volatile FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatInferenceRequest> extractorOfFlatInferenceRequest; private static volatile FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatInferenceRequest> extractorOfFlatInferenceRequest;
private static FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatInferenceRequest> getExtractorOfFlatInferenceRequest() { private static FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatInferenceRequest> getExtractorOfFlatInferenceRequest() {
if (extractorOfFlatInferenceRequest != null) return extractorOfFlatInferenceRequest; if (extractorOfFlatInferenceRequest != null) return extractorOfFlatInferenceRequest;
@ -163,7 +196,7 @@ public final class GraphInferenceServerGrpc {
return extractorOfFlatInferenceRequest; return extractorOfFlatInferenceRequest;
} }
} }
private static volatile FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatResult> extractorOfFlatResult; private static volatile FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatResult> extractorOfFlatResult;
private static FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatResult> getExtractorOfFlatResult() { private static FlatbuffersUtils.FBExtactor<org.nd4j.graph.FlatResult> getExtractorOfFlatResult() {
if (extractorOfFlatResult != null) return extractorOfFlatResult; if (extractorOfFlatResult != null) return extractorOfFlatResult;
@ -177,7 +210,7 @@ public final class GraphInferenceServerGrpc {
return extractorOfFlatResult; return extractorOfFlatResult;
} }
} }
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
public static io.grpc.MethodDescriptor<org.nd4j.graph.FlatInferenceRequest, public static io.grpc.MethodDescriptor<org.nd4j.graph.FlatInferenceRequest,
org.nd4j.graph.FlatResult> getInferenceRequestMethod() { org.nd4j.graph.FlatResult> getInferenceRequestMethod() {
@ -185,11 +218,11 @@ public final class GraphInferenceServerGrpc {
if ((getInferenceRequestMethod = GraphInferenceServerGrpc.getInferenceRequestMethod) == null) { if ((getInferenceRequestMethod = GraphInferenceServerGrpc.getInferenceRequestMethod) == null) {
synchronized (GraphInferenceServerGrpc.class) { synchronized (GraphInferenceServerGrpc.class) {
if ((getInferenceRequestMethod = GraphInferenceServerGrpc.getInferenceRequestMethod) == null) { if ((getInferenceRequestMethod = GraphInferenceServerGrpc.getInferenceRequestMethod) == null) {
GraphInferenceServerGrpc.getInferenceRequestMethod = getInferenceRequestMethod = GraphInferenceServerGrpc.getInferenceRequestMethod = getInferenceRequestMethod =
io.grpc.MethodDescriptor.<org.nd4j.graph.FlatInferenceRequest, org.nd4j.graph.FlatResult>newBuilder() io.grpc.MethodDescriptor.<org.nd4j.graph.FlatInferenceRequest, org.nd4j.graph.FlatResult>newBuilder()
.setType(io.grpc.MethodDescriptor.MethodType.UNARY) .setType(io.grpc.MethodDescriptor.MethodType.UNARY)
.setFullMethodName(generateFullMethodName( .setFullMethodName(generateFullMethodName(
"nd4j.graph.GraphInferenceServer", "InferenceRequest")) "org.nd4j.graph.GraphInferenceServer", "InferenceRequest"))
.setSampledToLocalTracing(true) .setSampledToLocalTracing(true)
.setRequestMarshaller(FlatbuffersUtils.marshaller( .setRequestMarshaller(FlatbuffersUtils.marshaller(
org.nd4j.graph.FlatInferenceRequest.class, getExtractorOfFlatInferenceRequest())) org.nd4j.graph.FlatInferenceRequest.class, getExtractorOfFlatInferenceRequest()))
@ -202,14 +235,14 @@ public final class GraphInferenceServerGrpc {
} }
return getInferenceRequestMethod; return getInferenceRequestMethod;
} }
/** /**
* Creates a new async stub that supports all call types for the service * Creates a new async stub that supports all call types for the service
*/ */
public static GraphInferenceServerStub newStub(io.grpc.Channel channel) { public static GraphInferenceServerStub newStub(io.grpc.Channel channel) {
return new GraphInferenceServerStub(channel); return new GraphInferenceServerStub(channel);
} }
/** /**
* Creates a new blocking-style stub that supports unary and streaming output calls on the service * Creates a new blocking-style stub that supports unary and streaming output calls on the service
*/ */
@ -217,7 +250,7 @@ public final class GraphInferenceServerGrpc {
io.grpc.Channel channel) { io.grpc.Channel channel) {
return new GraphInferenceServerBlockingStub(channel); return new GraphInferenceServerBlockingStub(channel);
} }
/** /**
* Creates a new ListenableFuture-style stub that supports unary calls on the service * Creates a new ListenableFuture-style stub that supports unary calls on the service
*/ */
@ -225,32 +258,39 @@ public final class GraphInferenceServerGrpc {
io.grpc.Channel channel) { io.grpc.Channel channel) {
return new GraphInferenceServerFutureStub(channel); return new GraphInferenceServerFutureStub(channel);
} }
/** /**
*/ */
public static abstract class GraphInferenceServerImplBase implements io.grpc.BindableService { public static abstract class GraphInferenceServerImplBase implements io.grpc.BindableService {
/** /**
*/ */
public void registerGraph(org.nd4j.graph.FlatGraph request, public void registerGraph(org.nd4j.graph.FlatGraph request,
io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse> responseObserver) { io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse> responseObserver) {
asyncUnimplementedUnaryCall(getRegisterGraphMethod(), responseObserver); asyncUnimplementedUnaryCall(getRegisterGraphMethod(), responseObserver);
} }
/** /**
*/ */
public void forgetGraph(org.nd4j.graph.FlatDropRequest request, public void forgetGraph(org.nd4j.graph.FlatDropRequest request,
io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse> responseObserver) { io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse> responseObserver) {
asyncUnimplementedUnaryCall(getForgetGraphMethod(), responseObserver); asyncUnimplementedUnaryCall(getForgetGraphMethod(), responseObserver);
} }
/**
*/
public void replaceGraph(org.nd4j.graph.FlatGraph request,
io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse> responseObserver) {
asyncUnimplementedUnaryCall(getReplaceGraphMethod(), responseObserver);
}
/** /**
*/ */
public void inferenceRequest(org.nd4j.graph.FlatInferenceRequest request, public void inferenceRequest(org.nd4j.graph.FlatInferenceRequest request,
io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResult> responseObserver) { io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResult> responseObserver) {
asyncUnimplementedUnaryCall(getInferenceRequestMethod(), responseObserver); asyncUnimplementedUnaryCall(getInferenceRequestMethod(), responseObserver);
} }
@Override public final io.grpc.ServerServiceDefinition bindService() { @Override public final io.grpc.ServerServiceDefinition bindService() {
return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor())
.addMethod( .addMethod(
@ -267,6 +307,13 @@ public final class GraphInferenceServerGrpc {
org.nd4j.graph.FlatDropRequest, org.nd4j.graph.FlatDropRequest,
org.nd4j.graph.FlatResponse>( org.nd4j.graph.FlatResponse>(
this, METHODID_FORGET_GRAPH))) this, METHODID_FORGET_GRAPH)))
.addMethod(
getReplaceGraphMethod(),
asyncUnaryCall(
new MethodHandlers<
org.nd4j.graph.FlatGraph,
org.nd4j.graph.FlatResponse>(
this, METHODID_REPLACE_GRAPH)))
.addMethod( .addMethod(
getInferenceRequestMethod(), getInferenceRequestMethod(),
asyncUnaryCall( asyncUnaryCall(
@ -277,25 +324,25 @@ public final class GraphInferenceServerGrpc {
.build(); .build();
} }
} }
/** /**
*/ */
public static final class GraphInferenceServerStub extends io.grpc.stub.AbstractStub<GraphInferenceServerStub> { public static final class GraphInferenceServerStub extends io.grpc.stub.AbstractStub<GraphInferenceServerStub> {
private GraphInferenceServerStub(io.grpc.Channel channel) { private GraphInferenceServerStub(io.grpc.Channel channel) {
super(channel); super(channel);
} }
private GraphInferenceServerStub(io.grpc.Channel channel, private GraphInferenceServerStub(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) { io.grpc.CallOptions callOptions) {
super(channel, callOptions); super(channel, callOptions);
} }
@Override @Override
protected GraphInferenceServerStub build(io.grpc.Channel channel, protected GraphInferenceServerStub build(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) { io.grpc.CallOptions callOptions) {
return new GraphInferenceServerStub(channel, callOptions); return new GraphInferenceServerStub(channel, callOptions);
} }
/** /**
*/ */
public void registerGraph(org.nd4j.graph.FlatGraph request, public void registerGraph(org.nd4j.graph.FlatGraph request,
@ -303,7 +350,7 @@ public final class GraphInferenceServerGrpc {
asyncUnaryCall( asyncUnaryCall(
getChannel().newCall(getRegisterGraphMethod(), getCallOptions()), request, responseObserver); getChannel().newCall(getRegisterGraphMethod(), getCallOptions()), request, responseObserver);
} }
/** /**
*/ */
public void forgetGraph(org.nd4j.graph.FlatDropRequest request, public void forgetGraph(org.nd4j.graph.FlatDropRequest request,
@ -311,7 +358,15 @@ public final class GraphInferenceServerGrpc {
asyncUnaryCall( asyncUnaryCall(
getChannel().newCall(getForgetGraphMethod(), getCallOptions()), request, responseObserver); getChannel().newCall(getForgetGraphMethod(), getCallOptions()), request, responseObserver);
} }
/**
*/
public void replaceGraph(org.nd4j.graph.FlatGraph request,
io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse> responseObserver) {
asyncUnaryCall(
getChannel().newCall(getReplaceGraphMethod(), getCallOptions()), request, responseObserver);
}
/** /**
*/ */
public void inferenceRequest(org.nd4j.graph.FlatInferenceRequest request, public void inferenceRequest(org.nd4j.graph.FlatInferenceRequest request,
@ -320,39 +375,46 @@ public final class GraphInferenceServerGrpc {
getChannel().newCall(getInferenceRequestMethod(), getCallOptions()), request, responseObserver); getChannel().newCall(getInferenceRequestMethod(), getCallOptions()), request, responseObserver);
} }
} }
/** /**
*/ */
public static final class GraphInferenceServerBlockingStub extends io.grpc.stub.AbstractStub<GraphInferenceServerBlockingStub> { public static final class GraphInferenceServerBlockingStub extends io.grpc.stub.AbstractStub<GraphInferenceServerBlockingStub> {
private GraphInferenceServerBlockingStub(io.grpc.Channel channel) { private GraphInferenceServerBlockingStub(io.grpc.Channel channel) {
super(channel); super(channel);
} }
private GraphInferenceServerBlockingStub(io.grpc.Channel channel, private GraphInferenceServerBlockingStub(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) { io.grpc.CallOptions callOptions) {
super(channel, callOptions); super(channel, callOptions);
} }
@Override @Override
protected GraphInferenceServerBlockingStub build(io.grpc.Channel channel, protected GraphInferenceServerBlockingStub build(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) { io.grpc.CallOptions callOptions) {
return new GraphInferenceServerBlockingStub(channel, callOptions); return new GraphInferenceServerBlockingStub(channel, callOptions);
} }
/** /**
*/ */
public org.nd4j.graph.FlatResponse registerGraph(org.nd4j.graph.FlatGraph request) { public org.nd4j.graph.FlatResponse registerGraph(org.nd4j.graph.FlatGraph request) {
return blockingUnaryCall( return blockingUnaryCall(
getChannel(), getRegisterGraphMethod(), getCallOptions(), request); getChannel(), getRegisterGraphMethod(), getCallOptions(), request);
} }
/** /**
*/ */
public org.nd4j.graph.FlatResponse forgetGraph(org.nd4j.graph.FlatDropRequest request) { public org.nd4j.graph.FlatResponse forgetGraph(org.nd4j.graph.FlatDropRequest request) {
return blockingUnaryCall( return blockingUnaryCall(
getChannel(), getForgetGraphMethod(), getCallOptions(), request); getChannel(), getForgetGraphMethod(), getCallOptions(), request);
} }
/**
*/
public org.nd4j.graph.FlatResponse replaceGraph(org.nd4j.graph.FlatGraph request) {
return blockingUnaryCall(
getChannel(), getReplaceGraphMethod(), getCallOptions(), request);
}
/** /**
*/ */
public org.nd4j.graph.FlatResult inferenceRequest(org.nd4j.graph.FlatInferenceRequest request) { public org.nd4j.graph.FlatResult inferenceRequest(org.nd4j.graph.FlatInferenceRequest request) {
@ -360,25 +422,25 @@ public final class GraphInferenceServerGrpc {
getChannel(), getInferenceRequestMethod(), getCallOptions(), request); getChannel(), getInferenceRequestMethod(), getCallOptions(), request);
} }
} }
/** /**
*/ */
public static final class GraphInferenceServerFutureStub extends io.grpc.stub.AbstractStub<GraphInferenceServerFutureStub> { public static final class GraphInferenceServerFutureStub extends io.grpc.stub.AbstractStub<GraphInferenceServerFutureStub> {
private GraphInferenceServerFutureStub(io.grpc.Channel channel) { private GraphInferenceServerFutureStub(io.grpc.Channel channel) {
super(channel); super(channel);
} }
private GraphInferenceServerFutureStub(io.grpc.Channel channel, private GraphInferenceServerFutureStub(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) { io.grpc.CallOptions callOptions) {
super(channel, callOptions); super(channel, callOptions);
} }
@Override @Override
protected GraphInferenceServerFutureStub build(io.grpc.Channel channel, protected GraphInferenceServerFutureStub build(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) { io.grpc.CallOptions callOptions) {
return new GraphInferenceServerFutureStub(channel, callOptions); return new GraphInferenceServerFutureStub(channel, callOptions);
} }
/** /**
*/ */
public com.google.common.util.concurrent.ListenableFuture<org.nd4j.graph.FlatResponse> registerGraph( public com.google.common.util.concurrent.ListenableFuture<org.nd4j.graph.FlatResponse> registerGraph(
@ -386,7 +448,7 @@ public final class GraphInferenceServerGrpc {
return futureUnaryCall( return futureUnaryCall(
getChannel().newCall(getRegisterGraphMethod(), getCallOptions()), request); getChannel().newCall(getRegisterGraphMethod(), getCallOptions()), request);
} }
/** /**
*/ */
public com.google.common.util.concurrent.ListenableFuture<org.nd4j.graph.FlatResponse> forgetGraph( public com.google.common.util.concurrent.ListenableFuture<org.nd4j.graph.FlatResponse> forgetGraph(
@ -394,7 +456,15 @@ public final class GraphInferenceServerGrpc {
return futureUnaryCall( return futureUnaryCall(
getChannel().newCall(getForgetGraphMethod(), getCallOptions()), request); getChannel().newCall(getForgetGraphMethod(), getCallOptions()), request);
} }
/**
*/
public com.google.common.util.concurrent.ListenableFuture<org.nd4j.graph.FlatResponse> replaceGraph(
org.nd4j.graph.FlatGraph request) {
return futureUnaryCall(
getChannel().newCall(getReplaceGraphMethod(), getCallOptions()), request);
}
/** /**
*/ */
public com.google.common.util.concurrent.ListenableFuture<org.nd4j.graph.FlatResult> inferenceRequest( public com.google.common.util.concurrent.ListenableFuture<org.nd4j.graph.FlatResult> inferenceRequest(
@ -403,11 +473,12 @@ public final class GraphInferenceServerGrpc {
getChannel().newCall(getInferenceRequestMethod(), getCallOptions()), request); getChannel().newCall(getInferenceRequestMethod(), getCallOptions()), request);
} }
} }
private static final int METHODID_REGISTER_GRAPH = 0; private static final int METHODID_REGISTER_GRAPH = 0;
private static final int METHODID_FORGET_GRAPH = 1; private static final int METHODID_FORGET_GRAPH = 1;
private static final int METHODID_INFERENCE_REQUEST = 2; private static final int METHODID_REPLACE_GRAPH = 2;
private static final int METHODID_INFERENCE_REQUEST = 3;
private static final class MethodHandlers<Req, Resp> implements private static final class MethodHandlers<Req, Resp> implements
io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp>, io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp>,
io.grpc.stub.ServerCalls.ServerStreamingMethod<Req, Resp>, io.grpc.stub.ServerCalls.ServerStreamingMethod<Req, Resp>,
@ -415,12 +486,12 @@ public final class GraphInferenceServerGrpc {
io.grpc.stub.ServerCalls.BidiStreamingMethod<Req, Resp> { io.grpc.stub.ServerCalls.BidiStreamingMethod<Req, Resp> {
private final GraphInferenceServerImplBase serviceImpl; private final GraphInferenceServerImplBase serviceImpl;
private final int methodId; private final int methodId;
MethodHandlers(GraphInferenceServerImplBase serviceImpl, int methodId) { MethodHandlers(GraphInferenceServerImplBase serviceImpl, int methodId) {
this.serviceImpl = serviceImpl; this.serviceImpl = serviceImpl;
this.methodId = methodId; this.methodId = methodId;
} }
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void invoke(Req request, io.grpc.stub.StreamObserver<Resp> responseObserver) { public void invoke(Req request, io.grpc.stub.StreamObserver<Resp> responseObserver) {
@ -433,6 +504,10 @@ public final class GraphInferenceServerGrpc {
serviceImpl.forgetGraph((org.nd4j.graph.FlatDropRequest) request, serviceImpl.forgetGraph((org.nd4j.graph.FlatDropRequest) request,
(io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse>) responseObserver); (io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse>) responseObserver);
break; break;
case METHODID_REPLACE_GRAPH:
serviceImpl.replaceGraph((org.nd4j.graph.FlatGraph) request,
(io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse>) responseObserver);
break;
case METHODID_INFERENCE_REQUEST: case METHODID_INFERENCE_REQUEST:
serviceImpl.inferenceRequest((org.nd4j.graph.FlatInferenceRequest) request, serviceImpl.inferenceRequest((org.nd4j.graph.FlatInferenceRequest) request,
(io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResult>) responseObserver); (io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResult>) responseObserver);
@ -465,6 +540,7 @@ public final class GraphInferenceServerGrpc {
.setSchemaDescriptor(null) .setSchemaDescriptor(null)
.addMethod(getRegisterGraphMethod()) .addMethod(getRegisterGraphMethod())
.addMethod(getForgetGraphMethod()) .addMethod(getForgetGraphMethod())
.addMethod(getReplaceGraphMethod())
.addMethod(getInferenceRequestMethod()) .addMethod(getInferenceRequestMethod())
.build(); .build();
} }