DataTypes & FlatBuffers (#197)
* flatbuffers version upgrade
* flatbuffers version upgrade java side
* flatbuffers dependency version upgrade java side
* MKLDNN version upgrade
* DArgs first pass
* signatures first pass
* signatures second pass
* signatures third pass
* signatures third pass
* signatures fourth pass
* signatures fifth pass
* flatbuffers UI version upgrade java side
* flatbuffers ui update
* flatbuffers downgrade
* flatbuffers downgrade java side

Signed-off-by: raver119 <raver119@gmail.com>
parent 5039fb22b7
commit ba961c7601
@@ -5,7 +5,7 @@ project(flatbuffers-download NONE)
 include(ExternalProject)
 ExternalProject_Add(flatbuffers
 GIT_REPOSITORY https://github.com/google/flatbuffers.git
-GIT_TAG v1.10.0
+GIT_TAG v1.11.0
 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src"
 BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-build"
 CONFIGURE_COMMAND ""

@@ -5,7 +5,7 @@ project(mkldnn-download NONE)
 include(ExternalProject)
 ExternalProject_Add(mkldnn
 GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
-GIT_TAG v1.1.2
+GIT_TAG v1.1.3
 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
 BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
 CONFIGURE_COMMAND ""
@@ -1607,6 +1607,7 @@ ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *
 ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
 ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
 ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
+ND4J_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments);
 ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
 ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
 ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
@@ -2130,7 +2130,7 @@ Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4
 biArgs[e] = bArgs[e];

 // hypothetically at this point we have everything filled
-auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, isInplace);
+auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, std::vector<nd4j::DataType>(), isInplace);
 //auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace);
@@ -2788,6 +2788,15 @@ void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, i
 void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
 ptr->setBArguments(arguments, numberOfArguments);
 }

+void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
+std::vector<nd4j::DataType> dtypes(numberOfArguments);
+for (int e = 0; e < numberOfArguments; e++)
+dtypes[e] = (nd4j::DataType) arguments[e];
+
+ptr->setDArguments(dtypes);
+}
+
 void deleteGraphContext(nd4j::graph::Context* ptr) {
 delete ptr;
 }
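For orientation, a hedged sketch of how this new native-ops entry point can be fed from the calling side: data types travel as plain ints and are cast back to nd4j::DataType inside. Obtaining the OpaqueContext is assumed to happen through the existing context plumbing, and the helper function and concrete codes below are illustrative only, not part of this commit.

    // Illustrative caller-side helper (not part of this commit): forward two
    // dtype codes into a graph context before an op is executed.
    void passDataTypes(OpaqueContext *ctx) {
        int dtypes[] = {5, 3};                    // assumed codes, e.g. FLOAT = 5, HALF = 3 in the DType enum
        setGraphContextDArguments(ctx, dtypes, 2);
        // internally each int is cast to nd4j::DataType and stored via Context::setDArguments
    }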
@@ -2831,7 +2831,7 @@ static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer*

 // hypothetically at this point we have everything filled
-auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, isInplace);
+auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, std::vector<nd4j::DataType>(), isInplace);
 //auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace);

@@ -3596,6 +3596,14 @@ void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int n
 ptr->setBArguments(arguments, numberOfArguments);
 }

+void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
+std::vector<nd4j::DataType> dtypes(numberOfArguments);
+for (int e = 0; e < numberOfArguments; e++)
+dtypes[e] = (nd4j::DataType) arguments[e];
+
+ptr->setDArguments(dtypes);
+}
+
 void deleteGraphContext(nd4j::graph::Context* ptr) {
 delete ptr;
 }
@@ -95,6 +95,10 @@ namespace nd4j {
 template<typename T>
 // struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value || std::is_same<long long, T>::value; };
 struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<unsigned int, T>::value || std::is_same<long long, T>::value || std::is_same<unsigned long long, T>::value || std::is_same<long int, T>::value || std::is_same<long unsigned int, T>::value || std::is_same<int8_t, T>::value || std::is_same<uint8_t, T>::value || std::is_same<int16_t, T>::value || std::is_same<uint16_t, T>::value || std::is_same<bool, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value; };

+template<typename T>
+struct scalarTypesForExecution { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<Nd4jLong, T>::value || std::is_same<int, T>::value || std::is_same<bool, T>::value; };
+
 };

@@ -158,7 +158,7 @@ namespace nd4j {
 iargs.push_back(_axis);

-auto result = op.execute(inputs, {}, {}, {});
+auto result = op.evaluate(inputs);

 auto array = new NDArray(result->at(0)->dup());
@@ -197,10 +197,12 @@ namespace nd4j {
 void setTArguments(double *arguments, int numberOfArguments);
 void setIArguments(Nd4jLong *arguments, int numberOfArguments);
 void setBArguments(bool *arguments, int numberOfArguments);
+void setDArguments(nd4j::DataType *arguments, int numberOfArguments);

 void setTArguments(const std::vector<double> &tArgs);
 void setIArguments(const std::vector<Nd4jLong> &tArgs);
 void setBArguments(const std::vector<bool> &tArgs);
+void setDArguments(const std::vector<nd4j::DataType> &dArgs);

 void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);

@@ -47,6 +47,9 @@ namespace nd4j {
 std::vector<int> _iArgs;
 std::vector<bool> _bArgs;
 std::vector<int> _axis;
+std::vector<nd4j::DataType> _dArgs;
+
+// TODO: remove this field
 nd4j::DataType _dataType = nd4j::DataType::FLOAT32;
 bool _isInplace;

@@ -93,6 +96,7 @@ namespace nd4j {
 std::vector<double>* getTArguments();
 std::vector<int>* getIArguments();
 std::vector<bool>* getBArguments();
+std::vector<nd4j::DataType>* getDArguments();
 std::vector<int>* getAxis();

 samediff::Engine engine();

@@ -100,6 +104,7 @@ namespace nd4j {
 size_t numT();
 size_t numI();
 size_t numB();
+size_t numD();

 std::pair<int, int>* input(int idx);
@@ -38,7 +38,9 @@ namespace nd4j {

 class ND4J_EXPORT Node {
 protected:
+// TODO: this field must be removed
 nd4j::DataType _dataType;

 OpType _opType;
 ContextPrototype* _protoContext = nullptr;
 Nd4jLong _opNum;

@@ -61,6 +63,7 @@ namespace nd4j {

 // optional scalar. used in scalar ops and in summary stats
+// TODO: this field must be removed
 NDArray _scalar;

 bool _hasExternalOutputs;

@@ -87,15 +90,15 @@ namespace nd4j {
 int _scope_id = 0;
 std::string _scope_name;

+// TODO: these 3 fields should be removed
 int _rewindNode = -1;
 std::pair<int, int> _rewindLayer = {-1, -1};

 Nd4jLong _frameId = -1;

 public:
-Node(nd4j::ops::DeclarableOp *customOp, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
+explicit Node(nd4j::ops::DeclarableOp *customOp, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
-Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
+explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
-Node(const nd4j::graph::FlatNode *node);
+explicit Node(const nd4j::graph::FlatNode *node);
 ~Node();

 bool equals(Node *other);
@@ -60,11 +60,13 @@ enum DType {
 DType_QINT16 = 16,
 DType_BFLOAT16 = 17,
 DType_UTF8 = 50,
+DType_UTF16 = 51,
+DType_UTF32 = 52,
 DType_MIN = DType_INHERIT,
-DType_MAX = DType_UTF8
+DType_MAX = DType_UTF32
 };

-inline const DType (&EnumValuesDType())[19] {
+inline const DType (&EnumValuesDType())[21] {
 static const DType values[] = {
 DType_INHERIT,
 DType_BOOL,

@@ -84,7 +86,9 @@ inline const DType (&EnumValuesDType())[19] {
 DType_QINT8,
 DType_QINT16,
 DType_BFLOAT16,
-DType_UTF8
+DType_UTF8,
+DType_UTF16,
+DType_UTF32
 };
 return values;
 }

@@ -142,6 +146,8 @@ inline const char * const *EnumNamesDType() {
 "",
 "",
 "UTF8",
+"UTF16",
+"UTF32",
 nullptr
 };
 return names;

@@ -42,7 +42,9 @@ nd4j.graph.DType = {
 QINT8: 15,
 QINT16: 16,
 BFLOAT16: 17,
-UTF8: 50
+UTF8: 50,
+UTF16: 51,
+UTF32: 52
 };

 /**

@@ -26,6 +26,8 @@ public enum DType : sbyte
 QINT16 = 16,
 BFLOAT16 = 17,
 UTF8 = 50,
+UTF16 = 51,
+UTF32 = 52,
 };

@@ -23,8 +23,10 @@ public final class DType {
 public static final byte QINT16 = 16;
 public static final byte BFLOAT16 = 17;
 public static final byte UTF8 = 50;
+public static final byte UTF16 = 51;
+public static final byte UTF32 = 52;

-public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", };
+public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", "UTF16", "UTF32", };

 public static String name(int e) { return names[e]; }
 }

@@ -22,4 +22,6 @@ class DType(object):
 QINT16 = 16
 BFLOAT16 = 17
 UTF8 = 50
+UTF16 = 51
+UTF32 = 52
@@ -26,7 +26,7 @@ public struct UIVariable : IFlatbufferObject
 #endif
 public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
 public VarType Type { get { int o = __p.__offset(8); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } }
-public DataType Datatype { get { int o = __p.__offset(10); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
+public DType Datatype { get { int o = __p.__offset(10); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
 public long Shape(int j) { int o = __p.__offset(12); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
 public int ShapeLength { get { int o = __p.__offset(12); return o != 0 ? __p.__vector_len(o) : 0; } }
 #if ENABLE_SPAN_T

@@ -70,7 +70,7 @@ public struct UIVariable : IFlatbufferObject
 Offset<IntPair> idOffset = default(Offset<IntPair>),
 StringOffset nameOffset = default(StringOffset),
 VarType type = VarType.VARIABLE,
-DataType datatype = DataType.INHERIT,
+DType datatype = DType.INHERIT,
 VectorOffset shapeOffset = default(VectorOffset),
 VectorOffset controlDepsOffset = default(VectorOffset),
 StringOffset outputOfOpOffset = default(StringOffset),

@@ -101,7 +101,7 @@ public struct UIVariable : IFlatbufferObject
 public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
 public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
 public static void AddType(FlatBufferBuilder builder, VarType type) { builder.AddSbyte(2, (sbyte)type, 0); }
-public static void AddDatatype(FlatBufferBuilder builder, DataType datatype) { builder.AddSbyte(3, (sbyte)datatype, 0); }
+public static void AddDatatype(FlatBufferBuilder builder, DType datatype) { builder.AddSbyte(3, (sbyte)datatype, 0); }
 public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(4, shapeOffset.Value, 0); }
 public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
 public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }

@@ -266,8 +266,8 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
 VarType type() const {
 return static_cast<VarType>(GetField<int8_t>(VT_TYPE, 0));
 }
-DataType datatype() const {
+DType datatype() const {
-return static_cast<DataType>(GetField<int8_t>(VT_DATATYPE, 0));
+return static_cast<DType>(GetField<int8_t>(VT_DATATYPE, 0));
 }
 const flatbuffers::Vector<int64_t> *shape() const {
 return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);

@@ -342,7 +342,7 @@ struct UIVariableBuilder {
 void add_type(VarType type) {
 fbb_.AddElement<int8_t>(UIVariable::VT_TYPE, static_cast<int8_t>(type), 0);
 }
-void add_datatype(DataType datatype) {
+void add_datatype(DType datatype) {
 fbb_.AddElement<int8_t>(UIVariable::VT_DATATYPE, static_cast<int8_t>(datatype), 0);
 }
 void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {

@@ -389,7 +389,7 @@ inline flatbuffers::Offset<UIVariable> CreateUIVariable(
 flatbuffers::Offset<IntPair> id = 0,
 flatbuffers::Offset<flatbuffers::String> name = 0,
 VarType type = VarType_VARIABLE,
-DataType datatype = DataType_INHERIT,
+DType datatype = DType_INHERIT,
 flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
 flatbuffers::Offset<flatbuffers::String> outputOfOp = 0,

@@ -421,7 +421,7 @@ inline flatbuffers::Offset<UIVariable> CreateUIVariableDirect(
 flatbuffers::Offset<IntPair> id = 0,
 const char *name = nullptr,
 VarType type = VarType_VARIABLE,
-DataType datatype = DataType_INHERIT,
+DType datatype = DType_INHERIT,
 const std::vector<int64_t> *shape = nullptr,
 const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
 const char *outputOfOp = nullptr,

@@ -503,11 +503,11 @@ nd4j.graph.UIVariable.prototype.type = function() {
 };

 /**
-* @returns {nd4j.graph.DataType}
+* @returns {nd4j.graph.DType}
 */
 nd4j.graph.UIVariable.prototype.datatype = function() {
 var offset = this.bb.__offset(this.bb_pos, 10);
-return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
+return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
 };

 /**

@@ -668,10 +668,10 @@ nd4j.graph.UIVariable.addType = function(builder, type) {

 /**
 * @param {flatbuffers.Builder} builder
-* @param {nd4j.graph.DataType} datatype
+* @param {nd4j.graph.DType} datatype
 */
 nd4j.graph.UIVariable.addDatatype = function(builder, datatype) {
-builder.addFieldInt8(3, datatype, nd4j.graph.DataType.INHERIT);
+builder.addFieldInt8(3, datatype, nd4j.graph.DType.INHERIT);
 };

 /**
@@ -551,6 +551,18 @@ namespace nd4j {
 bool Context::isInference() {
 return _execMode == samediff::ExecutionMode::MODE_INFERENCE;
 }

+void Context::setDArguments(nd4j::DataType *arguments, int numberOfArguments) {
+_dArgs.clear();
+for (int e = 0; e < numberOfArguments; e++)
+_dArgs.emplace_back(arguments[e]);
+}
+
+void Context::setDArguments(const std::vector<nd4j::DataType> &dArgs) {
+_dArgs.clear();
+for (auto d:dArgs)
+_dArgs.emplace_back(d);
+}
 }
 }
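A rough usage sketch for the new setters; the variable names are illustrative and the Context is assumed to be constructed the usual way elsewhere:

    // Illustrative only: hand the op a requested data type via D arguments.
    std::vector<nd4j::DataType> dtypes = { nd4j::DataType::FLOAT32 };
    ctx.setDArguments(dtypes);                     // stored in _dArgs
    auto count = ctx.numD();                       // 1
    auto requested = ctx.getDArguments()->at(0);   // FLOAT32; ops read this as D_ARG(0)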
@@ -173,5 +173,13 @@ namespace nd4j {
 return clone;
 }
+
+std::vector<nd4j::DataType> *ContextPrototype::getDArguments() {
+return &_dArgs;
+}
+
+size_t ContextPrototype::numD() {
+return _dArgs.size();
+}
 }
 }
@@ -587,6 +587,12 @@ namespace nd4j {
 block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
 }

+if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
+for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
+block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
+}
+}
+
 this->setContextPrototype(block);
 this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));
 block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());

@@ -618,6 +624,12 @@ namespace nd4j {
 block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
 }

+if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
+for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
+block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
+}
+}
+
 this->setContextPrototype(block);

 this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));

@@ -652,6 +664,12 @@ namespace nd4j {
 block->getBArguments()->push_back(node->extraBools()->Get(e));
 }

+if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) {
+for (int e = 0; e < (int) node->outputTypes()->size(); e++) {
+block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e));
+}
+}
+
 for (auto v: _dimensions)
 block->getAxis()->emplace_back(v);
@@ -40,7 +40,7 @@ namespace nd4j {

 NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); //[nHeads*hS, batch*timeSteps]
 nd4j::ops::matmul mmul;
-mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});
+mmul.execute({&projectionPrep, &inputPrep}, {&projected});

 projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
 projected.permutei({2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength]

@@ -66,7 +66,7 @@ namespace nd4j {
 nd4j::ops::matmul_bp mmulBp;
 NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context);
 NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context);
-mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});
+mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, std::vector<NDArray*>{&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});

 dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
 dLdProjectionMatrix->assign(dLdProjectionPrep);
@@ -1516,7 +1516,9 @@

 #define INPUT_LIST(INDEX) reinterpret_cast<nd4j::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())

+#define D_ARG(INDEX) block.getDArguments()->at(INDEX)
 #define INT_ARG(INDEX) block.getIArguments()->at(INDEX)
+#define I_ARG(INDEX) INT_ARG(INDEX)
 #define T_ARG(INDEX) block.getTArguments()->at(INDEX)
 #define B_ARG(INDEX) block.getBArguments()->at(INDEX)
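D_ARG and the new I_ARG alias resolve against the surrounding Context just like INT_ARG, T_ARG and B_ARG already do, so an op body or shape function can read an optional data-type argument in one line. A hypothetical snippet in the style of the one_hot change later in this commit:

    // Hypothetical op-body snippet: take the output type from the first D argument,
    // defaulting to FLOAT32 when the caller passed none.
    nd4j::DataType outputType = block.numD() > 0 ? D_ARG(0) : nd4j::DataType::FLOAT32;
    auto axis = block.numI() > 0 ? I_ARG(0) : -1;   // I_ARG is simply an alias for INT_ARG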
@@ -36,9 +36,8 @@ namespace nd4j {
 public:
 BooleanOp(const char *name, int numInputs, bool scalar);

-bool evaluate(std::initializer_list<nd4j::NDArray*> args);
-bool evaluate(std::vector<nd4j::NDArray*>& args);
-bool evaluate(nd4j::graph::Context& block);
+bool verify(const std::vector<nd4j::NDArray*>& args);
+bool verify(nd4j::graph::Context& block);

 Nd4jStatus execute(Context* block) override;
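The rename frees the evaluate name for the new DeclarableOp::evaluate helpers declared below, so condition-style ops are now queried through verify. A minimal, hypothetical call-site sketch, where op stands for any BooleanOp subclass and x, y for NDArray pointers prepared elsewhere:

    // before this change: bool holds = op.evaluate({x, y});
    bool holds = op.verify({x, y});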
@@ -169,13 +169,22 @@ namespace nd4j {
 */
 virtual Nd4jStatus execute(Context* block);

-nd4j::ResultSet* execute(std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
-Nd4jStatus execute(std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
-Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
-nd4j::ResultSet* execute(const std::vector<NDArray*>& inputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs = std::vector<bool>(), bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
-Nd4jStatus execute(std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs , std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
-Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs, std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
+Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs);
+
+template <class T>
+Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<T> tArgs);
+
+Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
+
+nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs);
+
+template <class T>
+nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
+
+nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
+
+Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);

 nd4j::ResultSet* execute(const nd4j::OpArgsHolder& holder, bool isInplace = false);
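To make the signature migration concrete, a small before/after sketch distilled from the call-site changes elsewhere in this commit; the op choice and array names are placeholders:

    nd4j::ops::concat_bp op;
    // old style: every unused argument list had to be spelled out
    //   auto res = op.execute({in0, in1, eps}, {}, {-1}, {});
    // new style: evaluate() takes the inputs plus whatever arguments exist, and the
    // full execute()/evaluate() overloads gain a dArgs vector ahead of isInplace
    auto res = op.evaluate({in0, in1, eps}, {-1});
    // ... use res->at(0) ...
    delete res;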
@@ -73,7 +73,7 @@ namespace nd4j {

 // at first step we build fwd activation
 nd4j::ops::crelu op;
-auto tmpResult = op.execute({input}, {}, {}, {});
+auto tmpResult = op.evaluate({input});
 if (tmpResult->status() != ND4J_STATUS_OK)
 return tmpResult->status();

@@ -84,7 +84,7 @@ namespace nd4j {
 helpers::reluDerivative(block.launchContext(), actv, epsilonNext);
 // now we split updated array into 2 chunks along last dimension
 nd4j::ops::concat_bp opc;
-auto dec = opc.execute({input, input, actv}, {}, {-1}, {});
+auto dec = opc.evaluate({input, input, actv}, {-1});
 if (dec->status() != ND4J_STATUS_OK)
 return dec->status();

@@ -103,7 +103,7 @@ namespace nd4j {
 // if (output->isEmpty())
 Nd4jLong width = condition->rankOf();
 nd4j::ops::Where op;
-std::unique_ptr<ResultSet> res(op.execute({condition}, {}, {}, {}));
+std::unique_ptr<ResultSet> res(op.evaluate({condition}));
 REQUIRE_OK(res->status());
 NDArray* whereTrue = res->at(0);
 if (whereTrue->isEmpty())
@@ -66,7 +66,7 @@ namespace nd4j {
 auto gradY = OUTPUT_VARIABLE(1);
 gradX->assign(epsNext);
 nd4j::ops::floormod op;
-std::unique_ptr<ResultSet> tmpResult(op.execute({x, y}, {}, {}, {}));
+std::unique_ptr<ResultSet> tmpResult(op.evaluate({x, y}));

 if (gradY->rankOf() == gradX->rankOf())
 epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY);

@@ -91,7 +91,7 @@ namespace ops {
 }

 nd4j::ops::softmax softmax;
-softmax.execute({weights}, {weights}, {}, {-2}, {}, true);
+softmax.execute({weights}, std::vector<NDArray*>{weights}, {}, {-2}, {}, {}, true);

 mmul.execute({values, weights}, {output}, {}, {}, {});

@@ -189,7 +189,7 @@ namespace ops {

 nd4j::ops::matmul_bp mmul_bp;
 NDArray dLdw(weights.getShapeInfo(), block.workspace());
-mmul_bp.execute({values, &weights, eps}, {dLdv, &dLdw}, {}, {}, {});
+mmul_bp.execute({values, &weights, eps}, std::vector<NDArray*>{dLdv, &dLdw}, {}, {}, {});

 NDArray dLds(preSoftmax.shapeInfo(), block.workspace());
 nd4j::ops::softmax_bp softmax_bp;

@@ -198,7 +198,7 @@ namespace ops {
 if(normalization)
 dLds /= factor;

-mmul_bp.execute({keys, queries, &dLds}, {dLdk, dLdq}, {}, {1}, {});
+mmul_bp.execute({keys, queries, &dLds}, std::vector<NDArray*>{dLdk, dLdq}, {}, {1}, {});

 return Status::OK();
 }

@@ -239,7 +239,7 @@ namespace ops {
 auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize});
 nd4j::ops::matmul_bp matmulBp;
 NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext());
-matmulBp.execute({&attnResults, Wo, &epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {});
+matmulBp.execute({&attnResults, Wo, &epsPostReshape}, std::vector<NDArray*>{&dLdPreWo, dLdWo}, {}, {}, {});

 // dLdAttn
 dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)});
@@ -40,7 +40,7 @@ namespace nd4j {
 //nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf());

 nd4j::ops::xw_plus_b op;
-std::unique_ptr<ResultSet> result(op.execute({x, w, b}, {}, {}, {}));
+std::unique_ptr<ResultSet> result(op.evaluate({x, w, b}));
 REQUIRE_TRUE(Status::OK() == result->status(), 0, "relu_layer: xw_plus_b op failed on input data.");

 auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;

@@ -34,7 +34,7 @@ namespace nd4j {

 auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
 bitcast res;
-auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false);
+auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false);
 if (tZ != &z0) {
 delete tZ;
 }

@@ -112,7 +112,7 @@ namespace ops {
 NDArray originalIndices(*indices); //->ordering(), indices->shapeInfo(), indices->dataType());
 originalIndices.linspace(0);
 ops::dynamic_partition op;
-auto res = op.execute({&originalIndices, indices}, {}, {numPartition});
+auto res = op.evaluate({&originalIndices, indices}, {numPartition});
 REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
 ops::dynamic_stitch stichOp;
 std::vector<NDArray*> partitions(numPartition * 2);

@@ -121,7 +121,7 @@ namespace ops {
 partitions[i + numPartition] = gradOutList[i];
 }

-auto result = stichOp.execute(partitions, {}, {numPartition}, {}, false);
+auto result = stichOp.evaluate(partitions, {numPartition});
 REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
 result->at(0)->reshapei(outputList[0]->getShapeAsVector());
 outputList[1]->assign(indices);
@@ -66,7 +66,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {

 nd4j::ops::gather op;

-std::unique_ptr<ResultSet> result(op.execute({input, indeces}, {}, {0}, {}));
+std::unique_ptr<ResultSet> result(op.evaluate({input, indeces}, {0}));
 REQUIRE_TRUE(result->status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
 REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
 output->assign(result->at(0));

@@ -94,7 +94,7 @@ DECLARE_SHAPE_FN(embedding_lookup) {
 for (int e = 1; e < outRank; e++)
 shapeInfo[e] = shape::sizeAt(inShapeInfo, e);

-auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), shape::order(inShapeInfo), shapeInfo);
+auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo);
 return SHAPELIST(outShapeInfo);
 }
@@ -74,6 +74,8 @@ namespace nd4j {
 DECLARE_SHAPE_FN(onehot) {
 auto inShape = inputShape->at(0);

+nd4j::DataType dtype = block.numD() > 0 ? D_ARG(0) : nd4j::DataType::FLOAT32;
+
 int depth = -1;
 Nd4jLong axis = -1;

@@ -99,7 +101,7 @@ namespace nd4j {
 shape.push_back(shape::shapeOf(inShape)[e]);

 shape.insert(shape.begin() + axis, depth);
-newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', rank + 1, shape.data());
+newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', rank + 1, shape.data());

 return SHAPELIST(newShape);
 }
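Because the shape function now honours D arguments, a caller can request the one_hot output type explicitly through the dArgs slot of the new overloads. A hedged sketch: the t/i argument layout follows whatever one_hot already expects, and INT32 plus the variable names here are only examples.

    nd4j::ops::onehot op;
    // on/off values in tArgs and axis/depth in iArgs keep their existing meaning;
    // the new dArgs vector carries the requested output data type.
    auto result = op.evaluate({indices}, {1.0, 0.0}, {axis, depth}, {}, {nd4j::DataType::INT32});
    // result->at(0) is expected to come back as an INT32 one-hot array
    delete result;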
@@ -84,7 +84,7 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {

 // forward steps
 nd4j::ops::dynamic_rnn dynamicRnn;
-auto resultsFW = dynamicRnn.execute({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {}, {timeMajor}, {}, false, x->dataType());
+auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor});
 hFW->assign(resultsFW->at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
 hFWFinal->assign(resultsFW->at(1));

@@ -97,17 +97,17 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {

 // reverse x
 nd4j::ops::reverse_sequence reverse;
-auto resultsIn = timeMajor ? reverse.execute({x, seqLen}, {}, {0, 1}, {}, false, x->dataType()) : reverse.execute({x, seqLen}, {}, {1, 0}, {}, false, x->dataType());
+auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0});
 REQUIRE_TRUE (resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
 auto revInput = resultsIn->at(0);

 // backward steps
-auto resultsBW = dynamicRnn.execute({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {}, {timeMajor}, {});
+auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
 auto hBWtemp = resultsBW->at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW]
 hBWFinal->assign(resultsBW->at(1));

 // reverse hBWtemp
-auto resultsOut = timeMajor ? reverse.execute({hBWtemp, seqLen}, {}, {0, 1}, {}) : reverse.execute({hBWtemp, seqLen}, {}, {1, 0}, {});
+auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
 hBW->assign(resultsOut->at(0));

 delete resultsOut;

@@ -48,7 +48,7 @@ namespace ops {

 auto conv = ArrayUtils::toLongVector(*block.getIArguments());

-auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), shape::order(in), conv);
+auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), shape::order(in), conv);

 return SHAPELIST(newShape);
 }
@@ -51,7 +51,7 @@ namespace helpers {
 throw std::runtime_error("multiUnique: cannot execute concat op properly.");

 nd4j::ops::unique opUnique;
-auto uResult = opUnique.execute({&arrayFull}, {}, {}, {});
+auto uResult = opUnique.evaluate({&arrayFull});
 if (Status::OK() != uResult->status())
 throw std::runtime_error("multiUnique: cannot execute unique op properly.");
@@ -36,7 +36,7 @@ namespace nd4j {
 return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL));
 }

-bool BooleanOp::evaluate(nd4j::graph::Context &block) {
+bool BooleanOp::verify(nd4j::graph::Context &block) {
 // check if scalar or not

 // validation?

@@ -58,11 +58,6 @@ namespace nd4j {
 }
 }

-bool BooleanOp::evaluate(std::initializer_list<nd4j::NDArray *> args) {
-std::vector<nd4j::NDArray *> vec(args);
-return this->evaluate(vec);
-}
-
 bool BooleanOp::prepareOutputs(Context& ctx) {

 auto variableSpace = ctx.getVariableSpace();

@@ -120,7 +115,7 @@ namespace nd4j {
 return ND4J_STATUS_KERNEL_FAILURE;
 }

-bool BooleanOp::evaluate(std::vector<nd4j::NDArray *> &args) {
+bool BooleanOp::verify(const std::vector<nd4j::NDArray *> &args) {
 VariableSpace variableSpace;

 int cnt = -1;

@@ -135,7 +130,7 @@ namespace nd4j {
 Context block(1, &variableSpace, false);
 block.fillInputs(in);

-return this->evaluate(block);
+return this->verify(block);
 }
 }
 }
@@ -15,7 +15,7 @@
 ******************************************************************************/

 //
-// Created by raver119 on 07.10.2017.
+// @author raver119@gmail.com
 //

 #include <ops/declarable/DeclarableOp.h>

@@ -27,6 +27,7 @@
 #include <ops/declarable/OpRegistrator.h>
 #include <exceptions/datatype_exception.h>
 #include <helpers/StringUtils.h>
+#include <cstdarg>

 namespace nd4j {
 namespace ops {
@@ -164,6 +165,9 @@ namespace nd4j {
 // we build list of input shapes
 if (ctx.isFastPath()) {
 for (const auto p:ctx.fastpath_in()) {
+if (p == nullptr)
+continue;
+
 inSha.push_back(p->getShapeInfo());
 }
 } else {

@@ -357,6 +361,9 @@ namespace nd4j {
 std::vector<nd4j::DataType> inputTypes(block.width());
 if (block.isFastPath()) {
 for (auto array: block.fastpath_in()) {
+if (array == nullptr)
+continue;
+
 inputTypes[inT++] = array->dataType();
 if (!_descriptor->checkInputMatch(cnt, array->dataType())) {
 auto ctype = DataTypeUtils::asString(array->dataType());

@@ -394,6 +401,9 @@ namespace nd4j {
 if (block.isFastPath()) {
 int index = 0;
 for (auto array: block.fastpath_out()) {
+if (array == nullptr)
+continue;
+
 auto cType = array->dataType();

 if (_descriptor->isSameMode()) {
@ -762,39 +772,7 @@ namespace nd4j {
            return ND4J_STATUS_OK;
        }

-       nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace, nd4j::DataType type) {
-           std::vector<NDArray*> ins(inputs);
-           std::vector<double> tas(tArgs);
-           std::vector<Nd4jLong> ias(iArgs);
-           std::vector<bool> bas(bArgs);
-           return this->execute(ins, tas, ias, bas, isInplace, type);
-       }
-
-       Nd4jStatus nd4j::ops::DeclarableOp::execute(std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs, std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace, nd4j::DataType type) {
-           std::vector<NDArray*> ins(inputs);
-           std::vector<NDArray*> ous(outputs);
-           std::vector<double> tas(tArgs);
-           std::vector<Nd4jLong> ias(iArgs);
-           std::vector<bool> bas(bArgs);
-           return this->execute(ins, ous, tas, ias, bas, isInplace, type);
-       }
-
-       Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs, std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace, nd4j::DataType type) {
-           std::vector<NDArray*> ins(inputs);
-           std::vector<NDArray*> ous(outputs);
-           std::vector<double> tas(tArgs);
-           std::vector<Nd4jLong> ias(iArgs);
-           std::vector<bool> bas(bArgs);
-           return this->execute(rng, ins, ous, tas, ias, bas, isInplace, type);
-       }
-
-       Nd4jStatus nd4j::ops::DeclarableOp::execute(std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs, std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace, nd4j::DataType type) {
-           // TODO: nullptr here might be replaced
-           nd4j::graph::RandomGenerator rng(0, 0);
-           return execute(rng, inputs, outputs, tArgs, iArgs, bArgs, isInplace, type);
-       }
-
-       Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs, std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace, nd4j::DataType type) {
+       Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<nd4j::DataType>& dArgs, bool isInplace, nd4j::DataType type) {
            VariableSpace variableSpace;
            FlowPath fp;
            variableSpace.setFlowPath(&fp);
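All of the removed initializer_list overloads funnel into this one const-reference signature, which now also threads the new DArgs (requested output data types) through to the Context. A hedged sketch of a call site against the consolidated signature; someOp and the x/y/z arrays are placeholders, only the parameter order comes from the signature above:

    nd4j::graph::RandomGenerator rng(0, 0);           // same seed pair the removed helper used
    std::vector<nd4j::NDArray*> inputs  {&x, &y};     // x, y, z: caller-owned arrays (assumed)
    std::vector<nd4j::NDArray*> outputs {&z};

    auto status = someOp.execute(rng, inputs, outputs,
                                 {},                          // tArgs
                                 {},                          // iArgs
                                 {},                          // bArgs
                                 {nd4j::DataType::FLOAT32},   // dArgs - the new argument class
                                 false,                       // isInplace
                                 nd4j::DataType::FLOAT32);    // legacy type hint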
@ -838,12 +816,124 @@ namespace nd4j {
            for (int e = 0; e < bArgs.size(); e++)
                block.getBArguments()->push_back(static_cast<int>(bArgs.at(e)));

+           for (int e = 0; e < dArgs.size(); e++)
+               block.getDArguments()->push_back(dArgs.at(e));
+
            Nd4jStatus result = this->execute(&block);

            return result;
        }

-       nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(const std::vector<NDArray*>& inputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, bool isInplace, nd4j::DataType type) {
+       Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs) {
+           return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+       }
+
+       template <>
+       Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<double> tArgs) {
+           std::vector<double> realArgs(tArgs);
+           return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+       }
+
+       template <>
+       Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<float> tArgs) {
+           std::vector<double> realArgs;
+           for (auto v:tArgs)
+               realArgs.emplace_back(v);
+
+           return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+       }
+
+       template <>
+       Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<Nd4jLong> iArgs) {
+           std::vector<Nd4jLong> realArgs(iArgs);
+           return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
+       }
+
+       template <>
+       Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<int> iArgs) {
+           std::vector<Nd4jLong> realArgs;
+           for (auto v:iArgs)
+               realArgs.emplace_back(v);
+
+           return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
+       }
+
+       template <>
+       Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<bool> bArgs) {
+           std::vector<bool> realArgs(bArgs);
+           return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());
+       }
+
+       Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
+           Context ctx(1);
+
+           for (int e = 0; e < inputs.size(); e++) {
+               if (inputs[e] == nullptr)
+                   break;
+
+               ctx.setInputArray(e, inputs[e]);
+           }
+
+           for (int e = 0; e < outputs.size(); e++) {
+               if (outputs[e] == nullptr)
+                   break;
+
+               ctx.setOutputArray(e, outputs[e]);
+           }
+
+           if (isInplace)
+               ctx.markInplace(isInplace);
+
+           ctx.setIArguments(iArgs);
+           ctx.setTArguments(tArgs);
+           ctx.setBArguments(bArgs);
+           ctx.setDArguments(dArgs);
+
+           return execute(&ctx);
+       }
+
+       nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
+           return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+       }
+
+       template <>
+       nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
+           std::vector<Nd4jLong> realArgs;
+           for (auto v:iArgs)
+               realArgs.emplace_back(v);
+
+           return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
+       }
+
+       template <>
+       nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
+           std::vector<Nd4jLong> realArgs(iArgs);
+           return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
+       }
+
+       template <>
+       nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
+           std::vector<double> realArgs;
+           for (auto v:tArgs)
+               realArgs.emplace_back(v);
+
+           return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+       }
+
+       template <>
+       nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
+           std::vector<double> realArgs(tArgs);
+           return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+       }
+
+       template <>
+       nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
+           std::vector<bool> realArgs(bArgs);
+           return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());
+       }
+
+       nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
            VariableSpace variableSpace;
            //ResultSet arrayList;
            FlowPath fp;
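The overloads added above split the public surface in two: evaluate() allocates its outputs and returns them in a ResultSet the caller must delete, while execute(inputs, outputs) writes into caller-owned arrays and only returns a status. A short usage sketch mirroring the test migrations later in this patch; the add op and the concrete shapes are illustrative:

    nd4j::ops::add op;
    auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
    auto y = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 1.f, 1.f, 1.f});

    // evaluate(): the op allocates the result and hands back a ResultSet
    auto result = op.evaluate({&x, &y});
    auto z = result->at(0);
    // ... inspect z ...
    delete result;

    // execute(): the caller supplies (and keeps owning) the output array
    auto out = NDArrayFactory::create<float>('c', {2, 2});
    auto status = op.execute({&x, &y}, {&out});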
@ -862,7 +952,7 @@ namespace nd4j {
            }

            Context block(1, &variableSpace, false);
-           block.setDataType(0, type);
+           block.setDataType(0, nd4j::DataType::FLOAT32);
            block.fillInputs(in);
            block.markInplace(isInplace);
            // block.setRNG(ProviderRNG::getInstance().getRNG());

@ -870,13 +960,15 @@ namespace nd4j {
            for (int e = 0; e < tArgs.size(); e++)
                block.getTArguments()->emplace_back(tArgs.at(e));

            for (int e = 0; e < iArgs.size(); e++)
                block.getIArguments()->emplace_back(iArgs.at(e));

            for (int e = 0; e < bArgs.size(); e++)
                block.getBArguments()->push_back(bArgs.at(e));

+           for (int e = 0; e < dArgs.size(); e++)
+               block.getDArguments()->push_back(dArgs.at(e));
+
            Nd4jStatus status = this->execute(&block);
            auto arrayList = new ResultSet();
            if (isInplace)

@ -907,7 +999,8 @@ namespace nd4j {
        }

        nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(const nd4j::OpArgsHolder& holder, bool isInplace) {
-           return execute(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), isInplace, nd4j::DataType::DOUBLE);
+           // FIXME: add DArgs to OpArgsHolder
+           return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector<nd4j::DataType>(), isInplace);
        }

        Nd4jStatus nd4j::ops::DeclarableOp::validateInputDimensionsMatch(Context& block) {
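Internally the trimmed execute() overload is just the fast-path Context wiring shown above; building the Context by hand makes it clear where the new DArguments slot in next to the existing I/T/B argument channels. Sketch only, with placeholder arrays and an assumed single-output op called someOp:

    Context ctx(1);
    ctx.setInputArray(0, &x);                        // x, y, z: placeholder NDArrays
    ctx.setInputArray(1, &y);
    ctx.setOutputArray(0, &z);

    std::vector<Nd4jLong>       iArgs;               // empty integer args
    std::vector<double>         tArgs;               // empty floating-point args
    std::vector<bool>           bArgs;               // empty boolean args
    std::vector<nd4j::DataType> dArgs {nd4j::DataType::FLOAT32};   // new: requested data types

    ctx.setIArguments(iArgs);
    ctx.setTArguments(tArgs);
    ctx.setBArguments(bArgs);
    ctx.setDArguments(dArgs);

    auto status = someOp.execute(&ctx);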
@ -43,7 +43,7 @@ TEST_F(AttentionTests, basic_dot_product_attention) {
    auto queries = NDArrayFactory::create<float>('c', {10, 4, 1});

    nd4j::ops::dot_product_attention op;
-   auto result = op.execute({&queries, &keys, &values}, {}, {1, 0}, {});
+   auto result = op.evaluate({&queries, &keys, &values}, {1, 0});
    ASSERT_EQ(Status::OK(), result->status());

    delete result;

@ -71,7 +71,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_weights) {
    auto queries = NDArrayFactory::create<float>('c', {10, 4, 1});

    nd4j::ops::dot_product_attention op;
-   auto result = op.execute({&queries, &keys, &values}, {}, {1, 1}, {});
+   auto result = op.evaluate({&queries, &keys, &values}, {1, 1});
    ASSERT_EQ(Status::OK(), result->status());

    delete result;

@ -85,7 +85,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
    mask.assign(1.);

    nd4j::ops::dot_product_attention op;
-   auto result = op.execute({&queries, &keys, &values, &mask}, {}, {1, 0}, {});
+   auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
    ASSERT_EQ(Status::OK(), result->status());

    delete result;

@ -117,7 +117,7 @@ TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) {
    mask.assign(1.);

    nd4j::ops::dot_product_attention op;
-   auto result = op.execute({&queries, &keys, &values, &mask}, {}, {1, 0}, {});
+   auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
    ASSERT_EQ(Status::OK(), result->status());

    delete result;

@ -153,7 +153,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention) {
    auto Wo = NDArrayFactory::create<float>('c', {2* 3, 4});

    nd4j::ops::multi_head_dot_product_attention op;
-   auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {}, {1, 0}, {});
+   auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0});
    ASSERT_EQ(Status::OK(), result->status());

    delete result;

@ -197,7 +197,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) {

    nd4j::ops::multi_head_dot_product_attention op;
-   auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {}, {1, 0}, {});
+   auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0});
    ASSERT_EQ(Status::OK(), result->status());

    delete result;
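The test churn above is mechanical: {1, 0} now reaches the op through the std::initializer_list<int> evaluate() overload as integer arguments, so the empty tArgs/bArgs lists the old execute() call had to spell out simply disappear. Before/after, shown once with the names used in the test above:

    // old call: empty tArgs and bArgs had to be written out explicitly
    // auto result = op.execute({&queries, &keys, &values}, {}, {1, 0}, {});

    // new call: {1, 0} binds to the initializer_list<int> iArgs overload of evaluate()
    auto result = op.evaluate({&queries, &keys, &values}, {1, 0});
    ASSERT_EQ(Status::OK(), result->status());
    delete result;    // evaluate() still allocates the ResultSet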
@ -37,7 +37,7 @@ TEST_F(BackpropTests, Test_Add_1) {
    NDArray e('c', {2, 3, 4}, nd4j::DataType::FLOAT32);

    nd4j::ops::add_bp op;
-   auto result = op.execute({&x, &y, &e}, {}, {}, {});
+   auto result = op.evaluate({&x, &y, &e});

    ASSERT_EQ(Status::OK(), result->status());
@ -38,7 +38,7 @@ TEST_F(BooleanOpsTests, LtTest_1) {
    nd4j::ops::lt_scalar op;

-   ASSERT_TRUE(op.evaluate({x, y}));
+   ASSERT_TRUE(op.verify({x, y}));

    delete x;
    delete y;

@ -51,7 +51,7 @@ TEST_F(BooleanOpsTests, LtTest_2) {
    nd4j::ops::lt_scalar op;

-   ASSERT_FALSE(op.evaluate({x, y}));
+   ASSERT_FALSE(op.verify({x, y}));

    delete x;
    delete y;

@ -62,7 +62,7 @@ TEST_F(BooleanOpsTests, Is_non_decreasing_1) {
    nd4j::ops::is_non_decreasing op;

-   ASSERT_TRUE(op.evaluate({&x}));
+   ASSERT_TRUE(op.verify({&x}));
}

@ -71,7 +71,7 @@ TEST_F(BooleanOpsTests, Is_non_decreasing_2) {
    nd4j::ops::is_non_decreasing op;

-   ASSERT_FALSE(op.evaluate({&x}));
+   ASSERT_FALSE(op.verify({&x}));
}

@ -80,7 +80,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_1) {
    nd4j::ops::is_strictly_increasing op;

-   ASSERT_TRUE(op.evaluate({&x}));
+   ASSERT_TRUE(op.verify({&x}));
}

@ -89,7 +89,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_2) {
    nd4j::ops::is_strictly_increasing op;

-   ASSERT_FALSE(op.evaluate({&x}));
+   ASSERT_FALSE(op.verify({&x}));
}

@ -98,7 +98,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_3) {
    nd4j::ops::is_strictly_increasing op;

-   ASSERT_FALSE(op.evaluate({&x}));
+   ASSERT_FALSE(op.verify({&x}));
}

TEST_F(BooleanOpsTests, Is_strictly_increasing_5) {

@ -107,7 +107,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_5) {
    nd4j::ops::is_strictly_increasing op;

-   ASSERT_TRUE(op.evaluate({&x}));
+   ASSERT_TRUE(op.verify({&x}));
}

TEST_F(BooleanOpsTests, Is_strictly_increasing_6) {

@ -118,7 +118,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_6) {
    nd4j::ops::is_strictly_increasing op;

-   ASSERT_FALSE(op.evaluate({&x}));
+   ASSERT_FALSE(op.verify({&x}));
}

TEST_F(BooleanOpsTests, Is_numeric_tensor_1) {

@ -126,7 +126,7 @@ TEST_F(BooleanOpsTests, Is_numeric_tensor_1) {
    nd4j::ops::is_numeric_tensor op;

-   ASSERT_TRUE(op.evaluate({&x}));
+   ASSERT_TRUE(op.verify({&x}));
}

TEST_F(BooleanOpsTests, test_where_1) {

@ -136,7 +136,7 @@ TEST_F(BooleanOpsTests, test_where_1) {
    nd4j::ops::choose op;

-   auto result = op.execute({&x, &y}, {}, {3});
+   auto result = op.evaluate({&x, &y}, {3});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
@ -46,7 +46,7 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) {
    exp.applyBroadcast(broadcast::Add, {1}, y, exp);

    nd4j::ops::add op;
-   auto result = op.execute({&x, &y}, {}, {}, {});
+   auto result = op.evaluate({&x, &y});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -73,7 +73,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_1) {
    exp.applyBroadcast(broadcast::Multiply, {1}, y, exp);

    nd4j::ops::multiply op;
-   auto result = op.execute({&x, &y}, {}, {}, {});
+   auto result = op.evaluate({&x, &y});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -98,7 +98,7 @@ TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) {

    nd4j::ops::squaredsubtract op;
-   auto result = op.execute({&x, &y}, {}, {}, {});
+   auto result = op.evaluate({&x, &y});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -117,7 +117,7 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) {
    auto exp = NDArrayFactory::create<float>('c', {1,3}, {1, 0, -1});

    nd4j::ops::subtract op;
-   auto result = op.execute({&x, &y}, {}, {}, {});
+   auto result = op.evaluate({&x, &y});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -136,7 +136,7 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) {
    auto exp = NDArrayFactory::create<float>('c', {1,3}, {1, 2, 3});

    nd4j::ops::add op;
-   auto result = op.execute({&x, &y}, {}, {}, {});
+   auto result = op.evaluate({&x, &y});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -155,7 +155,7 @@ TEST_F(BroadcastableOpsTests, Test_Maximum_1) {
    auto exp = NDArrayFactory::create<float>('c', {2, 3}, {2, 2, 2, 2, 3, 2});

    nd4j::ops::maximum op;
-   auto result = op.execute({&x, &row}, {}, {}, {});
+   auto result = op.evaluate({&x, &row});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

@ -173,7 +173,7 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) {
    auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 1, 1, 1, 1});

    nd4j::ops::minimum op;
-   auto result = op.execute({&x, &col}, {}, {}, {});
+   auto result = op.evaluate({&x, &col});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

@ -281,7 +281,7 @@ TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2}, {3, 4, 5, 6});

    nd4j::ops::add op;
-   auto result = op.execute({&x, &y}, {}, {}, {});
+   auto result = op.evaluate({&x, &y});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

@ -331,7 +331,7 @@ TEST_F(BroadcastableOpsTests, Test_Subtract_2) {
    auto e = NDArrayFactory::create<float>('c', {2}, {1.0f, 0.0f});

    nd4j::ops::subtract op;
-   auto result = op.execute({&x, &y}, {}, {}, {});
+   auto result = op.evaluate({&x, &y});
    auto z = result->at(0);

    ASSERT_TRUE(e.equalsTo(z));

@ -509,7 +509,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_7) {
    auto e = NDArrayFactory::create<float>('c', {1}, {8.f});

    nd4j::ops::multiply op;
-   auto result = op.execute({&x, &y}, {}, {}, {});
+   auto result = op.evaluate({&x, &y});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

@ -525,7 +525,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
    auto e = NDArrayFactory::create<float>('c', {1, 1}, {8.f});

    nd4j::ops::multiply op;
-   auto result = op.execute({&x, &y}, {}, {}, {});
+   auto result = op.evaluate({&x, &y});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

@ -544,7 +544,7 @@ TEST_F(BroadcastableOpsTests, broadcast_add_1) {
    NDArray exp('c', {1,4}, {2,3,4,5}, nd4j::DataType::DOUBLE);

    nd4j::ops::add op;
-   auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
+   auto status = op.execute({&x, &y}, {&z});

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(z.equalsTo(exp));

@ -559,7 +559,7 @@ TEST_F(BroadcastableOpsTests, broadcast_equals_1) {
    NDArray exp('c', {3,4}, {0,0,0,0, 1,1,1,1, 1,1,1,1}, nd4j::DataType::BOOL);

    nd4j::ops::equals op;
-   auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
+   auto status = op.execute({&x, &y}, {&z});
    // z.printIndexedBuffer();

    ASSERT_EQ(ND4J_STATUS_OK, status);

@ -603,7 +603,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
    NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});

    nd4j::ops::maximum op;
-   auto result = op.execute({&x, &y}, {}, {});
+   auto result = op.evaluate({&x, &y});

    ASSERT_EQ(Status::OK(), result->status());

@ -622,7 +622,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
    NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});

    nd4j::ops::maximum op;
-   auto result = op.execute({&x, &y}, {}, {});
+   auto result = op.evaluate({&x, &y});

    ASSERT_EQ(Status::OK(), result->status());

@ -641,7 +641,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
    NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});

    nd4j::ops::realdiv op;
-   auto result = op.execute({&x, &y}, {}, {});
+   auto result = op.evaluate({&x, &y});

    ASSERT_EQ(Status::OK(), result->status());

@ -660,7 +660,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
    NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});

    nd4j::ops::realdiv op;
-   auto result = op.execute({&x, &y}, {}, {});
+   auto result = op.evaluate({&x, &y});

    ASSERT_EQ(Status::OK(), result->status());

@ -679,7 +679,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
    NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2, 0});

    nd4j::ops::realdiv op;
-   auto result = op.execute({&x, &y}, {}, {});
+   auto result = op.evaluate({&x, &y});

    ASSERT_EQ(Status::OK(), result->status());

@ -715,7 +715,7 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {

    nd4j::ops::greater op;
-   auto result = op.execute({&x, &y}, {}, {}, {});
+   auto result = op.evaluate({&x, &y});

    auto z = result->at(0);

@ -741,7 +741,7 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
    nd4j::ops::greater op;

-   auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
+   auto status = op.execute({&x, &y}, {&z});

    ASSERT_EQ(ND4J_STATUS_OK, status);
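Where a test already owns its output array (the broadcast_add_1 / broadcast_bool_1 style above), the migration keeps execute() but drops the now-redundant empty argument lists; the output keeps whatever data type the caller gave it, for example BOOL for equals. A condensed sketch of that pattern, with x and y assumed to be set up as in the tests:

    // caller-owned output keeps its declared type (BOOL here), no ResultSet to delete
    NDArray z('c', {3, 4}, nd4j::DataType::BOOL);

    nd4j::ops::equals op;
    auto status = op.execute({&x, &y}, {&z});    // trimmed signature: inputs, outputs
    ASSERT_EQ(ND4J_STATUS_OK, status);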
@ -140,7 +140,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_2) {
    input.linspace(1);

    nd4j::ops::conv2d op;
-   auto result = op.execute({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
+   auto result = op.evaluate({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

@ -172,7 +172,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_3) {
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv2d op;
-   auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -201,7 +201,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_4) {
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv2d op;
-   auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -231,7 +231,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_5) {
    weights.permutei({2,3,1,0});

    nd4j::ops::conv2d op;
-   auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto output = results->at(0);

    // output->printIndexedBuffer();

@ -250,7 +250,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_6) {
    auto weights = NDArrayFactory::create<TypeParam>('c', {1, 2, 12, 2});

    nd4j::ops::conv2d op;
-   auto result = op.execute({&input, &weights}, {}, {-1,-1, 1,1, 0,0, 1,1, 1,1});
+   auto result = op.evaluate({&input, &weights}, {}, {-1,-1, 1,1, 0,0, 1,1, 1,1});
    ASSERT_EQ(Status::OK(), result->status());

    delete result;

@ -271,7 +271,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_7) {
    weights = 3.;

    nd4j::ops::conv2d op;
-   auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto* output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -305,7 +305,7 @@ TEST_F(ConvolutionTests1, conv2d_8) {
    1.764169, 2.584944, 2.521004, 1.744296, 1.707578, 2.237938, 2.325231, 0.984485, 1.766936, 1.590640, 1.347524, 1.404648, 1.422042, 1.709862, 1.155412});

    nd4j::ops::conv2d op;
-   auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto output = results->at(0);

    // output->printBuffer();
@ -419,7 +419,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_2) {

    nd4j::ops::sconv2d op;

-   auto resultFF = op.execute({&input, &weightsD, &weightsP}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}, {});
+   auto resultFF = op.evaluate({&input, &weightsD, &weightsP}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});

    auto z = resultFF->at(0);
    //z->printShapeInfo("FF shape");

@ -452,8 +452,8 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_3) {
    auto expOutput = NDArrayFactory::create<TypeParam>('c', {3, 2, 8, 8});

    nd4j::ops::sconv2d op;
-   Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0}, {});
+   Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {1, 1, 1, 1, 0, 0, 1, 1, 0});
-   auto result = op.execute({&input, &weightsD, &weightsP, &bias}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0}, {});
+   auto result = op.evaluate({&input, &weightsD, &weightsP, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 0});

    auto z = result->at(0);

@ -493,7 +493,7 @@ TEST_F(ConvolutionTests1, sconv2d_4) {
    0.962232, 0.980667, 1.623775, 1.417320, 1.845710, 1.237095, 1.762792, 1.352515});

    nd4j::ops::sconv2d op;
-   auto results = op.execute({&input, &weightsD, &weightsP, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weightsD, &weightsP, &biases}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto* output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -531,7 +531,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) {

    nd4j::ops::conv2d_bp op;

-   auto results = op.execute({&input, &weights, &bias, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {});
+   auto results = op.evaluate({&input, &weights, &bias, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {});

    ASSERT_TRUE(results->size() == 3);

@ -581,7 +581,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) {

    nd4j::ops::conv2d_bp op;

-   auto results = op.execute({&input, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {});
+   auto results = op.evaluate({&input, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {});

    ASSERT_TRUE(results->size() == 2);

@ -664,7 +664,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) {
    input.linspace(1);

    nd4j::ops::sconv2d op;
-   auto resultFF = op.execute({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {});
+   auto resultFF = op.evaluate({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {});

    auto z = resultFF->at(0);

@ -674,7 +674,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) {

    nd4j::ops::conv2d op2d;
    // weightsP.printShapeInfo();
-   auto result2D = op2d.execute({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});
+   auto result2D = op2d.evaluate({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});

    auto z2d = result2D->at(0);
    // z2d->printBuffer();

@ -717,7 +717,7 @@ TEST_F(ConvolutionTests1, TestDeconv_bp_1) {

    nd4j::ops::deconv2d_bp op;
-   auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -771,7 +771,7 @@ TEST_F(ConvolutionTests1, TestDeconv_bp_2) {

    nd4j::ops::deconv2d_bp<double> op;

-   auto result = op.execute({&input, &weights, &bias, &epsilon}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0});
+   auto result = op.evaluate({&input, &weights, &bias, &epsilon}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -791,7 +791,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
    bias.linspace(1);

    nd4j::ops::conv1d op;
-   auto result_FF = op.execute({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0});
+   auto result_FF = op.evaluate({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0});

    ASSERT_EQ(ND4J_STATUS_OK, result_FF->status());

@ -805,7 +805,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
    auto epsilonNxt = new NDArray(z->dup());
    epsilonNxt->linspace(1);

-   auto result_BP = op_bp.execute({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0});
+   auto result_BP = op_bp.evaluate({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0});
    ASSERT_EQ(ND4J_STATUS_OK, result_BP->status());

    auto eps = result_BP->at(0);

@ -833,7 +833,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) {
    input.linspace(1);

    nd4j::ops::conv1d op;
-   auto result = op.execute({&input, &weights}, {}, {2, 1, 0, 1, 1,0});
+   auto result = op.evaluate({&input, &weights}, {}, {2, 1, 0, 1, 1,0});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -860,7 +860,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_1) {
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv1d op;
-   auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -892,7 +892,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_2) {
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv1d op;
-   auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -923,7 +923,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_3) {
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv1d op;
-   auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -954,7 +954,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_4) {
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv1d op;
-   auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -985,7 +985,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_5) {
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv1d op;
-   auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -1016,7 +1016,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_6) {
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv1d op;
-   auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -1048,7 +1048,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_7) {
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv1d op;
-   auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

@ -1081,7 +1081,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_8) {
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv1d op;
-   auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());
@ -1129,7 +1129,7 @@ TEST_F(ConvolutionTests1, Test_Dilation2D_1) {
    weights.linspace(1);

    nd4j::ops::dilation2d op;
-   auto result = op.execute({&input, &weights}, {}, {1, 1,2,2,1, 1,2,2,1});
+   auto result = op.evaluate({&input, &weights}, {1, 1,2,2,1, 1,2,2,1});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

@ -1149,7 +1149,7 @@ TEST_F(ConvolutionTests1, Test_Dilation2D_2) {
    weights.linspace(1);

    nd4j::ops::dilation2d op;
-   auto result = op.execute({&input, &weights}, {}, {0, 1,2,2,1, 1,2,2,1});
+   auto result = op.evaluate({&input, &weights}, {0, 1,2,2,1, 1,2,2,1});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

@ -1188,7 +1188,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) {
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv2d_bp op;
-   auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto gradI = results->at(0);
    auto gradW = results->at(1);

@ -1231,7 +1231,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) {
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv2d_bp op;
-   auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto gradI = results->at(0);
    auto gradW = results->at(1);

@ -1276,7 +1276,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) {
    expGradW.permutei({2,3,1,0});

    nd4j::ops::conv2d_bp op;
-   auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto gradI = results->at(0);
    auto gradW = results->at(1);
    auto gradB = results->at(2);

@ -1358,7 +1358,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) {
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv3dnew_bp op;
-   auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    auto gradI = results->at(0);
    auto gradW = results->at(1);

@ -1406,7 +1406,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) {
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv3dnew_bp op;
-   auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    auto gradI = results->at(0);
    auto gradW = results->at(1);

@ -1459,7 +1459,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) {
    expGradW.permutei({2, 3, 4, 1, 0});

    nd4j::ops::conv3dnew_bp op;
-   auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    auto* gradI = results->at(0);
    auto* gradW = results->at(1);
    auto* gradB = results->at(2);

@ -1502,7 +1502,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) {
    gradO.linspace(0.01, 0.01);

    nd4j::ops::depthwise_conv2d_bp op;
-   auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto* gradI = results->at(0);
    auto* gradW = results->at(1);

@ -1540,7 +1540,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test2) {
    gradO.linspace(0.01, 0.01);

    nd4j::ops::depthwise_conv2d_bp op;
-   auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto* gradI = results->at(0);
    auto* gradW = results->at(1);

@ -1568,7 +1568,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test3) {
    auto gradB = b.like();

    nd4j:ops::depthwise_conv2d_bp op;
-   auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}, {});
+   auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0});
    ASSERT_EQ(Status::OK(), status);
}

@ -1607,7 +1607,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test4) {
    NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, nd4j::DataType::FLOAT32);

    nd4j::ops::depthwise_conv2d_bp op;
-   ResultSet* results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   ResultSet* results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    NDArray* gradI = results->at(0);
    NDArray* gradW = results->at(1);
    NDArray* gradB = results->at(2);

@ -1662,7 +1662,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test5) {
    NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, nd4j::DataType::FLOAT32);

    nd4j::ops::depthwise_conv2d_bp op;
-   ResultSet* results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   ResultSet* results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    NDArray* gradI = results->at(0);
    NDArray* gradW = results->at(1);
    NDArray* gradB = results->at(2);

@ -1706,7 +1706,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test6) {
    gradO.linspace(0.01, 0.01);

    nd4j::ops::depthwise_conv2d_bp op;
-   auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+   auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto* gradI = results->at(0);
    auto* gradW = results->at(1);
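For the backprop ops the switch to evaluate() also shows how multi-output results are consumed: the ResultSet indexes gradI/gradW/gradB in the order the op declares its outputs. A hedged sketch using the depthwise_conv2d_bp call shape from the tests above, with all array and hyper-parameter names assumed to be prepared as in those tests:

    nd4j::ops::depthwise_conv2d_bp op;
    auto results = op.evaluate({&input, &weights, &bias, &gradO},
                               {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    ASSERT_EQ(Status::OK(), results->status());

    auto gradI = results->at(0);     // gradient w.r.t. input
    auto gradW = results->at(1);     // gradient w.r.t. weights
    auto gradB = results->at(2);     // gradient w.r.t. bias
    // ... assertions against expected gradients ...
    delete results;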
@ -1742,7 +1742,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) {
|
||||||
weights = 1.;
|
weights = 1.;
|
||||||
|
|
||||||
nd4j::ops::conv3dnew op;
|
nd4j::ops::conv3dnew op;
|
||||||
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
||||||
auto* output = results->at(0);
|
auto* output = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -1774,7 +1774,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) {
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
|
||||||
nd4j::ops::conv3dnew op;
|
nd4j::ops::conv3dnew op;
|
||||||
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
||||||
auto* output = results->at(0);
|
auto* output = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -1801,7 +1801,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test3) {
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
|
||||||
nd4j::ops::conv3dnew op;
|
nd4j::ops::conv3dnew op;
|
||||||
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
||||||
auto* output = results->at(0);
|
auto* output = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@@ -1827,7 +1827,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test4) {
 expected = 48.;

 nd4j::ops::conv3dnew op;
-auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
 auto* output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -1855,7 +1855,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test5) {
 bias = 1.;

 nd4j::ops::conv3dnew op;
-auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
 auto* output = results->at(0);

 // output->printIndexedBuffer();
@@ -1884,7 +1884,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) {
 weights = 0.5;

 nd4j::ops::conv3dnew op;
-auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
 auto* output = results->at(0);

 // output->printIndexedBuffer();
@@ -1915,7 +1915,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) {
 weights.permutei({2, 3, 4, 1, 0});

 nd4j::ops::conv3dnew op;
-auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
 auto* output = results->at(0);

 // output->printIndexedBuffer();
@@ -1944,7 +1944,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test8) {
 weights.permutei({2, 3, 4, 1, 0});

 nd4j::ops::conv3dnew op;
-auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
 auto* output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -1961,7 +1961,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test9) {
 auto e = NDArrayFactory::create<TypeParam>('c', {4, 1, 7, 10, 4});

 nd4j::ops::conv3dnew op;
-auto result = op.execute({&x, &y}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1});
+auto result = op.evaluate({&x, &y}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -1977,7 +1977,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test10) {
 auto exp = NDArrayFactory::create<TypeParam>('c', {4, 1, 7, 10, 4});

 nd4j::ops::conv3dnew op;
-auto result = op.execute({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1});
+auto result = op.evaluate({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1});
 ASSERT_EQ(Status::OK(), result->status());

 ShapeList shapeList({x.shapeInfo(), w.shapeInfo()});
@@ -2039,7 +2039,7 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) {
 bias = 1.;

 nd4j::ops::pointwise_conv2d op;
-auto results = op.execute({&input, &weights, &bias}, {}, {dataFormat});
+auto results = op.evaluate({&input, &weights, &bias}, {}, {dataFormat});
 auto* output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -2063,7 +2063,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test11) {
 weights = 1.;

 nd4j::ops::conv3dnew op;
-auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
 auto* output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -2087,7 +2087,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test12) {
 weights = 1.;

 nd4j::ops::conv3dnew op;
-auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
 auto* output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -2205,7 +2205,7 @@ TEST_F(ConvolutionTests1, upsampling2d_test1) {
 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f});

 nd4j::ops::upsampling2d op;
-auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW});
+auto results = op.evaluate({&input}, {factorH, factorW, isNCHW});
 auto* output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -2233,7 +2233,7 @@ TEST_F(ConvolutionTests1, upsampling2d_test2) {
 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f});

 nd4j::ops::upsampling2d op;
-auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW});
+auto results = op.evaluate({&input}, {factorH, factorW, isNCHW});
 auto* output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -2271,7 +2271,7 @@ TEST_F(ConvolutionTests1, upsampling3d_test1) {
 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f});

 nd4j::ops::upsampling3d op;
-auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW});
+auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW});
 auto* output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -2305,7 +2305,7 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) {
 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f});

 nd4j::ops::upsampling3d op;
-auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW});
+auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW});
 auto* output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -2332,7 +2332,7 @@ TEST_F(ConvolutionTests1, upsampling3d_bp_test1) {
 expGradI = 8.;

 nd4j::ops::upsampling3d_bp op;
-auto results = op.execute({&input, &gradO}, {}, {isNCDHW});
+auto results = op.evaluate({&input, &gradO}, {isNCDHW});
 auto* gradI = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -2359,7 +2359,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) {

 nd4j::ops::conv2d_input_bp op;

-auto results = op.execute({&inputShape, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1});
+auto results = op.evaluate({&inputShape, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1});

 ASSERT_TRUE(results->size() == 1);

@@ -2424,7 +2424,7 @@ TEST_F(ConvolutionTests1, upsampling3d_bp_test3) {
 4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, nd4j::DataType::FLOAT32);

 nd4j::ops::upsampling3d_bp op;
-auto results = op.execute({&input, &gradO}, {}, {isNCDHW});
+auto results = op.evaluate({&input, &gradO}, {isNCDHW});
 auto* gradI = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -2457,7 +2457,7 @@ TEST_F(ConvolutionTests1, deconv2d_test1) {
 weights.linspace(0.1, 0.1);

 nd4j::ops::deconv2d op;
-auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
 ASSERT_EQ(Status::OK(), results->status());

 auto output = results->at(0);
@@ -2490,7 +2490,7 @@ TEST_F(ConvolutionTests1, deconv2d_test2) {
 weights.linspace(0.1, 0.1);

 nd4j::ops::deconv2d op;
-auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
 auto output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());
@@ -2522,7 +2522,7 @@ TEST_F(ConvolutionTests1, deconv2d_test3) {
 bias = 0.2;

 nd4j::ops::deconv2d op;
-auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
 ASSERT_EQ(Status::OK(), results->status());

 auto output = results->at(0);
@@ -2557,7 +2557,7 @@ TEST_F(ConvolutionTests1, deconv2d_test4) {
 weights.permutei({2,3,1,0});

 nd4j::ops::deconv2d op;
-auto result = op.execute({&input, &weights}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});
+auto result = op.evaluate({&input, &weights}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});

 auto z = result->at(0);
 // z->printShapeInfo();
@@ -2584,7 +2584,7 @@ TEST_F(ConvolutionTests1, deconv2d_test5) {
 weights.permutei({2,3,1,0});

 nd4j::ops::deconv2d op;
-auto result = op.execute({&input, &weights}, {&z}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0},{});
+auto result = op.execute({&input, &weights}, {&z}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});

 ASSERT_EQ(ND4J_STATUS_OK, result);

@@ -2615,7 +2615,7 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_test6) {
 input.linspace(1);

 nd4j::ops::deconv2d op;
-auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

 ASSERT_EQ(Status::OK(), results->status());

@@ -2640,7 +2640,7 @@ TEST_F(ConvolutionTests1, deconv2d_test7) {

 nd4j::ops::deconv2d op;

-auto result = op.execute({&input, &weights, &bias}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0});
+auto result = op.evaluate({&input, &weights, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0});

 ASSERT_EQ(ND4J_STATUS_OK, result->status());

@@ -2683,7 +2683,7 @@ TEST_F(ConvolutionTests1, deconv2d_test8) {
 1.471922, 1.484062, 1.212039, 1.144419, 1.266123});

 nd4j::ops::deconv2d op;
-auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

 ASSERT_EQ(Status::OK(), results->status());

@@ -2718,7 +2718,7 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) {
 weights.linspace(0.1, 0.1);

 nd4j::ops::deconv2d_tf op;
-auto results = op.execute({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+auto results = op.evaluate({&outShape, &weights, &input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
 auto output = results->at(0);

 ASSERT_EQ(Status::OK(), results->status());

File diff suppressed because one or more lines are too long
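Note: the change applied across these test files is mechanical: calls of the form op.execute(inputs, tArgs, iArgs, ...) become op.evaluate(...), which accepts only the argument lists that matter for the call and defaults the rest. A minimal sketch of the migration pattern, assuming the usual libnd4j test includes; the array values and the subtract op are illustrative only and not taken from any specific test:

#include <NDArrayFactory.h>
#include <ops/declarable/CustomOperations.h>

static void evaluateMigrationSketch() {
    // two small input arrays, purely for illustration
    auto x = nd4j::NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
    auto y = nd4j::NDArrayFactory::create<float>('c', {2, 2}, {5.f, 6.f, 7.f, 8.f});

    nd4j::ops::subtract op;

    // old call shape seen on the left side of these hunks:
    //     auto res = op.execute({&x, &y}, {}, {});
    // new call shape used on the right side:
    auto res = op.evaluate({&x, &y});

    // result handling is unchanged: same ResultSet access and cleanup
    auto z = res->at(0);
    z->printIndexedBuffer("x - y");
    delete res;
}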
@@ -46,7 +46,7 @@ TEST_F(DataTypesValidationTests, Basic_Test_1) {
 input.linspace(1);

 nd4j::ops::conv2d op;
-auto result = op.execute({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});
+auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});

 ASSERT_EQ(ND4J_STATUS_VALIDATION, result->status());

@@ -62,7 +62,7 @@ TEST_F(DataTypesValidationTests, Basic_Test_2) {
 input.linspace(1);

 nd4j::ops::conv2d op;
-auto result = op.execute({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});
+auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -161,7 +161,7 @@ TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) {
 auto exp = NDArrayFactory::create<double>('c', {3,4});
 exp.linspace(0.9, 0.9);
 nd4j::ops::apply_sgd op;
-auto result = op.execute({&x, &y}, {1.}, {}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x, &y}, {1.}, {});
 ASSERT_EQ(result->status(), ND4J_STATUS_OK);
 auto z = result->at(0);

@@ -175,7 +175,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) {
 auto y = NDArrayFactory::create<double>('c', {1,4}, {0.1,0.2,0.3,0.4});
 auto exp = NDArrayFactory::create<double>('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4});
 nd4j::ops::assign op;
-auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x, &y});
 ASSERT_EQ(result->status(), ND4J_STATUS_OK);
 auto z = result->at(0);

@@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) {
 auto exp1 = NDArrayFactory::create<double>('c', {3,4}); // zero
 auto exp2 = NDArrayFactory::create<double>('c', {1,4}, {3, 6, 9, 12});
 nd4j::ops::assign_bp op;
-auto result = op.execute({&x, &y, &eps}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x, &y, &eps});
 ASSERT_EQ(result->status(), ND4J_STATUS_OK);
 auto z1 = result->at(0);
 auto z2 = result->at(1);
@@ -208,7 +208,7 @@ TEST_F(DeclarableOpsTests1, AXpY_Test_1) {
 auto exp = NDArrayFactory::create<double>('c', {3,4});
 exp.linspace(3, 3);
 nd4j::ops::axpy op;
-auto result = op.execute({&x, &y}, {2.}, {}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x, &y}, {2.});
 ASSERT_EQ(result->status(), ND4J_STATUS_OK);
 auto z = result->at(0);

@ -249,7 +249,7 @@ TEST_F(DeclarableOpsTests1, TestTensorMmul1) {
|
||||||
NDArray exp('c', {2, 2}, {650.0, 1586.0, 1586.0, 4250.0}, nd4j::DataType::FLOAT32);
|
NDArray exp('c', {2, 2}, {650.0, 1586.0, 1586.0, 4250.0}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::tensormmul op;
|
nd4j::ops::tensormmul op;
|
||||||
auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2});
|
auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -269,7 +269,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot2) {
|
||||||
NDArray exp('c', {2, 2}, {2300.0, 2444.0, 2444.0, 2600.0}, nd4j::DataType::FLOAT32);
|
NDArray exp('c', {2, 2}, {2300.0, 2444.0, 2444.0, 2600.0}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::tensormmul op;
|
nd4j::ops::tensormmul op;
|
||||||
auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2});
|
auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -289,7 +289,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot3) {
|
||||||
NDArray exp('f', {2, 2}, {1090.0, 2818.0, 1168.0, 3040.0}, nd4j::DataType::FLOAT32);
|
NDArray exp('f', {2, 2}, {1090.0, 2818.0, 1168.0, 3040.0}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::tensormmul op;
|
nd4j::ops::tensormmul op;
|
||||||
auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2});
|
auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -309,7 +309,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) {
|
||||||
NDArray exp('f', {2, 2}, {1090.0, 1168.0, 2818.0, 3040.0}, nd4j::DataType::FLOAT32);
|
NDArray exp('f', {2, 2}, {1090.0, 1168.0, 2818.0, 3040.0}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::tensormmul op;
|
nd4j::ops::tensormmul op;
|
||||||
auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2});
|
auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -506,7 +506,7 @@ TEST_F(DeclarableOpsTests1, SubtractTest_2) {
|
||||||
|
|
||||||
nd4j::ops::subtract subOp;
|
nd4j::ops::subtract subOp;
|
||||||
|
|
||||||
auto res = subOp.execute({&x, &y}, {}, {});
|
auto res = subOp.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
||||||
|
|
||||||
|
@ -767,7 +767,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) {
|
||||||
|
|
||||||
nd4j::ops::reversesubtract subOp;
|
nd4j::ops::reversesubtract subOp;
|
||||||
|
|
||||||
auto res = subOp.execute({&x, &y}, {}, {});
|
auto res = subOp.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res->at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res->at(0)->equalsTo(&exp));
|
||||||
|
@ -792,7 +792,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) {
|
||||||
|
|
||||||
nd4j::ops::reversesubtract subOp;
|
nd4j::ops::reversesubtract subOp;
|
||||||
|
|
||||||
auto res = subOp.execute({&x, &y}, {}, {});
|
auto res = subOp.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res->at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res->at(0)->equalsTo(&exp));
|
||||||
|
@ -815,7 +815,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) {
|
||||||
ASSERT_TRUE(z.equalsTo(&exp));
|
ASSERT_TRUE(z.equalsTo(&exp));
|
||||||
nd4j::ops::reversesubtract subOp;
|
nd4j::ops::reversesubtract subOp;
|
||||||
|
|
||||||
auto res = subOp.execute({&x, &y}, {}, {});
|
auto res = subOp.evaluate({&x, &y});
|
||||||
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res->at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res->at(0)->equalsTo(&exp));
|
||||||
|
|
||||||
|
@ -841,7 +841,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) {
|
||||||
|
|
||||||
nd4j::ops::reversemod subOp;
|
nd4j::ops::reversemod subOp;
|
||||||
|
|
||||||
auto res = subOp.execute({&x, &y}, {}, {});
|
auto res = subOp.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res->at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res->at(0)->equalsTo(&exp));
|
||||||
|
@ -868,7 +868,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) {
|
||||||
|
|
||||||
nd4j::ops::reversemod subOp;
|
nd4j::ops::reversemod subOp;
|
||||||
|
|
||||||
auto res = subOp.execute({&x, &y}, {}, {});
|
auto res = subOp.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res->at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res->at(0)->equalsTo(&exp));
|
||||||
|
@ -1157,7 +1157,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) {
|
||||||
|
|
||||||
nd4j::ops::divide div;
|
nd4j::ops::divide div;
|
||||||
|
|
||||||
auto res = div.execute({&x, &y}, {}, {});
|
auto res = div.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res->at(0)->equalsTo(exp));
|
ASSERT_TRUE(res->at(0)->equalsTo(exp));
|
||||||
|
@ -1176,7 +1176,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) {
|
||||||
exp.assign(3);
|
exp.assign(3);
|
||||||
|
|
||||||
nd4j::ops::divide_no_nan div;
|
nd4j::ops::divide_no_nan div;
|
||||||
auto res = div.execute({&x, &y}, {}, {});
|
auto res = div.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res->at(0)->equalsTo(exp));
|
ASSERT_TRUE(res->at(0)->equalsTo(exp));
|
||||||
|
@ -1192,7 +1192,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) {
|
||||||
auto exp = NDArrayFactory::create<float>({2, 2, 0, 2, 2});
|
auto exp = NDArrayFactory::create<float>({2, 2, 0, 2, 2});
|
||||||
|
|
||||||
nd4j::ops::divide_no_nan div;
|
nd4j::ops::divide_no_nan div;
|
||||||
auto res = div.execute({&x, &y}, {}, {});
|
auto res = div.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res->at(0)->equalsTo(exp));
|
ASSERT_TRUE(res->at(0)->equalsTo(exp));
|
||||||
|
@ -1212,7 +1212,7 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {
|
||||||
|
|
||||||
nd4j::ops::reversedivide div;
|
nd4j::ops::reversedivide div;
|
||||||
|
|
||||||
auto res = div.execute({&x, &y}, {}, {});
|
auto res = div.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
|
|
||||||
|
@ -1469,7 +1469,7 @@ TEST_F(DeclarableOpsTests1, Test_Cast_1) {
|
||||||
yExp.linspace(1);
|
yExp.linspace(1);
|
||||||
nd4j::ops::cast op;
|
nd4j::ops::cast op;
|
||||||
|
|
||||||
auto result = op.execute({&x}, {}, {3});
|
auto result = op.evaluate({&x}, {}, {3});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -1673,7 +1673,7 @@ TEST_F(DeclarableOpsTests1, Reshape3) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
|
auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
|
||||||
|
|
||||||
nd4j::ops::reshape op;
|
nd4j::ops::reshape op;
|
||||||
auto result = op.execute({&x}, {}, {-99, 3, 4, 5});
|
auto result = op.evaluate({&x}, {}, {-99, 3, 4, 5});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1688,7 +1688,7 @@ TEST_F(DeclarableOpsTests1, Reshape4) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
|
auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
|
||||||
|
|
||||||
nd4j::ops::reshape op;
|
nd4j::ops::reshape op;
|
||||||
auto result = op.execute({&x}, {}, {3, 4, 5});
|
auto result = op.evaluate({&x}, {}, {3, 4, 5});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1703,7 +1703,7 @@ TEST_F(DeclarableOpsTests1, Reshape5) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
|
auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
|
||||||
|
|
||||||
nd4j::ops::reshape op;
|
nd4j::ops::reshape op;
|
||||||
auto result = op.execute({&x}, {}, {5, 4, 3});
|
auto result = op.evaluate({&x}, {}, {5, 4, 3});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1715,7 +1715,7 @@ TEST_F(DeclarableOpsTests1, Reshape6){
|
||||||
auto exp = NDArrayFactory::create<float>('c', {4, 15});
|
auto exp = NDArrayFactory::create<float>('c', {4, 15});
|
||||||
|
|
||||||
nd4j::ops::reshape op;
|
nd4j::ops::reshape op;
|
||||||
auto result = op.execute({&x}, {}, {4, -1});
|
auto result = op.evaluate({&x}, {}, {4, -1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1732,7 +1732,7 @@ TEST_F(DeclarableOpsTests1, Reshape7){
|
||||||
auto exp = NDArrayFactory::create<float>('c', {60});
|
auto exp = NDArrayFactory::create<float>('c', {60});
|
||||||
|
|
||||||
nd4j::ops::reshape op;
|
nd4j::ops::reshape op;
|
||||||
auto result = op.execute({&x}, {}, {-1});
|
auto result = op.evaluate({&x}, {}, {-1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2217,7 +2217,7 @@ TEST_F(DeclarableOpsTests1, IsMax1) {
|
||||||
exp.p<bool>(2, 2, true);
|
exp.p<bool>(2, 2, true);
|
||||||
|
|
||||||
nd4j::ops::ismax ismaxOp;
|
nd4j::ops::ismax ismaxOp;
|
||||||
auto result = ismaxOp.execute({&x}, {}, {1});
|
auto result = ismaxOp.evaluate({&x}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2239,7 +2239,7 @@ TEST_F(DeclarableOpsTests1, IsMax2) {
|
||||||
exp.p<bool>(2, 2, true);
|
exp.p<bool>(2, 2, true);
|
||||||
|
|
||||||
nd4j::ops::ismax ismaxOp;
|
nd4j::ops::ismax ismaxOp;
|
||||||
auto result = ismaxOp.execute({&x}, {}, {0, 1});
|
auto result = ismaxOp.evaluate({&x}, {}, {0, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2261,7 +2261,7 @@ TEST_F(DeclarableOpsTests1, IsMax3) {
|
||||||
//exp.p<bool>(2, 2, true);
|
//exp.p<bool>(2, 2, true);
|
||||||
|
|
||||||
nd4j::ops::ismax ismaxOp;
|
nd4j::ops::ismax ismaxOp;
|
||||||
auto result = ismaxOp.execute({&x}, {}, {0});
|
auto result = ismaxOp.evaluate({&x}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2279,7 +2279,7 @@ TEST_F(DeclarableOpsTests1, IsMax4) {
|
||||||
auto e = NDArrayFactory::create<bool>('c', {6}, {false, false, false, true, false, false});
|
auto e = NDArrayFactory::create<bool>('c', {6}, {false, false, false, true, false, false});
|
||||||
|
|
||||||
nd4j::ops::ismax op;
|
nd4j::ops::ismax op;
|
||||||
auto result = op.execute({&x}, {&z}, {}, {}, {});
|
auto result = op.execute({&x}, {&z});
|
||||||
ASSERT_EQ(Status::OK(), result);
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
|
@ -2343,7 +2343,7 @@ TEST_F(DeclarableOpsTests1, sru_test1) {
|
||||||
mask.assign(1.);
|
mask.assign(1.);
|
||||||
|
|
||||||
nd4j::ops::sru op;
|
nd4j::ops::sru op;
|
||||||
auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, {});
|
auto results = op.evaluate({&input, &weights, &bias, &init, &mask});
|
||||||
ASSERT_TRUE(results->size() == 2);
|
ASSERT_TRUE(results->size() == 2);
|
||||||
|
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
@ -2390,7 +2390,7 @@ TEST_F(DeclarableOpsTests1, sru_bp) {
|
||||||
inGradH.assign(0.5);
|
inGradH.assign(0.5);
|
||||||
|
|
||||||
nd4j::ops::sru_bp bp;
|
nd4j::ops::sru_bp bp;
|
||||||
auto resultsBP = bp.execute({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {});
|
auto resultsBP = bp.evaluate({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {});
|
||||||
ASSERT_TRUE(resultsBP->size() == 4);
|
ASSERT_TRUE(resultsBP->size() == 4);
|
||||||
|
|
||||||
auto gradX = resultsBP->at(0);
|
auto gradX = resultsBP->at(0);
|
||||||
|
@ -2429,7 +2429,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_1) {
|
||||||
mask.assign(1.);
|
mask.assign(1.);
|
||||||
|
|
||||||
nd4j::ops::sru_bi op;
|
nd4j::ops::sru_bi op;
|
||||||
auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, {});
|
auto results = op.evaluate({&input, &weights, &bias, &init, &mask}, {}, {});
|
||||||
ASSERT_TRUE(results->size() == 2);
|
ASSERT_TRUE(results->size() == 2);
|
||||||
|
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
@ -2480,7 +2480,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
|
||||||
inGradH.assign(0.5);
|
inGradH.assign(0.5);
|
||||||
|
|
||||||
nd4j::ops::sru_bi_bp bp;
|
nd4j::ops::sru_bi_bp bp;
|
||||||
auto resultsBP = bp.execute({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {});
|
auto resultsBP = bp.evaluate({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {});
|
||||||
ASSERT_TRUE(resultsBP->size() == 4);
|
ASSERT_TRUE(resultsBP->size() == 4);
|
||||||
|
|
||||||
auto gradX = resultsBP->at(0);
|
auto gradX = resultsBP->at(0);
|
||||||
|
@ -2504,7 +2504,7 @@ TEST_F(DeclarableOpsTests1, ArgMax1) {
|
||||||
|
|
||||||
nd4j::ops::argmax op;
|
nd4j::ops::argmax op;
|
||||||
|
|
||||||
auto result = op.execute({&x}, {}, {1});
|
auto result = op.evaluate({&x}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2525,7 +2525,7 @@ TEST_F(DeclarableOpsTests1, ArgMax2) {
|
||||||
|
|
||||||
nd4j::ops::argmax op;
|
nd4j::ops::argmax op;
|
||||||
|
|
||||||
auto result = op.execute({&x}, {}, {0});
|
auto result = op.evaluate({&x}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2547,7 +2547,7 @@ TEST_F(DeclarableOpsTests1, ArgMax3) {
|
||||||
|
|
||||||
nd4j::ops::argmax op;
|
nd4j::ops::argmax op;
|
||||||
|
|
||||||
auto result = op.execute({&x, &dim}, {}, {});
|
auto result = op.evaluate({&x, &dim}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2568,7 +2568,7 @@ TEST_F(DeclarableOpsTests1, ArgMax4) {
|
||||||
|
|
||||||
nd4j::ops::argmax op;
|
nd4j::ops::argmax op;
|
||||||
|
|
||||||
auto result = op.execute({&x, &dim}, {}, {});
|
auto result = op.evaluate({&x, &dim}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2590,7 +2590,7 @@ TEST_F(DeclarableOpsTests1, ArgMax5) {
|
||||||
|
|
||||||
nd4j::ops::argmax op;
|
nd4j::ops::argmax op;
|
||||||
|
|
||||||
auto result = op.execute({&x, &dim}, {}, {});
|
auto result = op.evaluate({&x, &dim}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2610,12 +2610,12 @@ TEST_F(DeclarableOpsTests1, ArgMax6) {
|
||||||
|
|
||||||
nd4j::ops::argmax op;
|
nd4j::ops::argmax op;
|
||||||
|
|
||||||
auto expected = op.execute({&x}, {}, {2});
|
auto expected = op.evaluate({&x}, {}, {2});
|
||||||
ASSERT_EQ(Status::OK(), expected->status());
|
ASSERT_EQ(Status::OK(), expected->status());
|
||||||
auto exp = expected->at(0);
|
auto exp = expected->at(0);
|
||||||
|
|
||||||
|
|
||||||
auto result = op.execute({&x, &dim}, {}, {});
|
auto result = op.evaluate({&x, &dim}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -2636,7 +2636,7 @@ TEST_F(DeclarableOpsTests1, ArgMin1) {
|
||||||
|
|
||||||
nd4j::ops::argmin op;
|
nd4j::ops::argmin op;
|
||||||
|
|
||||||
auto result = op.execute({&x}, {}, {1});
|
auto result = op.evaluate({&x}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2659,7 +2659,7 @@ TEST_F(DeclarableOpsTests1, SquareTests1) {
|
||||||
|
|
||||||
nd4j::ops::square op;
|
nd4j::ops::square op;
|
||||||
|
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -2677,7 +2677,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_1) {
|
||||||
|
|
||||||
nd4j::ops::onehot op;
|
nd4j::ops::onehot op;
|
||||||
|
|
||||||
auto result = op.execute({&indices}, {1.0f, 0.0f}, {-1, 3});
|
auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -2695,7 +2695,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_2) {
|
||||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f});
|
auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f});
|
||||||
|
|
||||||
nd4j::ops::onehot op;
|
nd4j::ops::onehot op;
|
||||||
auto result = op.execute({&indices}, {1.0f, 0.0f}, {-1, 3});
|
auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2715,7 +2715,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_3) {
|
||||||
|
|
||||||
nd4j::ops::onehot op;
|
nd4j::ops::onehot op;
|
||||||
|
|
||||||
auto result = op.execute({&indices}, {1.0f, 0.0f}, {-1, 3});
|
auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -2736,7 +2736,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_4) {
|
||||||
|
|
||||||
nd4j::ops::onehot op;
|
nd4j::ops::onehot op;
|
||||||
|
|
||||||
auto result = op.execute({&indices, &depth}, {1.0f, 0.0f}, {});
|
auto result = op.evaluate({&indices, &depth}, {1.0f, 0.0f}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -2757,7 +2757,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) {
|
||||||
|
|
||||||
nd4j::ops::onehot op;
|
nd4j::ops::onehot op;
|
||||||
|
|
||||||
auto result = op.execute({&indices, &depth, &on, &off}, {}, {});
|
auto result = op.evaluate({&indices, &depth, &on, &off}, {}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@@ -2769,11 +2769,24 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) {
 }

 TEST_F(DeclarableOpsTests1, OneHotTests_6) {
-auto indices = NDArrayFactory::create<float>('c', {3}, {0., 1., 2.});
+auto indices = NDArrayFactory::create<float>('c', {3}, {0.f, 1.f, 2.f});
-auto e = NDArrayFactory::create<float>('c', {3, 3}, {1., 0., 0., 0., 1., 0., 0., 0., 1.});
+auto e = NDArrayFactory::create<float>('c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f});

 nd4j::ops::onehot op;
-auto result = op.execute({&indices}, {1.0, 0.0}, {0, 3});
+auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3});
+auto z = result->at(0);
+
+ASSERT_EQ(e, *z);
+
+delete result;
+}
+
+TEST_F(DeclarableOpsTests1, OneHotTests_7) {
+auto indices = NDArrayFactory::create<int>('c', {3}, {0, 1, 2});
+auto e = NDArrayFactory::create<float16>('c', {3, 3}, {1., 0., 0., 0., 1., 0., 0., 0., 1.});
+
+nd4j::ops::onehot op;
+auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3}, {}, {nd4j::DataType::HALF}, false);
 auto z = result->at(0);

 ASSERT_EQ(e, *z);
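Note: the new OneHotTests_7 case above exercises the DArgs-style signature this change set introduces: judging from that test line, the evaluate overload takes inputs, t-args, i-args, b-args, then a list of requested output data types, followed by a boolean flag. A hedged sketch of that call shape, reusing only values that appear in the test itself (the exact meaning of the trailing boolean is an assumption here):

// argument order inferred from OneHotTests_7: inputs, targs, iargs, bargs, dargs, bool
auto indices = NDArrayFactory::create<int>('c', {3}, {0, 1, 2});
nd4j::ops::onehot op;
auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3}, {}, {nd4j::DataType::HALF}, false);
auto z = result->at(0);   // one-hot output produced as HALF via the DataType argument
delete result;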
@ -2788,7 +2801,7 @@ TEST_F(DeclarableOpsTests1, FillAs_1) {
|
||||||
float scalar = 119.f;
|
float scalar = 119.f;
|
||||||
|
|
||||||
nd4j::ops::fill_as op;
|
nd4j::ops::fill_as op;
|
||||||
auto result = op.execute({&x}, {scalar}, {});
|
auto result = op.evaluate({&x}, {scalar}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -2824,7 +2837,7 @@ TEST_F(DeclarableOpsTests1, Stack_1) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input2}, {}, {0});
|
auto results = op.evaluate({&input1, &input2}, {}, {0});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
|
@ -2852,7 +2865,7 @@ TEST_F(DeclarableOpsTests1, Stack_2) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input2}, {}, {1});
|
auto results = op.evaluate({&input1, &input2}, {}, {1});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
|
@ -2880,7 +2893,7 @@ TEST_F(DeclarableOpsTests1, Stack_3) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input2}, {}, {0});
|
auto results = op.evaluate({&input1, &input2}, {}, {0});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
|
@ -2907,7 +2920,7 @@ TEST_F(DeclarableOpsTests1, Stack_4) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input2}, {}, {1});
|
auto results = op.evaluate({&input1, &input2}, {}, {1});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
|
@ -2934,7 +2947,7 @@ TEST_F(DeclarableOpsTests1, Stack_5) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input2}, {}, {0});
|
auto results = op.evaluate({&input1, &input2}, {}, {0});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
|
@ -2961,7 +2974,7 @@ TEST_F(DeclarableOpsTests1, Stack_6) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input2}, {}, {1});
|
auto results = op.evaluate({&input1, &input2}, {}, {1});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
|
@ -2985,7 +2998,7 @@ TEST_F(DeclarableOpsTests1, Stack_7) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input1, &input1}, {}, {0});
|
auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
|
@ -3008,7 +3021,7 @@ TEST_F(DeclarableOpsTests1, Stack_8) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input1, &input1}, {}, {0});
|
auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
|
@ -3031,7 +3044,7 @@ TEST_F(DeclarableOpsTests1, Stack_9) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input1, &input1}, {}, {1});
|
auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
|
@ -3054,7 +3067,7 @@ TEST_F(DeclarableOpsTests1, Stack_10) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input1, &input1}, {}, {1});
|
auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
//expected.printShapeInfo("exp");
|
//expected.printShapeInfo("exp");
|
||||||
|
@ -3079,7 +3092,7 @@ TEST_F(DeclarableOpsTests1, Stack_11) {
|
||||||
NDArray expected(expBuff, expShape);
|
NDArray expected(expBuff, expShape);
|
||||||
|
|
||||||
nd4j::ops::stack op;
|
nd4j::ops::stack op;
|
||||||
auto results = op.execute({&input1, &input1, &input1}, {}, {});
|
auto results = op.evaluate({&input1, &input1, &input1}, {}, {});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
|
@ -3095,7 +3108,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) {
|
||||||
|
|
||||||
nd4j::ops::range op;
|
nd4j::ops::range op;
|
||||||
|
|
||||||
auto result = op.execute({}, {}, {1, 5, 1});
|
auto result = op.evaluate({}, {}, {1, 5, 1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result->size());
|
||||||
|
@ -3122,7 +3135,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) {
|
||||||
|
|
||||||
nd4j::ops::range op;
|
nd4j::ops::range op;
|
||||||
|
|
||||||
auto result = op.execute({&start, &stop, &step}, {}, {});
|
auto result = op.evaluate({&start, &stop, &step}, {}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result->size());
|
||||||
|
@ -3142,7 +3155,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) {
|
||||||
|
|
||||||
nd4j::ops::range op;
|
nd4j::ops::range op;
|
||||||
|
|
||||||
auto result = op.execute({}, {1.f, 5.f, 1.f}, {});
|
auto result = op.evaluate({}, {1.f, 5.f, 1.f}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result->size());
|
||||||
|
@ -3161,7 +3174,7 @@ TEST_F(DeclarableOpsTests1, softmax_test1) {
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {3, 3}, {1.14195199e-01, 8.43794734e-01, 4.20100661e-02, 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, 9.02116571e-05, 2.68917160e-01, 7.30992629e-01});
|
auto expOutput = NDArrayFactory::create<double>('c', {3, 3}, {1.14195199e-01, 8.43794734e-01, 4.20100661e-02, 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, 9.02116571e-05, 2.68917160e-01, 7.30992629e-01});
|
||||||
|
|
||||||
nd4j::ops::softmax op;
|
nd4j::ops::softmax op;
|
||||||
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
auto results = op.evaluate({&input}, {}, {}, {});
|
||||||
auto z = results->at(0);
|
auto z = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -3177,7 +3190,7 @@ TEST_F(DeclarableOpsTests1, softmax_test2) {
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {3, 3, 3}, {4.73142e-02,4.73847e-02,6.69062e-03, 9.50330e-01,8.67881e-04,9.92976e-01, 2.35563e-03,9.51747e-01,3.33106e-04, 4.74259e-02,2.26032e-06,4.74259e-02, 2.91395e-07,9.99998e-01,3.94360e-08, 9.52574e-01,1.12535e-07,9.52574e-01, 7.58256e-10,4.74259e-02,1.22325e-11, 1.00000e+00,1.32293e-11,1.19203e-01, 3.77513e-11,9.52574e-01,8.80797e-01});
|
auto expOutput = NDArrayFactory::create<double>('c', {3, 3, 3}, {4.73142e-02,4.73847e-02,6.69062e-03, 9.50330e-01,8.67881e-04,9.92976e-01, 2.35563e-03,9.51747e-01,3.33106e-04, 4.74259e-02,2.26032e-06,4.74259e-02, 2.91395e-07,9.99998e-01,3.94360e-08, 9.52574e-01,1.12535e-07,9.52574e-01, 7.58256e-10,4.74259e-02,1.22325e-11, 1.00000e+00,1.32293e-11,1.19203e-01, 3.77513e-11,9.52574e-01,8.80797e-01});
|
||||||
|
|
||||||
nd4j::ops::softmax op;
|
nd4j::ops::softmax op;
|
||||||
auto results = op.execute({&input}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
|
auto results = op.evaluate({&input}, {}, {1}, {});
|
||||||
auto z = results->at(0);
|
auto z = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -3193,7 +3206,7 @@ TEST_F(DeclarableOpsTests1, softmax_test3) {
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {3, 3, 3}, {2.47262e-03,1.23395e-04,3.35350e-04, 1.23395e-04,4.53979e-05,1.23395e-04, 6.14417e-06,1.23395e-04,5.56530e-09, 9.97527e-01,1.12521e-07,9.99665e-01, 1.52281e-08,9.99955e-01,2.06090e-09, 9.99994e-01,2.78912e-10,6.69285e-03, 3.05146e-07,9.99876e-01,4.13855e-08, 9.99877e-01,5.60254e-09,9.99877e-01, 7.58251e-10,9.99877e-01,9.93307e-01});
|
auto expOutput = NDArrayFactory::create<double>('c', {3, 3, 3}, {2.47262e-03,1.23395e-04,3.35350e-04, 1.23395e-04,4.53979e-05,1.23395e-04, 6.14417e-06,1.23395e-04,5.56530e-09, 9.97527e-01,1.12521e-07,9.99665e-01, 1.52281e-08,9.99955e-01,2.06090e-09, 9.99994e-01,2.78912e-10,6.69285e-03, 3.05146e-07,9.99876e-01,4.13855e-08, 9.99877e-01,5.60254e-09,9.99877e-01, 7.58251e-10,9.99877e-01,9.93307e-01});
|
||||||
|
|
||||||
nd4j::ops::softmax op;
|
nd4j::ops::softmax op;
|
||||||
auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
|
auto results = op.evaluate({&input}, {}, {0}, {});
|
||||||
auto z = results->at(0);
|
auto z = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -3209,7 +3222,7 @@ TEST_F(DeclarableOpsTests1, softmax_test4) {
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {1, 5}, {0.01198,0.08855,0.00441,0.24072,0.65434});
|
auto expOutput = NDArrayFactory::create<double>('c', {1, 5}, {0.01198,0.08855,0.00441,0.24072,0.65434});
|
||||||
|
|
||||||
nd4j::ops::softmax op;
|
nd4j::ops::softmax op;
|
||||||
auto results = op.execute({&input}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
|
auto results = op.evaluate({&input}, {}, {1}, {});
|
||||||
auto z = results->at(0);
|
auto z = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -3225,7 +3238,7 @@ TEST_F(DeclarableOpsTests1, softmax_test5) {
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {1, 5}, {1,1,1,1,1});
|
auto expOutput = NDArrayFactory::create<double>('c', {1, 5}, {1,1,1,1,1});
|
||||||
|
|
||||||
nd4j::ops::softmax op;
|
nd4j::ops::softmax op;
|
||||||
auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
|
auto results = op.evaluate({&input}, {}, {0});
|
||||||
auto z = results->at(0);
|
auto z = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -3241,7 +3254,7 @@ TEST_F(DeclarableOpsTests1, softmax_test6) {
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {5, 1}, {0.01198,0.08855,0.00441,0.24072,0.65434});
|
auto expOutput = NDArrayFactory::create<double>('c', {5, 1}, {0.01198,0.08855,0.00441,0.24072,0.65434});
|
||||||
|
|
||||||
nd4j::ops::softmax op;
|
nd4j::ops::softmax op;
|
||||||
auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
|
auto results = op.evaluate({&input}, {}, {0}, {});
|
||||||
auto z = results->at(0);
|
auto z = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -3257,7 +3270,7 @@ TEST_F(DeclarableOpsTests1, softmax_test7) {
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {5, 1}, {1,1,1,1,1});
|
auto expOutput = NDArrayFactory::create<double>('c', {5, 1}, {1,1,1,1,1});
|
||||||
|
|
||||||
nd4j::ops::softmax op;
|
nd4j::ops::softmax op;
|
||||||
auto results = op.execute({&input}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
|
auto results = op.evaluate({&input}, {}, {1}, {});
|
||||||
auto z = results->at(0);
|
auto z = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -3273,7 +3286,7 @@ TEST_F(DeclarableOpsTests1, softmax_test8) {
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {5}, {0.01198,0.08855,0.00441,0.24072,0.65434});
|
auto expOutput = NDArrayFactory::create<double>('c', {5}, {0.01198,0.08855,0.00441,0.24072,0.65434});
|
||||||
|
|
||||||
nd4j::ops::softmax op;
|
nd4j::ops::softmax op;
|
||||||
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
auto results = op.evaluate({&input}, {}, {}, {});
|
||||||
auto z = results->at(0);
|
auto z = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -3294,7 +3307,7 @@ TEST_F(DeclarableOpsTests1, Test_Stack_Edge_1) {

nd4j::ops::stack op;

- auto result = op.execute({&input}, {}, {0});
+ auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());

auto z = result->at(0);

@ -3316,7 +3329,7 @@ TEST_F(DeclarableOpsTests1, Test_Stack_Edge_2) {

nd4j::ops::stack op;

- auto result = op.execute({&input}, {}, {0});
+ auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());

auto z = result->at(0);

@ -3338,7 +3351,7 @@ TEST_F(DeclarableOpsTests1, Test_Stack_Edge_3) {

nd4j::ops::stack op;

- auto result = op.execute({&input}, {}, {1});
+ auto result = op.evaluate({&input}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());

auto z = result->at(0);

@ -3364,7 +3377,7 @@ TEST_F(DeclarableOpsTests1, Reverse_1 ) {
NDArray output(shapeInfo);

nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {0,1,2});
+ auto results = op.evaluate({&input}, {}, {0,1,2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3389,7 +3402,7 @@ TEST_F(DeclarableOpsTests1, Reverse_2 ) {
NDArray output(shapeInfo);

nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {}, {}, true);
+ auto results = op.evaluate({&input}, {}, {}, {}, {}, true);

ASSERT_EQ(ND4J_STATUS_OK, results->status());
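Worth noting in Reverse_2 above (and Reverse_6 below): the trailing true is the in-place flag, and it keeps its meaning but shifts one slot to the right, since evaluate() takes an extra, here empty, argument list before it, presumably the new data-type argument list introduced elsewhere in this PR. A sketch of the two call shapes with assumed slot names:

    nd4j::ops::reverse op;
    // old: execute(inputs, tArgs, iArgs, bArgs, isInplace)
    //   op.execute({&input}, {}, {}, {}, true);
    // new: evaluate(inputs, tArgs, iArgs, bArgs, dArgs, isInplace) -- one extra (empty) list before the flag
    auto results = op.evaluate({&input}, {}, {}, {}, {}, true);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());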
@ -3414,7 +3427,7 @@ TEST_F(DeclarableOpsTests1, Reverse_3 ) {
NDArray output(shapeInfo);

nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {1,2});
+ auto results = op.evaluate({&input}, {}, {1,2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3440,7 +3453,7 @@ TEST_F(DeclarableOpsTests1, Reverse_4 ) {
NDArray output(shapeInfo);

nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {0,2});
+ auto results = op.evaluate({&input}, {}, {0,2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3466,7 +3479,7 @@ TEST_F(DeclarableOpsTests1, Reverse_5 ) {
NDArray output(shapeInfo);

nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {0,1});
+ auto results = op.evaluate({&input}, {}, {0,1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3491,7 +3504,7 @@ TEST_F(DeclarableOpsTests1, Reverse_6 ) {
NDArray output(shapeInfo);

nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {2}, {}, true);
+ auto results = op.evaluate({&input}, {}, {2}, {}, {}, true);

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3518,7 +3531,7 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) {
NDArray output(shapeInfo);

nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {1});
+ auto results = op.evaluate({&input}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3547,7 +3560,7 @@ TEST_F(DeclarableOpsTests1, Reverse_8 ) {
NDArray output(shapeInfo);

nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {2,1});
+ auto results = op.evaluate({&input}, {}, {2,1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3573,7 +3586,7 @@ TEST_F(DeclarableOpsTests1, Reverse_9 ) {
NDArray output(shapeInfo);

nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {0});
+ auto results = op.evaluate({&input}, {}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3591,7 +3604,7 @@ TEST_F(DeclarableOpsTests1, Reverse_10 ) {
auto e = NDArrayFactory::create<double>('c', {4, 3}, {0.09966054, 0.1592365, 1.5375735, -1.0355669, 1.144433, 0.677872,0.85020787, -0.67863184, 0.48456487, -1.1660044, 0.20998026, 0.13950661});

nd4j::ops::reverse op;
- auto result = op.execute({&x, &i}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+ auto result = op.evaluate({&x, &i}, {}, {}, {});

auto z = result->at(0);

@ -3612,7 +3625,7 @@ TEST_F(DeclarableOpsTests1, Reverse_11 ) {

input.linspace(1);
nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {0, 1, 2});
+ auto results = op.evaluate({&input}, {}, {0, 1, 2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3633,7 +3646,7 @@ TEST_F(DeclarableOpsTests1, Reverse_12 ) {

//input.linspace(1);
nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {0});
+ auto results = op.evaluate({&input}, {}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3655,7 +3668,7 @@ TEST_F(DeclarableOpsTests1, Reverse_13 ) {

//input.linspace(1);
nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {-1});
+ auto results = op.evaluate({&input}, {}, {-1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3676,7 +3689,7 @@ TEST_F(DeclarableOpsTests1, Reverse_14 ) {

//input.linspace(1);
nd4j::ops::reverse op;
- auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+ auto results = op.evaluate({&input}, {}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3694,7 +3707,7 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) {

nd4j::ops::expose op;

- auto result = op.execute({&input0, &input1}, {}, {});
+ auto result = op.evaluate({&input0, &input1});

ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -60,7 +60,7 @@ TEST_F(DeclarableOpsTests10, Test_ArgMax_1) {

nd4j::ops::argmax op;
- auto result = op.execute({&x}, {}, {}, {});
+ auto result = op.evaluate({&x});
ASSERT_EQ(Status::OK(), result->status());

@ -79,7 +79,7 @@ TEST_F(DeclarableOpsTests10, Test_ArgMax_2) {
x.linspace(1.0);

nd4j::ops::argmax op;
- auto result = op.execute({&x, &y}, {}, {}, {});
+ auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());

auto z = *result->at(0);

@ -98,7 +98,7 @@ TEST_F(DeclarableOpsTests10, Test_And_1) {
auto e = NDArrayFactory::create<double>('c', {4}, {0, 0, 0, 1});

nd4j::ops::boolean_and op;
- auto result = op.execute({&x, &y}, {}, {}, {});
+ auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());

ASSERT_EQ(e, *result->at(0));

@ -112,7 +112,7 @@ TEST_F(DeclarableOpsTests10, Test_Or_1) {
auto e = NDArrayFactory::create<double>('c', {4}, {1, 1, 0, 1});

nd4j::ops::boolean_or op;
- auto result = op.execute({&x, &y}, {}, {}, {});
+ auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());

ASSERT_EQ(e, *result->at(0));

@ -127,7 +127,7 @@ TEST_F(DeclarableOpsTests10, Test_Not_1) {
auto e = NDArrayFactory::create<bool>('c', {4}, {false, false, true, false});

nd4j::ops::boolean_not op;
- auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+ auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
auto res = result->at(0);

@ -141,7 +141,7 @@ TEST_F(DeclarableOpsTests10, Test_Size_at_1) {
auto e = NDArrayFactory::create<Nd4jLong>(20);

nd4j::ops::size_at op;
- auto result = op.execute({&x}, {}, {1});
+ auto result = op.evaluate({&x}, {1});
ASSERT_EQ(Status::OK(), result->status());

ASSERT_EQ(e, *result->at(0));
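Several calls in this file also move to shorter convenience overloads: where an op needs nothing but integer args, the empty tArgs slot is dropped entirely, as Test_Size_at_1 above shows. A minimal sketch, assuming an inputs-plus-integer-args overload exists as that call site suggests:

    nd4j::ops::size_at op;
    // shorter overload: inputs and integer args only; the old call carried an extra empty {} for tArgs
    auto result = op.evaluate({&x}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    ASSERT_EQ(e, *result->at(0));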
@ -161,7 +161,7 @@ TEST_F(DeclarableOpsTests10, MirrorPad_SGO_Test_1) {

nd4j::ops::mirror_pad op;

- auto res = op.execute({&in, &pad}, {10.0}, {0}, {}, false, nd4j::DataType::DOUBLE);
+ auto res = op.evaluate({&in, &pad}, {10.0}, {0});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);

ASSERT_TRUE(exp.equalsTo(res->at(0)));

@ -175,7 +175,7 @@ TEST_F(DeclarableOpsTests10, Unique_SGO_Test_1) {
auto exp = NDArrayFactory::create<double>({3., 4., 1., 0., 2.});

nd4j::ops::unique op;
- auto res = op.execute({&input}, {}, {});
+ auto res = op.evaluate({&input}, {}, {});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto res1 = res->at(0);
auto res2 = res->at(1);

@ -192,7 +192,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_1) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {6, 2}, {0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 2LL, 0LL, 2LL, 1LL, 2LL, 2LL});

nd4j::ops::Where op;
- auto res = op.execute({&input}, {}, {});
+ auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0);

@ -209,7 +209,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_02) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5, 3}, {0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 1LL, 0LL});

nd4j::ops::Where op;
- auto res = op.execute({&input}, {}, {});
+ auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0);

@ -227,7 +227,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) {
auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 0, 0, 1});
auto exp3 = NDArrayFactory::create<Nd4jLong>({0, 1, 0, 1, 0});
nd4j::ops::where_np op;
- auto res = op.execute({&cond3d}, {}, {});
+ auto res = op.evaluate({&cond3d}, {}, {});
ASSERT_TRUE(res->size() == 3);
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto res1 = res->at(0);

@ -251,7 +251,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) {
auto exp1 = NDArrayFactory::create<Nd4jLong>({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2});
auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4});
nd4j::ops::where_np op;
- auto res = op.execute({&cond2d}, {}, {});
+ auto res = op.evaluate({&cond2d}, {}, {});
ASSERT_TRUE(res->size() == 2);
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
ASSERT_TRUE(exp1.equalsTo(res->at(0)));

@ -267,7 +267,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_2) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {4,1}, {0, 2, 3, 4});

nd4j::ops::Where op;
- auto res = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::INT64);
+ auto res = op.evaluate({&input});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0);
// resA->printIndexedBuffer("Result A");

@ -285,7 +285,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_3) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0});

nd4j::ops::Where op;
- auto res = op.execute({&input}, {}, {});
+ auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0);
//resA->printIndexedBuffer("Result A");

@ -303,7 +303,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_4) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0});

nd4j::ops::Where op;
- auto res = op.execute({&input}, {}, {});
+ auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0);
ASSERT_TRUE(resA->isEmpty());

@ -322,7 +322,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_5) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {3, 1}, {0, 3, 4});

nd4j::ops::Where op;
- auto res = op.execute({&input}, {}, {});
+ auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0);
//ASSERT_TRUE(resA->isEmpty());

@ -340,7 +340,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_4) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0});

nd4j::ops::where_np op;
- auto res = op.execute({&input}, {}, {});
+ auto res = op.evaluate({&input}, {}, {});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0);
ASSERT_TRUE(resA->isEmpty());

@ -361,7 +361,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_1) {
auto exp = NDArrayFactory::create<double>(0.6);

nd4j::ops::cosine_distance_loss op;
- auto res = op.execute({&predictions, &weights, &labels}, {}, {3, 1});
+ auto res = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0);

@ -379,7 +379,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) {
auto exp = NDArrayFactory::create<double>(0.6);

nd4j::ops::cosine_distance_loss op;
- auto res = op.execute({&predictions, &weights, &labels}, {}, {2, 1});
+ auto res = op.evaluate({&predictions, &weights, &labels}, {}, {2, 1});
ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
auto resA = res->at(0);

@ -402,7 +402,7 @@ TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_1) {
exp.p(1, 2, 0, 0.);

nd4j::ops::matrix_band_part op;
- auto results = op.execute({&x}, {}, {1, 1}, {}, false, nd4j::DataType::DOUBLE);
+ auto results = op.evaluate({&x}, {}, {1, 1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
//results->at(0)->printIndexedBuffer("MBP Test1");

@ -422,7 +422,7 @@ TEST_F(DeclarableOpsTests10, atan2_test1) {
0.33172, 0.69614, 0.81846, 0.87776, 0.91253, 0.93533, 0.95141, 0.96336, 0.97259, 0.97993, 0.98591, 1.01266,});

nd4j::ops::tf_atan2 op;
- auto result = op.execute({&y, &x}, {}, {});
+ auto result = op.evaluate({&y, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);

@ -442,7 +442,7 @@ TEST_F(DeclarableOpsTests10, atan2_test2) {
3.11208, 2.99987, 2.83399, 2.57869, 2.207 , 1.77611, 1.41664, 1.17298, 1.01458, 0.90829, 0.8336 , 0.77879});

nd4j::ops::tf_atan2 op;
- auto result = op.execute({&y, &x}, {}, {});
+ auto result = op.evaluate({&y, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer();

@ -465,7 +465,7 @@ TEST_F(DeclarableOpsTests10, atan2_test3) {
-1.54128, -1.42907, -1.2632 , -1.00789,-0.63621, -0.20531, 0.15416, 0.39782, 0.55622, 0.6625 , 0.7372 , 0.79201});

nd4j::ops::tf_atan2 op;
- auto result = op.execute({&x, &y}, {}, {}, {});
+ auto result = op.evaluate({&x, &y}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);

@ -485,7 +485,7 @@ TEST_F(DeclarableOpsTests10, atan2_test4) {
3.05688, 3.03942, 3.01293, 2.9681 , 2.18167, 1.87635, 1.50156, 1.14451, 1.13674, 0.97626, 0.84423, 0.7372 });

nd4j::ops::tf_atan2 op;
- auto result = op.execute({&x, &y}, {}, {}, {});
+ auto result = op.evaluate({&x, &y}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);

@ -505,7 +505,7 @@ TEST_F(DeclarableOpsTests10, atan2_test5) {
-1.48608, -1.46862, -1.44214, -1.3973 ,-0.61088, -0.30556, 0.06924, 0.42629, 0.43405, 0.59453, 0.72657, 0.8336 });

nd4j::ops::tf_atan2 op;
- auto result = op.execute({&y, &x}, {}, {}, {});
+ auto result = op.evaluate({&y, &x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);

@ -524,7 +524,7 @@ TEST_F(DeclarableOpsTests10, atan2_test6) {
auto exp = NDArrayFactory::create<double>('c', {1,3,4}, {-2.25712, -1.68608, -1.44214, -0.54006,-2.77695, -2.16855, 0.34972, 0.24585, 2.71267, 1.74453, 1.45312, 0.8336 });

nd4j::ops::tf_atan2 op;
- auto result = op.execute({&y, &x}, {}, {}, {});
+ auto result = op.evaluate({&y, &x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);

@ -546,7 +546,7 @@ TEST_F(DeclarableOpsTests10, IGamma_Test1) {
0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735});

nd4j::ops::igamma op;
- auto result = op.execute({&y, &x}, {}, {}, {});
+ auto result = op.evaluate({&y, &x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printBuffer("OUtput");

@ -568,7 +568,7 @@ TEST_F(DeclarableOpsTests10, IGamma_Test2) {
0.999996, 0.999914, 0.999564, 0.998773});

nd4j::ops::igammac op;
- auto result = op.execute({&y, &x}, {}, {}, {});
+ auto result = op.evaluate({&y, &x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printBuffer("OUtput");

@ -591,7 +591,7 @@ TEST_F(DeclarableOpsTests10, LGamma_Test1) {
});

nd4j::ops::lgamma op;
- auto result = op.execute({&x}, {}, {}, {});
+ auto result = op.evaluate({&x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printBuffer("OUtput");

@ -610,7 +610,7 @@ TEST_F(DeclarableOpsTests10, range_test10) {
auto exp = NDArrayFactory::create<double>('c', {5}, {0.,1.,2.,3.,4.});

nd4j::ops::range op;
- auto result = op.execute({&limit}, {}, {}, {});
+ auto result = op.evaluate({&limit}, {}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -632,7 +632,7 @@ TEST_F(DeclarableOpsTests10, range_test11) {
auto exp = NDArrayFactory::create<double>('c', {5}, {0.5,1.5,2.5,3.5,4.5});

nd4j::ops::range op;
- auto result = op.execute({&start, &limit}, {}, {}, {});
+ auto result = op.evaluate({&start, &limit}, {}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -650,7 +650,7 @@ TEST_F(DeclarableOpsTests10, range_test12) {
auto exp = NDArrayFactory::create<float>('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f});

nd4j::ops::range op;
- auto result = op.execute({}, {0.5, 5, 0.5}, {}, {});
+ auto result = op.evaluate({}, {0.5, 5, 0.5}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, result->status());
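range_test12 above is the one call in this stretch with no input arrays at all: start, limit and step travel in the floating-point argument list. A sketch of that shape, reusing the call from the hunk; the start/limit/step reading is an assumption:

    nd4j::ops::range op;
    // no input arrays; {0.5, 5, 0.5} are the tArgs, read here as start, limit, step
    auto result = op.evaluate({}, {0.5, 5, 0.5}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());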
@ -671,7 +671,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test1) {

nd4j::ops::top_k op;
- auto result = op.execute({&x}, {}, {4}, {false});
+ auto result = op.evaluate({&x}, {}, {4}, {false});

ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -681,7 +681,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test1) {
ASSERT_TRUE(expUnsorted.isSameShape(z));
ASSERT_TRUE(expUnsorted.equalsTo(z));

- auto result2 = op.execute({&x}, {}, {5}, {true});
+ auto result2 = op.evaluate({&x}, {}, {5}, {true});

ASSERT_EQ(ND4J_STATUS_OK, result2->status());

@ -704,7 +704,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test2) {

nd4j::ops::top_k op;
- auto result = op.execute({&x}, {}, {5}, {false});
+ auto result = op.evaluate({&x}, {}, {5}, {false});

ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -714,7 +714,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test2) {
ASSERT_TRUE(expUnsorted.isSameShape(z));
ASSERT_TRUE(expUnsorted.equalsTo(z));

- auto result2 = op.execute({&x}, {}, {5}, {true});
+ auto result2 = op.evaluate({&x}, {}, {5}, {true});

ASSERT_EQ(ND4J_STATUS_OK, result2->status());

@ -738,7 +738,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test1
logits.linspace(0.1, 0.1);

nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
- auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+ auto results = op.evaluate({&labels, &logits});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -760,7 +760,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test2
logits.linspace(0.1, 0.1);

nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
- auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+ auto results = op.evaluate({&labels, &logits});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -782,7 +782,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test3
logits.linspace(0.1, 0.1);

nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
- auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+ auto results = op.evaluate({&labels, &logits});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -804,7 +804,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4
logits.linspace(0.1, 0.1);

nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
- auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+ auto results = op.evaluate({&labels, &logits});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -825,7 +825,7 @@ TEST_F(DeclarableOpsTests10, split_test4) {
auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f});

nd4j::ops::split op;
- auto results = op.execute({&input, &axis}, {}, {2}, {});
+ auto results = op.evaluate({&input, &axis}, {}, {2}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -849,7 +849,7 @@ TEST_F(DeclarableOpsTests10, split_test5) {
auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});

nd4j::ops::split op;
- auto results = op.execute({&input}, {}, {2,-1},{});
+ auto results = op.evaluate({&input}, {}, {2,-1},{});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -872,7 +872,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {2, 1, 1, 0, 2});

nd4j::ops::histogram_fixed_width op;
- auto results = op.execute({&input, &range}, {}, {5}, {});
+ auto results = op.evaluate({&input, &range}, {}, {5}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -892,7 +892,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test2) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {5, 2, 5, 3, 9});

nd4j::ops::histogram_fixed_width op;
- auto results = op.execute({&input, &range}, {}, {5}, {});
+ auto results = op.evaluate({&input, &range}, {}, {5}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -912,7 +912,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test3) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {5, 2, 5, 4, 8});

nd4j::ops::histogram_fixed_width op;
- auto results = op.execute({&input, &range}, {}, {5}, {});
+ auto results = op.evaluate({&input, &range}, {}, {5}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -937,7 +937,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test4) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {22, 17, 24, 19, 18});

nd4j::ops::histogram_fixed_width op;
- auto results = op.execute({&input, &range}, {}, {5}, {});
+ auto results = op.evaluate({&input, &range}, {}, {5}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -963,7 +963,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test5) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {23, 15, 24, 17, 21});

nd4j::ops::histogram_fixed_width op;
- auto results = op.execute({&input, &range}, {}, {5}, {});
+ auto results = op.evaluate({&input, &range}, {}, {5}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -986,7 +986,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test6) {
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {3, 1, 2, 0, 1});

nd4j::ops::histogram_fixed_width op;
- auto results = op.execute({&input, &range, &bins}, {}, {}, {});
+ auto results = op.evaluate({&input, &range, &bins}, {}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1010,7 +1010,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_1) {
//input.linspace(1.f);

nd4j::ops::nth_element op;
- auto results = op.execute({&input, &n}, {}, {});
+ auto results = op.evaluate({&input, &n}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1032,7 +1032,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_2) {
// input.linspace(1.f);

nd4j::ops::nth_element op;
- auto results = op.execute({&input, &n}, {}, {});
+ auto results = op.evaluate({&input, &n}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1054,7 +1054,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_3) {
//input.linspace(1.f);

nd4j::ops::nth_element op;
- auto results = op.execute({&input, &n}, {}, {1}); // with reverse = true
+ auto results = op.evaluate({&input, &n}, {}, {1}); // with reverse = true

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1076,7 +1076,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_4) {
//input.linspace(1.f);

nd4j::ops::nth_element op;
- auto results = op.execute({&input, &n}, {}, {});
+ auto results = op.evaluate({&input, &n}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1097,7 +1097,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_04) {
input.linspace(1.f);

nd4j::ops::nth_element op;
- auto results = op.execute({&input, &n}, {}, {});
+ auto results = op.evaluate({&input, &n}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1118,7 +1118,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_5) {
// input.linspace(1.f);

nd4j::ops::nth_element op;
- auto results = op.execute({&input, &n}, {}, {1});
+ auto results = op.evaluate({&input, &n}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1140,7 +1140,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_6) {
// input.linspace(1.f);

nd4j::ops::nth_element op;
- auto results = op.execute({&input, &n}, {}, {0});
+ auto results = op.evaluate({&input, &n}, {}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1160,7 +1160,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_06) {
// input.linspace(1.f);

nd4j::ops::nth_element op;
- auto results = op.execute({&input, &n}, {}, {1});
+ auto results = op.evaluate({&input, &n}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1186,7 +1186,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) {
//input.linspace(1.f);

nd4j::ops::nth_element op;
- auto results = op.execute({&input, &n}, {}, {0});
+ auto results = op.evaluate({&input, &n}, {}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1213,7 +1213,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_8) {
//input.linspace(1.f);

nd4j::ops::nth_element op;
- auto results = op.execute({&input, &n}, {}, {1});
+ auto results = op.evaluate({&input, &n}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.evaluate({&input, &shape}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1257,7 +1257,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test2) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.evaluate({&input, &shape}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1279,7 +1279,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test3) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.evaluate({&input, &shape}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1299,7 +1299,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test4) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3,3}, {10.f, 10.f, 10.f,10.f, 10.f, 10.f, 10.f, 10.f, 10.f});
|
auto exp = NDArrayFactory::create<double>('c', {3,3}, {10.f, 10.f, 10.f,10.f, 10.f, 10.f, 10.f, 10.f, 10.f});
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.evaluate({&input, &shape}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1319,7 +1319,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test5) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3}, {10.f, 10.f, 10.f});
|
auto exp = NDArrayFactory::create<double>('c', {3}, {10.f, 10.f, 10.f});
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.evaluate({&input, &shape}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1339,7 +1339,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test6) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {1}, {10.f});
|
auto exp = NDArrayFactory::create<double>('c', {1}, {10.f});
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.evaluate({&input, &shape}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1359,7 +1359,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test7) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {1}, {10.});
|
auto exp = NDArrayFactory::create<double>('c', {1}, {10.});
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.evaluate({&input, &shape}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1381,7 +1381,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test8) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.evaluate({&input, &shape}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1403,7 +1403,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test9) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.evaluate({&input, &shape}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1425,7 +1425,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.evaluate({&input, &shape}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1481,7 +1481,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input}, {}, {10, 10});
|
auto results = op.evaluate({&input}, {}, {10, 10});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1503,7 +1503,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) {
|
||||||
auto size = NDArrayFactory::create<int>({65,65});
|
auto size = NDArrayFactory::create<int>({65,65});
|
||||||
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
|
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input, &size}, {}, {}, {false});
|
auto results = op.evaluate({&input, &size}, {}, {}, {false});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1522,7 +1522,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) {
|
||||||
auto size = NDArrayFactory::create<int>({65,65});
|
auto size = NDArrayFactory::create<int>({65,65});
|
||||||
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
|
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input, &size}, {}, {}, {true});
|
auto results = op.evaluate({&input, &size}, {}, {}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1566,7 +1566,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input}, {}, {4, 5}, {false, true});
|
auto results = op.evaluate({&input}, {}, {4, 5}, {false, true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1613,7 +1613,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input}, {}, {4, 5}, {false, true});
|
auto results = op.evaluate({&input}, {}, {4, 5}, {false, true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1669,7 +1669,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input}, {}, {10, 10});
|
auto results = op.evaluate({&input}, {}, {10, 10});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1824,7 +1824,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
|
||||||
//input.linspace(1);
|
//input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input}, {}, {9, 9});
|
auto results = op.evaluate({&input}, {}, {9, 9});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1883,7 +1883,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input, &size}, {}, {});
|
auto results = op.evaluate({&input, &size}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2013,7 +2013,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input}, {}, {10, 10}, {true});
|
auto results = op.evaluate({&input}, {}, {10, 10}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2142,7 +2142,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input, &size}, {}, {}, {true});
|
auto results = op.evaluate({&input, &size}, {}, {}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2166,7 +2166,7 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) {
|
||||||
8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.});
|
8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.});
|
||||||
|
|
||||||
nd4j::ops::lin_space op;
|
nd4j::ops::lin_space op;
|
||||||
auto result = op.execute({&start, &finish, &num}, {}, {});
|
auto result = op.evaluate({&start, &finish, &num}, {}, {});
|
||||||
ASSERT_EQ(result->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(result->status(), ND4J_STATUS_OK);
|
||||||
auto res = result->at(0);
|
auto res = result->at(0);
|
||||||
|
|
||||||
|
@ -2208,7 +2208,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_nearest_neighbor op;
|
nd4j::ops::resize_nearest_neighbor op;
|
||||||
auto results = op.execute({&input}, {}, {4, 5}, {false, false});
|
auto results = op.evaluate({&input}, {}, {4, 5}, {false, false});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2256,7 +2256,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_nearest_neighbor op;
|
nd4j::ops::resize_nearest_neighbor op;
|
||||||
auto results = op.execute({&input}, {}, {4, 5});
|
auto results = op.evaluate({&input}, {}, {4, 5});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2304,7 +2304,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_nearest_neighbor op;
|
nd4j::ops::resize_nearest_neighbor op;
|
||||||
auto results = op.execute({&input}, {}, {4,5}, {false, true});
|
auto results = op.evaluate({&input}, {}, {4,5}, {false, true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2351,7 +2351,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_nearest_neighbor op;
|
nd4j::ops::resize_nearest_neighbor op;
|
||||||
auto results = op.execute({&input}, {}, {4, 5});
|
auto results = op.evaluate({&input}, {}, {4, 5});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2373,7 +2373,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) {
|
||||||
NDArray expected = NDArrayFactory::create<double>(2.5206409f);
|
NDArray expected = NDArrayFactory::create<double>(2.5206409f);
|
||||||
|
|
||||||
nd4j::ops::reduce_logsumexp op;
|
nd4j::ops::reduce_logsumexp op;
|
||||||
auto results = op.execute({&input}, {}, {});
|
auto results = op.evaluate({&input}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2394,7 +2394,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) {
|
||||||
NDArray expected = NDArrayFactory::create<double>({1.0986123f, 1.8619947f, 1.0986123f});
|
NDArray expected = NDArrayFactory::create<double>({1.0986123f, 1.8619947f, 1.0986123f});
|
||||||
|
|
||||||
nd4j::ops::reduce_logsumexp op;
|
nd4j::ops::reduce_logsumexp op;
|
||||||
auto results = op.execute({&input}, {}, {0});
|
auto results = op.evaluate({&input}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2414,7 +2414,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) {
|
||||||
NDArray expected = NDArrayFactory::create<float>('c', {1,3}, {1.0986123f, 1.8619947f, 1.0986123f});
|
NDArray expected = NDArrayFactory::create<float>('c', {1,3}, {1.0986123f, 1.8619947f, 1.0986123f});
|
||||||
|
|
||||||
nd4j::ops::reduce_logsumexp op;
|
nd4j::ops::reduce_logsumexp op;
|
||||||
auto results = op.execute({&input}, {1.f}, {0});
|
auto results = op.evaluate({&input}, {1.f}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2435,7 +2435,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
boxes.linspace(1.f);

nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scores}, {}, {3});
auto results = op.evaluate({&boxes, &scores}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2457,7 +2457,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
NDArray expected = NDArrayFactory::create<int>('c', {3}, {3,0,5});

nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales}, {0.5}, {3});
auto results = op.evaluate({&boxes, &scales}, {0.5}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2479,7 +2479,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) {
NDArray expected = NDArrayFactory::create<int>('c', {1}, {1});

nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales}, {0.5, 0.5}, {2});
auto results = op.evaluate({&boxes, &scales}, {0.5, 0.5}, {2});

ASSERT_EQ(Status::OK(), results->status());

@ -2502,7 +2502,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) {
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(0.5);
nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});

ASSERT_EQ(Status::OK(), results->status());

@ -2524,7 +2524,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) {
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});

ASSERT_EQ(Status::OK(), results->status());

@ -2547,7 +2547,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) {
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
nd4j::ops::non_max_suppression_v3 op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});

ASSERT_EQ(Status::OK(), results->status());

@ -2571,7 +2571,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) {
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
nd4j::ops::non_max_suppression_v3 op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});

ASSERT_EQ(Status::OK(), results->status());

@ -2594,7 +2594,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) {
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(0.5f);
nd4j::ops::non_max_suppression_v3 op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});

ASSERT_EQ(Status::OK(), results->status());

@ -2619,7 +2619,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) {
NDArray expected = NDArrayFactory::create<int>('c', {1,}, {3});

nd4j::ops::non_max_suppression_overlaps op;
auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {});
auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2644,7 +2644,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) {
NDArray expected = NDArrayFactory::create<int>('c', {3,}, {1,1,1});

nd4j::ops::non_max_suppression_overlaps op;
auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {});
auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2669,7 +2669,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) {
NDArray expected = NDArrayFactory::create<int>('c', {5,}, {1,1,1,1,1});

nd4j::ops::non_max_suppression_overlaps op;
auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {});
auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2693,7 +2693,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
NDArray expected = NDArrayFactory::create<double>('c', {1,1,1,1}, {2.5f});

nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {});
auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2718,7 +2718,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) {
NDArray expected = NDArrayFactory::create<float>('c', {1,1,1,1}, {4.f});

nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2742,7 +2742,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32);

nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {0});
auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2766,7 +2766,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32);

nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2790,7 +2790,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) {
NDArray expected('c', {1, 10, 10,3}, nd4j::DataType::FLOAT32);

nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2826,7 +2826,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) {
});
images.linspace(1.);
nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {});
auto results = op.evaluate({&images, &boxes, &colors}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2859,7 +2859,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) {
73.1f , 74.1f, 75.1f, 76.1f, 77.1f , 78.1f, 79.1f , 80.1f , 81.1f });
images.linspace(1.1);
nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {});
auto results = op.evaluate({&images, &boxes, &colors}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2912,7 +2912,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) {
0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f,
0.5793f, 0.573f , 0.1822f, 0.642f , 0.9143f});
nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {});
auto results = op.evaluate({&images, &boxes, &colors}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());

auto result = results->at(0);
@ -2937,7 +2937,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);

nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {});
auto results = op.evaluate({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2958,7 +2958,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) {
NDArray max = NDArrayFactory::create<double>(0.1);

nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {});
auto results = op.evaluate({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2979,7 +2979,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) {
NDArray max = NDArrayFactory::create<double>('c', {1}, {0.1});

nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});
auto results = op.evaluate({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3003,7 +3003,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) {
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});

nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});
auto results = op.evaluate({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3026,7 +3026,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) {
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});

nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {8}, {true});
auto results = op.evaluate({&x, &min, &max}, {}, {8}, {true});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3050,7 +3050,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) {
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});

nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {6}, {true});
auto results = op.evaluate({&x, &min, &max}, {}, {6}, {true});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3073,7 +3073,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) {
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});

nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {6}, {false});
auto results = op.evaluate({&x, &min, &max}, {}, {6}, {false});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3108,7 +3108,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
NDArray max = NDArrayFactory::create<float>({65.f, 70.f, 90.f});
x.linspace(1.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});
auto results = op.evaluate({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3161,7 +3161,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
NDArray max = NDArrayFactory::create<float>({20.f, 21.f, 22.f, 23.f});
x.linspace(-60.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});
auto results = op.evaluate({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3195,7 +3195,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
NDArray max = NDArrayFactory::create<float>('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
// x.linspace(-60.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});
auto results = op.evaluate({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3241,7 +3241,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) {
NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
x.linspace(0., 0.01);
nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {});
auto results = op.evaluate({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -3266,7 +3266,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) {
NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
x.linspace(0., 0.1);
nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {});
auto results = op.evaluate({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -43,7 +43,7 @@ TEST_F(DeclarableOpsTests11, test_listdiff_1) {
auto y = NDArrayFactory::create<int>('c',{2}, {3, 1});

nd4j::ops::listdiff op;
auto result = op.execute({&x, &y}, {}, {});
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());

delete result;

@ -68,7 +68,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test1) {
weights.assign(0.5);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0}, {});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -100,7 +100,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test2) {
weights.assign(0.5);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -130,7 +130,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
weights.assign(0.5);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -162,7 +162,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test4) {
weights.assign(0.5);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -195,7 +195,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test5) {
weights.assign(0.5);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -227,7 +227,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test6) {
weights.assign(0.5);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -253,7 +253,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
weights.assign(0.5);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -288,7 +288,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test8) {
weights.p(3, 0.);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -325,7 +325,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test9) {
weights.assign(0.5);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -357,7 +357,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test10) {
weights.assign(0.5);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -383,7 +383,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test11) {
weights.assign(0.5);

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -420,7 +420,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test12) {

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -459,7 +459,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {
weights.t<double>(2) = 0.;

nd4j::ops::log_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -642,7 +642,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {

auto size = NDArrayFactory::create<int>({30, 30});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);

@ -716,7 +716,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {
input.linspace(1);
auto size = NDArrayFactory::create<int>({10, 8});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -753,7 +753,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -790,7 +790,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 8});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -833,7 +833,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {
input.linspace(1);
auto size = NDArrayFactory::create<int>({8, 8});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -963,7 +963,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {

auto size = NDArrayFactory::create<int>({30, 30});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);

@ -1021,7 +1021,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) {
});
auto size = NDArrayFactory::create<int>({9, 9});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1074,7 +1074,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {

auto size = NDArrayFactory::create<int>({9, 9});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {}, {true, false});
auto results = op.evaluate({&input, &size}, {}, {}, {true, false});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1135,7 +1135,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1162,7 +1162,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) {
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1190,7 +1190,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) {
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1228,7 +1228,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) {
//input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1266,7 +1266,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) {
//input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1304,7 +1304,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) {
//input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {}, {true});
auto results = op.evaluate({&input, &size}, {}, {}, {true});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1342,7 +1342,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) {
//input.linspace(1);
// auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {6, 6}, {true});
auto results = op.evaluate({&input}, {}, {6, 6}, {true});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1372,7 +1372,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) {
//input.linspace(1);
// auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {6, 6}, {true});
auto results = op.evaluate({&input}, {}, {6, 6}, {true});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1399,7 +1399,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) {
//input.linspace(1);
auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
auto results = op.evaluate({&input, &size}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1426,7 +1426,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) {
//input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {10, 10});
auto results = op.evaluate({&input}, {}, {10, 10});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1453,7 +1453,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) {
//input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {6, 9});
auto results = op.evaluate({&input}, {}, {6, 9});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1480,7 +1480,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) {
//input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {10, 15});
auto results = op.evaluate({&input}, {}, {10, 15});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1507,7 +1507,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) {
//input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {9, 9});
auto results = op.evaluate({&input}, {}, {9, 9});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1558,7 +1558,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
weights.assign(0.5);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1590,7 +1590,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) {
weights.assign(0.5);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1618,7 +1618,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
weights.assign(0.5);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1650,7 +1650,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) {
weights.assign(0.5);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1679,7 +1679,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) {
weights.assign(0.5);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1711,7 +1711,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) {
weights.assign(0.5);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1737,7 +1737,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
weights.assign(0.5);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1770,7 +1770,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) {
weights.p(3, 0.);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1805,7 +1805,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) {
weights.assign(0.5);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1837,7 +1837,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) {
weights.assign(0.5);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1863,7 +1863,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) {
weights.assign(0.5);

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1896,7 +1896,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) {
weights.t<double>(3) = 0.;

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -1933,7 +1933,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) {
weights.t<double>(2) = 0.;

nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1956,7 +1956,7 @@ TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) {
auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
auto exp = NDArrayFactory::create<float>('c', {4}, {9, 1,1, 9});
nd4j::ops::squaredsubtract op;
auto result = op.execute({&x, &y}, {}, {});
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(exp.equalsTo(result->at(0)));

@ -1968,7 +1968,7 @@ TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) {
auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
auto exp = NDArrayFactory::create<float>('c', {2, 4}, {9, 1,1, 9, 9, 1, 1, 9});
nd4j::ops::squaredsubtract op;
auto result = op.execute({&x, &y}, {}, {});
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;

@ -1980,7 +1980,7 @@ TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) {
auto exp = NDArrayFactory::create<float>('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48});
auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1,2,3,4,5,6,7,8});
nd4j::ops::squaredsubtract_bp op;
auto result = op.execute({&x, &y, &eps}, {}, {});
auto result = op.evaluate({&x, &y, &eps}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
@ -2003,7 +2003,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) {
weights.assign(0.5);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2035,7 +2035,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) {
weights.assign(0.5);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2063,7 +2063,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
weights.assign(0.5);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2095,7 +2095,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) {
weights.assign(0.5);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2124,7 +2124,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) {
weights.assign(0.5);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2156,7 +2156,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) {
weights.assign(0.5);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2182,7 +2182,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
weights.assign(0.5);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2215,7 +2215,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) {
weights.p(3, 0.);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2250,7 +2250,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) {
weights.assign(0.5);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2282,7 +2282,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) {
weights.assign(0.5);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2308,7 +2308,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) {
weights.assign(0.5);

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2341,7 +2341,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) {
weights.t<double>(3) = 0.;

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2378,7 +2378,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) {
weights.t<double>(2) = 0.;

nd4j::ops::absolute_difference_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2407,7 +2407,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_1) {
y.linspace(1);
exp.linspace(2,2);
nd4j::ops::add op;
auto results = op.execute({&x, &y}, {}, {});
auto results = op.evaluate({&x, &y}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2428,7 +2428,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_2) {
y.linspace(1);
exp.linspace(2,2);
nd4j::ops::add op;
auto results = op.execute({&x, &y}, {}, {});
auto results = op.evaluate({&x, &y}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2449,7 +2449,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_3) {
y.linspace(1);
exp.linspace(2,2);
nd4j::ops::add op;
auto results = op.execute({&x, &y}, {}, {});
auto results = op.evaluate({&x, &y}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2478,7 +2478,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) {
weights.assign(0.5);

nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.}, {0});
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2514,7 +2514,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) {
weights.assign(0.5);

nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {0});
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {0});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2550,7 +2550,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
weights.assign(0.5);

nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {1});
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2582,7 +2582,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) {
weights.assign(0.5);

nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {1});
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2613,7 +2613,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) {
weights.assign(0.5);

nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2});
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2645,7 +2645,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) {
weights.assign(0.5);

nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2});
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2671,7 +2671,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
weights.assign(0.5);

nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2});
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2705,7 +2705,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) {
weights.p(3, 0.);

nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2});
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2741,7 +2741,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) {
weights.assign(0.5);

nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3});
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -2773,7 +2773,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) {
weights.assign(0.5);

nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3});
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});

ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2799,7 +2799,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) {
|
||||||
weights.assign(0.5);
|
weights.assign(0.5);
|
||||||
|
|
||||||
nd4j::ops::sigm_cross_entropy_loss_grad op;
|
nd4j::ops::sigm_cross_entropy_loss_grad op;
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2834,7 +2834,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) {
|
||||||
|
|
||||||
|
|
||||||
nd4j::ops::sigm_cross_entropy_loss_grad op;
|
nd4j::ops::sigm_cross_entropy_loss_grad op;
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2872,7 +2872,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) {
|
||||||
weights.t<double>(2) = 0.;
|
weights.t<double>(2) = 0.;
|
||||||
|
|
||||||
nd4j::ops::sigm_cross_entropy_loss_grad op;
|
nd4j::ops::sigm_cross_entropy_loss_grad op;
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2901,7 +2901,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_4) {
|
||||||
y.linspace(1);
|
y.linspace(1);
|
||||||
exp.linspace(2,2);
|
exp.linspace(2,2);
|
||||||
nd4j::ops::add op;
|
nd4j::ops::add op;
|
||||||
auto results = op.execute({&x, &y}, {}, {});
|
auto results = op.evaluate({&x, &y}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2922,7 +2922,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_5) {
|
||||||
y.linspace(1);
|
y.linspace(1);
|
||||||
exp.linspace(1);
|
exp.linspace(1);
|
||||||
nd4j::ops::subtract op;
|
nd4j::ops::subtract op;
|
||||||
auto results = op.execute({&x, &y}, {}, {});
|
auto results = op.evaluate({&x, &y}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2943,7 +2943,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_6) {
|
||||||
y.linspace(1);
|
y.linspace(1);
|
||||||
exp.linspace(1);
|
exp.linspace(1);
|
||||||
nd4j::ops::subtract op;
|
nd4j::ops::subtract op;
|
||||||
auto results = op.execute({&x, &y}, {}, {});
|
auto results = op.evaluate({&x, &y}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2968,7 +2968,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.}, {0});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2999,7 +2999,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.}, {1});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3030,7 +3030,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.}, {1});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3061,7 +3061,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.}, {2});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3092,7 +3092,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.}, {3});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3123,7 +3123,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3154,7 +3154,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.}, {3});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3196,7 +3196,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &weights, &labels}, {0.}, {2});
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3237,7 +3237,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &labels}, {}, {});
|
auto results = op.evaluate({&logits, &labels}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3261,7 +3261,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &labels}, {}, {1});
|
auto results = op.evaluate({&logits, &labels}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3284,7 +3284,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &labels}, {}, {0});
|
auto results = op.evaluate({&logits, &labels}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3306,7 +3306,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &labels}, {}, {1});
|
auto results = op.evaluate({&logits, &labels}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3328,7 +3328,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &labels}, {}, {0});
|
auto results = op.evaluate({&logits, &labels}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3350,7 +3350,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &labels}, {}, {0});
|
auto results = op.evaluate({&logits, &labels}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3372,7 +3372,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &labels}, {}, {0});
|
auto results = op.evaluate({&logits, &labels}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3394,7 +3394,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) {
|
||||||
|
|
||||||
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&logits, &labels}, {}, {0});
|
auto results = op.evaluate({&logits, &labels}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3421,7 +3421,7 @@ TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) {
|
||||||
dLdpExp.assign(1.0);
|
dLdpExp.assign(1.0);
|
||||||
nd4j::ops::multiply_bp op;
|
nd4j::ops::multiply_bp op;
|
||||||
|
|
||||||
auto results = op.execute({&x, &y, &dLdp}, {}, {});
|
auto results = op.evaluate({&x, &y, &dLdp}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3444,7 +3444,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) {
|
||||||
|
|
||||||
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&labels, &logits}, {}, {});
|
auto results = op.evaluate({&labels, &logits}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3468,7 +3468,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
|
||||||
|
|
||||||
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&labels, &logits}, {}, {});
|
auto results = op.evaluate({&labels, &logits}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3490,7 +3490,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
|
||||||
|
|
||||||
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&labels, &logits}, {}, {});
|
auto results = op.evaluate({&labels, &logits}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3514,7 +3514,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) {
|
||||||
|
|
||||||
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&labels, &logits}, {}, {});
|
auto results = op.evaluate({&labels, &logits}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3536,7 +3536,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) {
|
||||||
|
|
||||||
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
|
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
|
||||||
|
|
||||||
auto results = op.execute({&labels, &logits}, {}, {});
|
auto results = op.evaluate({&labels, &logits}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
|
|
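The change applied throughout the DeclarableOpsTests11 hunks above is uniform: tests that let the op allocate its own outputs switch from execute() to the new evaluate() helper with the same input and argument lists. A minimal before/after sketch of the pattern, assuming only the result-set accessors already used in these tests (status(), at(), delete), not quoting any particular test verbatim:

nd4j::ops::add op;
// previously: auto results = op.execute({&x, &y}, {}, {});
auto results = op.evaluate({&x, &y}, {}, {});   // same inputs, tArgs, iArgs
ASSERT_EQ(ND4J_STATUS_OK, results->status());
delete results;                                  // caller still owns the result set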
@ -44,7 +44,7 @@ TEST_F(DeclarableOpsTests12, test_any_validation_1) {
auto y = NDArrayFactory::create<int>('c', {2}, {1, 0});
nd4j::ops::transpose op;
auto result = op.execute({&x, &y}, {}, {});
auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
@ -69,7 +69,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test1) {
nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0, -1});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, -1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -101,7 +101,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test2) {
nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 0});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -135,7 +135,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) {
nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 0});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -169,7 +169,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) {
nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {1, 1});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -204,7 +204,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) {
nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2, 0});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -238,7 +238,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test6) {
nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3, 1});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -274,7 +274,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test7) {
nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2, 0});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -310,7 +310,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test8) {
nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {3, 1});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -346,7 +346,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) {
nd4j::ops::cosine_distance_loss_grad op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 2});
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -422,7 +422,7 @@ TEST_F(DeclarableOpsTests12, TestDivideBP_2) {
eps.linspace(1.);
nd4j::ops::divide_bp op;
Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {});
Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector<NDArray*>{&output1, &output2}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output1.equalsTo(exp1));
@ -443,7 +443,7 @@ TEST_F(DeclarableOpsTests12, TestReverseDivideBP_1) {
eps.linspace(1.);
nd4j::ops::reversedivide_bp op;
Nd4jStatus status = op.execute({&y, &x, &eps}, {&output2, &output1}, {}, {}, {});
Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector<NDArray*>{&output2, &output1}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
//ASSERT_TRUE(output.e<double>(0) == 47.);
@ -467,7 +467,7 @@ TEST_F(DeclarableOpsTests12, TestReverseDivideBP_2) {
exp1.assign(1.);
exp2.assign(-2.);
nd4j::ops::reversedivide_bp op;
Nd4jStatus status = op.execute({&y, &x, &eps}, {&output2, &output1}, {}, {}, {});
Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector<NDArray*>{&output2, &output1}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output1.equalsTo(exp1));
@ -539,7 +539,7 @@ TEST_F(DeclarableOpsTests12, TestMaximumBP_1) {
//exp1.assign(1.);
//exp2.assign(-2.);
nd4j::ops::maximum_bp op;
Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {});
Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector<NDArray*>{&output1, &output2}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output1.equalsTo(exp1));
@ -564,7 +564,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) {
//exp1.assign(1.);
//exp2.assign(-2.);
nd4j::ops::minimum_bp op;
Nd4jStatus status = op.execute({&x, &y, &eps}, {&output2, &output1}, {}, {}, {});
Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector<NDArray*>{&output2, &output1}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output1.equalsTo(exp1));
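In the backprop tests just above (divide_bp, reversedivide_bp, maximum_bp, minimum_bp) the outputs are caller-allocated, so execute() is kept, but the brace-initialized output list now carries an explicit std::vector<NDArray*> type. Presumably this disambiguates the call against the widened overload set introduced alongside evaluate(); a sketch of the pattern, reusing the same names as the tests above:

nd4j::ops::divide_bp op;
// the explicit vector type tells the compiler which execute() overload is meant
Nd4jStatus status = op.execute({&x, &y, &eps},
                               std::vector<NDArray*>{&output1, &output2},
                               {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output1.equalsTo(exp1));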
@ -716,7 +716,7 @@ TEST_F(DeclarableOpsTests12, tensormmul_6) {
NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {1,0, 1,1});
auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -743,7 +743,7 @@ TEST_F(DeclarableOpsTests12, reduceMeanBp_4) {
exp = 0.333333;
nd4j::ops::reduce_mean_bp op;
auto result = op.execute({&x, &gradO}, {}, {0});
auto result = op.evaluate({&x, &gradO}, {}, {0});
auto output = result->at(0);
// output->printShapeInfo();
@ -765,7 +765,7 @@ TEST_F(DeclarableOpsTests12, reduceMeanBp_5) {
exp = 0.2;
nd4j::ops::reduce_mean_bp op;
auto result = op.execute({&x, &gradO}, {}, {1});
auto result = op.evaluate({&x, &gradO}, {}, {1});
auto output = result->at(0);
// output->printShapeInfo();
@ -783,7 +783,7 @@ TEST_F(DeclarableOpsTests12, reduceSqnormBp_1) {
NDArray gradO('c', {8,6,1}, nd4j::DataType::DOUBLE);
nd4j::ops::reduce_sqnorm_bp op;
auto result = op.execute({&x, &gradO}, {1}, {2});
auto result = op.evaluate({&x, &gradO}, {1}, {2});
ASSERT_EQ(Status::OK(), result->status());
delete result;
@ -937,7 +937,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_1) {
nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 1., 1}, {5});
auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {5});
auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp);
@ -968,7 +968,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_2) {
nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 1., 1}, {2});
auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {2});
auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp);
@ -999,7 +999,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_3) {
nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 1., 1}, {7});
auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {7});
auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp);
@ -1030,7 +1030,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_4) {
nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 1., 1}, {12});
auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {12});
auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp);
@ -1053,7 +1053,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_5) {
nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 1., 0.5}, {2});
auto results = op.evaluate({&input, &gradO}, {1., 1., 0.5}, {2});
auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp);
@ -1072,7 +1072,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_6) {
nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {10});
auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {10});
auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp);
@ -1126,7 +1126,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_9) {
nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {3});
auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {3});
auto gradI = results->at(0);
// for (int i = 0; i < exp.lengthOf(); ++i)
@ -1146,7 +1146,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_10) {
nd4j::ops::lrn_bp op;
auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {1});
auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {1});
auto gradI = results->at(0);
ASSERT_EQ(*gradI, exp);
@ -1167,7 +1167,7 @@ TEST_F(DeclarableOpsTests12, lrn_1) {
nd4j::ops::lrn op;
auto results = op.execute({&input}, {1., 2., 0.5}, {2});
auto results = op.evaluate({&input}, {1., 2., 0.5}, {2});
auto output = results->at(0);
ASSERT_EQ(*output, exp);
@ -1183,7 +1183,7 @@ TEST_F(DeclarableOpsTests12, lrn_2) {
nd4j::ops::lrn op;
auto results = op.execute({&input}, {0.1, 2., 0.5}, {5});
auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5});
auto output = results->at(0);
ASSERT_EQ(*output, exp);
@ -1198,7 +1198,7 @@ TEST_F(DeclarableOpsTests12, lrn_3) {
nd4j::ops::lrn op;
auto results = op.execute({&input}, {0.1, 2., 0.5}, {5});
auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5});
auto output = results->at(0);
ASSERT_EQ(*output, exp);
@ -1213,7 +1213,7 @@ TEST_F(DeclarableOpsTests12, lrn_4) {
nd4j::ops::lrn op;
auto results = op.execute({&input}, {0.1, 2., 0.5}, {0});
auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0});
auto output = results->at(0);
ASSERT_EQ(*output, exp);
@ -1228,7 +1228,7 @@ TEST_F(DeclarableOpsTests12, lrn_5) {
nd4j::ops::lrn op;
auto results = op.execute({&input}, {0.1, 2., 0.5}, {0});
auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0});
auto output = results->at(0);
ASSERT_EQ(*output, exp);
@ -1268,7 +1268,7 @@ TEST_F(DeclarableOpsTests12, inTopK_2) {
nd4j::ops::in_top_k op;
auto res = op.execute({&input, &idx}, {}, {1}, {}, false, nd4j::DataType::BOOL);
auto res = op.evaluate({&input, &idx}, {}, {1});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
//res->at(0)->printIndexedBuffer("IN_TOP_K output");
@ -1283,7 +1283,7 @@ TEST_F(DeclarableOpsTests12, inTopK_3) {
auto expV = NDArrayFactory::create<bool>('c', {2}, {true, false});
nd4j::ops::in_top_k op;
auto result = op.execute({&x, &y}, {}, {2});
auto result = op.evaluate({&x, &y}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(1, result->size());
@ -1303,7 +1303,7 @@ TEST_F(DeclarableOpsTests12, inTopK_4) {
auto expV = NDArrayFactory::create<bool>('c', {6}, {true, false, true, false, false, true});
nd4j::ops::in_top_k op;
auto result = op.execute({&x, &y}, {}, {2});
auto result = op.evaluate({&x, &y}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(1, result->size());
@ -1324,7 +1324,7 @@ TEST_F(DeclarableOpsTests12, inTopK_5) {
auto expV = NDArrayFactory::create<bool>('f', {6}, {true, false, false, false, false, false });
nd4j::ops::in_top_k op;
auto result = op.execute({&x, &y}, {}, {2});
auto result = op.evaluate({&x, &y}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(1, result->size());
@ -1345,7 +1345,7 @@ TEST_F(DeclarableOpsTests12, cube_1) {
nd4j::ops::cube op;
auto result = op.execute({&x}, {}, {});
auto result = op.evaluate({&x});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -1368,7 +1368,7 @@ TEST_F(DeclarableOpsTests12, cube_bp_1) {
nd4j::ops::cube_bp op;
auto result = op.execute({&x, &gradO}, {}, {});
auto result = op.evaluate({&x, &gradO});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
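Where a test passes no floating-point or integer arguments, the empty {} lists are dropped from the call as well (for example cube_1, cube_bp_1, and the LU tests further down), which suggests evaluate() defaults those parameters. A sketch of the shortened form, assuming that defaulting:

nd4j::ops::cube op;
// previously: auto result = op.execute({&x}, {}, {});
auto result = op.evaluate({&x});   // tArgs/iArgs omitted entirely
ASSERT_EQ(ND4J_STATUS_OK, result->status());
delete result;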
@ -1391,7 +1391,7 @@ TEST_F(DeclarableOpsTests12, pad_tests1) {
|
||||||
NDArray expected('c', {4,7}, {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}, nd4j::DataType::FLOAT32);
|
NDArray expected('c', {4,7}, {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1418,7 +1418,7 @@ TEST_F(DeclarableOpsTests12, pad_tests2) {
|
||||||
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7});
|
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1445,7 +1445,7 @@ TEST_F(DeclarableOpsTests12, pad_tests3) {
|
||||||
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7});
|
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1476,7 +1476,7 @@ TEST_F(DeclarableOpsTests12, pad_tests4) {
|
||||||
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7,7});
|
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1510,7 +1510,7 @@ TEST_F(DeclarableOpsTests12, pad_tests5) {
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1537,7 +1537,7 @@ TEST_F(DeclarableOpsTests12, pad_tests6) {
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1563,7 +1563,7 @@ TEST_F(DeclarableOpsTests12, pad_tests7)
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1589,7 +1589,7 @@ TEST_F(DeclarableOpsTests12, pad_tests8)
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1615,7 +1615,7 @@ TEST_F(DeclarableOpsTests12, pad_tests9)
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1638,7 +1638,7 @@ TEST_F(DeclarableOpsTests12, pad_tests10) {
|
||||||
input = 1.f;
|
input = 1.f;
|
||||||
//input.assign(1.);
|
//input.assign(1.);
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1660,7 +1660,7 @@ TEST_F(DeclarableOpsTests12, pad_tests11) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1688,7 +1688,7 @@ TEST_F(DeclarableOpsTests12, pad_tests12) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1710,7 +1710,7 @@ TEST_F(DeclarableOpsTests12, pad_tests13) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1732,7 +1732,7 @@ TEST_F(DeclarableOpsTests12, pad_tests14) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1753,7 +1753,7 @@ TEST_F(DeclarableOpsTests12, pad_tests15) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1774,7 +1774,7 @@ TEST_F(DeclarableOpsTests12, pad_tests16) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1795,7 +1795,7 @@ TEST_F(DeclarableOpsTests12, pad_tests17) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1816,7 +1816,7 @@ TEST_F(DeclarableOpsTests12, pad_tests18) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1837,7 +1837,7 @@ TEST_F(DeclarableOpsTests12, pad_tests19) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1858,7 +1858,7 @@ TEST_F(DeclarableOpsTests12, pad_tests20) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1880,7 +1880,7 @@ TEST_F(DeclarableOpsTests12, pad_tests21) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1903,7 +1903,7 @@ TEST_F(DeclarableOpsTests12, pad_tests22) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1926,7 +1926,7 @@ TEST_F(DeclarableOpsTests12, pad_tests23) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1950,7 +1950,7 @@ TEST_F(DeclarableOpsTests12, pad_tests24) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1972,7 +1972,7 @@ TEST_F(DeclarableOpsTests12, pad_tests25) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1994,7 +1994,7 @@ TEST_F(DeclarableOpsTests12, pad_tests26) {
|
||||||
input.linspace(1.f);
|
input.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2054,7 +2054,7 @@ TEST_F(DeclarableOpsTests12, pad_tests29) {
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
|
|
||||||
auto res = op.execute({&in, &pad}, {10.0}, {0});
|
auto res = op.evaluate({&in, &pad}, {10.0}, {0});
|
||||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
delete res;
|
delete res;
|
||||||
|
@ -2071,7 +2071,7 @@ TEST_F(DeclarableOpsTests12, pad_tests30) {
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
|
|
||||||
auto res = op.execute({&in, &pad}, {10.0}, {2});
|
auto res = op.evaluate({&in, &pad}, {10.0}, {2});
|
||||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
delete res;
|
delete res;
|
||||||
|
@ -2089,7 +2089,7 @@ TEST_F(DeclarableOpsTests12, pad_tests31) {
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
|
|
||||||
auto res = op.execute({&in, &pad}, {10.0}, {1});
|
auto res = op.evaluate({&in, &pad}, {10.0}, {1});
|
||||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
delete res;
|
delete res;
|
||||||
|
@ -2105,7 +2105,7 @@ TEST_F(DeclarableOpsTests12, pad_tests32) {
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
|
|
||||||
auto res = op.execute({&in, &pad}, {10.0}, {2});
|
auto res = op.evaluate({&in, &pad}, {10.0}, {2});
|
||||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
delete res;
|
delete res;
|
||||||
|
@ -2128,7 +2128,7 @@ TEST_F(DeclarableOpsTests12, pad_tests33) {
|
||||||
11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2.});
|
11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2.});
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
|
|
||||||
auto res = op.execute({&in, &pad}, {10.0}, {2});
|
auto res = op.evaluate({&in, &pad}, {10.0}, {2});
|
||||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
delete res;
|
delete res;
|
||||||
|
@ -2163,7 +2163,7 @@ TEST_F(DeclarableOpsTests12, Pad_1) {
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2190,7 +2190,7 @@ TEST_F(DeclarableOpsTests12, Pad_2) {
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2217,7 +2217,7 @@ TEST_F(DeclarableOpsTests12, Pad_3) {
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2244,7 +2244,7 @@ TEST_F(DeclarableOpsTests12, Pad_4) {
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2271,7 +2271,7 @@ TEST_F(DeclarableOpsTests12, Pad_5) {
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2298,7 +2298,7 @@ TEST_F(DeclarableOpsTests12, Pad_6) {
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2324,7 +2324,7 @@ TEST_F(DeclarableOpsTests12, Pad_7)
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {0});
|
auto results = op.evaluate({&input, &paddings}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2350,7 +2350,7 @@ TEST_F(DeclarableOpsTests12, Pad_8)
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {1});
|
auto results = op.evaluate({&input, &paddings}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2376,7 +2376,7 @@ TEST_F(DeclarableOpsTests12, Pad_9)
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
auto results = op.execute({&input, &paddings}, {}, {2});
|
auto results = op.evaluate({&input, &paddings}, {}, {2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2395,7 +2395,7 @@ TEST_F(DeclarableOpsTests12, Test_Expose_1) {
|
||||||
|
|
||||||
nd4j::ops::expose op;
|
nd4j::ops::expose op;
|
||||||
|
|
||||||
auto result = op.execute({&input0, &input1}, {}, {});
|
auto result = op.evaluate({&input0, &input1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@@ -2420,7 +2420,7 @@ TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) {
     nd4j::ops::pad op;
-    auto res = op.execute({&in, &pad}, {10.0}, {0});
+    auto res = op.evaluate({&in, &pad}, {10.0}, {0});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     // res->at(0)->printIndexedBuffer("PAD_SGO");
     // exp.printIndexedBuffer("PAD_EXP");
@@ -2436,7 +2436,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_1) {
     auto pExp = NDArrayFactory::create<int>('c', {3}, {0, 1, 2});
     nd4j::ops::lu op;
-    auto res = op.execute({&in}, {}, {});
+    auto res = op.evaluate({&in});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
     auto p = res->at(1);
@@ -2457,7 +2457,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_2) {
     auto expP = NDArrayFactory::create<int>({2, 0, 1});
     nd4j::ops::lu op;
-    auto res = op.execute({&in}, {}, {});
+    auto res = op.evaluate({&in});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
     auto p = res->at(1);
@@ -2480,7 +2480,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_3) {
     auto expP = NDArrayFactory::create<int>({2, 1, 0});
     nd4j::ops::lu op;
-    auto res = op.execute({&in}, {}, {});
+    auto res = op.evaluate({&in});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
     auto p = res->at(1);
@@ -2522,7 +2522,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_4) {
     auto expP = NDArrayFactory::create<int>({1, 2, 7, 3, 6, 8, 5, 4, 0, 9});
     nd4j::ops::lu op;
-    auto res = op.execute({&in}, {}, {});
+    auto res = op.evaluate({&in});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
     auto p = res->at(1);
@@ -2592,7 +2592,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_5) {
     });
     nd4j::ops::lu op;
-    auto res = op.execute({&in}, {}, {});
+    auto res = op.evaluate({&in});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
     auto p = res->at(1);
@@ -2613,7 +2613,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_1_2) {
     nd4j::ops::lu op;
-    auto res = op.execute({&in}, {}, {});
+    auto res = op.evaluate({&in});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
     auto p = res->at(1);
@@ -2641,7 +2641,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_3_2) {
     auto expP = NDArrayFactory::create<int>('c', {2,3}, {2, 1, 0, 2, 1, 0});
     nd4j::ops::lu op;
-    auto res = op.execute({&in}, {}, {});
+    auto res = op.evaluate({&in});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
     auto p = res->at(1);
@@ -2669,7 +2669,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_3_3) {
     auto expP = NDArrayFactory::create<int>('c', {2,3}, {2, 1, 0, 0, 2, 1});
     nd4j::ops::lu op;
-    auto res = op.execute({&in}, {}, {});
+    auto res = op.evaluate({&in});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
     auto p = res->at(1);
@@ -2697,7 +2697,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_4_1) {
     auto expP = NDArrayFactory::create<int>('c', {2,2}, {0, 1, 0, 1});
     nd4j::ops::lu op;
-    auto res = op.execute({&in}, {}, {});
+    auto res = op.evaluate({&in});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
     auto p = res->at(1);
@@ -2725,7 +2725,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_4_2) {
     auto expP = NDArrayFactory::create<Nd4jLong>('c', {2,2}, {0, 1, 0, 1});
     nd4j::ops::lu op;
-    auto res = op.execute({&in}, {}, {nd4j::DataType::INT64});
+    auto res = op.evaluate({&in}, {}, {nd4j::DataType::INT64});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
     auto p = res->at(1);
@@ -2750,7 +2750,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1) {
     auto expR = NDArrayFactory::create<double>('c', {5,3}, {
        -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. });
     nd4j::ops::qr op;
-    auto res = op.execute({&in}, {}, {}, {true});
+    auto res = op.evaluate({&in}, {}, {}, {true});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto q = res->at(0);
@@ -2762,7 +2762,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1) {
     // q->printShapeInfo("Q shape");
     // r->printShapeInfo("R shape");
     nd4j::ops::matmul opMul;
-    auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false);
+    auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false);
     auto exp = res2->at(0);//->printIndexedBuffer("Result as result");
     ASSERT_TRUE(exp->isSameShape(in));
     // ASSERT_TRUE(q->isSameShape(expQ));
@@ -2797,7 +2797,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1_1) {
        -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0.
     });
     nd4j::ops::qr op;
-    auto res = op.execute({&in}, {}, {}, {true});
+    auto res = op.evaluate({&in}, {}, {}, {true});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto q = res->at(0);
@@ -2809,7 +2809,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1_1) {
     // q->printShapeInfo("Q shape");
     // r->printShapeInfo("R shape");
     nd4j::ops::matmul opMul;
-    auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false);
+    auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false);
     auto exp = res2->at(0);//->printIndexedBuffer("Result as result");
     ASSERT_TRUE(exp->isSameShape(in));
     // ASSERT_TRUE(q->isSameShape(expQ));
@@ -2836,7 +2836,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_2) {
     });
     nd4j::ops::qr op;
-    auto res = op.execute({&in}, {}, {}, {false});
+    auto res = op.evaluate({&in}, {}, {}, {false});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto q = res->at(0);
@@ -2847,7 +2847,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_2) {
     // r->printIndexedBuffer("Upper triangular 5x3");
     nd4j::ops::matmul opMul;
-    auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false);
+    auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false);
     auto exp = res2->at(0);//->printIndexedBuffer("Result as result");
     ASSERT_TRUE(exp->isSameShape(in));
     ASSERT_TRUE(exp->equalsTo(in));
@@ -2874,7 +2874,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_1) {
     nd4j::ops::triangular_solve op;
-    auto res = op.execute({&a, &b}, {}, {});
+    auto res = op.evaluate({&a, &b});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
@@ -2903,7 +2903,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_2) {
     nd4j::ops::triangular_solve op;
-    auto res = op.execute({&a, &b}, {}, {});
+    auto res = op.evaluate({&a, &b});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
@@ -2940,7 +2940,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_3) {
     nd4j::ops::triangular_solve op;
-    auto res = op.execute({&a, &b}, {}, {});
+    auto res = op.evaluate({&a, &b});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
@@ -2969,7 +2969,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_4) {
     nd4j::ops::triangular_solve op;
-    auto res = op.execute({&a, &b}, {}, {}, {false});
+    auto res = op.evaluate({&a, &b}, {}, {}, {false});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
@@ -2999,7 +2999,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
     nd4j::ops::triangular_solve op;
-    auto res = op.execute({&a, &b}, {}, {}, {false, true});
+    auto res = op.evaluate({&a, &b}, {}, {}, {false, true});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto z = res->at(0);
@@ -58,7 +58,7 @@ TEST_F(DeclarableOpsTests13, test_pow_1) {
     auto e = NDArrayFactory::create<float>('c', {2, 2}, {8.f, 8.f, 8.f, 8.f});
     nd4j::ops::Pow op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -73,7 +73,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_1) {
     auto limit = NDArrayFactory::create<int>(0);
     nd4j::ops::range op;
-    auto result = op.execute({&start, &limit}, {}, {});
+    auto result = op.evaluate({&start, &limit});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -85,7 +85,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_1) {
 TEST_F(DeclarableOpsTests13, test_empty_range_2) {
     nd4j::ops::range op;
-    auto result = op.execute({}, {1.0, 1.0}, {});
+    auto result = op.evaluate({}, {1.0, 1.0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -97,7 +97,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_2) {
 TEST_F(DeclarableOpsTests13, test_empty_range_3) {
     nd4j::ops::range op;
-    auto result = op.execute({}, {}, {1, 1});
+    auto result = op.evaluate({}, {1, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -143,7 +143,7 @@ TEST_F(DeclarableOpsTests13, test_listdiff_1) {
     auto oi = NDArrayFactory::create<int>('c', {2});
     nd4j::ops::listdiff op;
-    auto result = op.execute({&x, &y}, {&od, &oi}, {}, {}, {});
+    auto result = op.execute({&x, &y}, std::vector<NDArray*>{&od, &oi}, {}, {}, {});
     ASSERT_EQ(Status::OK(), result);
 }
@@ -152,7 +152,7 @@ TEST_F(DeclarableOpsTests13, test_greater_1) {
     auto y = NDArrayFactory::create<float>('c', {1, 4});
     nd4j::ops::greater op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
@@ -165,7 +165,7 @@ TEST_F(DeclarableOpsTests13, test_eval_reduction_shape_1) {
     auto exp = NDArrayFactory::create<Nd4jLong>('c', {2}, {1, 2});
     nd4j::ops::evaluate_reduction_shape op;
-    auto result = op.execute({&x, &y}, {}, {}, {true});
+    auto result = op.evaluate({&x, &y}, {true});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -218,7 +218,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_1) {
     auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     auto exp = NDArrayFactory::create<double>('c', {2,3}, {1.2,2.2,3.2,4.2,5.2,6.2});
     nd4j::ops::barnes_gains op;
-    auto result = op.execute({&x, &y, &eps}, {}, {});
+    auto result = op.evaluate({&x, &y, &eps});
     ASSERT_EQ(result->status(), Status::OK());
     //result->at(0)->printBuffer("Gains out");
     ASSERT_TRUE(exp.equalsTo(result->at(0)));
@@ -232,7 +232,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_2) {
     auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     auto exp = NDArrayFactory::create<double>('c', {2,3}, {1.2, 0.01, 3.2, 0.01, 5.2, 0.01});
     nd4j::ops::barnes_gains op;
-    auto result = op.execute({&x, &y, &eps}, {}, {});
+    auto result = op.evaluate({&x, &y, &eps}, {}, {});
     ASSERT_EQ(result->status(), Status::OK());
     //result->at(0)->printBuffer("Gains out");
     ASSERT_TRUE(exp.equalsTo(result->at(0)));
@@ -247,7 +247,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_3) {
     auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     auto exp = NDArrayFactory::create<double>('c', {2,3}, {0.01, 2.2, 0.01, 4.2, 0.01, 6.2});
     nd4j::ops::barnes_gains op;
-    auto result = op.execute({&x, &y, &eps}, {}, {});
+    auto result = op.evaluate({&x, &y, &eps}, {}, {});
     ASSERT_EQ(result->status(), Status::OK());
     //result->at(0)->printBuffer("Gains out");
     ASSERT_TRUE(exp.equalsTo(result->at(0)));
@@ -269,7 +269,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) {
     // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
     nd4j::ops::barnes_edge_forces op;
-    auto result = op.execute({&rows, &cols, &vals, &data}, {}, {1});
+    auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {1});
     ASSERT_EQ(result->status(), Status::OK());
@@ -293,7 +293,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) {
     // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
     nd4j::ops::barnes_edge_forces op;
-    auto result = op.execute({&rows, &cols, &vals, &data}, {}, {2});
+    auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {2});
     ASSERT_EQ(result->status(), Status::OK());
@@ -317,7 +317,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
     // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
     nd4j::ops::barnes_edge_forces op;
-    auto result = op.execute({&rows, &cols, &vals, &data}, {}, {11});
+    auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {11});
     //nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf());
     ASSERT_EQ(result->status(), Status::OK());
@@ -340,7 +340,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) {
     // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
     nd4j::ops::barnes_symmetrized op;
-    auto result = op.execute({&rows, &cols, &vals}, {}, {1});
+    auto result = op.evaluate({&rows, &cols, &vals}, {}, {1});
     ASSERT_EQ(result->status(), Status::OK());
     //result->at(2)->printBuffer("Symmetrized1");
     ASSERT_TRUE(exp.equalsTo(result->at(2)));
@@ -359,7 +359,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) {
     // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
     nd4j::ops::barnes_symmetrized op;
-    auto result = op.execute({&rows, &cols, &vals}, {}, {3});
+    auto result = op.evaluate({&rows, &cols, &vals}, {}, {3});
     ASSERT_EQ(result->status(), Status::OK());
     //result->at(2)->printBuffer("Symmetrized2");
     // ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
@@ -378,7 +378,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) {
     // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
     nd4j::ops::barnes_symmetrized op;
-    auto result = op.execute({&rows, &cols, &vals}, {}, {11});
+    auto result = op.evaluate({&rows, &cols, &vals}, {}, {11});
     ASSERT_EQ(result->status(), Status::OK());
     //result->at(2)->printBuffer("Symmetrized3");
     //exp.printBuffer("EXPect symm3");
@@ -402,7 +402,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {
     // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
     nd4j::ops::barnes_symmetrized op;
-    auto result = op.execute({&rows, &cols, &vals}, {}, {11});
+    auto result = op.evaluate({&rows, &cols, &vals}, {}, {11});
     ASSERT_EQ(result->status(), Status::OK());
     auto res = result->at(2);
     // res->printBuffer("Symmetrized4");
@@ -428,7 +428,7 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) {
     // auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
     // auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
     nd4j::ops::cell_contains op;
-    auto result = op.execute({&corners, &width, &point}, {}, {5});
+    auto result = op.evaluate({&corners, &width, &point}, {}, {5});
     ASSERT_EQ(result->status(), Status::OK());
     ASSERT_TRUE(result->at(0)->e<bool>(0));
     //result->at(2)->printBuffer("Symmetrized3");
@@ -446,7 +446,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_1) {
     NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32);
     nd4j::ops::adjust_hue op;
-    std::unique_ptr<nd4j::ResultSet> results (op.execute({&input, &factor}, {}, {2}));
+    std::unique_ptr<nd4j::ResultSet> results (op.evaluate({&input, &factor}, {}, {2}));
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -467,7 +467,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_2) {
     nd4j::ops::adjust_hue op;
-    std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {0.9}, {2}));
+    std::unique_ptr<nd4j::ResultSet> results(op.evaluate({&input}, {0.9}, {2}));
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -487,7 +487,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_3) {
     NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32);
     nd4j::ops::adjust_hue op;
-    std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {-0.9}, {2}));
+    std::unique_ptr<nd4j::ResultSet> results(op.evaluate({&input}, {-0.9}, {2}));
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -506,7 +506,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_4) {
     NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, nd4j::DataType::FLOAT32);
     nd4j::ops::adjust_hue op;
-    std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {0.5}, {1}));
+    std::unique_ptr<nd4j::ResultSet> results(op.evaluate({&input}, {0.5}, {1}));
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -525,7 +525,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) {
     NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, nd4j::DataType::FLOAT32);
     nd4j::ops::adjust_hue op;
-    std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {0.5}, {0}));
+    std::unique_ptr<nd4j::ResultSet> results(op.evaluate({&input}, {0.5}, {0}));
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -545,7 +545,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
     NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32);
     nd4j::ops::adjust_saturation op;
-    auto results = op.execute({&input, &factor}, {}, {2});
+    auto results = op.evaluate({&input, &factor}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -564,7 +564,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
     NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::DOUBLE);
     nd4j::ops::adjust_saturation op;
-    auto results = op.execute({&input}, {10}, {2});
+    auto results = op.evaluate({&input}, {10}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -585,7 +585,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_3) {
     NDArray exp ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, nd4j::DataType::FLOAT32);
     nd4j::ops::adjust_saturation op;
-    auto results = op.execute({&input}, {-10}, {2});
+    auto results = op.evaluate({&input}, {-10}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -605,7 +605,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_4) {
     NDArray exp ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5, 190,255, 163.5,128.5, 230,134}, nd4j::DataType::FLOAT32);
     nd4j::ops::adjust_saturation op;
-    auto results = op.execute({&input}, {0.5}, {1});
+    auto results = op.evaluate({&input}, {0.5}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -625,7 +625,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) {
     NDArray exp ('c', {3,2,2}, {50,118.5, 190,255, 100,220, 163.5,128.5, 78,112.5, 230,134}, nd4j::DataType::FLOAT32);
     nd4j::ops::adjust_saturation op;
-    auto results = op.execute({&input}, {0.5}, {0});
+    auto results = op.evaluate({&input}, {0.5}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -646,7 +646,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_1) {
     e.assign(512);
     nd4j::ops::shift_bits op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -664,7 +664,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_1) {
     e.assign(32);
     nd4j::ops::rshift_bits op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -682,7 +682,7 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) {
     e.assign(512);
     nd4j::ops::cyclic_shift_bits op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -700,7 +700,7 @@ TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) {
     e.assign(32);
     nd4j::ops::cyclic_rshift_bits op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -719,7 +719,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_2) {
     e.assign(512);
     nd4j::ops::shift_bits op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -738,7 +738,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_2) {
     e.assign(32);
     nd4j::ops::rshift_bits op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -757,7 +757,7 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) {
     e.assign(512);
     nd4j::ops::cyclic_shift_bits op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -776,7 +776,7 @@ TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) {
     e.assign(32);
     nd4j::ops::cyclic_rshift_bits op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -794,7 +794,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_3) {
     e.assign(512);
     nd4j::ops::shift_bits op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -817,7 +817,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_1) {
     exp.linspace(1);
     nd4j::ops::space_to_batch_nd op;
-    auto result = op.execute({&x, &blockShape, &paddings}, {}, {});
+    auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -844,7 +844,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_2) {
     x.linspace(1);
     nd4j::ops::space_to_batch_nd op;
-    auto result = op.execute({&x, &blockShape, &paddings}, {}, {});
+    auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -875,7 +875,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_3) {
     x.linspace(1);
     nd4j::ops::space_to_batch_nd op;
-    auto result = op.execute({&x, &blockShape, &paddings}, {}, {});
+    auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -901,7 +901,7 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_1) {
     exp.linspace(1);
     nd4j::ops::batch_to_space_nd op;
-    auto result = op.execute({&x, &blockShape, &crop}, {}, {});
+    auto result = op.evaluate({&x, &blockShape, &crop}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -924,7 +924,7 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_2) {
     x.linspace(1);
     nd4j::ops::batch_to_space_nd op;
-    auto result = op.execute({&x, &blockShape, &crop}, {}, {});
+    auto result = op.evaluate({&x, &blockShape, &crop}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -948,7 +948,7 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_3) {
     x.linspace(1);
     nd4j::ops::batch_to_space_nd op;
-    auto result = op.execute({&x, &blockShape, &crop}, {}, {});
+    auto result = op.evaluate({&x, &blockShape, &crop}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -974,7 +974,7 @@ TEST_F(DeclarableOpsTests13, mergemax_1) {
     nd4j::ops::mergemax op;
-    auto result = op.execute({&x1, &x2, &x3}, {}, {});
+    auto result = op.evaluate({&x1, &x2, &x3}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -1040,9 +1040,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) {
     hI = 1.;
     cI = 2.;
-    std::initializer_list<double> tArgs = {cellClip};
-    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
-    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
     auto expH = NDArrayFactory::create<float>('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f,
                                                                     0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f,
@@ -1053,7 +1053,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) {
     auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f});
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1110,9 +1110,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) {
     hI = 1.;
     cI = 2.;
-    std::initializer_list<double> tArgs = {cellClip};
-    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
-    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
     auto expH = NDArrayFactory::create<float>('c', {bS, sL, nOut}, {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, 0.485999f, 0.485999f, 0.485999f,
                                                                     0.596965f, 0.596965f, 0.596965f, 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f,
@@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) {
     auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, 1.301922f, 1.301922f, 1.301922f});
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1178,9 +1178,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) {
     hI = 1.;
     cI = 2.;
-    std::initializer_list<double> tArgs = {cellClip};
-    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
-    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
     NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f,
                                        0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f,
@@ -1190,7 +1190,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) {
     NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, nd4j::DataType::FLOAT32);
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1256,9 +1256,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
     cI({0,1, 0,0, 0,0}) = 2;
     cI({1,2, 0,0, 0,0}) = -2;
-    std::initializer_list<double> tArgs = {cellClip};
-    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
-    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
     NDArray expH('c', {sL, bS, 2 * nOut}, {
            0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f,
@@ -1275,7 +1275,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
            -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, nd4j::DataType::FLOAT32);
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1340,9 +1340,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) {
     cI({0,1, 0,0, 0,0}) = 2;
     cI({1,2, 0,0, 0,0}) = -2;
-    std::initializer_list<double> tArgs = {cellClip};
-    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
-    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
     NDArray expH('c', {bS, sL, 2*nOut}, {
            0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f,
@@ -1357,7 +1357,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) {
           -0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, nd4j::DataType::FLOAT32);
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1426,9 +1426,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) {
     cI({0,1, 0,0, 0,0}) = 2;
     cI({1,2, 0,0, 0,0}) = -2;
-    std::initializer_list<double> tArgs = {cellClip};
-    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
-    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
     NDArray expH('c', {sL, bS, nOut}, {
            0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f,
@@ -1443,7 +1443,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) {
           nd4j::DataType::FLOAT32);
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1519,7 +1519,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_7) {
     NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089,1.197228, 1.197228, 1.197228}, nd4j::DataType::FLOAT32);
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1597,7 +1597,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) {
     NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, nd4j::DataType::FLOAT32);
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1684,7 +1684,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) {
          -0.292174f, -0.292174f, -0.292174f}, nd4j::DataType::FLOAT32);
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1769,7 +1769,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) {
     NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, nd4j::DataType::FLOAT32);
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1849,7 +1849,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) {
     NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, nd4j::DataType::FLOAT32);
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1940,7 +1940,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
          0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, nd4j::DataType::FLOAT32);
     nd4j::ops::lstmLayer op;
-    auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1977,7 +1977,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test1) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});

@@ -2010,7 +2010,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test2) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});

@@ -2039,7 +2039,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test3) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});

@@ -2067,7 +2067,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test4) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2});

@@ -2095,7 +2095,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test5) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1});

@@ -2124,7 +2124,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test6) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3});

@@ -2193,7 +2193,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test8) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3});

@@ -2236,7 +2236,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3,4});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3,4});
@@ -2271,7 +2271,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});

@@ -2314,7 +2314,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

@@ -2356,7 +2356,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});

@@ -2395,7 +2395,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});

@@ -2439,7 +2439,7 @@ return;
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

@@ -2484,7 +2484,7 @@ return;
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});

@@ -2532,7 +2532,7 @@ return;
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});

@@ -2581,7 +2581,7 @@ return;
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

@@ -2635,7 +2635,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

@@ -2687,7 +2687,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});

@@ -2751,7 +2751,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) {
-    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3});
+    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3});
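Every hunk above in DeclarableOpsTests13 applies the same mechanical change: test-only calls to DeclarableOp::execute(inputs, tArgs, iArgs, ...) that merely inspect the returned ResultSet are switched to the evaluate() overloads with identical arguments. Below is a minimal sketch of the new calling convention, assuming the header paths and the evaluate() overload exactly as they appear in the hunks above; the helper function, shapes, and variable names are illustrative, not part of the commit.

// Sketch only -- mirrors the call pattern in the updated batchnorm tests above.
#include <ops/declarable/CustomOperations.h>   // assumed header path for nd4j::ops::*
#include <NDArrayFactory.h>

using namespace nd4j;

static void batchnormEvaluateSketch() {
    auto input    = NDArrayFactory::create<float>('c', {2, 4});
    auto mean     = NDArrayFactory::create<float>('c', {4});
    auto variance = NDArrayFactory::create<float>('c', {4});
    auto gamma    = NDArrayFactory::create<float>('c', {4});
    auto beta     = NDArrayFactory::create<float>('c', {4});
    input.linspace(1.f);
    variance.assign(1.f);
    gamma.assign(1.f);

    nd4j::ops::batchnorm op;

    // before: auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1});
    // after:  same inputs, tArgs {epsilon} and iArgs {applyScale, applyOffset} go to evaluate()
    auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1});

    if (results->status() == ND4J_STATUS_OK) {
        auto z = results->at(0);                 // normalized output array
        z->printShapeInfo("batchnorm output");
    }
    delete results;                              // the tests above delete the ResultSet explicitly
}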
@@ -44,7 +44,7 @@ TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) {
-    auto result = op.execute({&x}, {4.0f},{}, {});
+    auto result = op.evaluate({&x}, {4.0f});

@@ -62,7 +62,7 @@ TEST_F(DeclarableOpsTests14, Test_Reshape_CF_1) {
-    auto result = op.execute({&x}, {}, {3, 2}, {});
+    auto result = op.evaluate({&x}, {3, 2});

@@ -96,7 +96,7 @@ TEST_F(DeclarableOpsTests14, Multiply_test) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});

@@ -113,7 +113,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) {
-    auto result = op.execute({&x, &y}, {}, {}, {false, false});
+    auto result = op.evaluate({&x, &y}, {}, {}, {false, false});

@@ -128,7 +128,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) {
-    auto result = op.execute({&x, &y}, {}, {}, {true, false});
+    auto result = op.evaluate({&x, &y}, {}, {}, {true, false});

@@ -183,7 +183,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});

@@ -200,7 +200,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});

@@ -213,7 +213,7 @@ TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});
@@ -250,13 +250,13 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_1) {
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});
-    auto res2 = sumOp.execute({&e}, {1.}, {1});
+    auto res2 = sumOp.evaluate({&e}, {1.}, {1});

@@ -270,7 +270,7 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_2) {
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});

@@ -284,7 +284,7 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_3) {
-    auto result = op.execute({&x, &x}, {}, {0});
+    auto result = op.evaluate({&x, &x}, {}, {0});

@@ -298,7 +298,7 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_4) {
-    auto result = op.execute({&x, &x}, {}, {0});
+    auto result = op.evaluate({&x, &x}, {}, {0});

@@ -311,7 +311,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) {
-    auto res2 = sumOp.execute({&e}, {1.}, {1});
+    auto res2 = sumOp.evaluate({&e}, {1.}, {1});

@@ -323,7 +323,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
-    auto res2 = sumOp.execute({&e}, {1.}, {1});
+    auto res2 = sumOp.evaluate({&e}, {1.}, {1});

@@ -335,7 +335,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
-    auto res2 = sumOp.execute({&e}, {1.}, {1});
+    auto res2 = sumOp.evaluate({&e}, {1.}, {1});

@@ -346,7 +346,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
-    auto res2 = sumOp.execute({&e}, {1.}, {1});
+    auto res2 = sumOp.evaluate({&e}, {1.}, {1});
@@ -366,7 +366,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) {
-    auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0});
+    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0});

@@ -387,7 +387,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) {
-    auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});

@@ -405,7 +405,7 @@ TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});

@@ -432,7 +432,7 @@ TEST_F(DeclarableOpsTests14, test_empty_tanh_5) {
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});

@@ -450,7 +450,7 @@ TEST_F(DeclarableOpsTests14, repeat_1) {
-    auto result = op.execute({&x}, {}, {2, 0});
+    auto result = op.evaluate({&x}, {}, {2, 0});

@@ -468,7 +468,7 @@ TEST_F(DeclarableOpsTests14, repeat_2) {
-    auto result = op.execute({&x}, {}, {2, 1});
+    auto result = op.evaluate({&x}, {}, {2, 1});

@@ -486,7 +486,7 @@ TEST_F(DeclarableOpsTests14, repeat_3) {
-    auto result = op.execute({&x}, {}, {1,2,3, 1});
+    auto result = op.evaluate({&x}, {}, {1,2,3, 1});

@@ -504,7 +504,7 @@ TEST_F(DeclarableOpsTests14, repeat_4) {
-    auto result = op.execute({&x}, {}, {3,4, 0});
+    auto result = op.evaluate({&x}, {}, {3,4, 0});

@@ -522,7 +522,7 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
-    auto result = op.execute({&x}, {}, {1,2,1, 1});
+    auto result = op.evaluate({&x}, {}, {1,2,1, 1});
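In the DeclarableOpsTests14 hunks above, the migration also drops empty argument lists: evaluate() has shorter overloads, so op.execute({&x, &y}, {}, {}) becomes op.evaluate({&x, &y}), and fill/reshape move their scalar and integer args into the short form. A small sketch of both spellings follows, assuming only the overloads visible in this diff; the helper function, shapes, and names are illustrative.

// Sketch only -- not part of the commit.
#include <ops/declarable/CustomOperations.h>   // assumed header path
#include <NDArrayFactory.h>

using namespace nd4j;

static void shortEvaluateSketch() {
    auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto y = NDArrayFactory::create<float>('c', {2, 3});
    y.assign(1.f);

    nd4j::ops::add add;
    // previously: add.execute({&x, &y}, {}, {});
    auto sum = add.evaluate({&x, &y});               // empty tArgs/iArgs lists are simply omitted

    nd4j::ops::reshape reshape;
    // previously: reshape.execute({&x}, {}, {3, 2}, {});
    auto reshaped = reshape.evaluate({&x}, {3, 2});  // integer args move into the two-argument form

    if (sum->status() == ND4J_STATUS_OK && reshaped->status() == ND4J_STATUS_OK) {
        sum->at(0)->printIndexedBuffer("x + 1");
        reshaped->at(0)->printShapeInfo("reshaped to 3x2");
    }
    delete sum;
    delete reshaped;
}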
@@ -49,7 +49,7 @@ TEST_F(DeclarableOpsTests15, Test_NormalizeMoments_1) {
-    auto result = op.execute({&w, &x, &y}, {&z0, &z1}, {1e-4}, {}, {});
+    auto result = op.execute({&w, &x, &y}, std::vector<NDArray*>{&z0, &z1}, {1e-4}, {}, {});

@@ -87,7 +87,7 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
-    auto result = op.execute({&x, &eps}, {}, {0}, {});
+    auto result = op.evaluate({&x, &eps}, {0});

@@ -103,7 +103,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
-    auto result = op.execute({&x, &factor}, {}, {}, {});
+    auto result = op.evaluate({&x, &factor}, {}, {}, {});

@@ -121,7 +121,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
-    auto result = op.execute({&x}, {2.}, {}, {});
+    auto result = op.evaluate({&x}, {2.});

@@ -139,7 +139,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
-    auto result = op.execute({&x}, {2.}, {}, {});
+    auto result = op.evaluate({&x}, {2.});

@@ -157,7 +157,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
-    auto result = op.execute({&x}, {2.}, {}, {});
+    auto result = op.evaluate({&x}, {2.}, {}, {});

@@ -172,7 +172,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) {
-    auto result = op.execute({&x}, {2.}, {}, {});
+    auto result = op.evaluate({&x}, {2.}, {}, {});
@@ -302,7 +302,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) {
-    auto result = op.execute({&x}, {2.}, {}, {});
+    auto result = op.evaluate({&x}, {2.}, {}, {});

@@ -407,7 +407,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) {
-    auto result = op.execute({&x}, {2.}, {}, {});
+    auto result = op.evaluate({&x}, {2.}, {}, {});

@@ -423,7 +423,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
-    auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {});
+    auto result = op.evaluate({&x}, {(int) nd4j::DataType::DOUBLE});

@@ -437,7 +437,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
-    auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {});
+    auto result = op.evaluate({&x}, {(int) nd4j::DataType::HALF});

@@ -450,7 +450,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_3) {
-    auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
+    auto result = op.evaluate({&x}, {(int) nd4j::DataType::INT64});

@@ -478,7 +478,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) {
-    auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
+    auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {});

@@ -497,7 +497,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_5) {
-    auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
+    auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {});

@@ -515,7 +515,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_6) {
-    auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
+    auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {});

@@ -532,7 +532,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
-    auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
+    auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {});
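The BitCast hunks above show two spellings that coexist after this change: the short evaluate() form passes the target type as an explicitly int-cast integer arg, while the four-argument form keeps the bare nd4j::DataType enum. The sketch below mirrors both call shapes copied from this diff; the helper function, the INT32 target, and the input shape are illustrative (the float-to-double reinterpretation assumes an even last dimension).

// Sketch only -- not part of the commit.
#include <ops/declarable/CustomOperations.h>   // assumed header path
#include <NDArrayFactory.h>

using namespace nd4j;

static void bitcastSketch() {
    auto x = NDArrayFactory::create<float>('c', {2, 4});
    x.linspace(1.f);

    nd4j::ops::bitcast op;

    // short form: target DataType passed as an int iArg
    auto asDouble = op.evaluate({&x}, {(int) nd4j::DataType::DOUBLE});

    // long form: empty tArgs, bare enum iArg, empty bArgs
    auto asInt32 = op.evaluate({&x}, {}, {nd4j::DataType::INT32}, {});

    if (asDouble->status() == ND4J_STATUS_OK)
        asDouble->at(0)->printShapeInfo("floats reinterpreted as double");
    if (asInt32->status() == ND4J_STATUS_OK)
        asInt32->at(0)->printShapeInfo("floats reinterpreted as int32");

    delete asDouble;
    delete asInt32;
}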
|
@ -549,7 +549,7 @@ TEST_F(DeclarableOpsTests15, test_matmul_bp_1) {
|
||||||
auto gB = NDArrayFactory::create<double>('c', {1, 4});
|
auto gB = NDArrayFactory::create<double>('c', {1, 4});
|
||||||
|
|
||||||
nd4j::ops::matmul_bp op;
|
nd4j::ops::matmul_bp op;
|
||||||
auto status = op.execute({&a, &b, &gI}, {&gA, &gB}, {}, {1, 0, 0}, {});
|
auto status = op.execute({&a, &b, &gI}, std::vector<NDArray*>{&gA, &gB}, {}, {1, 0, 0}, {});
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -573,7 +573,7 @@ TEST_F(DeclarableOpsTests15, test_check_numeric_1) {
|
||||||
auto y = NDArrayFactory::string("shouldn't ever trigger");
|
auto y = NDArrayFactory::string("shouldn't ever trigger");
|
||||||
|
|
||||||
nd4j::ops::check_numerics op;
|
nd4j::ops::check_numerics op;
|
||||||
auto result = op.execute({&x, &y}, {}, {});
|
auto result = op.evaluate({&x, &y}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -617,7 +617,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
|
||||||
auto b = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
|
auto b = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
|
||||||
|
|
||||||
nd4j::ops::layer_norm op;
|
nd4j::ops::layer_norm op;
|
||||||
auto result = op.execute({&x, &g, &b}, {}, {0}, {false});
|
auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
@ -629,7 +629,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
|
||||||
auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f});
|
auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f});
|
||||||
|
|
||||||
nd4j::ops::layer_norm_bp op;
|
nd4j::ops::layer_norm_bp op;
|
||||||
auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {false});
|
auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
@ -662,9 +662,9 @@ TEST_F(DeclarableOpsTests15, test_hashCode_1) {
|
||||||
y.linspace(2.);
|
y.linspace(2.);
|
||||||
|
|
||||||
nd4j::ops::hashcode op;
|
nd4j::ops::hashcode op;
|
||||||
auto resultA0 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto resultA0 = op.evaluate({&x});
|
||||||
auto resultA1 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto resultA1 = op.evaluate({&x});
|
||||||
auto resultB0 = op.execute({&y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto resultB0 = op.evaluate({&y});
|
||||||
// resultA0->at(0)->printIndexedBuffer("A0");
|
// resultA0->at(0)->printIndexedBuffer("A0");
|
||||||
// resultA1->at(0)->printIndexedBuffer("A1");
|
// resultA1->at(0)->printIndexedBuffer("A1");
|
||||||
// resultB0->at(0)->printIndexedBuffer("B0");
|
// resultB0->at(0)->printIndexedBuffer("B0");
|
||||||
|
@ -684,9 +684,9 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) {
|
||||||
y.linspace(2.);
|
y.linspace(2.);
|
||||||
|
|
||||||
nd4j::ops::hashcode op;
|
nd4j::ops::hashcode op;
|
||||||
auto resultA0 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto resultA0 = op.evaluate({&x});
|
||||||
auto resultA1 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto resultA1 = op.evaluate({&x});
|
||||||
auto resultB0 = op.execute({&y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto resultB0 = op.evaluate({&y});
|
||||||
|
|
||||||
// resultA0->at(0)->printIndexedBuffer("A0");
|
// resultA0->at(0)->printIndexedBuffer("A0");
|
||||||
// resultA1->at(0)->printIndexedBuffer("A1");
|
// resultA1->at(0)->printIndexedBuffer("A1");
|
||||||
|
@ -705,7 +705,7 @@ TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) {
|
||||||
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
|
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
|
||||||
|
|
||||||
nd4j::ops::reshape op;
|
nd4j::ops::reshape op;
|
||||||
auto result = op.execute({&array}, {}, {1, 1});
|
auto result = op.evaluate({&array}, {}, {1, 1});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -742,7 +742,7 @@ TEST_F(DeclarableOpsTests15, test_rank_2) {
|
||||||
auto e = NDArrayFactory::create<int>('c', {}, {2});
|
auto e = NDArrayFactory::create<int>('c', {}, {2});
|
||||||
|
|
||||||
nd4j::ops::rank op;
|
nd4j::ops::rank op;
|
||||||
auto result = op.execute({&array}, {}, {});
|
auto result = op.evaluate({&array}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -764,7 +764,7 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
|
||||||
auto x8 = NDArrayFactory::create<float>('c', {12});
|
auto x8 = NDArrayFactory::create<float>('c', {12});
|
||||||
|
|
||||||
nd4j::ops::lstmBlock op;
|
nd4j::ops::lstmBlock op;
|
||||||
auto result = op.execute({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {2.0, 0.3}, {0, 0});
|
auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {2.0, 0.3}, {0, 0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -790,7 +790,7 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_2) {
|
||||||
auto x8 = NDArrayFactory::create<float>('f', {4 * nIn});
|
auto x8 = NDArrayFactory::create<float>('f', {4 * nIn});
|
||||||
|
|
||||||
nd4j::ops::lstmBlock op;
|
nd4j::ops::lstmBlock op;
|
||||||
auto result = op.execute({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {1.0, 0.0}, {0, 1});
|
auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {1.0, 0.0}, {0, 1});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -860,7 +860,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) {
|
||||||
NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32);
|
NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32);
|
||||||
NDArray expected('c', { 1 }, { 55 }, nd4j::DataType::INT32);
|
NDArray expected('c', { 1 }, { 55 }, nd4j::DataType::INT32);
|
||||||
nd4j::ops::rgb_to_grs op;
|
nd4j::ops::rgb_to_grs op;
|
||||||
auto result = op.execute({&rgbs}, {}, {});
|
auto result = op.evaluate({&rgbs}, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -876,7 +876,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_2) {
|
||||||
auto rgbs = NDArrayFactory::create<int>('f', { 3 }, { 1, 120, -25 });
|
auto rgbs = NDArrayFactory::create<int>('f', { 3 }, { 1, 120, -25 });
|
||||||
auto expected = NDArrayFactory::create<int>('f', { 1 }, { 67 });
|
auto expected = NDArrayFactory::create<int>('f', { 1 }, { 67 });
|
||||||
nd4j::ops::rgb_to_grs op;
|
nd4j::ops::rgb_to_grs op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -892,7 +892,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_3) {
|
||||||
NDArray rgbs('c', { 4, 3 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32);
|
NDArray rgbs('c', { 4, 3 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32);
|
||||||
NDArray expected('c', { 4, 1 }, { 41, 105, 101, 101 }, nd4j::DataType::INT32);
|
NDArray expected('c', { 4, 1 }, { 41, 105, 101, 101 }, nd4j::DataType::INT32);
|
||||||
nd4j::ops::rgb_to_grs op;
|
nd4j::ops::rgb_to_grs op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -910,7 +910,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_4) {
|
||||||
rgbs.permutei({1,0});
|
rgbs.permutei({1,0});
|
||||||
NDArray expected('c', { 2, 1 }, { 138, 58 }, nd4j::DataType::INT32);
|
NDArray expected('c', { 2, 1 }, { 138, 58 }, nd4j::DataType::INT32);
|
||||||
nd4j::ops::rgb_to_grs op;
|
nd4j::ops::rgb_to_grs op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -926,7 +926,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_5) {
|
||||||
NDArray rgbs('c', { 3, 4 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32);
|
NDArray rgbs('c', { 3, 4 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32);
|
||||||
NDArray expected('c', { 1, 4 }, { 50, 100, 105, 94 }, nd4j::DataType::INT32);
|
NDArray expected('c', { 1, 4 }, { 50, 100, 105, 94 }, nd4j::DataType::INT32);
|
||||||
nd4j::ops::rgb_to_grs op;
|
nd4j::ops::rgb_to_grs op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {0});
|
auto result = op.evaluate({ &rgbs }, {}, {0});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -943,7 +943,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_6) {
|
||||||
auto expected = NDArrayFactory::create<float>('c', { 5,4,1 }, {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f});
|
auto expected = NDArrayFactory::create<float>('c', { 5,4,1 }, {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f});
|
||||||
|
|
||||||
nd4j::ops::rgb_to_grs op;
|
nd4j::ops::rgb_to_grs op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -960,7 +960,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_7) {
|
||||||
auto expected = NDArrayFactory::create<float>('c', { 5,1,4 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f,2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f });
|
auto expected = NDArrayFactory::create<float>('c', { 5,1,4 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f,2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f });
|
||||||
|
|
||||||
nd4j::ops::rgb_to_grs op;
|
nd4j::ops::rgb_to_grs op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {1});
|
auto result = op.evaluate({ &rgbs }, {}, {1});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -976,7 +976,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) {
|
||||||
auto rgbs = NDArrayFactory::create<float>('c', { 3,5,4 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f});
|
auto rgbs = NDArrayFactory::create<float>('c', { 3,5,4 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f});
|
||||||
try {
|
try {
|
||||||
nd4j::ops::rgb_to_grs op;
|
nd4j::ops::rgb_to_grs op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
ASSERT_EQ(Status::THROW(), result->status());
|
ASSERT_EQ(Status::THROW(), result->status());
|
||||||
delete result;
|
delete result;
|
||||||
} catch (std::exception& e) {
|
} catch (std::exception& e) {
|
||||||
|
@ -991,7 +991,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_9) {
|
||||||
auto expected = NDArrayFactory::create<float>('f', { 2,2,1 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f });
|
auto expected = NDArrayFactory::create<float>('f', { 2,2,1 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f });
|
||||||
|
|
||||||
nd4j::ops::rgb_to_grs op;
|
nd4j::ops::rgb_to_grs op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -1007,7 +1007,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_1) {
|
||||||
NDArray rgbs('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
|
NDArray rgbs('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
|
||||||
NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
|
NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
|
||||||
nd4j::ops::rgb_to_yuv op;
|
nd4j::ops::rgb_to_yuv op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -1026,7 +1026,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) {
|
||||||
NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, nd4j::DataType::FLOAT32);
|
NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, nd4j::DataType::FLOAT32);
|
||||||
nd4j::ops::rgb_to_yuv op;
|
nd4j::ops::rgb_to_yuv op;
|
||||||
|
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -1043,7 +1043,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) {
|
||||||
NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
|
NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::rgb_to_yuv op;
|
nd4j::ops::rgb_to_yuv op;
|
||||||
auto result = op.execute({ &rgbs }, {}, { 0 });
|
auto result = op.evaluate({ &rgbs }, {}, { 0 });
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
|
@ -1059,7 +1059,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_4) {
|
||||||
NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32);
|
NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::rgb_to_yuv op;
|
nd4j::ops::rgb_to_yuv op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -1076,7 +1076,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) {
|
||||||
NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, nd4j::DataType::FLOAT32);
|
NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::rgb_to_yuv op;
|
nd4j::ops::rgb_to_yuv op;
|
||||||
auto result = op.execute({ &rgbs }, {}, { 1 });
|
auto result = op.evaluate({ &rgbs }, {}, { 1 });
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -1091,7 +1091,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) {
|
||||||
NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
|
NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
|
||||||
try {
|
try {
|
||||||
nd4j::ops::rgb_to_yuv op;
|
nd4j::ops::rgb_to_yuv op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
ASSERT_EQ(Status::THROW(), result->status());
|
ASSERT_EQ(Status::THROW(), result->status());
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
@ -1107,7 +1107,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) {
|
||||||
NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
|
NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::rgb_to_yuv op;
|
nd4j::ops::rgb_to_yuv op;
|
||||||
auto result = op.execute({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -1123,7 +1123,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) {
|
||||||
NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
|
NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
|
||||||
NDArray expected('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
|
NDArray expected('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
|
||||||
nd4j::ops::yuv_to_rgb op;
|
nd4j::ops::yuv_to_rgb op;
|
||||||
auto result = op.execute({ &yuv }, {}, {});
|
auto result = op.evaluate({ &yuv }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -1139,7 +1139,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) {
|
||||||
NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
|
NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
|
||||||
NDArray expected('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
|
NDArray expected('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
|
||||||
nd4j::ops::yuv_to_rgb op;
|
nd4j::ops::yuv_to_rgb op;
|
||||||
auto result = op.execute({ &yuv }, {}, {});
|
auto result = op.evaluate({ &yuv }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -1156,7 +1156,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_3) {
|
||||||
NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
|
NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::yuv_to_rgb op;
|
nd4j::ops::yuv_to_rgb op;
|
||||||
auto result = op.execute({ &yuv }, {}, { 0 });
|
auto result = op.evaluate({ &yuv }, {}, { 0 });
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
|
@ -1172,7 +1172,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_4) {
|
||||||
NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32);
|
NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::yuv_to_rgb op;
|
nd4j::ops::yuv_to_rgb op;
|
||||||
auto result = op.execute({ &yuv }, {}, {});
|
auto result = op.evaluate({ &yuv }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -1189,7 +1189,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) {
|
||||||
NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, nd4j::DataType::FLOAT32);
|
NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::yuv_to_rgb op;
|
nd4j::ops::yuv_to_rgb op;
|
||||||
auto result = op.execute({ &yuv }, {}, { 1 });
|
auto result = op.evaluate({ &yuv }, {}, { 1 });
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@ -1204,7 +1204,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) {
|
||||||
NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
|
NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
|
||||||
try {
|
try {
|
||||||
nd4j::ops::yuv_to_rgb op;
|
nd4j::ops::yuv_to_rgb op;
|
||||||
auto result = op.execute({ &yuv }, {}, {});
|
auto result = op.evaluate({ &yuv }, {}, {});
|
||||||
ASSERT_EQ(Status::THROW(), result->status());
|
ASSERT_EQ(Status::THROW(), result->status());
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
@ -1220,7 +1220,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) {
|
||||||
NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
|
NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::yuv_to_rgb op;
|
nd4j::ops::yuv_to_rgb op;
|
||||||
auto result = op.execute({ &yuv }, {}, {});
|
auto result = op.evaluate({ &yuv }, {}, {});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
@@ -1246,7 +1246,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test1) {
-    auto results = op.execute({ &x, &y, &dLdz }, {}, {});
+    auto results = op.evaluate({ &x, &y, &dLdz }, {}, {});
@@ -1275,7 +1275,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test2) {
-    auto results = op.execute({ &x, &y, &dLdz }, {}, {});
+    auto results = op.evaluate({ &x, &y, &dLdz }, {}, {});
@@ -1305,7 +1305,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test3) {
-    auto resultsY = op.execute({ &xY, &yY, &dLdz }, {}, {});
+    auto resultsY = op.evaluate({ &xY, &yY, &dLdz }, {}, {});
@@ -1337,7 +1337,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test4) {
-    auto resultsX = op.execute({ &xX, &yX, &dLdz }, {}, {});
+    auto resultsX = op.evaluate({ &xX, &yX, &dLdz }, {}, {});
@@ -1369,7 +1369,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test5) {
-    auto results = op.execute({ &xConst, &yConst, &dLdz }, {}, {});
+    auto results = op.evaluate({ &xConst, &yConst, &dLdz }, {}, {});
@@ -1399,7 +1399,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test6) {
-    auto resultsXC = op.execute({ &xConst, &y, &dLdzC }, {}, {});
+    auto resultsXC = op.evaluate({ &xConst, &y, &dLdzC }, {}, {});
@@ -1428,7 +1428,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test7) {
-    auto resultsYs = op.execute({ &x, &Y, &dLdzC }, {}, {});
+    auto resultsYs = op.evaluate({ &x, &Y, &dLdzC }, {}, {});
@@ -1454,7 +1454,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test8) {
-    auto results = op.execute({ &X, &Y, &dLdz }, {}, {});
+    auto results = op.evaluate({ &X, &Y, &dLdz }, {}, {});
@@ -1484,7 +1484,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test9) {
-    auto results = op.execute({ &x, &y, &dLdz }, {}, {});
+    auto results = op.evaluate({ &x, &y, &dLdz }, {}, {});
@@ -1513,7 +1513,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test10) {
-    auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {});
+    auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {});
@@ -1540,7 +1540,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) {
-    auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {});
+    auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {});
@@ -46,7 +46,7 @@ TEST_F(DeclarableOpsTests16, scatter_upd_1) {
-    auto result = op.execute({ &x, &y, &w }, {}, {});
+    auto result = op.evaluate({ &x, &y, &w });
@@ -66,7 +66,7 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) {
-    auto result = op.execute({ &x, &indices, &updates }, {}, {});
+    auto result = op.evaluate({ &x, &indices, &updates });
@@ -135,7 +135,7 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
-    auto result = op.execute({ &x, &y }, {}, {});
+    auto result = op.evaluate({ &x, &y });
@@ -166,7 +166,7 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) {
-    auto result = op.execute({ &x }, {}, { 10 });
+    auto result = op.evaluate({&x}, {10});
@@ -48,7 +48,7 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
-    auto result = op.execute({&ranges, &shape, &values, &def}, {}, {});
+    auto result = op.evaluate({&ranges, &shape, &values, &def});
@@ -63,7 +63,7 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
-    auto result = op.execute({&ranges, &shape, &values, &def}, {}, {});
+    auto result = op.evaluate({&ranges, &shape, &values, &def});
@@ -77,7 +77,7 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
-    auto result = op.execute({&x, &delimiter}, {}, {});
+    auto result = op.evaluate({&x, &delimiter});

[File diff suppressed because it is too large]
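Across all of these test suites the change is the same mechanical migration: a DeclarableOp call written as op.execute(inputs, tArgs, iArgs), whose returned result set was status-checked and then deleted, becomes op.evaluate(...), and empty tArgs/iArgs lists are dropped wherever the shorter overload suffices. A minimal sketch of the resulting test shape follows; the test name, includes and array values here are assumptions made for illustration (modelled loosely on the diag hunks further down), not lines taken from this commit.

// Sketch only: post-migration test pattern, assuming the fixture, "testlayers.h" and
// <ops/declarable/CustomOperations.h> setup already used by these suites.
TEST_F(DeclarableOpsTests3, evaluate_pattern_sketch) {
    auto input    = NDArrayFactory::create<double>('c', {3},    {1, 2, 3});
    auto expected = NDArrayFactory::create<double>('c', {3, 3}, {1,0,0, 0,2,0, 0,0,3});

    nd4j::ops::diag op;
    // before: auto results = op.execute({&input}, {}, {});
    auto results = op.evaluate({&input});            // empty T/I argument lists can be omitted

    ASSERT_EQ(ND4J_STATUS_OK, results->status());    // status check is unchanged
    auto output = results->at(0);                    // outputs are still fetched by index
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;                                  // caller still owns the returned result set
}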
@@ -43,7 +43,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_1) {
-    auto result = op.execute({&x, &rep_vector}, {}, {});
+    auto result = op.evaluate({&x, &rep_vector});
@@ -61,7 +61,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_2) {
-    auto result = op.execute({&x}, {}, {2, 2});
+    auto result = op.evaluate({&x}, {}, {2, 2});
@@ -77,7 +77,7 @@ TEST_F(DeclarableOpsTests3, Test_Permute_1) {
-    auto result = op.execute({&x, &permute}, {}, {});
+    auto result = op.evaluate({&x, &permute});
@@ -92,7 +92,7 @@ TEST_F(DeclarableOpsTests3, Test_Permute_2) {
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x});
@@ -110,7 +110,7 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) {
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
@@ -136,7 +136,7 @@ TEST_F(DeclarableOpsTests3, Test_Unique_2) {
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
@@ -169,7 +169,7 @@ TEST_F(DeclarableOpsTests3, Test_Rint_1) {
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
@@ -188,7 +188,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) {
-    auto result0 = op.execute({&x}, {0.}, {});
+    auto result0 = op.evaluate({&x}, {0.}, {});
@@ -197,7 +197,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) {
-    auto result1 = op.execute({&x}, {1.}, {1});
+    auto result1 = op.evaluate({&x}, {1.}, {1});
@@ -210,7 +210,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) {
-    auto result4 = op.execute({&x}, {4.}, {1});
+    auto result4 = op.evaluate({&x}, {4.}, {1});
@@ -230,7 +230,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {
-    auto result0 = op.execute({&x}, {0}, {});
+    auto result0 = op.evaluate({&x}, {0}, {});
@@ -239,7 +239,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {
-    auto result1 = op.execute({&x, &axis}, {1}, {});
+    auto result1 = op.evaluate({&x, &axis}, {1}, {});
@@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {
-    auto result4 = op.execute({&x, &axis}, {4}, {});
+    auto result4 = op.evaluate({&x, &axis}, {4}, {});
@@ -264,7 +264,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_1) {
-    auto result = op.execute({&x}, {0.8}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {0.8}, {});
@@ -279,7 +279,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) {
-    auto result = op.execute({&x}, {0.9}, {});
+    auto result = op.evaluate({&x}, {0.9}, {});
@@ -295,7 +295,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) {
-    auto result = op.execute({&x}, {4.0}, {});
+    auto result = op.evaluate({&x}, {4.0}, {});
@@ -310,7 +310,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) {
-    auto result = op.execute({&x}, {6.0}, {});
+    auto result = op.evaluate({&x}, {6.0}, {});
@@ -340,7 +340,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) {
-    auto result = op.execute({&x}, {1.0}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {1.0}, {1});
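The two clip-norm hunks above are the only ones in this block where the old call passed more than inputs and T/I arguments: the explicit output list, inplace flag and target data type (..., {}, false, nd4j::DataType::DOUBLE) are dropped, and the new evaluate call keeps only the clip value and the optional dimension. A hedged before/after sketch with made-up array contents, not copied from this commit:

    // Sketch only: the trimmed clipbynorm call shape after the migration.
    auto x = NDArrayFactory::create<double>('c', {3, 5});
    x.linspace(1.0);

    nd4j::ops::clipbynorm op;
    // before: auto result = op.execute({&x}, {1.0}, {1}, {}, false, nd4j::DataType::DOUBLE);
    auto result = op.evaluate({&x}, {1.0}, {1});   // T arg: clip norm, I arg: dimension; output type presumably follows the input
    auto z = result->at(0);

    delete result;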
@@ -360,7 +360,7 @@ TEST_F(DeclarableOpsTests3, Test_ListDiff_1) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});
@@ -386,7 +386,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_1) {
-    auto result = op.execute({&start, &stop, &step}, {}, {});
+    auto result = op.evaluate({&start, &stop, &step});
@@ -406,7 +406,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_2) {
-    auto result = op.execute({&start, &stop, &step}, {}, {});
+    auto result = op.evaluate({&start, &stop, &step});
@@ -425,7 +425,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_3) {
-    auto result = op.execute({&start, &stop, &step}, {}, {});
+    auto result = op.evaluate({&start, &stop, &step});
@@ -442,7 +442,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_4) {
-    auto result = op.execute({}, {-10., 10., 1.666}, {});
+    auto result = op.evaluate({}, {-10., 10., 1.666}, {});
@@ -459,7 +459,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_5) {
-    auto result = op.execute({}, {2, 0, -1}, {});
+    auto result = op.evaluate({}, {2, 0, -1}, {});
@@ -475,7 +475,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_6) {
-    auto result = op.execute({}, {0, 2, 1}, {});
+    auto result = op.evaluate({}, {0, 2, 1}, {});
@@ -491,7 +491,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_7) {
-    auto result = op.execute({}, {10,-5,-1.666}, {});
+    auto result = op.evaluate({}, {10,-5,-1.666}, {});
@@ -509,7 +509,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_8) {
-    auto result = op.execute({}, {}, {2, 0, -1});
+    auto result = op.evaluate({}, {}, {2, 0, -1});
@@ -525,7 +525,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_9) {
-    auto result = op.execute({}, {}, {0, 2, 1});
+    auto result = op.evaluate({}, {}, {0, 2, 1});
@@ -546,7 +546,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) {
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 3, 3, 3, 3, 3, 3, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 3, 3, 3, 3, 3, 3, 3});
@@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) {
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 3, 3, 3, 3, 3, 3, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 3, 3, 3, 3, 3, 3, 3});
@@ -602,7 +602,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) {
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 111, 3, 3, 3, 3, 3, 3, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 111, 3, 3, 3, 3, 3, 3, 3});
@@ -630,7 +630,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) {
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 5, 4, 3, 5, 3, 5, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 5, 4, 3, 5, 3, 5, 3});
@@ -658,7 +658,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) {
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 5, 4, 3, 3, 4, 5, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 5, 4, 3, 3, 4, 5, 3});
@@ -687,7 +687,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) {
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 2, 3, 5, 2, 5, 2, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 2, 3, 5, 2, 5, 2, 3});
@@ -717,7 +717,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
@@ -744,7 +744,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) {
-        auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
+        auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
@@ -775,7 +775,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_1) {
-    auto result = op.execute({&x, &y}, {}, {1, 1});
+    auto result = op.evaluate({&x, &y}, {}, {1, 1});
@@ -794,7 +794,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_2) {
-    auto result = op.execute({&x, &y}, {}, {0, 0});
+    auto result = op.evaluate({&x, &y}, {}, {0, 0});
@@ -813,7 +813,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_3) {
-    auto result = op.execute({&x, &y}, {}, {1, 0});
+    auto result = op.evaluate({&x, &y}, {}, {1, 0});
@@ -832,7 +832,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_4) {
-    auto result = op.execute({&x, &y}, {}, {0, 1});
+    auto result = op.evaluate({&x, &y}, {}, {0, 1});
@@ -851,7 +851,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_5) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
@@ -870,7 +870,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});
@@ -889,7 +889,7 @@ TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});
@@ -920,7 +920,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test1) {
-    auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {});
+    auto results = op.evaluate({&xt, &ct_1, &w, &b});
@@ -956,7 +956,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test2) {
-    auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {});
+    auto results = op.evaluate({&xt, &ct_1, &w, &b});
@@ -991,7 +991,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test3) {
-    auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {});
+    auto results = op.evaluate({&xt, &ct_1, &w, &b});
@@ -1030,7 +1030,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test1) {
-    auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {});
+    auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc});
@@ -1066,7 +1066,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test2) {
-    auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {});
+    auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc});
@@ -1102,7 +1102,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test3) {
-    auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {});
+    auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc});
@@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test1) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input});
@@ -1140,7 +1140,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test2) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input});
@@ -1159,7 +1159,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test3) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input});
@@ -1180,7 +1180,7 @@ TEST_F(DeclarableOpsTests3, diag_test1) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input});
@@ -1201,7 +1201,7 @@ TEST_F(DeclarableOpsTests3, diag_test2) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input});
@@ -1222,7 +1222,7 @@ TEST_F(DeclarableOpsTests3, diag_test_vector) {
-    auto results = op.execute({input}, {}, {});
+    auto results = op.evaluate({input});
@@ -1246,7 +1246,7 @@ TEST_F(DeclarableOpsTests3, diag_test_col_vector) {
-    auto results = op.execute({input}, {}, {});
+    auto results = op.evaluate({input}, {}, {});
@@ -1267,7 +1267,7 @@ TEST_F(DeclarableOpsTests3, diag_test3) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
@@ -1288,7 +1288,7 @@ TEST_F(DeclarableOpsTests3, diag_test4) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
@@ -1309,7 +1309,7 @@ TEST_F(DeclarableOpsTests3, diag_test5) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
@@ -1330,7 +1330,7 @@ TEST_F(DeclarableOpsTests3, diag_test6) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
@@ -1353,7 +1353,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test1) {
-    auto results = op.execute({&input, &diagonal}, {}, {});
+    auto results = op.evaluate({&input, &diagonal}, {}, {});
@@ -1376,7 +1376,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) {
-    auto results = op.execute({&input, &diagonal}, {}, {});
+    auto results = op.evaluate({&input, &diagonal}, {}, {});
@@ -1399,7 +1399,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test3) {
-    auto results = op.execute({&input, &diagonal}, {}, {});
+    auto results = op.evaluate({&input, &diagonal}, {}, {});
@@ -1422,7 +1422,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test4) {
-    auto results = op.execute({&input, &diagonal}, {}, {});
+    auto results = op.evaluate({&input, &diagonal}, {}, {});
@@ -1443,7 +1443,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test1) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
@@ -1465,7 +1465,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test2) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
@@ -1486,7 +1486,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test3) {
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
@ -1512,7 +1512,7 @@ TEST_F(DeclarableOpsTests3, betainc_test1) {
|
||||||
auto expected = NDArrayFactory::create<float16>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
|
auto expected = NDArrayFactory::create<float16>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1538,7 +1538,7 @@ TEST_F(DeclarableOpsTests3, betainc_test2) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1564,7 +1564,7 @@ TEST_F(DeclarableOpsTests3, betainc_test3) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1590,7 +1590,7 @@ TEST_F(DeclarableOpsTests3, betainc_test4) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, 1.14644360e-05f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, 1.14644360e-05f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1616,7 +1616,7 @@ TEST_F(DeclarableOpsTests3, betainc_test5) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1642,7 +1642,7 @@ TEST_F(DeclarableOpsTests3, betainc_test6) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, 8.39227685e-10f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, 8.39227685e-10f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1668,7 +1668,7 @@ TEST_F(DeclarableOpsTests3, betainc_test7) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, 0.99999998f, 0.99999999f, 1.f, 1.f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, 0.99999998f, 0.99999999f, 1.f, 1.f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1694,7 +1694,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1720,7 +1720,7 @@ TEST_F(DeclarableOpsTests3, betainc_test9) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1746,7 +1746,7 @@ TEST_F(DeclarableOpsTests3, betainc_test10) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1767,7 +1767,7 @@ TEST_F(DeclarableOpsTests3, betainc_test11) {
|
||||||
|
|
||||||
NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, nd4j::DataType::FLOAT32);
|
NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, nd4j::DataType::FLOAT32);
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1789,7 +1789,7 @@ TEST_F(DeclarableOpsTests3, betainc_test12) {
|
||||||
NDArray expected('c', {4}, {0.9999995 , 0.8594694 , 0.999988 , 0.49124345}, nd4j::DataType::FLOAT32);
|
NDArray expected('c', {4}, {0.9999995 , 0.8594694 , 0.999988 , 0.49124345}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.evaluate({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1813,7 +1813,7 @@ TEST_F(DeclarableOpsTests3, zeta_test1) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f});
|
||||||
|
|
||||||
nd4j::ops::zeta op;
|
nd4j::ops::zeta op;
|
||||||
auto results = op.execute({&x, &q}, {}, {});
|
auto results = op.evaluate({&x, &q}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1837,7 +1837,7 @@ TEST_F(DeclarableOpsTests3, zeta_test2) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f});
|
||||||
|
|
||||||
nd4j::ops::zeta op;
|
nd4j::ops::zeta op;
|
||||||
auto results = op.execute({&x, &q}, {}, {});
|
auto results = op.evaluate({&x, &q}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1861,7 +1861,7 @@ TEST_F(DeclarableOpsTests3, zeta_test3) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f});
|
||||||
|
|
||||||
nd4j::ops::zeta op;
|
nd4j::ops::zeta op;
|
||||||
auto results = op.execute({&x, &q}, {}, {});
|
auto results = op.evaluate({&x, &q}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1886,7 +1886,7 @@ TEST_F(DeclarableOpsTests3, zeta_test4) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f});
|
||||||
|
|
||||||
nd4j::ops::zeta op;
|
nd4j::ops::zeta op;
|
||||||
auto results = op.execute({&x, &q}, {}, {});
|
auto results = op.evaluate({&x, &q}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1910,7 +1910,7 @@ TEST_F(DeclarableOpsTests3, zeta_test5) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f});
|
||||||
|
|
||||||
nd4j::ops::zeta op;
|
nd4j::ops::zeta op;
|
||||||
auto results = op.execute({&x, &q}, {}, {});
|
auto results = op.evaluate({&x, &q}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1934,7 +1934,7 @@ TEST_F(DeclarableOpsTests3, zeta_test6) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 98.50514758f, 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 98.50514758f, 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f});
|
||||||
|
|
||||||
nd4j::ops::zeta op;
|
nd4j::ops::zeta op;
|
||||||
auto results = op.execute({&x, &q}, {}, {});
|
auto results = op.evaluate({&x, &q}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1958,7 +1958,7 @@ TEST_F(DeclarableOpsTests3, zeta_test7) {
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, 4.56065812e-10f});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, 4.56065812e-10f});
|
||||||
|
|
||||||
nd4j::ops::zeta op;
|
nd4j::ops::zeta op;
|
||||||
auto results = op.execute({&x, &q}, {}, {});
|
auto results = op.evaluate({&x, &q}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1982,7 +1982,7 @@ TEST_F(DeclarableOpsTests3, zeta_test8) {
|
||||||
auto expected= NDArrayFactory::create<double>('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});
|
auto expected= NDArrayFactory::create<double>('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});
|
||||||
|
|
||||||
nd4j::ops::zeta op;
|
nd4j::ops::zeta op;
|
||||||
auto results = op.execute({&x, &q}, {}, {});
|
auto results = op.evaluate({&x, &q}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2054,7 +2054,7 @@ TEST_F(DeclarableOpsTests3, Test_SplitV_Validation_1) {
|
||||||
auto z1 = NDArrayFactory::create<float>('c', {3, 7});
|
auto z1 = NDArrayFactory::create<float>('c', {3, 7});
|
||||||
|
|
||||||
nd4j::ops::split_v op;
|
nd4j::ops::split_v op;
|
||||||
auto status = op.execute({&x, &indices, &axis}, {&z0, &z1}, {}, {}, {});
|
auto status = op.execute({&x, &indices, &axis}, std::vector<NDArray*>{&z0, &z1}, {}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2070,7 +2070,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) {
|
||||||
auto expected= NDArrayFactory::create<double>('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08});
|
auto expected= NDArrayFactory::create<double>('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08});
|
||||||
|
|
||||||
nd4j::ops::polygamma op;
|
nd4j::ops::polygamma op;
|
||||||
auto results = op.execute({&n, &x}, {}, {});
|
auto results = op.evaluate({&n, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2097,7 +2097,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test2) {
|
||||||
//ASSERT_FALSE(true);
|
//ASSERT_FALSE(true);
|
||||||
|
|
||||||
nd4j::ops::polygamma op;
|
nd4j::ops::polygamma op;
|
||||||
auto results = op.execute({&n, &x}, {}, {});
|
auto results = op.evaluate({&n, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2120,7 +2120,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
|
||||||
|
|
||||||
auto expected= NDArrayFactory::create<double>('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07});
|
auto expected= NDArrayFactory::create<double>('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07});
|
||||||
nd4j::ops::polygamma op;
|
nd4j::ops::polygamma op;
|
||||||
auto results = op.execute({&n, &x}, {}, {});
|
auto results = op.evaluate({&n, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2141,7 +2141,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test4) {
|
||||||
1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE);
|
1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
nd4j::ops::polygamma op;
|
nd4j::ops::polygamma op;
|
||||||
auto results = op.execute({&n, &x}, {}, {});
|
auto results = op.evaluate({&n, &x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2161,7 +2161,7 @@ TEST_F(DeclarableOpsTests3, digamma_1) {
|
||||||
std::numeric_limits<double>::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE);
|
std::numeric_limits<double>::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
nd4j::ops::digamma op;
|
nd4j::ops::digamma op;
|
||||||
auto results = op.execute({&x}, {}, {});
|
auto results = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2182,7 +2182,7 @@ TEST_F(DeclarableOpsTests3, svd_test1) {
|
||||||
auto expV= NDArrayFactory::create<double>('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543});
|
auto expV= NDArrayFactory::create<double>('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {1, 1, 16});
|
auto results = op.evaluate({&x}, {}, {1, 1, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2219,7 +2219,7 @@ TEST_F(DeclarableOpsTests3, svd_test2) {
|
||||||
auto expV= NDArrayFactory::create<double>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
|
auto expV= NDArrayFactory::create<double>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {1, 1, 16});
|
auto results = op.evaluate({&x}, {}, {1, 1, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2256,7 +2256,7 @@ TEST_F(DeclarableOpsTests3, svd_test3) {
|
||||||
auto expV= NDArrayFactory::create<double>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
|
auto expV= NDArrayFactory::create<double>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {0, 1, 16});
|
auto results = op.evaluate({&x}, {}, {0, 1, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2293,7 +2293,7 @@ TEST_F(DeclarableOpsTests3, svd_test4) {
|
||||||
auto expV= NDArrayFactory::create<double>('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065});
|
auto expV= NDArrayFactory::create<double>('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {1, 1, 16});
|
auto results = op.evaluate({&x}, {}, {1, 1, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2330,7 +2330,7 @@ TEST_F(DeclarableOpsTests3, svd_test5) {
|
||||||
auto expV= NDArrayFactory::create<double>('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151});
|
auto expV= NDArrayFactory::create<double>('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {0, 1, 16});
|
auto results = op.evaluate({&x}, {}, {0, 1, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2385,7 +2385,7 @@ TEST_F(DeclarableOpsTests3, svd_test6) {
|
||||||
-0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428});
|
-0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {1, 1, 16});
|
auto results = op.evaluate({&x}, {}, {1, 1, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2423,7 +2423,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) {
|
||||||
39.34498, 32.54861, 17.52492, 7.03003, 2.2399,44.72126, 32.3164 , 16.60139, 6.88783, 0.78122});
|
39.34498, 32.54861, 17.52492, 7.03003, 2.2399,44.72126, 32.3164 , 16.60139, 6.88783, 0.78122});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {0, 0, 16});
|
auto results = op.evaluate({&x}, {}, {0, 0, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2622,7 +2622,7 @@ TEST_F(DeclarableOpsTests3, svd_test9) {
|
||||||
1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01,-4.39400000e-02, 2.17750000e-01,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01, -1.74620000e-01});
|
1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01,-4.39400000e-02, 2.17750000e-01,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01, -1.74620000e-01});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {1, 1, 16});
|
auto results = op.evaluate({&x}, {}, {1, 1, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2681,7 +2681,7 @@ TEST_F(DeclarableOpsTests3, svd_test10) {
|
||||||
-4.39400000e-02,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01});
|
-4.39400000e-02,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {0, 1, 16});
|
auto results = op.evaluate({&x}, {}, {0, 1, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2726,7 +2726,7 @@ TEST_F(DeclarableOpsTests3, svd_test11) {
|
||||||
-0.43596, 0.83108, -0.34531});
|
-0.43596, 0.83108, -0.34531});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {0, 1, 16});
|
auto results = op.evaluate({&x}, {}, {0, 1, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2761,7 +2761,7 @@ TEST_F(DeclarableOpsTests3, svd_test12) {
|
||||||
NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371});
|
NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371});
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {1, 0, 16});
|
auto results = op.evaluate({&x}, {}, {1, 0, 16});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2780,7 +2780,7 @@ TEST_F(DeclarableOpsTests3, elu_test1) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9});
|
auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9});
|
||||||
|
|
||||||
nd4j::ops::elu op;
|
nd4j::ops::elu op;
|
||||||
auto results = op.execute({&x}, {0.5}, {});
|
auto results = op.evaluate({&x}, {0.5}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2799,7 +2799,7 @@ TEST_F(DeclarableOpsTests3, elu_bp_test1) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2});
|
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2});
|
||||||
|
|
||||||
nd4j::ops::elu_bp op;
|
nd4j::ops::elu_bp op;
|
||||||
auto results = op.execute({ &x, &eps }, {0.5}, {});
|
auto results = op.evaluate({ &x, &eps }, {0.5}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2816,7 +2816,7 @@ TEST_F(DeclarableOpsTests3, lrelu_test1) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9});
|
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9});
|
||||||
|
|
||||||
nd4j::ops::lrelu op;
|
nd4j::ops::lrelu op;
|
||||||
auto results = op.execute({&x}, {0.2}, {});
|
auto results = op.evaluate({&x}, {0.2}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2833,7 +2833,7 @@ TEST_F(DeclarableOpsTests3, lrelu_bp_test1) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2});
|
auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2});
|
||||||
|
|
||||||
nd4j::ops::lrelu_bp op;
|
nd4j::ops::lrelu_bp op;
|
||||||
auto results = op.execute({&x, &eps}, {0.2}, {});
|
auto results = op.evaluate({&x, &eps}, {0.2}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2850,7 +2850,7 @@ TEST_F(DeclarableOpsTests3, selu_test1) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309});
|
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309});
|
||||||
|
|
||||||
nd4j::ops::selu op;
|
nd4j::ops::selu op;
|
||||||
auto results = op.execute({&x}, {}, {});
|
auto results = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2868,7 +2868,7 @@ TEST_F(DeclarableOpsTests3, selu_test2) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3,3}, {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, 2.101402, 2.101402});
|
auto exp = NDArrayFactory::create<double>('c', {3,3}, {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, 2.101402, 2.101402});
|
||||||
|
|
||||||
nd4j::ops::selu_bp op;
|
nd4j::ops::selu_bp op;
|
||||||
auto results = op.execute({&x, &eps}, {0.2}, {});
|
auto results = op.evaluate({&x, &eps}, {0.2}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2888,7 +2888,7 @@ TEST_F(DeclarableOpsTests3, EQScalarTests_1) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::eq_scalar op;
|
nd4j::ops::eq_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_TRUE(res);
|
ASSERT_TRUE(res);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -2900,7 +2900,7 @@ TEST_F(DeclarableOpsTests3, EQScalarTests_2) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::eq_scalar op;
|
nd4j::ops::eq_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_FALSE(res);
|
ASSERT_FALSE(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2911,7 +2911,7 @@ TEST_F(DeclarableOpsTests3, GTScalarTests_1) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::gt_scalar op;
|
nd4j::ops::gt_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_FALSE(res);
|
ASSERT_FALSE(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2922,7 +2922,7 @@ TEST_F(DeclarableOpsTests3, GTScalarTests_2) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::gt_scalar op;
|
nd4j::ops::gt_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_TRUE(res);
|
ASSERT_TRUE(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2933,7 +2933,7 @@ TEST_F(DeclarableOpsTests3, GTEScalarTests_1) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::gte_scalar op;
|
nd4j::ops::gte_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_TRUE(res);
|
ASSERT_TRUE(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2944,7 +2944,7 @@ TEST_F(DeclarableOpsTests3, GTEScalarTests_2) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::gte_scalar op;
|
nd4j::ops::gte_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_TRUE(res);
|
ASSERT_TRUE(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2955,7 +2955,7 @@ TEST_F(DeclarableOpsTests3, GTEScalarTests_3) {
|
||||||
auto scalar = NDArrayFactory::create(2.0f);
|
auto scalar = NDArrayFactory::create(2.0f);
|
||||||
|
|
||||||
nd4j::ops::gte_scalar op;
|
nd4j::ops::gte_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_FALSE(res);
|
ASSERT_FALSE(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2966,7 +2966,7 @@ TEST_F(DeclarableOpsTests3, LTEScalarTests_1) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::lte_scalar op;
|
nd4j::ops::lte_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_TRUE(res);
|
ASSERT_TRUE(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2977,7 +2977,7 @@ TEST_F(DeclarableOpsTests3, LTEScalarTests_2) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::lte_scalar op;
|
nd4j::ops::lte_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_FALSE(res);
|
ASSERT_FALSE(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2988,7 +2988,7 @@ TEST_F(DeclarableOpsTests3, LTEScalarTests_3) {
|
||||||
auto scalar = NDArrayFactory::create(2.0f);
|
auto scalar = NDArrayFactory::create(2.0f);
|
||||||
|
|
||||||
nd4j::ops::lte_scalar op;
|
nd4j::ops::lte_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_TRUE(res);
|
ASSERT_TRUE(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2999,7 +2999,7 @@ TEST_F(DeclarableOpsTests3, NEQScalarTests_1) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::neq_scalar op;
|
nd4j::ops::neq_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_FALSE(res);
|
ASSERT_FALSE(res);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -3011,7 +3011,7 @@ TEST_F(DeclarableOpsTests3, NEQScalarTests_2) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::neq_scalar op;
|
nd4j::ops::neq_scalar op;
|
||||||
auto res = op.evaluate({&x, &scalar});
|
auto res = op.verify({&x, &scalar});
|
||||||
ASSERT_TRUE(res);
|
ASSERT_TRUE(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3022,7 +3022,7 @@ TEST_F(DeclarableOpsTests3, NOOPTests_1) {
|
||||||
auto scalar = NDArrayFactory::create(1.0f);
|
auto scalar = NDArrayFactory::create(1.0f);
|
||||||
|
|
||||||
nd4j::ops::noop op;
|
nd4j::ops::noop op;
|
||||||
auto res = op.execute({&x, &scalar}, {}, {});
|
auto res = op.evaluate({&x, &scalar}, {}, {});
|
||||||
ASSERT_TRUE(res->status() == nd4j::Status::OK());
|
ASSERT_TRUE(res->status() == nd4j::Status::OK());
|
||||||
delete res;
|
delete res;
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because one or more lines are too long
@@ -50,7 +50,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) {
 matrix.linspace(1);

 nd4j::ops::strided_slice op;
-auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -71,7 +71,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {
 matrix.linspace(1);

 nd4j::ops::strided_slice op;
-auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -92,7 +92,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
 //matrix.linspace(1);

 nd4j::ops::strided_slice op;
-auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -114,7 +114,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
 //matrix.linspace(1);

 nd4j::ops::strided_slice op;
-auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -133,7 +133,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) {
 auto s = NDArrayFactory::create_<int>('c', {1}, {1});
 nd4j::ops::ones_as opOnes;
 //auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
-auto onesRes = opOnes.execute({&matrix}, {}, {});
+auto onesRes = opOnes.evaluate({&matrix});
 //matrix.linspace(1);
 ASSERT_EQ(onesRes->status(), Status::OK());

@@ -181,7 +181,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) {
 //matrix.linspace(1);

 nd4j::ops::strided_slice op;
-auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -201,7 +201,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
 //matrix.linspace(1);

 nd4j::ops::strided_slice op;
-auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2});
+auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -222,7 +222,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
 //matrix.linspace(1);

 nd4j::ops::strided_slice op;
-auto result = op.execute({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0});
+auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -244,7 +244,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
 grad.linspace(1);

 nd4j::ops::strided_slice_bp op;
-auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
+auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -266,7 +266,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
 //grad.linspace(1);

 nd4j::ops::strided_slice_bp op;
-auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
+auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -288,7 +288,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
 grad.linspace(1);

 nd4j::ops::strided_slice_bp op;
-auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1});
+auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -302,7 +302,7 @@ TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
 auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});

 nd4j::ops::test_scalar op;
-auto result = op.execute({&x}, {}, {});
+auto result = op.evaluate({&x}, {}, {});

 ASSERT_EQ(Status::OK(), result->status());

@@ -321,7 +321,7 @@ TEST_F(DeclarableOpsTests6, Test_Order_1) {
 exp.linspace(1);

 nd4j::ops::order op;
-auto result = op.execute({&x}, {}, {0});
+auto result = op.evaluate({&x}, {}, {0});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -336,7 +336,7 @@ TEST_F(DeclarableOpsTests6, cumSum_1) {
 auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1.f, 3.f, 6.f, 10.f});

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 0});
+auto result = op.evaluate({&x}, {}, {0, 0});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);
@@ -352,7 +352,7 @@ TEST_F(DeclarableOpsTests6, cumSum_2) {
 auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f});

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 0, 1});
+auto result = op.evaluate({&x}, {}, {0, 0, 1});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);
@@ -369,7 +369,7 @@ TEST_F(DeclarableOpsTests6, cumSum_3) {
 auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f});

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 0, 0});
+auto result = op.evaluate({&x}, {}, {0, 0, 0});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);
@@ -385,7 +385,7 @@ TEST_F(DeclarableOpsTests6, cumSum_4) {
 auto exp = NDArrayFactory::create<double>('c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.});

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x}, {}, {0, 1, 0}, {}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -401,7 +401,7 @@ TEST_F(DeclarableOpsTests6, cumSum_5) {
 auto exp = NDArrayFactory::create<double>('c', {3, 3}, {6.f, 5.f, 3.f, 15.f, 11.f, 6.f, 24.f, 17.f, 9.f,});

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x}, {}, {0, 1, 1}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -416,7 +416,7 @@ TEST_F(DeclarableOpsTests6, cumSum_6) {
 auto exp = NDArrayFactory::create<double>('c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f});

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {1, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x}, {}, {1, 1, 0}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -431,7 +431,7 @@ TEST_F(DeclarableOpsTests6, cumSum_7) {
 auto exp = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {1, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x}, {}, {1, 1, 1}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -447,7 +447,7 @@ TEST_F(DeclarableOpsTests6, cumSum_8) {
 auto exp = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});

 nd4j::ops::cumsum op;
-auto result = op.execute({&x, &axis}, {}, {1, 1}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x, &axis}, {}, {1, 1}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -475,7 +475,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
 exclusive = 0; reverse = 0;

 nd4j::ops::cumsum op;
-auto result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}, {});
 ASSERT_EQ(Status::OK(), result->status());
 auto z = result->at(0);
 ASSERT_TRUE(expFF.equalsTo(z));
@@ -484,7 +484,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
 //************************************//
 exclusive = 1; reverse = 0;

-result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
 ASSERT_EQ(Status::OK(), result->status());
 z = result->at(0);
 ASSERT_TRUE(expTF.equalsTo(z));
@@ -493,7 +493,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
 //************************************//
 exclusive = 0; reverse = 1;

-result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
 ASSERT_EQ(Status::OK(), result->status());
 z = result->at(0);
 ASSERT_TRUE(expFT.equalsTo(z));
@@ -502,7 +502,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
 //************************************//
 exclusive = 1; reverse = 1;

-result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
 ASSERT_EQ(Status::OK(), result->status());
 z = result->at(0);
 ASSERT_TRUE(expTT.equalsTo(z));
@@ -516,7 +516,7 @@ TEST_F(DeclarableOpsTests6, cumSum_10) {
 auto y = NDArrayFactory::create<int>(-3);

 nd4j::ops::cumsum op;
-auto result = op.execute({&x, &y}, {}, {1, 1}, {});
+auto result = op.evaluate({&x, &y}, {}, {1, 1});
 ASSERT_EQ(Status::OK(), result->status());

 delete result;
@@ -531,7 +531,7 @@ TEST_F(DeclarableOpsTests6, cumSum_11) {
 x.linspace(1);

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x}, {}, {0, 1, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -550,7 +550,7 @@ TEST_F(DeclarableOpsTests6, cumSum_12) {
 x.linspace(1);

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 0, 1}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x}, {}, {0, 0, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -569,7 +569,7 @@ TEST_F(DeclarableOpsTests6, cumSum_13) {
 x.linspace(1);

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {1, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x}, {}, {1, 1, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -588,7 +588,7 @@ TEST_F(DeclarableOpsTests6, cumSum_14) {
 x.linspace(1);

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {1, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x}, {}, {1, 1, 0});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -607,7 +607,7 @@ TEST_F(DeclarableOpsTests6, cumSum_15) {
 x.linspace(1);

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 1, 2}, {}, false, nd4j::DataType::DOUBLE);
+auto result = op.evaluate({&x}, {}, {0, 1, 2});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -623,7 +623,7 @@ TEST_F(DeclarableOpsTests6, cumSum_16) {
 NDArray x('f', {3, 4}, nd4j::DataType::FLOAT32);

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 0, 1});
+auto result = op.evaluate({&x}, {}, {0, 0, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -659,7 +659,7 @@ TEST_F(DeclarableOpsTests6, cumSum_17) {
 }

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 0, 1});
+auto result = op.evaluate({&x}, {}, {0, 0, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -692,7 +692,7 @@ TEST_F(DeclarableOpsTests6, cumSum_18) {
 }

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {1, 0, 1});
+auto result = op.evaluate({&x}, {}, {1, 0, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -725,7 +725,7 @@ TEST_F(DeclarableOpsTests6, cumSum_19) {
 }

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {0, 1, 1});
+auto result = op.evaluate({&x}, {}, {0, 1, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -759,7 +759,7 @@ TEST_F(DeclarableOpsTests6, cumSum_20) {
 }

 nd4j::ops::cumsum op;
-auto result = op.execute({&x}, {}, {1, 1, 1});
+auto result = op.evaluate({&x}, {}, {1, 1, 1});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -778,7 +778,7 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) {
 auto exp = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
 nd4j::ops::mergemaxindex op;

-auto ress = op.execute({&x, &y, &z}, {}, {}, {});
+auto ress = op.evaluate({&x, &y, &z}, {}, {}, {});

 ASSERT_EQ(ND4J_STATUS_OK, ress->status());
 // ress->at(0)->printIndexedBuffer("MergeMaxIndex Result is ");
@@ -797,7 +797,7 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) {
 auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
 nd4j::ops::mergemaxindex op;

-auto ress = op.execute({&x, &y, &z}, {}, {nd4j::DataType::INT64}, {});
+auto ress = op.evaluate({&x, &y, &z}, {}, {nd4j::DataType::INT64});

 ASSERT_EQ(ND4J_STATUS_OK, ress->status());
 // ress->at(0)->printIndexedBuffer("MergeMaxIndex2 Result is ");
@@ -814,7 +814,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) {
 auto shape = NDArrayFactory::create<Nd4jLong>({2, 2});
 nd4j::ops::dropout op;

-auto ress = op.execute({&x, &shape}, {0.2f}, {113}, {}, false, nd4j::DataType::DOUBLE);
+auto ress = op.evaluate({&x, &shape}, {0.2f}, {113});

 ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
||||||
//ress->at(0)->printIndexedBuffer("Result is ");
|
//ress->at(0)->printIndexedBuffer("Result is ");
|
||||||
|
@ -830,7 +830,7 @@ TEST_F(DeclarableOpsTests6, TestMod_1) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0});
|
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0});
|
||||||
nd4j::ops::mod op;
|
nd4j::ops::mod op;
|
||||||
|
|
||||||
auto ress = op.execute({&x, &y}, {}, {}, {});
|
auto ress = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
||||||
// ress->at(0)->printIndexedBuffer("MOD Result is ");
|
// ress->at(0)->printIndexedBuffer("MOD Result is ");
|
||||||
|
@ -848,7 +848,7 @@ TEST_F(DeclarableOpsTests6, TestMod_BP_1) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2});
|
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2});
|
||||||
nd4j::ops::mod_bp op;
|
nd4j::ops::mod_bp op;
|
||||||
|
|
||||||
auto ress = op.execute({&x, &y, &eps}, {}, {}, {});
|
auto ress = op.evaluate({&x, &y, &eps});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
||||||
// ress->at(0)->printIndexedBuffer("MOD_BP Result is ");
|
// ress->at(0)->printIndexedBuffer("MOD_BP Result is ");
|
||||||
|
@ -867,7 +867,7 @@ TEST_F(DeclarableOpsTests6, TestRank_1) {
|
||||||
auto exp = NDArrayFactory::create<int>(3);
|
auto exp = NDArrayFactory::create<int>(3);
|
||||||
nd4j::ops::rank op;
|
nd4j::ops::rank op;
|
||||||
|
|
||||||
auto ress = op.execute({&x}, {}, {}, {});
|
auto ress = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
||||||
|
|
||||||
|
@ -881,7 +881,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_2) {
|
||||||
|
|
||||||
nd4j::ops::dropout op;
|
nd4j::ops::dropout op;
|
||||||
|
|
||||||
auto ress = op.execute({&x}, {0.4f}, {113}, {}, false, nd4j::DataType::DOUBLE);
|
auto ress = op.evaluate({&x}, {0.4f}, {113});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
||||||
|
|
||||||
|
@ -896,7 +896,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_3) {
|
||||||
|
|
||||||
nd4j::ops::dropout op;
|
nd4j::ops::dropout op;
|
||||||
|
|
||||||
auto ress = op.execute({&x, &shape}, {0.4f}, {113}, {}, false, nd4j::DataType::DOUBLE);
|
auto ress = op.evaluate({&x, &shape}, {0.4f}, {113});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
||||||
|
|
||||||
|
@ -913,7 +913,7 @@ TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) {
|
||||||
|
|
||||||
nd4j::ops::max_pool_with_argmax op;
|
nd4j::ops::max_pool_with_argmax op;
|
||||||
|
|
||||||
auto ress = op.execute({&x}, {}, {1,1,1,1,1,1,1,1,1});
|
auto ress = op.evaluate({&x}, {}, {1,1,1,1,1,1,1,1,1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
||||||
ASSERT_TRUE(expI.isSameShape(ress->at(0)));
|
ASSERT_TRUE(expI.isSameShape(ress->at(0)));
|
||||||
|
@ -942,7 +942,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_1) {
|
||||||
|
|
||||||
nd4j::ops::sufficient_statistics op;
|
nd4j::ops::sufficient_statistics op;
|
||||||
|
|
||||||
auto ress = op.execute({&x, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
auto ress = op.evaluate({&x, &axis});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
||||||
ASSERT_EQ(ress->at(0)->e<double>(0), count);
|
ASSERT_EQ(ress->at(0)->e<double>(0), count);
|
||||||
|
@ -974,7 +974,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_2) {
|
||||||
|
|
||||||
nd4j::ops::sufficient_statistics op;
|
nd4j::ops::sufficient_statistics op;
|
||||||
|
|
||||||
auto ress = op.execute({&x, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
auto ress = op.evaluate({&x, &axis});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
|
||||||
ASSERT_EQ(ress->at(0)->e<double>(0), count);
|
ASSERT_EQ(ress->at(0)->e<double>(0), count);
|
||||||
|
@ -996,7 +996,7 @@ TEST_F(DeclarableOpsTests6, BinCount_1) {
|
||||||
|
|
||||||
nd4j::ops::bincount op;
|
nd4j::ops::bincount op;
|
||||||
|
|
||||||
auto res = op.execute({&x}, {}, {});
|
auto res = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1021,7 +1021,7 @@ TEST_F(DeclarableOpsTests6, BinCount_2) {
|
||||||
|
|
||||||
nd4j::ops::bincount op;
|
nd4j::ops::bincount op;
|
||||||
|
|
||||||
auto res = op.execute({&x, &weights}, {}, {});
|
auto res = op.evaluate({&x, &weights});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1046,7 +1046,7 @@ TEST_F(DeclarableOpsTests6, BinCount_3) {
|
||||||
|
|
||||||
nd4j::ops::bincount op;
|
nd4j::ops::bincount op;
|
||||||
|
|
||||||
auto res = op.execute({&x, &weights}, {}, {0, 2});
|
auto res = op.evaluate({&x, &weights}, {}, {0, 2});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1071,7 +1071,7 @@ TEST_F(DeclarableOpsTests6, BinCount_4) {
|
||||||
|
|
||||||
nd4j::ops::bincount op;
|
nd4j::ops::bincount op;
|
||||||
|
|
||||||
auto res = op.execute({&x, &weights}, {}, {4, 4});
|
auto res = op.evaluate({&x, &weights}, {}, {4, 4});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1097,7 +1097,7 @@ TEST_F(DeclarableOpsTests6, BinCount_5) {
|
||||||
|
|
||||||
nd4j::ops::bincount op;
|
nd4j::ops::bincount op;
|
||||||
|
|
||||||
auto res = op.execute({&x, &weights, &minV, &maxV}, {}, {});
|
auto res = op.evaluate({&x, &weights, &minV, &maxV});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
// res->at(0)->printBuffer("BC out");
|
// res->at(0)->printBuffer("BC out");
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1116,7 +1116,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) {
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
|
|
||||||
auto res = op.execute({&x, &y}, {}, {});
|
auto res = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1135,7 +1135,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
|
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto res = op.evaluate({&x, &y});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
|
||||||
|
@ -1153,7 +1153,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) {
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
|
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {});
|
auto res = op.evaluate({&x, &y}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1172,7 +1172,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
|
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto res = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
//res->at(0)->printBuffer("Shape SGO 4");
|
//res->at(0)->printBuffer("Shape SGO 4");
|
||||||
|
@ -1191,7 +1191,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
|
||||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto res = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1209,7 +1209,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {
|
||||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3});
|
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3});
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto res = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1274,7 +1274,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
|
||||||
// auto expNorm(8.660254);
|
// auto expNorm(8.660254);
|
||||||
|
|
||||||
nd4j::ops::clip_by_global_norm op;
|
nd4j::ops::clip_by_global_norm op;
|
||||||
auto result = op.execute({&x}, {0.8}, {});
|
auto result = op.evaluate({&x}, {0.8}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1316,7 +1316,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) {
|
||||||
);
|
);
|
||||||
|
|
||||||
nd4j::ops::clip_by_global_norm op;
|
nd4j::ops::clip_by_global_norm op;
|
||||||
auto result = op.execute({&x, &a}, {1.8}, {});
|
auto result = op.evaluate({&x, &a}, {1.8}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1346,7 +1346,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) {
|
||||||
);
|
);
|
||||||
|
|
||||||
nd4j::ops::clip_by_global_norm op;
|
nd4j::ops::clip_by_global_norm op;
|
||||||
auto result = op.execute({&x, &a}, {0.8}, {});
|
auto result = op.evaluate({&x, &a}, {0.8}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1372,7 +1372,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) {
|
||||||
auto exp = NDArrayFactory::create<double>({36.0, -48.0});
|
auto exp = NDArrayFactory::create<double>({36.0, -48.0});
|
||||||
|
|
||||||
nd4j::ops::matrix_determinant op;
|
nd4j::ops::matrix_determinant op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1393,7 +1393,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {
|
||||||
auto exp = NDArrayFactory::create<double>({-2.0, -2.0});
|
auto exp = NDArrayFactory::create<double>({-2.0, -2.0});
|
||||||
|
|
||||||
nd4j::ops::matrix_determinant op;
|
nd4j::ops::matrix_determinant op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1414,7 +1414,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {
|
||||||
NDArray exp('c', {1}, {-54.0});
|
NDArray exp('c', {1}, {-54.0});
|
||||||
|
|
||||||
nd4j::ops::matrix_determinant op;
|
nd4j::ops::matrix_determinant op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1435,7 +1435,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {1}, {189.0});
|
auto exp = NDArrayFactory::create<double>('c', {1}, {189.0});
|
||||||
|
|
||||||
nd4j::ops::matrix_determinant op;
|
nd4j::ops::matrix_determinant op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1459,7 +1459,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {
|
||||||
x.p(12, 12.0);
|
x.p(12, 12.0);
|
||||||
|
|
||||||
nd4j::ops::matrix_determinant op;
|
nd4j::ops::matrix_determinant op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1483,7 +1483,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) {
|
||||||
x.p(12, 12.0);
|
x.p(12, 12.0);
|
||||||
|
|
||||||
nd4j::ops::matrix_determinant op;
|
nd4j::ops::matrix_determinant op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1505,7 +1505,7 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
|
||||||
auto exp = NDArrayFactory::create<double>({3.58351893845611, 3.871201010907891});
|
auto exp = NDArrayFactory::create<double>({3.58351893845611, 3.871201010907891});
|
||||||
|
|
||||||
nd4j::ops::log_matrix_determinant op;
|
nd4j::ops::log_matrix_determinant op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1524,7 +1524,7 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
|
||||||
auto exp = NDArrayFactory::create<double>({ 3.5835189, 4.159008});
|
auto exp = NDArrayFactory::create<double>({ 3.5835189, 4.159008});
|
||||||
|
|
||||||
nd4j::ops::logdet op;
|
nd4j::ops::logdet op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1542,7 +1542,7 @@ TEST_F(DeclarableOpsTests6, LogDet_2) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {1}, { 3.5835189});
|
auto exp = NDArrayFactory::create<double>('c', {1}, { 3.5835189});
|
||||||
|
|
||||||
nd4j::ops::logdet op;
|
nd4j::ops::logdet op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1561,7 +1561,7 @@ TEST_F(DeclarableOpsTests6, LogDet_3) {
|
||||||
auto exp = NDArrayFactory::create<double>( 3.5835189);
|
auto exp = NDArrayFactory::create<double>( 3.5835189);
|
||||||
|
|
||||||
nd4j::ops::logdet op;
|
nd4j::ops::logdet op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1605,7 +1605,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
||||||
});
|
});
|
||||||
|
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
auto result = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1624,7 +1624,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
|
||||||
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f});
|
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f});
|
||||||
|
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
auto result = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1643,7 +1643,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f });
|
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f });
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
auto result = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1662,7 +1662,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
|
||||||
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f });
|
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f });
|
||||||
|
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
auto result = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1707,7 +1707,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
|
||||||
});
|
});
|
||||||
|
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1740,7 +1740,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
|
||||||
});
|
});
|
||||||
|
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
auto result = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1774,7 +1774,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
|
||||||
});
|
});
|
||||||
|
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
auto result = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1808,7 +1808,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
|
||||||
});
|
});
|
||||||
|
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
auto result = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1842,7 +1842,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_04) {
|
||||||
});
|
});
|
||||||
|
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
auto result = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests6, ReluLayer_1) {
|
||||||
26.2, 31.65, 60.7});
|
26.2, 31.65, 60.7});
|
||||||
|
|
||||||
nd4j::ops::relu_layer op;
|
nd4j::ops::relu_layer op;
|
||||||
auto result = op.execute({&x, &w, &b}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
auto result = op.evaluate({&x, &w, &b});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -1923,7 +1923,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test1) {
|
||||||
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});
|
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});
|
||||||
|
|
||||||
nd4j::ops::static_rnn op;
|
nd4j::ops::static_rnn op;
|
||||||
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
|
auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1966,7 +1966,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test2) {
|
||||||
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648});
|
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648});
|
||||||
|
|
||||||
nd4j::ops::static_rnn op;
|
nd4j::ops::static_rnn op;
|
||||||
auto results = op.execute({&x, &Wx, &Wh, &b, &h0}, {}, {});
|
auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2009,7 +2009,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test3) {
|
||||||
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2 , 0.2 , 0.2 , 0.2});
|
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2 , 0.2 , 0.2 , 0.2});
|
||||||
|
|
||||||
nd4j::ops::static_rnn op;
|
nd4j::ops::static_rnn op;
|
||||||
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
|
auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2051,7 +2051,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test4) {
|
||||||
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882});
|
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882});
|
||||||
|
|
||||||
nd4j::ops::static_rnn op;
|
nd4j::ops::static_rnn op;
|
||||||
auto results = op.execute({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
|
auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2093,7 +2093,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test5) {
|
||||||
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653});
|
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653});
|
||||||
|
|
||||||
nd4j::ops::static_rnn op;
|
nd4j::ops::static_rnn op;
|
||||||
auto results = op.execute({&x, &Wx, &Wh, &b}, {}, {});
|
auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2144,7 +2144,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) {
|
||||||
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25});
|
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25});
|
||||||
|
|
||||||
nd4j::ops::static_bidirectional_rnn op;
|
nd4j::ops::static_bidirectional_rnn op;
|
||||||
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
|
auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2197,7 +2197,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) {
|
||||||
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.});
|
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.});
|
||||||
|
|
||||||
nd4j::ops::static_bidirectional_rnn op;
|
nd4j::ops::static_bidirectional_rnn op;
|
||||||
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
|
auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2250,7 +2250,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) {
|
||||||
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713});
|
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713});
|
||||||
|
|
||||||
nd4j::ops::static_bidirectional_rnn op;
|
nd4j::ops::static_bidirectional_rnn op;
|
||||||
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
|
auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2296,7 +2296,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) {
|
||||||
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});
|
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});
|
||||||
|
|
||||||
nd4j::ops::dynamic_rnn op;
|
nd4j::ops::dynamic_rnn op;
|
||||||
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1});
|
auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2341,7 +2341,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) {
|
||||||
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782});
|
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782});
|
||||||
|
|
||||||
nd4j::ops::dynamic_rnn op;
|
nd4j::ops::dynamic_rnn op;
|
||||||
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
|
auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2383,7 +2383,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) {
|
||||||
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, 0.98120782, 0.98120782});
|
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, 0.98120782, 0.98120782});
|
||||||
|
|
||||||
nd4j::ops::dynamic_rnn op;
|
nd4j::ops::dynamic_rnn op;
|
||||||
auto results = op.execute({&x, &Wx, &Wh, &b, &h0}, {}, {});
|
auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2424,7 +2424,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) {
|
||||||
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.57368608, 0.57368608, 0.57368608, 0.57368608});
|
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.57368608, 0.57368608, 0.57368608, 0.57368608});
|
||||||
|
|
||||||
nd4j::ops::dynamic_rnn op;
|
nd4j::ops::dynamic_rnn op;
|
||||||
auto results = op.execute({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
|
auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2465,7 +2465,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) {
|
||||||
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97486307, 0.97486307, 0.97486307, 0.97486307,0.98119833, 0.98119833, 0.98119833, 0.98119833});
|
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97486307, 0.97486307, 0.97486307, 0.97486307,0.98119833, 0.98119833, 0.98119833, 0.98119833});
|
||||||
|
|
||||||
nd4j::ops::dynamic_rnn op;
|
nd4j::ops::dynamic_rnn op;
|
||||||
auto results = op.execute({&x, &Wx, &Wh, &b}, {}, {});
|
auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2521,7 +2521,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) {
|
||||||
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25 , 0.25 , 0.25});
|
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25 , 0.25 , 0.25});
|
||||||
|
|
||||||
nd4j::ops::dynamic_bidirectional_rnn op;
|
nd4j::ops::dynamic_bidirectional_rnn op;
|
||||||
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
|
auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2581,7 +2581,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) {
|
||||||
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25 , 0.25 , 0.25});
|
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25 , 0.25 , 0.25});
|
||||||
|
|
||||||
nd4j::ops::dynamic_bidirectional_rnn op;
|
nd4j::ops::dynamic_bidirectional_rnn op;
|
||||||
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
|
auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2637,7 +2637,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) {
|
||||||
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.});
|
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.});
|
||||||
|
|
||||||
nd4j::ops::dynamic_bidirectional_rnn op;
|
nd4j::ops::dynamic_bidirectional_rnn op;
|
||||||
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
|
auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2696,7 +2696,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) {
|
||||||
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357});
|
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357});
|
||||||
|
|
||||||
nd4j::ops::dynamic_bidirectional_rnn op;
|
nd4j::ops::dynamic_bidirectional_rnn op;
|
||||||
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {});
|
auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2749,7 +2749,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {
|
||||||
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234});
|
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234});
|
||||||
|
|
||||||
nd4j::ops::dynamic_bidirectional_rnn op;
|
nd4j::ops::dynamic_bidirectional_rnn op;
|
||||||
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
|
auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2776,7 +2776,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_1) {
|
||||||
auto e = NDArrayFactory::create<double>('c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f});
|
auto e = NDArrayFactory::create<double>('c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f});
|
||||||
|
|
||||||
nd4j::ops::diag op;
|
nd4j::ops::diag op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
ASSERT_EQ(e, *result->at(0));
|
ASSERT_EQ(e, *result->at(0));
|
||||||
|
@ -2789,7 +2789,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {
|
||||||
auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});
|
auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});
|
||||||
|
|
||||||
nd4j::ops::diag op;
|
nd4j::ops::diag op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
ASSERT_EQ(e, *result->at(0));
|
ASSERT_EQ(e, *result->at(0));
|
||||||
|
@ -2802,7 +2802,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
|
||||||
auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});
|
auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});
|
||||||
|
|
||||||
nd4j::ops::diag op;
|
nd4j::ops::diag op;
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
ASSERT_EQ(e, *result->at(0));
|
ASSERT_EQ(e, *result->at(0));
|
||||||
|
|
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -51,7 +51,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {
|
||||||
|
|
||||||
nd4j::ops::reduce_stdev_bp op;
|
nd4j::ops::reduce_stdev_bp op;
|
||||||
|
|
||||||
auto result = op.execute({&x, &gradO2}, {0,0}, {1});
|
auto result = op.evaluate({&x, &gradO2}, {0,0}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
// output->printIndexedBuffer();
|
// output->printIndexedBuffer();
|
||||||
|
@ -59,7 +59,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
delete result;
|
delete result;
|
||||||
|
|
||||||
result = op.execute({&x, &gradO1}, {1,0}, {1});
|
result = op.evaluate({&x, &gradO1}, {1,0}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
output = result->at(0);
|
output = result->at(0);
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
|
@ -80,7 +80,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) {
|
||||||
|
|
||||||
nd4j::ops::reduce_stdev_bp op;
|
nd4j::ops::reduce_stdev_bp op;
|
||||||
|
|
||||||
auto result = op.execute({&x, &gradO2, &axis}, {}, {}, {false, false});
|
auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
// output->printIndexedBuffer();
|
// output->printIndexedBuffer();
|
||||||
|
@ -88,7 +88,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) {
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
delete result;
|
delete result;
|
||||||
|
|
||||||
result = op.execute({&x, &gradO1}, {1,0}, {1});
|
result = op.evaluate({&x, &gradO1}, {1,0}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
output = result->at(0);
|
output = result->at(0);
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
|
@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {1});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -272,7 +272,7 @@ TEST_F(DeclarableOpsTests9, concat_test2) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {1});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -296,7 +296,7 @@ TEST_F(DeclarableOpsTests9, concat_test3) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {0});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -316,7 +316,7 @@ TEST_F(DeclarableOpsTests9, concat_test4) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {1});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -336,7 +336,7 @@ TEST_F(DeclarableOpsTests9, concat_test5) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {0});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -356,7 +356,7 @@ TEST_F(DeclarableOpsTests9, concat_test6) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {0});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -376,7 +376,7 @@ TEST_F(DeclarableOpsTests9, concat_test7) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {0});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -394,7 +394,7 @@ TEST_F(DeclarableOpsTests9, concat_test8) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0}, {}, {0});
|
auto result = op.evaluate({&x0}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -412,7 +412,7 @@ TEST_F(DeclarableOpsTests9, concat_test9) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0}, {}, {0});
|
auto result = op.evaluate({&x0}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -437,7 +437,7 @@ TEST_F(DeclarableOpsTests9, concat_test10) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {1});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -462,7 +462,7 @@ TEST_F(DeclarableOpsTests9, concat_test11) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {1});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -487,7 +487,7 @@ TEST_F(DeclarableOpsTests9, concat_test12) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {1});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -512,7 +512,7 @@ TEST_F(DeclarableOpsTests9, concat_test13) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &x2}, {}, {1});
|
auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
@ -532,7 +532,7 @@ TEST_F(DeclarableOpsTests9, concat_test14) {
|
||||||
x1 = 2.;
|
x1 = 2.;
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
auto result = op.execute({&x0, &x1}, {}, {0}, {});
|
auto result = op.evaluate({&x0, &x1}, {}, {0}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -555,7 +555,7 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3}, {1, 0, 3});
|
auto exp = NDArrayFactory::create<double>('c', {3}, {1, 0, 3});
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
auto result = op.execute({&x, &y}, {}, {0});
|
auto result = op.evaluate({&x, &y}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
|
||||||
auto exp = NDArrayFactory::create<double>('c', {0,2,3});
|
auto exp = NDArrayFactory::create<double>('c', {0,2,3});
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
auto result = op.execute({&x, &y}, {}, {0});
|
auto result = op.evaluate({&x, &y}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -594,7 +594,7 @@ TEST_F(DeclarableOpsTests9, concat_test17) {
|
||||||
x1 = 2.;
|
x1 = 2.;
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
auto result = op.execute({&x0, &x1}, {}, {0}, {});
|
auto result = op.evaluate({&x0, &x1}, {}, {0}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -675,7 +675,7 @@ TEST_F(DeclarableOpsTests9, concat_test20) {
|
||||||
x3.assign(4.0);
|
x3.assign(4.0);
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {});
|
auto result = op.evaluate({&x0, &x1, &x2, &x3}, {}, {0}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -763,7 +763,7 @@ TEST_F(DeclarableOpsTests9, concat_test25) {
|
||||||
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
auto result = op.execute({&x0, &x1, &axis}, {}, {}, {true});
|
auto result = op.evaluate({&x0, &x1, &axis}, {}, {}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
|
@ -784,7 +784,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test1) {
|
||||||
gradO.linspace(0.01, 0.01);
|
gradO.linspace(0.01, 0.01);
|
||||||
|
|
||||||
nd4j::ops::tile_bp op;
|
nd4j::ops::tile_bp op;
|
||||||
auto results = op.execute({&input, &gradO}, {}, {2, 3});
|
auto results = op.evaluate({&input, &gradO}, {}, {2, 3});
|
||||||
auto gradI = results->at(0);
|
auto gradI = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -804,7 +804,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test2) {
-    auto results = op.execute({&input, &gradO}, {}, {1, 3});
+    auto results = op.evaluate({&input, &gradO}, {}, {1, 3});

@ -823,7 +823,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test3) {
-    auto results = op.execute({&input, &gradO}, {}, {1, 1});
+    auto results = op.evaluate({&input, &gradO}, {}, {1, 1});

@ -843,7 +843,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test4) {
-    auto results = op.execute({&input, &gradO}, {}, {2});
+    auto results = op.evaluate({&input, &gradO}, {}, {2});

@ -863,7 +863,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test5) {
-    auto results = op.execute({&input, &gradO}, {}, {1});
+    auto results = op.evaluate({&input, &gradO}, {}, {1});

@ -883,7 +883,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test6) {
-    auto results = op.execute({&input, &gradO}, {}, {1, 3, 2});
+    auto results = op.evaluate({&input, &gradO}, {}, {1, 3, 2});

@ -904,7 +904,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test7) {
-    auto results = op.execute({&input, &reps, &gradO}, {}, {});
+    auto results = op.evaluate({&input, &reps, &gradO}, {}, {});

@ -922,7 +922,7 @@ TEST_F(DeclarableOpsTests9, tile_test1) {
-    auto results = op.execute({&input, &reps}, {}, {});
+    auto results = op.evaluate({&input, &reps}, {}, {});
@ -944,7 +944,7 @@ TEST_F(DeclarableOpsTests9, matmul_test1) {
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});

@ -966,7 +966,7 @@ TEST_F(DeclarableOpsTests9, matmul_test2) {
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});

@ -987,7 +987,7 @@ TEST_F(DeclarableOpsTests9, matmul_test3) {
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});

@ -1009,7 +1009,7 @@ TEST_F(DeclarableOpsTests9, matmul_test4) {
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});

@ -1031,7 +1031,7 @@ TEST_F(DeclarableOpsTests9, matmul_test5) {
-    auto results = op.execute({&x, &y}, {}, {1});
+    auto results = op.evaluate({&x, &y}, {}, {1});

@ -1052,7 +1052,7 @@ TEST_F(DeclarableOpsTests9, matmul_test6) {
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});

@ -1075,7 +1075,7 @@ TEST_F(DeclarableOpsTests9, matmul_test7) {
-    auto results = op.execute({&x, &y}, {}, {0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {0, 1});

@ -1100,7 +1100,7 @@ TEST_F(DeclarableOpsTests9, matmul_test8) {
-    auto results = op.execute({&x, &y}, {}, {0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {0, 1});

@ -1125,7 +1125,7 @@ TEST_F(DeclarableOpsTests9, matmul_test9) {
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});
@ -1142,7 +1142,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_BP_1) {
-    auto ress = op.execute({&x, &errs, &shape}, {0.2f}, {113});
+    auto ress = op.evaluate({&x, &errs, &shape}, {0.2f}, {113});

@ -1159,7 +1159,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_1) {
-    auto ress = op.execute({&x}, {0.2f}, {113});
+    auto ress = op.evaluate({&x}, {0.2f}, {113});

@ -1167,7 +1167,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_1) {
-    auto ress2 = op.execute({&x}, {0.2f}, {113});
+    auto ress2 = op.evaluate({&x}, {0.2f}, {113});

@ -1214,7 +1214,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
-    auto ress = op.execute({&x1}, {0.5f}, {119});
+    auto ress = op.evaluate({&x1}, {0.5f}, {119});

@ -1225,11 +1225,11 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
-    auto ressX = op2.execute({&x1, &x1}, {0.5f}, {119}); // , false, nd4j::DataType::FLOAT32); // skipped due given by default
+    auto ressX = op2.evaluate({&x1, &x1}, {0.5f}, {119}); // , false, nd4j::DataType::FLOAT32); // skipped due given by default
     ASSERT_EQ(ND4J_STATUS_OK, ressX->status());
-    auto ressY = op2.execute({&x1, &x0}, {0.5f}, {119});
+    auto ressY = op2.evaluate({&x1, &x0}, {0.5f}, {119});
@ -1264,17 +1264,17 @@ TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) {
-    auto ress = op.execute({&x}, {0.5f}, {119});
+    auto ress = op.evaluate({&x}, {0.5f}, {119});
     nd4j::ops::dropout_bp op2;
-    auto ressX = op2.execute({&x, &x}, {0.5f}, {119});
+    auto ressX = op2.evaluate({&x, &x}, {0.5f}, {119});
-    auto ressY = op2.execute({&x, &x}, {0.5f}, {119});
+    auto ressY = op2.evaluate({&x, &x}, {0.5f}, {119});

@ -1307,12 +1307,12 @@ TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) {
-    auto ress = op.execute({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
+    auto ress = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
     NDArray* res = ress->at(0);
-    auto ress2 = op.execute({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
+    auto ress2 = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
@ -1336,7 +1336,7 @@ TEST_F(DeclarableOpsTests9, matmul_test10) {
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});

@ -1356,7 +1356,7 @@ TEST_F(DeclarableOpsTests9, matmul_test11) {
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});

@ -1377,7 +1377,7 @@ TEST_F(DeclarableOpsTests9, matmul_test12) {
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});

@ -1398,7 +1398,7 @@ TEST_F(DeclarableOpsTests9, matmul_test13) {
-    auto results = op.execute({&x, &y}, {}, {0, 0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {0, 0, 1});

@ -1419,7 +1419,7 @@ TEST_F(DeclarableOpsTests9, matmul_test14) {
-    auto results = op.execute({&x, &y}, {}, {1, 0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});

@ -1440,7 +1440,7 @@ TEST_F(DeclarableOpsTests9, matmul_test15) {
-    auto results = op.execute({&x, &y}, {}, {1, 0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});

@ -1464,7 +1464,7 @@ TEST_F(DeclarableOpsTests9, matmul_test16) {
-    auto results = op.execute({&x, &y}, {}, {1, 1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1, 1});

@ -1485,7 +1485,7 @@ TEST_F(DeclarableOpsTests9, matmul_test17) {
-    auto results = op.execute({&x, &y}, {}, {1, 0});
+    auto results = op.evaluate({&x, &y}, {}, {1, 0});

@ -1506,7 +1506,7 @@ TEST_F(DeclarableOpsTests9, matmul_test18) {
-    auto results = op.execute({&x, &y}, {}, {0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {0, 1});

@ -1527,7 +1527,7 @@ TEST_F(DeclarableOpsTests9, matmul_test19) {
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});

@ -1549,7 +1549,7 @@ TEST_F(DeclarableOpsTests9, matmul_test20) {
-    auto results = op.execute({&x, &y}, {}, {1,1,1});
+    auto results = op.evaluate({&x, &y}, {}, {1,1,1});

@ -1571,7 +1571,7 @@ TEST_F(DeclarableOpsTests9, matmul_test21) {
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});

@ -1593,7 +1593,7 @@ TEST_F(DeclarableOpsTests9, matmul_test22) {
-    auto results = op.execute({&x, &y}, {}, {1});
+    auto results = op.evaluate({&x, &y}, {}, {1});

@ -1615,7 +1615,7 @@ TEST_F(DeclarableOpsTests9, matmul_test23) {
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});
@ -1634,7 +1634,7 @@ TEST_F(DeclarableOpsTests9, matmul_test24) {
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});

@ -1650,7 +1650,7 @@ TEST_F(DeclarableOpsTests9, test_range_int_1) {
-    auto result = op.execute({&x0, &x1, &x2}, {}, {});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {});

@ -1664,7 +1664,7 @@ TEST_F(DeclarableOpsTests9, test_range_empty_1) {
-    auto result = op.execute({&x0, &x1, &x2}, {}, {});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {});

@ -1703,7 +1703,7 @@ TEST_F(DeclarableOpsTests9, test_unstack_1) {
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});

@ -1721,7 +1721,7 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) {
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});
@ -1758,7 +1758,7 @@ TEST_F(DeclarableOpsTests9, clipbynorm_test12) {
-    auto result = op.execute({&y}, {clip}, {axis}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&y}, {clip}, {axis});

@ -1852,7 +1852,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
     exclusive = 0; reverse = 0;
-    auto result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});

@ -1861,7 +1861,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
     exclusive = 1; reverse = 0;
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});

@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
     exclusive = 0; reverse = 1;
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});

@ -1879,7 +1879,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
     exclusive = 1; reverse = 1;
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});

@ -1910,7 +1910,7 @@ TEST_F(DeclarableOpsTests9, cumprod_2) {
-    auto result = op.execute({&x}, {}, {0, 0, 1});
+    auto result = op.evaluate({&x}, {}, {0, 0, 1});
@ -2111,7 +2111,7 @@ TEST_F(DeclarableOpsTests9, prelu_test1) {
-    auto result = op.execute({&x, &alpha}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha});

@ -2129,7 +2129,7 @@ TEST_F(DeclarableOpsTests9, prelu_test2) {
-    auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {0});

@ -2147,7 +2147,7 @@ TEST_F(DeclarableOpsTests9, prelu_test3) {
-    auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {0});

@ -2165,7 +2165,7 @@ TEST_F(DeclarableOpsTests9, prelu_test4) {
-    auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {0});

@ -2183,7 +2183,7 @@ TEST_F(DeclarableOpsTests9, prelu_test5) {
-    auto result = op.execute({&x, &alpha}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1});

@ -2201,7 +2201,7 @@ TEST_F(DeclarableOpsTests9, prelu_test6) {
-    auto result = op.execute({&x, &alpha}, {}, {1,0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1,0});

@ -2220,7 +2220,7 @@ TEST_F(DeclarableOpsTests9, prelu_test7) {
-    auto result = op.execute({&x, &alpha}, {}, {1,0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1,0});

@ -2238,7 +2238,7 @@ TEST_F(DeclarableOpsTests9, prelu_test8) {
-    auto result = op.execute({&x, &alpha}, {}, {1,0,1,0,1,0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1,0,1,0,1,0});

@ -2256,7 +2256,7 @@ TEST_F(DeclarableOpsTests9, prelu_test9) {
-    auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {0});

@ -2274,7 +2274,7 @@ TEST_F(DeclarableOpsTests9, prelu_test10) {
-    auto result = op.execute({&x, &alpha}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1});

@ -2299,7 +2299,7 @@ TEST_F(DeclarableOpsTests9, prelu_test11) {
-    auto result = op.execute({&x, &alpha}, {}, {1,3}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1,3});

@ -2323,7 +2323,7 @@ TEST_F(DeclarableOpsTests9, prelu_test12) {
-    auto result = op.execute({&x, &alpha}, {}, {-1, 2}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {-1, 2});

@ -2347,7 +2347,7 @@ TEST_F(DeclarableOpsTests9, prelu_test13) {
-    auto result = op.execute({&x, &alpha}, {}, {-1, 2}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {-1, 2});

@ -2372,7 +2372,7 @@ TEST_F(DeclarableOpsTests9, prelu_test14) {
-    auto result = op.execute({&x, &alpha}, {}, {-2}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {-2});
@ -2391,7 +2391,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
-    auto result = op.execute({&x}, {theta}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {theta});

@ -2411,7 +2411,7 @@ TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
-    auto result = op.execute({&x, &threshold}, {}, {}, {});
+    auto result = op.evaluate({&x, &threshold}, {}, {}, {});

@ -2429,7 +2429,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
-    auto result = op.execute({&x}, {theta}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {theta});
@ -2544,7 +2544,7 @@ TEST_F(DeclarableOpsTests9, multiply_test1) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});

@ -2564,7 +2564,7 @@ TEST_F(DeclarableOpsTests9, multiply_test2) {
-    auto result = op.execute({&y, &x}, {}, {});
+    auto result = op.evaluate({&y, &x}, {}, {});

@ -2584,7 +2584,7 @@ TEST_F(DeclarableOpsTests9, multiply_test3) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});

@ -2603,7 +2603,7 @@ TEST_F(DeclarableOpsTests9, multiply_test4) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});

@ -2621,7 +2621,7 @@ TEST_F(DeclarableOpsTests9, multiply_test5) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});

@ -2643,8 +2643,8 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test1) {
-    auto resFF = opFF.execute({&x, &y}, {}, {});
+    auto resFF = opFF.evaluate({&x, &y}, {}, {});
-    auto resBP = opBP.execute({&x, &y, &dLdz}, {}, {});
+    auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {});
@ -2800,7 +2800,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_2) {
-    auto resBP = opBP.execute({&x, &y, &dLdz}, {}, {});
+    auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {});

@ -2832,10 +2832,10 @@ TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) {
-    auto res1 = op1.execute({&x, &y}, {}, {3});
+    auto res1 = op1.evaluate({&x, &y}, {}, {3});
     nd4j::ops::dynamic_partition_bp op2;
-    auto res2 = op2.execute({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3});
+    auto res2 = op2.evaluate({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3});

@ -2879,7 +2879,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
-    auto result = op.execute({&x, &y, &eps}, {}, {});
+    auto result = op.evaluate({&x, &y, &eps}, {}, {});

@ -2924,7 +2924,7 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
-    auto results = op.execute(argsHolderFF);
+    auto results = op.evaluate(argsHolderFF);
@ -2964,7 +2964,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});

@ -2980,7 +2980,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});

@ -2996,7 +2996,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
@ -50,7 +50,7 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
-    auto result = op.execute({&x}, {0.0},{3});
+    auto result = op.evaluate({&x}, {0.0},{3});
@ -66,7 +66,7 @@ TEST_F(EmptyTests, Test_Concat_1) {
-    auto result = op.execute({empty, vector}, {}, {0});
+    auto result = op.evaluate({empty, vector}, {}, {0});

@ -91,7 +91,7 @@ TEST_F(EmptyTests, Test_Concat_2) {
-    auto result = op.execute({empty, scalar1, scalar2}, {}, {0});
+    auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0});

@ -116,7 +116,7 @@ TEST_F(EmptyTests, Test_Concat_3) {
-    auto result = op.execute({&empty, &scalar1, &scalar2}, {}, {0});
+    auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0});

@ -135,7 +135,7 @@ TEST_F(EmptyTests, Test_Concat_4) {
-    auto result = op.execute({&scalar1, &empty, &scalar2}, {}, {0});
+    auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0});

@ -151,7 +151,7 @@ TEST_F(EmptyTests, Test_Reshape_1) {
-    auto result = op.execute({&vector, empty}, {}, {});
+    auto result = op.evaluate({&vector, empty}, {}, {});

@ -167,7 +167,7 @@ TEST_F(EmptyTests, Test_Reshape_2) {
-    auto result = op.execute({&vector, empty}, {}, {}, {}, true);
+    auto result = op.evaluate({&vector, empty}, {}, {}, {}, {}, true);

@ -184,7 +184,7 @@ TEST_F(EmptyTests, Test_Reshape_3) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});

@ -213,7 +213,7 @@ TEST_F(EmptyTests, test_empty_scatter_1) {
-    auto result = op.execute({&x, &indices, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true});

@ -311,12 +311,12 @@ TEST_F(EmptyTests, test_empty_reshape_1) {
-    auto result0 = op.execute({&x0, &shape0}, {}, {});
+    auto result0 = op.evaluate({&x0, &shape0}, {}, {});
-    auto result1 = op.execute({&x1, &shape1}, {}, {});
+    auto result1 = op.evaluate({&x1, &shape1}, {}, {});

@ -332,7 +332,7 @@ TEST_F(EmptyTests, test_empty_matmul_1) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});

@ -347,7 +347,7 @@ TEST_F(EmptyTests, test_empty_matmul_2) {
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
@ -46,7 +46,7 @@ TEST_F(IndexingTests, StridedSlice_1) {
-    auto result = op.execute({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});
+    auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});

@ -65,7 +65,7 @@ TEST_F(IndexingTests, StridedSlice_2) {
-    auto result = op.execute({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});
+    auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});

@ -85,7 +85,7 @@ TEST_F(IndexingTests, StridedSlice_3) {
-    auto result = op.execute({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});
+    auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});

@ -108,7 +108,7 @@ TEST_F(IndexingTests, SimpleSlice_1) {
-    auto result = op.execute({&input}, {}, {1,0,0, 1,1,3});
+    auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3});

@ -134,7 +134,7 @@ TEST_F(IndexingTests, SimpleSlice_2) {
-    auto result = op.execute({&input}, {}, {1,0,0, 1,2,3});
+    auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3});

@ -159,7 +159,7 @@ TEST_F(IndexingTests, SimpleSlice_3) {
-    auto result = op.execute({&input}, {}, {1,0,0, 2,1,3});
+    auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3});

@ -179,7 +179,7 @@ TEST_F(IndexingTests, SimpleSlice_4) {
-    auto result = op.execute({&input, &start, &stop}, {}, {});
+    auto result = op.evaluate({&input, &start, &stop});

@ -202,7 +202,7 @@ TEST_F(IndexingTests, MaskedSlice_0) {
-    auto result = op.execute({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});

@ -228,7 +228,7 @@ TEST_F(IndexingTests, MaskedSlice_00) {
-    auto result = op.execute({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});

@ -252,7 +252,7 @@ TEST_F(IndexingTests, MaskedSlice_1) {
-    auto result = op.execute({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});

@ -273,7 +273,7 @@ TEST_F(IndexingTests, MaskedSlice_2) {
-    auto result = op.execute({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});

@ -293,7 +293,7 @@ TEST_F(IndexingTests, MaskedSlice_3) {
-    auto result = op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});

@ -313,7 +313,7 @@ TEST_F(IndexingTests, MaskedSlice_4) {
-    auto result = op.execute({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});

@ -336,7 +336,7 @@ TEST_F(IndexingTests, Live_Slice_1) {
-    auto result = op.execute({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});
+    auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});

@ -359,7 +359,7 @@ TEST_F(IndexingTests, Test_StridedSlice_1) {
-    auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});

@ -379,7 +379,7 @@ TEST_F(IndexingTests, Test_StridedSlice_2) {
-    auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -402,7 +402,7 @@ TEST_F(IndexingTests, Test_StridedSlice_3) {
|
||||||
auto exp = NDArrayFactory::create<float>('c', {1}, {6.0});
|
auto exp = NDArrayFactory::create<float>('c', {1}, {6.0});
|
||||||
|
|
||||||
nd4j::ops::strided_slice op;
|
nd4j::ops::strided_slice op;
|
||||||
auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -423,7 +423,7 @@ TEST_F(IndexingTests, Test_StridedSlice_4) {
|
||||||
auto exp = NDArrayFactory::create<float>({5.0f, 2});
|
auto exp = NDArrayFactory::create<float>({5.0f, 2});
|
||||||
|
|
||||||
nd4j::ops::strided_slice op;
|
nd4j::ops::strided_slice op;
|
||||||
auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1});
|
// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
|
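The pattern in the hunks above, and in the test files that follow, is mechanical: test code that called op.execute(...) only to obtain a ResultSet now calls the new evaluate(...) helper with the same input/tArgs/iArgs lists, and trailing empty lists may simply be dropped (see SimpleSlice_4). A minimal before/after sketch of that migration; the test name, header paths and input values here are assumptions, while the op and factory calls mirror the tests above:

    // Sketch only: illustrates the test-side move from execute() to evaluate().
    // Header paths, the test name and the input shape are assumed, not from the diff.
    #include "testlayers.h"
    #include <NDArrayFactory.h>
    #include <ops/declarable/CustomOperations.h>

    TEST_F(IndexingTests, SimpleSlice_migration_sketch) {
        auto input = NDArrayFactory::create<float>('c', {3, 2, 3});
        input.linspace(1.0f);

        nd4j::ops::slice op;

        // old call style: empty argument lists still had to be spelled out
        // auto result = op.execute({&input}, {}, {1,0,0, 1,1,3});

        // new call style: same lists, passed to evaluate()
        auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3});

        ASSERT_EQ(ND4J_STATUS_OK, result->status());
        auto z = result->at(0);
        ASSERT_EQ(3, z->rankOf());

        delete result;
    }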
@ -62,7 +62,7 @@ TEST_F(LegacyOpsTests, TransformTests_2) {
    exp.assign(-1.0);
    nd4j::ops::LegacyTransformSameOp op(transform::Neg); // Neg
-   auto result = op.execute({&x}, {}, {});
+   auto result = op.evaluate({&x}, {}, {});
    ASSERT_EQ(1, result->size());

@ -119,7 +119,7 @@ TEST_F(LegacyOpsTests, PWT_Tests_2) {
    exp.assign(6.0);
    nd4j::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply
-   auto result = op.execute({&x, &y}, {}, {});
+   auto result = op.evaluate({&x, &y}, {}, {});
    auto z = result->at(0);

@ -152,7 +152,7 @@ TEST_F(LegacyOpsTests, Scalar_Test_2) {
    auto y = NDArrayFactory::create<float>(5.0f);
    nd4j::ops::LegacyScalarOp op(scalar::Add, y);
-   auto result = op.execute({&x}, {}, {});
+   auto result = op.evaluate({&x}, {}, {});
    auto z = result->at(0);
    ASSERT_TRUE(exp.equalsTo(z));

@ -167,7 +167,7 @@ TEST_F(LegacyOpsTests, ReduceTests_1) {
    int opNum = reduce::Sum;
    nd4j::ops::LegacyReduceSameOp op(opNum);
-   auto result = op.execute({&x}, {}, {});
+   auto result = op.evaluate({&x}, {}, {});
    ASSERT_EQ(1, result->size());

@ -186,7 +186,7 @@ TEST_F(LegacyOpsTests, ReduceTests_2) {
    nd4j::ops::LegacyReduceSameOp op(reduce::Sum);
    auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
-   auto result = op.execute({&x, &axis}, {}, {});
+   auto result = op.evaluate({&x, &axis}, {}, {});
    ASSERT_EQ(1, result->size());

@ -208,7 +208,7 @@ TEST_F(LegacyOpsTests, ReduceTests_3) {
    nd4j::ops::LegacyReduceSameOp op(reduce::Sum);
-   auto result = op.execute({&x, &indices}, {}, {});
+   auto result = op.evaluate({&x, &indices}, {}, {});
    auto z = result->at(0);
    auto exp = x.reduceAlongDimension(reduce::Sum,{1});

@ -228,7 +228,7 @@ TEST_F(LegacyOpsTests, ReduceTests_4) {
    nd4j::ops::LegacyReduceSameOp op(reduce::Sum);
-   auto result = op.execute({&x, &indices}, {}, {}, {true});
+   auto result = op.evaluate({&x, &indices}, {}, {}, {true});
    auto z = result->at(0);
    auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true);
    // indices.printShapeInfo("Indices shape");

@ -247,7 +247,7 @@ TEST_F(LegacyOpsTests, ReduceTests_5) {
    int opNum = reduce::Mean;
    nd4j::ops::LegacyReduceFloatOp op(opNum);
-   ResultSet* result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+   auto result = op.evaluate({&x});
    ASSERT_EQ(1, result->size());

@ -266,7 +266,7 @@ TEST_F(LegacyOpsTests, ReduceTests_6) {
    auto axis = NDArrayFactory::create<int>('c', {1}, {1});
    nd4j::ops::LegacyReduceFloatOp op(reduce::Mean);
-   auto result = op.execute({&x, &axis}, {}, {});
+   auto result = op.evaluate({&x, &axis}, {}, {});
    ASSERT_EQ(1, result->size());

@ -288,7 +288,7 @@ TEST_F(LegacyOpsTests, ReduceTests_7) {
    nd4j::ops::LegacyReduceFloatOp op(reduce::Mean);
-   auto result = op.execute({&x, &indices}, {}, {});
+   auto result = op.evaluate({&x, &indices}, {}, {});
    auto z = result->at(0);
    auto exp = x.reduceAlongDimension(reduce::Mean,{1});

@ -308,7 +308,7 @@ TEST_F(LegacyOpsTests, ReduceTests_8) {
    nd4j::ops::LegacyReduceFloatOp op(reduce::Mean);
-   auto result = op.execute({&x, &indices}, {}, {}, {true});
+   auto result = op.evaluate({&x, &indices}, {}, {}, {true});
    auto z = result->at(0);
    auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true);

@ -329,7 +329,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_1) {
    nd4j::ops::LegacyIndexReduceOp op(indexreduce::IndexMax);
-   auto result = op.execute({&x}, {}, {});
+   auto result = op.evaluate({&x}, {}, {});
    ASSERT_EQ(1, result->size());

@ -349,7 +349,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
    auto exp = NDArrayFactory::create<Nd4jLong>({4,4,4,4,4});
    nd4j::ops::LegacyIndexReduceOp op(indexreduce::IndexMax);
-   auto result = op.execute({&x, &indices}, {}, {});
+   auto result = op.evaluate({&x, &indices}, {}, {});
    ASSERT_EQ(1, result->size());
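For the legacy-op wrappers the change also removes boilerplate that execute() required: in ReduceTests_5 the result was held as a raw ResultSet* and the call had to pass an in-place flag plus an explicit nd4j::DataType::FLOAT32, all of which evaluate() now defaults. A minimal sketch of the new call shape; the header paths and input values are assumptions, and caller ownership of the returned ResultSet is inferred from the `delete result` calls kept elsewhere in these tests:

    // Sketch only: header paths and inputs are assumed; op/enum names come from the diff above.
    #include "testlayers.h"
    #include <NDArrayFactory.h>
    #include <ops/declarable/LegacyReduceFloatOp.h>

    TEST_F(LegacyOpsTests, ReduceMean_migration_sketch) {
        auto x = NDArrayFactory::create<float>('c', {5, 5});
        x.assign(1.0f);

        nd4j::ops::LegacyReduceFloatOp op(reduce::Mean);

        // old: ResultSet* result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
        // new: the float output type is implied by the op class, so the extras go away
        auto result = op.evaluate({&x});

        ASSERT_EQ(1, result->size());
        auto z = result->at(0);                      // mean over the whole array (no axis input)
        ASSERT_NEAR(1.0f, z->e<float>(0), 1e-5f);

        delete result;                               // ResultSet ownership stays with the caller
    }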
@ -133,7 +133,7 @@ TEST_F(MultiDataTypeTests, Basic_Test_7) {
    auto e = NDArrayFactory::create<float>('c', {2, 3}, {0.f, 2.f, 4.f, 6.f, 8.f, 10.f});
    nd4j::ops::add op;
-   auto result = op.execute({&x, &y},{}, {});
+   auto result = op.evaluate({&x, &y});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@ -65,7 +65,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
    auto inferenceVector = NDArrayFactory::empty<float>();
    nd4j::ops::skipgram op;
-   auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true);
+   auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());
    auto row0 = syn0({0,1, 0,0}, true);

@ -106,7 +106,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
    auto inferenceVector = NDArrayFactory::empty<float>();
    nd4j::ops::skipgram op;
-   auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true);
+   auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());
    auto row0 = syn0({0,1, 0,0}, true);

@ -157,8 +157,8 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
    auto inferenceVector = NDArrayFactory::empty<float>();
    nd4j::ops::skipgram op;
-   auto result0 = op.execute({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true);
+   auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
-   auto result1 = op.execute({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true);
+   auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
    ASSERT_EQ(Status::OK(), result0->status());
    auto row00 = syn00({0,1, 0,0}, true);

@ -191,7 +191,7 @@ TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
    auto inferenceVector = NDArrayFactory::empty<float>();
    nd4j::ops::skipgram op;
-   auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, true);
+   auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());
    delete result;

@ -226,7 +226,7 @@ TEST_F(NlpTests, basic_sg_ns_test_1) {
    auto inferenceVector = NDArrayFactory::empty<float>();
    nd4j::ops::skipgram op;
-   auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, true);
+   auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());
    auto row0 = syn0({1,2, 0,0}, true);

@ -268,7 +268,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
    auto inferenceVector = NDArrayFactory::empty<float>();
    nd4j::ops::cbow op;
-   auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, true);
+   auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());
    auto row_s0_0 = syn0({0,1, 0,0}, true);

@ -322,7 +322,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
    auto inferenceVector = NDArrayFactory::empty<float>();
    nd4j::ops::cbow op;
-   auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, true);
+   auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());
    auto row_s0_0 = syn0({0,1, 0,0}, true);

@ -371,7 +371,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
    expTable.assign(0.5);
    nd4j::ops::skipgram op;
-   auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, true);
+   auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());
    auto row0 = syn0({0,1, 0,0}, true);

@ -415,7 +415,7 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
    negTable.linspace(0.0);
    nd4j::ops::skipgram op;
-   auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, true);
+   auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());
    auto row0 = syn0({0,0, 0,0}, true);

@ -452,7 +452,7 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
    auto inferenceVector = NDArrayFactory::empty<float>();
    nd4j::ops::cbow op;
-   auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, true);
+   auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());
    auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
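In the skipgram/cbow tests above the migration is not just a rename: one extra, currently empty list is threaded through before the final boolean. Judging by the rest of this changeset it is the new per-op data-type argument slot, but that reading is an inference from the diff, not something these hunks state. A fragment showing where the new slot sits, reusing the arrays declared in the NlpTests cases above; the parameter-role comments are assumptions:

    // Fragment, not a standalone test: target, ngStarter, ..., neu1e are the arrays
    // built by the NlpTests cases above. Parameter-role comments are assumptions.
    nd4j::ops::skipgram op;

    // old: op.execute(inputs, /*tArgs*/ {}, /*iArgs*/ {}, /*bArgs*/ {false}, /*isInplace*/ true);
    auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg,
                               &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e},
                              /*tArgs*/ {}, /*iArgs*/ {}, /*bArgs*/ {false}, /*new slot*/ {}, /*isInplace*/ true);

    ASSERT_EQ(Status::OK(), result->status());
    delete result;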
@ -41,7 +41,7 @@ TEST_F(ParityOpsTests, TestZeroAs1) {
    nd4j::ops::zeros_as op;
-   auto result = op.execute({&x}, {}, {});
+   auto result = op.evaluate({&x}, {}, {});
    auto z = result->at(0);

@ -60,7 +60,7 @@ TEST_F(ParityOpsTests, TestMaximum1) {
    nd4j::ops::maximum op;
-   auto result = op.execute({&x, &y}, {}, {});
+   auto result = op.evaluate({&x, &y}, {}, {});
    auto z = result->at(0);

@ -80,7 +80,7 @@ TEST_F(ParityOpsTests, TestMinimum1) {
    nd4j::ops::minimum op;
-   auto result = op.execute({&x, &y}, {}, {});
+   auto result = op.evaluate({&x, &y}, {}, {});
    auto z = result->at(0);

@ -99,7 +99,7 @@ TEST_F(ParityOpsTests, TestTear1) {
    nd4j::ops::tear op;
-   auto result = op.execute({&input}, {}, {1});
+   auto result = op.evaluate({&input}, {}, {1});
    ASSERT_EQ(10, result->size());

@ -119,7 +119,7 @@ TEST_F(ParityOpsTests, TestUnstack1) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {0});
+   auto result = op.evaluate({&input}, {}, {0});
    ASSERT_EQ(10, result->size());

@ -141,7 +141,7 @@ TEST_F(ParityOpsTests, TestUnstack2) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {2});
+   auto result = op.evaluate({&input}, {}, {2});
    ASSERT_EQ(6, result->size());

@ -158,7 +158,7 @@ TEST_F(ParityOpsTests, TestUnstack3) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {2});
+   auto result = op.evaluate({&input}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -177,7 +177,7 @@ TEST_F(ParityOpsTests, TestUnstack4) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {1});
+   auto result = op.evaluate({&input}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -195,7 +195,7 @@ TEST_F(ParityOpsTests, TestUnstack5) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {0});
+   auto result = op.evaluate({&input}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -213,7 +213,7 @@ TEST_F(ParityOpsTests, TestUnstack6) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {0});
+   auto result = op.evaluate({&input}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -231,7 +231,7 @@ TEST_F(ParityOpsTests, TestUnstack7) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {1});
+   auto result = op.evaluate({&input}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -249,7 +249,7 @@ TEST_F(ParityOpsTests, TestUnstack8) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {0});
+   auto result = op.evaluate({&input}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -267,7 +267,7 @@ TEST_F(ParityOpsTests, TestUnstack9) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {1});
+   auto result = op.evaluate({&input}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -286,7 +286,7 @@ TEST_F(ParityOpsTests, TestUnstack10) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {0});
+   auto result = op.evaluate({&input}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    ASSERT_TRUE(exp.isSameShape(result->at(0)));

@ -304,7 +304,7 @@ TEST_F(ParityOpsTests, TestUnstack11) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {2});
+   auto result = op.evaluate({&input}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    ASSERT_TRUE(exp.isSameShape(result->at(0)));

@ -320,7 +320,7 @@ TEST_F(ParityOpsTests, TestUnstack12) {
    nd4j::ops::unstack op;
-   auto result = op.execute({&input}, {}, {1});
+   auto result = op.evaluate({&input}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    ASSERT_TRUE(result->size() == 0);

@ -334,7 +334,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest1) {
    auto reshaped = input.reshape('c', {5, 1, 5});
    nd4j::ops::expand_dims op;
-   auto result = op.execute({&input}, {}, {1});
+   auto result = op.evaluate({&input}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -353,7 +353,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest2) {
    auto reshaped = input.reshape('c', {1, 3, 4});
    nd4j::ops::expand_dims op;
-   auto result = op.execute({&input}, {}, {0});
+   auto result = op.evaluate({&input}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -372,7 +372,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest3) {
    auto reshaped = input.reshape('c', {3, 1, 4});
    nd4j::ops::expand_dims op;
-   auto result = op.execute({&input}, {}, {-2});
+   auto result = op.evaluate({&input}, {}, {-2});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -390,7 +390,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest4) {
    auto reshaped = input.reshape('c', {1, 3, 4});
    nd4j::ops::expand_dims op;
-   auto result = op.execute({&input}, {}, {-3});
+   auto result = op.evaluate({&input}, {}, {-3});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

@ -408,7 +408,7 @@ TEST_F(ParityOpsTests, Test_Shape_1) {
    auto exp = NDArrayFactory::create<Nd4jLong>('c', {4}, {3, 4, 5, 6});
    nd4j::ops::shape_of op;
-   auto result = op.execute({&x}, {}, {});
+   auto result = op.evaluate({&x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -426,7 +426,7 @@ TEST_F(ParityOpsTests, Test_Equals_1) {
    auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {1, 0, 1, 0, 1});
    nd4j::ops::equals op;
-   auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+   auto result = op.evaluate({&x, &y});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -444,7 +444,7 @@ TEST_F(ParityOpsTests, Test_NotEquals_1) {
    auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 1, 0, 1, 0});
    nd4j::ops::not_equals op;
-   auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+   auto result = op.evaluate({&x, &y});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -461,7 +461,7 @@ TEST_F(ParityOpsTests, Test_Less_1) {
    auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {1, 1, 0, 0, 0});
    nd4j::ops::less op;
-   auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+   auto result = op.evaluate({&x, &y});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -478,7 +478,7 @@ TEST_F(ParityOpsTests, Test_LessEquals_1) {
    auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {1, 1, 1, 0, 0});
    nd4j::ops::less_equal op;
-   auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+   auto result = op.evaluate({&x, &y});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -495,7 +495,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_1) {
    auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 0, 1, 1, 1});
    nd4j::ops::greater_equal op;
-   auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+   auto result = op.evaluate({&x, &y});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -512,7 +512,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_2) {
    auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 0, 1, 1, 1});
    nd4j::ops::greater_equal op;
-   auto result = op.execute({&x, &y}, {}, {}, {}, false);
+   auto result = op.evaluate({&x, &y}, {}, {}, {}, {}, false);
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -529,7 +529,7 @@ TEST_F(ParityOpsTests, Test_Greater_1) {
    auto exp = NDArrayFactory::create<bool>('c', {1, 5}, {0, 0, 0, 1, 1});
    nd4j::ops::greater op;
-   auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+   auto result = op.evaluate({&x, &y});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);
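The comparison ops above show the other notable simplification: with execute() the tests had to pass an in-place flag and force the output type to nd4j::DataType::BOOL, while with evaluate() the boolean output type is resolved by the op itself and the call collapses to just the inputs. A small self-contained sketch, assuming the same factory/op API as the surrounding tests; the input values are made up for illustration:

    // Sketch only: invented inputs, but the call shapes are taken from the diff above.
    auto x = NDArrayFactory::create<float>('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f});
    auto y = NDArrayFactory::create<float>(3.0f);                  // scalar, broadcast against x

    nd4j::ops::greater op;

    // old: auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
    auto result = op.evaluate({&x, &y});                           // BOOL output inferred

    auto z = result->at(0);                                        // expected values: {0, 0, 0, 1, 1}
    ASSERT_EQ(nd4j::DataType::BOOL, z->dataType());

    delete result;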
@ -547,7 +547,7 @@ TEST_F(ParityOpsTests, Test_Where_1) {
    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 6, 5, 4, 7, 8, 9});
    nd4j::ops::Where op;
-   auto result = op.execute({&mask, &x, &y}, {}, {});
+   auto result = op.evaluate({&mask, &x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -567,7 +567,7 @@ TEST_F(ParityOpsTests, Test_Where_2) {
    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1});
    nd4j::ops::Where op;
-   auto result = op.execute({&mask, &x, &y}, {}, {});
+   auto result = op.evaluate({&mask, &x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -584,7 +584,7 @@ TEST_F(ParityOpsTests, Test_Where_3) {
    auto exp = NDArrayFactory::create<Nd4jLong>('c', {5, 3}, {0, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 1, 1, 2});
    nd4j::ops::Where op;
-   auto result = op.execute({&mask}, {}, {});
+   auto result = op.evaluate({&mask}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -604,7 +604,7 @@ TEST_F(ParityOpsTests, Test_Select_1) {
    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1});
    nd4j::ops::select op;
-   auto result = op.execute({&mask, &x, &y}, {}, {});
+   auto result = op.evaluate({&mask, &x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -622,7 +622,7 @@ TEST_F(ParityOpsTests, Test_Select_2) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1, 8, 3, 6});
    nd4j::ops::select op;
-   auto result = op.execute({&mask, &x, &y}, {}, {});
+   auto result = op.evaluate({&mask, &x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -641,7 +641,7 @@ TEST_F(ParityOpsTests, Test_Select_3) {
    auto exp = NDArrayFactory::create<float>('c', {1, 1}, {2});
    nd4j::ops::select op;
-   auto result = op.execute({&mask, &x, &y}, {}, {});
+   auto result = op.evaluate({&mask, &x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -660,7 +660,7 @@ TEST_F(ParityOpsTests, Test_Reshape_TF_1) {
    nd4j::ops::reshape op;
-   auto result = op.execute({&x, &shape}, {}, {});
+   auto result = op.evaluate({&x, &shape}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -677,7 +677,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) {
    auto bias = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5});
    nd4j::ops::biasadd op;
-   auto result = op.execute({&x, &bias}, {}, {});
+   auto result = op.evaluate({&x, &bias}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -697,7 +697,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2}, {2, 3, 3, 4});
    nd4j::ops::scatter_add op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -715,7 +715,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
    auto exp = NDArrayFactory::create<float>('c', {1, 4}, {2, 3, 4, 5});
    nd4j::ops::scatter_add op;
-   auto result = op.execute({&vec, &idc, &updates}, {}, {});
+   auto result = op.evaluate({&vec, &idc, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -732,7 +732,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8});
    nd4j::ops::scatter_add op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -749,7 +749,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8});
    nd4j::ops::scatter_add op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -766,7 +766,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_5) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.});
    nd4j::ops::scatter_add op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -784,7 +784,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13});
    nd4j::ops::scatter_add op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -801,7 +801,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_7) {
    auto exp = NDArrayFactory::create<float>('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f});
    nd4j::ops::scatter_add op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);
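The scatter tests that follow keep their boolean argument lists ({true}, {true, true}) in the same positional slot when moving to evaluate(); only the method name changes. A fragment restating that, reusing the matrix/idc/updates arrays set up by the tests above; what the flags mean (for example locking or index validation) is not established by this diff:

    // Fragment: matrix, idc and updates are the arrays built by the scatter tests above.
    nd4j::ops::scatter_add op;

    // old: auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true});
    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);
    delete result;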
@ -850,7 +850,7 @@ TEST_F(ParityOpsTests, scatterMax_test1) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10, 2, 3, 4});
    nd4j::ops::scatter_max op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -867,7 +867,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) {
    auto exp = NDArrayFactory::create<float>('c', {1, 4}, {10, 2, 30, 4});
    nd4j::ops::scatter_max op;
-   auto result = op.execute({&vec, &idc, &updates}, {}, {}, {true});
+   auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -884,7 +884,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8});
    nd4j::ops::scatter_max op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -901,7 +901,7 @@ TEST_F(ParityOpsTests, scatterMax_test4) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8});
    nd4j::ops::scatter_max op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {true}, {true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {true}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -918,7 +918,7 @@ TEST_F(ParityOpsTests, scatterMax_test5) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {10, 2, 10, 2, 10, 2, 2, 10, 2, 10, 2, 10});
    nd4j::ops::scatter_max op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -935,7 +935,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {2, 1, 2, 1, 1, 2, 1, 2});
    nd4j::ops::scatter_max op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -953,7 +953,7 @@ TEST_F(ParityOpsTests, scatterMin_test1) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-1, 1, 3, 4});
    nd4j::ops::scatter_min op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -970,7 +970,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) {
    auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 3, 1});
    nd4j::ops::scatter_min op;
-   auto result = op.execute({&vec, &idc, &updates}, {}, {}, {true});
+   auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -987,7 +987,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8});
    nd4j::ops::scatter_min op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1004,7 +1004,7 @@ TEST_F(ParityOpsTests, scatterMin_test4) {
    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8});
    nd4j::ops::scatter_min op;
-   auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+   auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1036,7 +1036,7 @@ TEST_F(ParityOpsTests, scatterND_test1) {
    auto exp = NDArrayFactory::create<float>('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f});
    nd4j::ops::scatter_nd op;
-   auto result = op.execute({&indices, &updates, &shape}, {}, {false, true});
+   auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1058,7 +1058,7 @@ TEST_F(ParityOpsTests, scatterND_test2) {
    updates.linspace(1.f);
    nd4j::ops::scatter_nd op;
-   auto result = op.execute({&indices, &updates, &shape}, {}, {});
+   auto result = op.evaluate({&indices, &updates, &shape}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1083,7 +1083,7 @@ TEST_F(ParityOpsTests, scatterND_test3) {
    updates.linspace(1.f);
    nd4j::ops::scatter_nd op;
-   auto result = op.execute({&indices, &updates, &shape}, {}, {false, true});
+   auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1103,7 +1103,7 @@ TEST_F(ParityOpsTests, scatterND_test4) {
    auto exp = NDArrayFactory::create<float>('c', {8}, {0.f, 11.f, 0.f, 10.f, 9.f, 0.f, 0.f, 12.f});
    nd4j::ops::scatter_nd op;
-   auto result = op.execute({&indices, &updates, &shape}, {}, {});
+   auto result = op.evaluate({&indices, &updates, &shape}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1123,7 +1123,7 @@ TEST_F(ParityOpsTests, scatterND_test5) {
    auto exp = NDArrayFactory::create<float>('c', {8}, {0.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
    nd4j::ops::scatter_nd op;
-   auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true});
+   auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1150,7 +1150,7 @@ TEST_F(ParityOpsTests, scatterND_test6) {
    updates.linspace(1);
    nd4j::ops::scatter_nd op;
-   auto result = op.execute({&indices, &updates, &shape}, {}, {});
+   auto result = op.evaluate({&indices, &updates, &shape}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1177,7 +1177,7 @@ TEST_F(ParityOpsTests, scatterND_test7) {
    updates.linspace(1);
    nd4j::ops::scatter_nd op;
-   auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true, true});
+   auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true, true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1198,7 +1198,7 @@ TEST_F(ParityOpsTests, scatterND_test8) {
    auto exp = NDArrayFactory::create<float>('c', {6,4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
    nd4j::ops::scatter_nd op;
-   auto result = op.execute({&indices, &updates, &shape}, {}, {true});
+   auto result = op.evaluate({&indices, &updates, &shape}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);
@ -1233,7 +1233,7 @@ TEST_F(ParityOpsTests, scatterND_add_test1) {
    auto exp = NDArrayFactory::create<float>('c', {8}, {1.f, 13.f, 3.f, 14.f, 14.f, 6.f, 7.f, 20.f});
    nd4j::ops::scatter_nd_add op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1256,7 +1256,7 @@ TEST_F(ParityOpsTests, scatterND_add_test2) {
    updates.linspace(1.f);
    nd4j::ops::scatter_nd_add op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1280,7 +1280,7 @@ TEST_F(ParityOpsTests, scatterND_add_test3) {
    updates.linspace(1.f);
    nd4j::ops::scatter_nd_add op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1307,7 +1307,7 @@ TEST_F(ParityOpsTests, scatterND_add_test4) {
    updates.linspace(1.f);
    nd4j::ops::scatter_nd_add op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1343,7 +1343,7 @@ TEST_F(ParityOpsTests, scatterND_add_test5) {
    updates.linspace(1.f);
    nd4j::ops::scatter_nd_add op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1376,7 +1376,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test1) {
    auto exp = NDArrayFactory::create<float>('c', {8}, {1.f, -9.f, 3.f, -6.f, -4.f, 6.f, 7.f, -4.f});
    nd4j::ops::scatter_nd_sub op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1399,7 +1399,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test2) {
    updates.linspace(1.f);
    nd4j::ops::scatter_nd_sub op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1424,7 +1424,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test3) {
    updates.linspace(1.f);
    nd4j::ops::scatter_nd_sub op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1451,7 +1451,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test4) {
    updates.linspace(1.f);
    nd4j::ops::scatter_nd_sub op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1487,7 +1487,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test5) {
    updates.linspace(1.f);
    nd4j::ops::scatter_nd_sub op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);

@ -1508,7 +1508,7 @@ TEST_F(ParityOpsTests, scatterND_update_test1) {
    auto exp = NDArrayFactory::create<float>('c', {8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f});
    nd4j::ops::scatter_nd_update op;
-   auto result = op.execute({&input, &indices, &updates}, {}, {});
+   auto result = op.evaluate({&input, &indices, &updates}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -1531,7 +1531,7 @@ TEST_F(ParityOpsTests, scatterND_update_test2) {
|
||||||
updates.linspace(1.f);
|
updates.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::scatter_nd_update op;
|
nd4j::ops::scatter_nd_update op;
|
||||||
auto result = op.execute({&input, &indices, &updates}, {}, {});
|
auto result = op.evaluate({&input, &indices, &updates}, {}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -1555,7 +1555,7 @@ TEST_F(ParityOpsTests, scatterND_update_test3) {
|
||||||
updates.linspace(1.f);
|
updates.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::scatter_nd_update op;
|
nd4j::ops::scatter_nd_update op;
|
||||||
auto result = op.execute({&input, &indices, &updates}, {}, {});
|
auto result = op.evaluate({&input, &indices, &updates}, {}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -1583,7 +1583,7 @@ TEST_F(ParityOpsTests, scatterND_update_test4) {
|
||||||
updates.linspace(1.f);
|
updates.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::scatter_nd_update op;
|
nd4j::ops::scatter_nd_update op;
|
||||||
auto result = op.execute({&input, &indices, &updates}, {}, {});
|
auto result = op.evaluate({&input, &indices, &updates}, {}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -1619,7 +1619,7 @@ TEST_F(ParityOpsTests, scatterND_update_test5) {
|
||||||
updates.linspace(1.f);
|
updates.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::scatter_nd_update op;
|
nd4j::ops::scatter_nd_update op;
|
||||||
auto result = op.execute({&input, &indices, &updates}, {}, {});
|
auto result = op.evaluate({&input, &indices, &updates}, {}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -1652,7 +1652,7 @@ TEST_F(ParityOpsTests, scatter_update_1) {
|
||||||
NDArray exp('c', {2,2}, {30,40,10,20}, nd4j::DataType::INT32);
|
NDArray exp('c', {2,2}, {30,40,10,20}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
nd4j::ops::scatter_update op;
|
nd4j::ops::scatter_update op;
|
||||||
auto results = op.execute({&x, &updates}, {}, {6, 1,1, 2,1,0});
|
auto results = op.evaluate({&x, &updates}, {}, {6, 1,1, 2,1,0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
// x.printBuffer();
|
// x.printBuffer();
|
||||||
|
@ -1672,7 +1672,7 @@ TEST_F(ParityOpsTests, scatter_update_2) {
|
||||||
NDArray exp('c', {2,2}, {20,10,40,30}, nd4j::DataType::INT32);
|
NDArray exp('c', {2,2}, {20,10,40,30}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
nd4j::ops::scatter_update op;
|
nd4j::ops::scatter_update op;
|
||||||
auto results = op.execute({&x, &updates}, {}, {6, 1,0, 2,1,0});
|
auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,1,0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1691,7 +1691,7 @@ TEST_F(ParityOpsTests, scatter_update_3) {
|
||||||
NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, nd4j::DataType::INT32);
|
NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
nd4j::ops::scatter_update op;
|
nd4j::ops::scatter_update op;
|
||||||
auto results = op.execute({&x, &updates}, {}, {6, 2,1,2, 2,1,0});
|
auto results = op.evaluate({&x, &updates}, {}, {6, 2,1,2, 2,1,0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1710,7 +1710,7 @@ TEST_F(ParityOpsTests, scatter_update_4) {
|
||||||
NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, nd4j::DataType::INT32);
|
NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
nd4j::ops::scatter_update op;
|
nd4j::ops::scatter_update op;
|
||||||
auto results = op.execute({&x, &updates}, {}, {6, 1,0, 2,3,0});
|
auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,3,0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
|
|
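Note on the ParityOpsTests hunks above: every ResultSet-returning op.execute(...) call is renamed to op.evaluate(...); the argument groups (inputs, t-args, i-args) and the result handling stay the same. A minimal sketch of the migrated pattern, using illustrative shapes and values rather than the exact fixtures from these tests:

// Sketch only: mirrors the post-change call pattern in ParityOpsTests (illustrative data).
auto input   = NDArrayFactory::create<float>('c', {8}, {1, 2, 3, 4, 5, 6, 7, 8});
auto indices = NDArrayFactory::create<int>('c', {4, 1}, {1, 3, 4, 7});
auto updates = NDArrayFactory::create<float>('c', {4}, {10, 10, 10, 10});

nd4j::ops::scatter_nd_add op;
// evaluate(...) takes the same inputs / t-args / i-args groups the old execute(...) overload took,
// and still returns a heap-allocated ResultSet that the caller owns.
auto result = op.evaluate({&input, &indices, &updates}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);   // first output array, owned by the ResultSet
delete result;            // same ownership rule as the explicit deletes in the RNG tests below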
@@ -257,7 +257,7 @@ TEST_F(RNGTests, Test_Gaussian_21) {
 ASSERT_FALSE(x0.equalsTo(nexp1));
 ASSERT_FALSE(x0.equalsTo(nexp2));
 nd4j::ops::moments op;
-auto result = op.execute({&x0}, {}, {});
+auto result = op.evaluate({&x0}, {}, {});
 //x0.printIndexedBuffer("X0 Normal");
 //x1.printIndexedBuffer("X1 Normal");
 ASSERT_TRUE(result->status() == Status::OK());

@@ -289,7 +289,7 @@ TEST_F(RNGTests, Test_Gaussian_22) {
 ASSERT_FALSE(x0.equalsTo(nexp1));
 ASSERT_FALSE(x0.equalsTo(nexp2));
 nd4j::ops::moments op;
-auto result = op.execute({&x0}, {}, {});
+auto result = op.evaluate({&x0}, {}, {});
 //x0.printIndexedBuffer("X0 Normal");
 //x1.printIndexedBuffer("X1 Normal");
 ASSERT_TRUE(result->status() == Status::OK());

@@ -412,14 +412,14 @@ TEST_F(RNGTests, Test_Truncated_21) {
 ASSERT_NEAR(mean.e<float>(0), 1.f, 0.002);
 ASSERT_NEAR(deviation.e<float>(0), 2.f, 0.5);
 nd4j::ops::moments op;
-auto result = op.execute({&x0}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
 // result->at(0)->printBuffer("MEAN");
 // result->at(1)->printBuffer("VARIANCE");
 delete result;
 nd4j::ops::reduce_min minOp;
 nd4j::ops::reduce_max maxOp;
-auto minRes = minOp.execute({&x1}, {}, {}, {});
-auto maxRes = maxOp.execute({&x0}, {}, {}, {});
+auto minRes = minOp.evaluate({&x1}, {}, {}, {});
+auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
 // minRes->at(0)->printBuffer("MIN for Truncated");
 // maxRes->at(0)->printBuffer("MAX for Truncated");

@@ -459,14 +459,14 @@ TEST_F(RNGTests, Test_Truncated_22) {
 ASSERT_NEAR(mean.e<float>(0), 2.f, 0.01);
 ASSERT_NEAR(deviation.e<float>(0), 4.f, 0.52);
 nd4j::ops::moments op;
-auto result = op.execute({&x0}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
 // result->at(0)->printBuffer("MEAN");
 // result->at(1)->printBuffer("VARIANCE");
 delete result;
 nd4j::ops::reduce_min minOp;
 nd4j::ops::reduce_max maxOp;
-auto minRes = minOp.execute({&x1}, {}, {}, {});
-auto maxRes = maxOp.execute({&x0}, {}, {}, {});
+auto minRes = minOp.evaluate({&x1}, {}, {}, {});
+auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
 // minRes->at(0)->printBuffer("MIN for Truncated2");
 // maxRes->at(0)->printBuffer("MAX for Truncated2");

@@ -506,14 +506,14 @@ TEST_F(RNGTests, Test_Truncated_23) {
 ASSERT_NEAR(mean.e<float>(0), 0.f, 0.01);
 ASSERT_NEAR(deviation.e<float>(0), 1.f, 0.5);
 nd4j::ops::moments op;
-auto result = op.execute({&x0}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+auto result = op.evaluate({&x0});
 // result->at(0)->printBuffer("MEAN");
 // result->at(1)->printBuffer("VARIANCE");
 delete result;
 nd4j::ops::reduce_min minOp;
 nd4j::ops::reduce_max maxOp;
-auto minRes = minOp.execute({&x1}, {}, {}, {});
-auto maxRes = maxOp.execute({&x0}, {}, {}, {});
+auto minRes = minOp.evaluate({&x1}, {}, {}, {});
+auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
 // minRes->at(0)->printBuffer("MIN for Truncated3");
 // maxRes->at(0)->printBuffer("MAX for Truncated3");

@@ -686,7 +686,7 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) {

 nd4j::ops::random_normal op;
-auto result = op.execute({&x}, {0.0, 1.0f}, {});
+auto result = op.evaluate({&x}, {0.0, 1.0f}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);

@@ -707,7 +707,7 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) {

 nd4j::ops::random_bernoulli op;
-auto result = op.execute({&x}, {0.5f}, {});
+auto result = op.evaluate({&x}, {0.5f}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);

@@ -728,7 +728,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {

 nd4j::ops::random_exponential op;
-auto result = op.execute({&x}, {0.25f}, {0});
+auto result = op.evaluate({&x}, {0.25f}, {0});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);

@@ -752,7 +752,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {

 nd4j::ops::random_exponential op;
-auto result = op.execute({&x, &y}, {0.25f}, {0});
+auto result = op.evaluate({&x, &y}, {0.25f}, {0});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);

@@ -776,7 +776,7 @@ TEST_F(RNGTests, Test_PoissonDistribution_1) {

 nd4j::ops::random_poisson op;
-auto result = op.execute({&x, &la}, {}, {});
+auto result = op.evaluate({&x, &la}, {}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);

@@ -796,7 +796,7 @@ TEST_F(RNGTests, Test_GammaDistribution_1) {

 nd4j::ops::random_gamma op;
-auto result = op.execute({&x, &al}, {}, {});
+auto result = op.evaluate({&x, &al}, {}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);

@@ -817,7 +817,7 @@ TEST_F(RNGTests, Test_GammaDistribution_2) {
 be.assign(1.0);

 nd4j::ops::random_gamma op;
-auto result = op.execute({&x, &al, &be}, {}, {});
+auto result = op.evaluate({&x, &al, &be}, {}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);

@@ -838,7 +838,7 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
 be.assign(2.0);

 nd4j::ops::random_gamma op;
-auto result = op.execute({&x, &al, &be}, {}, {});
+auto result = op.evaluate({&x, &al, &be}, {}, {});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);

@@ -857,7 +857,7 @@ TEST_F(RNGTests, Test_UniformDistribution_04) {

 nd4j::ops::randomuniform op;
-auto result = op.execute({&x, &al, &be}, {}, {DataType::INT32});
+auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);

@@ -878,7 +878,7 @@ namespace nd4j {
 auto min = NDArrayFactory::create(0.0);
 auto max = NDArrayFactory::create(1.0);
 nd4j::ops::randomuniform op;
-op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, false);
+op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, {}, false);

 list.emplace_back(arrayR);
 }

@@ -1013,14 +1013,14 @@ TEST_F(RNGTests, test_multinomial_1) {

 nd4j::ops::random_multinomial op;
 RandomGenerator rng(1234, 1234);
-ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, false) );
+ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, {}, false) );
 ASSERT_TRUE(expected.isSameShape(output));
 ASSERT_TRUE(expected.equalsTo(output));

 NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
 NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64);

-auto result = op.execute({ &probsZ, &samples }, { }, { 1, INT64 });
+auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 });
 auto outputZ = result->at(0);

 ASSERT_EQ(Status::OK(), result->status());

@@ -1038,7 +1038,7 @@ TEST_F(RNGTests, test_multinomial_2) {

 nd4j::ops::random_multinomial op;
 RandomGenerator rng(1234, 1234);
-ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
+ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
 ASSERT_TRUE(expected.isSameShape(output));
 ASSERT_TRUE(expected.equalsTo(output));

@@ -1047,7 +1047,7 @@ TEST_F(RNGTests, test_multinomial_2) {
 NDArray output2('c', { 20, 3 }, nd4j::DataType::INT64);

 rng.setStates(1234, 1234);
-ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, false));
+ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, {}, false));
 ASSERT_TRUE(expected2.isSameShape(output2));
 ASSERT_TRUE(expected2.equalsTo(output2));
 }

@@ -1061,10 +1061,10 @@ TEST_F(RNGTests, test_multinomial_3) {
 RandomGenerator rng(1234, 1234);

 nd4j::ops::random_multinomial op;
-ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, false));
+ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, {}, false));

 rng.setStates(1234, 1234);
-ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
+ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
 ASSERT_TRUE(expected.isSameShape(output));
 ASSERT_TRUE(expected.equalsTo(output));
 }

@@ -1078,10 +1078,10 @@ TEST_F(RNGTests, test_multinomial_4) {

 RandomGenerator rng(1234, 1234);
 nd4j::ops::random_multinomial op;
-ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, false));
+ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, {}, false));

 rng.setStates(1234, 1234);
-ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, false));
+ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, {}, false));
 ASSERT_TRUE(expected.isSameShape(output));
 ASSERT_TRUE(expected.equalsTo(output));
 }

@@ -1101,7 +1101,7 @@ TEST_F(RNGTests, test_multinomial_5) {
 NDArray output('c', { Samples, batchValue }, nd4j::DataType::INT64);
 RandomGenerator rng(1234, 1234);

-ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, false));
+ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false));

 auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
 auto mean = output.meanNumber();

@@ -1115,7 +1115,7 @@ TEST_F(RNGTests, test_multinomial_5) {
 ASSERT_TRUE(value >= 0 && value < ClassValue);
 }

-auto resultR = op.execute({ &probs, &samples }, { }, { 1 });
+auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 });
 auto outputR = resultR->at(0);
 ASSERT_EQ(Status::OK(), resultR->status());

@@ -1148,7 +1148,7 @@ TEST_F(RNGTests, test_multinomial_6) {
 // without seed
 NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);

-auto resultR = op.execute({ &probsR, &samples }, { }, { 0 });
+auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 });
 auto outputR = resultR->at(0);
 ASSERT_EQ(Status::OK(), resultR->status());

@@ -1180,7 +1180,7 @@ TEST_F(RNGTests, test_multinomial_6) {
 NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);
 NDArray output('c', { batchValue, Samples }, nd4j::DataType::INT64);

-ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
+ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));

 NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
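Note on the RNG hunks above: in the overloads that write into preallocated outputs, op.execute(rng, inputs, outputs, tArgs, iArgs, bArgs, ...) gains one extra list before the isInplace flag; judging from the signature changes further below, this is the new d-args (data type arguments) group. A hedged sketch of the updated call shape, assuming probs, samples and output are set up as in test_multinomial_1:

nd4j::ops::random_multinomial op;
RandomGenerator rng(1234, 1234);

// argument groups after this change:
//   inputs, outputs, t-args, i-args, b-args, d-args (new), isInplace
ASSERT_EQ(Status::OK(),
          op.execute(rng,
                     { &probs, &samples },  // inputs
                     { &output },           // outputs (preallocated)
                     {},                    // t-args
                     { 0, INT64 },          // i-args, as passed by the test
                     {},                    // b-args
                     {},                    // d-args -- newly inserted, empty here
                     false));               // isInplace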
@@ -94,7 +94,7 @@ TEST_F(ScalarTests, Test_Concat_1) {
 auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});

 nd4j::ops::concat op;
-auto result = op.execute({&t, &u, &v}, {}, {0});
+auto result = op.evaluate({&t, &u, &v}, {}, {0});

 ASSERT_EQ(ND4J_STATUS_OK, result->status());

@@ -114,7 +114,7 @@ TEST_F(ScalarTests, Test_Concat_2) {
 auto exp = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5});

 nd4j::ops::concat op;
-auto result = op.execute({&t, &u, &v}, {}, {0});
+auto result = op.evaluate({&t, &u, &v}, {}, {0});

 ASSERT_EQ(ND4J_STATUS_OK, result->status());

@@ -135,7 +135,7 @@ TEST_F(ScalarTests, Test_Concat_3) {
 auto exp = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5});

 nd4j::ops::concat op;
-auto result = op.execute({&t, &u, &v}, {}, {0});
+auto result = op.evaluate({&t, &u, &v}, {}, {0});

 ASSERT_EQ(ND4J_STATUS_OK, result->status());

@@ -154,7 +154,7 @@ TEST_F(ScalarTests, Test_ExpandDims_1) {
 auto exp = NDArrayFactory::create<float>('c', {1}, {2.0f});

 nd4j::ops::expand_dims op;
-auto result = op.execute({&x}, {}, {0});
+auto result = op.evaluate({&x}, {}, {0});

 ASSERT_EQ(ND4J_STATUS_OK, result->status());

@@ -171,7 +171,7 @@ TEST_F(ScalarTests, Test_Squeeze_1) {
 auto exp = NDArrayFactory::create<float>(2.0f);

 nd4j::ops::squeeze op;
-auto result = op.execute({&x}, {}, {});
+auto result = op.evaluate({&x}, {}, {});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);

@@ -188,7 +188,7 @@ TEST_F(ScalarTests, Test_Reshape_1) {
 auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f});

 nd4j::ops::reshape op;
-auto result = op.execute({&x}, {}, {-99, 1, 1, 1});
+auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);

@@ -205,7 +205,7 @@ TEST_F(ScalarTests, Test_Permute_1) {
 auto exp = NDArrayFactory::create<float>(3.0f);

 nd4j::ops::permute op;
-auto result = op.execute({&x}, {}, {0});
+auto result = op.evaluate({&x}, {}, {0});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);

@@ -224,7 +224,7 @@ TEST_F(ScalarTests, Test_Stack_1) {
 auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});

 nd4j::ops::stack op;
-auto result = op.execute({&t, &u, &v}, {}, {0});
+auto result = op.evaluate({&t, &u, &v}, {}, {0});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);

@@ -243,7 +243,7 @@ TEST_F(ScalarTests, Test_Stack_2) {
 auto exp = NDArrayFactory::create<float>('c', {4, 1, 1}, {1, 2, 3, 4});

 nd4j::ops::stack op;
-auto result = op.execute({&t, &u, &v, &w}, {}, {0});
+auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);

@@ -265,7 +265,7 @@ TEST_F(ScalarTests, Test_Concat_Scalar_1) {
 auto exp = NDArrayFactory::create<float>('c', {4, 1}, {1, 2, 3, 4});

 nd4j::ops::concat op;
-auto result = op.execute({&t, &u, &v, &w}, {}, {0});
+auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);

@@ -285,7 +285,7 @@ TEST_F(ScalarTests, Test_Concat_Scalar_2) {
 auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1, 2, 3, 4});

 nd4j::ops::concat op;
-auto result = op.execute({&t, &u, &v, &w}, {}, {1});
+auto result = op.evaluate({&t, &u, &v, &w}, {}, {1});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);
@@ -307,7 +307,7 @@ TEST_F(ShapeTests, Tests_Transpose_119_2) {
 auto exp = x.transpose();

 nd4j::ops::transpose op;
-auto result = op.execute({&x},{}, {});
+auto result = op.evaluate({&x});
 ASSERT_EQ(Status::OK(), result->status());

 auto z = result->at(0);
@@ -68,7 +68,7 @@ TEST_F(SingleDimTests, Test_Concat_1) {
 auto exp = NDArrayFactory::create<float>('c', {6}, {1, 2, 3, 4, 5, 6});

 nd4j::ops::concat op;
-auto result = op.execute({&x, &y}, {}, {0});
+auto result = op.evaluate({&x, &y}, {}, {0});

 ASSERT_EQ(ND4J_STATUS_OK, result->status());

@@ -102,7 +102,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) {
 auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});

 nd4j::ops::expand_dims op;
-auto result = op.execute({&x}, {}, {0});
+auto result = op.evaluate({&x}, {}, {0});

 ASSERT_EQ(ND4J_STATUS_OK, result->status());

@@ -120,7 +120,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) {
 auto exp = NDArrayFactory::create<float>('c', {3, 1}, {1, 2, 3});

 nd4j::ops::expand_dims op;
-auto result = op.execute({&x}, {}, {1});
+auto result = op.evaluate({&x}, {}, {1});

 ASSERT_EQ(ND4J_STATUS_OK, result->status());

@@ -140,7 +140,7 @@ TEST_F(SingleDimTests, Test_Squeeze_1) {
 auto exp = NDArrayFactory::create<float>(3.0f);

 nd4j::ops::squeeze op;
-auto result = op.execute({&x}, {}, {});
+auto result = op.evaluate({&x}, {}, {});

 ASSERT_EQ(ND4J_STATUS_OK, result->status());

@@ -157,7 +157,7 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {
 auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});

 nd4j::ops::squeeze op;
-auto result = op.execute({&x}, {}, {});
+auto result = op.evaluate({&x}, {}, {});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);

@@ -173,7 +173,7 @@ TEST_F(SingleDimTests, Test_Reshape_1) {
 auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});

 nd4j::ops::reshape op;
-auto result = op.execute({&x}, {}, {-99, 3});
+auto result = op.evaluate({&x}, {}, {-99, 3});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);

@@ -189,7 +189,7 @@ TEST_F(SingleDimTests, Test_Reshape_2) {
 auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});

 nd4j::ops::reshape op;
-auto result = op.execute({&x}, {}, {-99, 1, 3});
+auto result = op.evaluate({&x}, {}, {-99, 1, 3});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);

@@ -206,7 +206,7 @@ TEST_F(SingleDimTests, Test_Permute_1) {
 auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});

 nd4j::ops::permute op;
-auto result = op.execute({&x}, {}, {0});
+auto result = op.evaluate({&x}, {}, {0});
 ASSERT_EQ(ND4J_STATUS_OK, result->status());

 auto z = result->at(0);
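The Java-side changes that follow mirror the same addition: DType gains UTF16/UTF32 codes, and the generated bindings expose setDArguments/getDArguments/numD plus execute and evaluate overloads that carry a d-args vector. The C++ tests above only ever pass an empty d-args list, so the following is purely a hypothetical illustration of where a data type would be supplied; the parameter order (t-args, i-args, b-args, d-args, isInplace) is inferred from the generated evaluate signatures below and is an assumption, not something these tests exercise:

// Hypothetical sketch: requesting an explicit output data type through the new d-args group.
auto shape = NDArrayFactory::create<int>('c', {2}, {2, 3});
auto minV  = NDArrayFactory::create<float>(0.f);
auto maxV  = NDArrayFactory::create<float>(1.f);

nd4j::ops::randomuniform op;
auto result = op.evaluate({&shape, &minV, &maxV},   // inputs
                          {},                       // t-args
                          {},                       // i-args
                          {},                       // b-args
                          {nd4j::DataType::INT32},  // d-args (new group, assumed usage)
                          false);                   // isInplace

ASSERT_EQ(Status::OK(), result->status());
delete result;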
@@ -23,8 +23,10 @@ public final class DType {
 public static final byte QINT16 = 16;
 public static final byte BFLOAT16 = 17;
 public static final byte UTF8 = 50;
+public static final byte UTF16 = 51;
+public static final byte UTF32 = 52;

-public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", };
+public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", "UTF16", "UTF32", };

 public static String name(int e) { return names[e]; }
 }
@@ -88,4 +88,3 @@ public final class FlatVariable extends Table {
 public static void finishSizePrefixedFlatVariableBuffer(FlatBufferBuilder builder, int offset) { builder.finishSizePrefixed(offset); }
 }
-
@@ -3096,6 +3096,9 @@ public native void setGraphContextInputArray(OpaqueContext ptr, int index, Point
 public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
 public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
 public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
+public native void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments);
+public native void setGraphContextDArguments(OpaqueContext ptr, IntBuffer arguments, int numberOfArguments);
+public native void setGraphContextDArguments(OpaqueContext ptr, int[] arguments, int numberOfArguments);
 public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
 public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
 public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);

@@ -6435,6 +6438,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
 public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
 public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments);
 public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments);
+public native void setDArguments(@Cast("nd4j::DataType*") IntPointer arguments, int numberOfArguments);
+public native void setDArguments(@Cast("nd4j::DataType*") IntBuffer arguments, int numberOfArguments);
+public native void setDArguments(@Cast("nd4j::DataType*") int[] arguments, int numberOfArguments);

 public native void setTArguments(@StdVector DoublePointer tArgs);
 public native void setTArguments(@StdVector DoubleBuffer tArgs);

@@ -6444,6 +6450,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
 public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs);
 public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs);
 public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs);
+public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntPointer dArgs);
+public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
+public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);

 public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);

@@ -6547,6 +6556,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
 public native @StdVector DoublePointer getTArguments();
 public native @StdVector IntPointer getIArguments();
 public native @Cast("bool*") @StdVector BooleanPointer getBArguments();
+public native @Cast("nd4j::DataType*") @StdVector IntPointer getDArguments();
 public native @StdVector IntPointer getAxis();

 public native @Cast("samediff::Engine") int engine();

@@ -6554,6 +6564,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
 public native @Cast("size_t") long numT();
 public native @Cast("size_t") long numI();
 public native @Cast("size_t") long numB();
+public native @Cast("size_t") long numD();

 public native IntIntPair input(int idx);
@@ -9418,39 +9429,43 @@ public static final int PREALLOC_SIZE = 33554432;
 */
 public native @Cast("Nd4jStatus") int execute(Context block);

-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
|
||||||
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
|
||||||
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
|
||||||
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
|
||||||
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
|
||||||
|
|
||||||
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
|
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
|
public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
|
||||||
|
@@ -9649,8 +9664,9 @@ public static final int PREALLOC_SIZE = 33554432;
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public BooleanOp(Pointer p) { super(p); }

-public native @Cast("bool") boolean evaluate(@ByRef NDArrayVector args);
-public native @Cast("bool") boolean evaluate(@ByRef Context block);
+public native @Cast("bool") boolean verify(@Const @ByRef NDArrayVector args);
+public native @Cast("bool") boolean verify(@ByRef Context block);

public native @Cast("Nd4jStatus") int execute(Context block);
@@ -3099,6 +3099,9 @@ public native void setGraphContextInputArray(OpaqueContext ptr, int index, Point
public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
+public native void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments);
+public native void setGraphContextDArguments(OpaqueContext ptr, IntBuffer arguments, int numberOfArguments);
+public native void setGraphContextDArguments(OpaqueContext ptr, int[] arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
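The three setGraphContextDArguments overloads added above mirror the existing T/I/B argument setters on an op context. A minimal, hedged sketch of how a caller might push data-type arguments into a context through these bindings; the nativeOps facade, the opaqueContext handle, and the enum ordinal used here are assumptions for illustration, not part of this diff:

    // Hedged sketch: nativeOps and opaqueContext are assumed to exist already.
    int float32Ordinal = 5;                      // assumed ordinal of nd4j::DataType::FLOAT32
    int[] dArgs = new int[] { float32Ordinal };  // requested data type(s) for the op
    nativeOps.setGraphContextDArguments(opaqueContext, dArgs, dArgs.length);  // int[] overload
    // The IntPointer and IntBuffer overloads accept the same values via JavaCPP pointer types.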
@@ -6438,6 +6441,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments);
public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments);
+public native void setDArguments(@Cast("nd4j::DataType*") IntPointer arguments, int numberOfArguments);
+public native void setDArguments(@Cast("nd4j::DataType*") IntBuffer arguments, int numberOfArguments);
+public native void setDArguments(@Cast("nd4j::DataType*") int[] arguments, int numberOfArguments);
public native void setTArguments(@StdVector DoublePointer tArgs);
public native void setTArguments(@StdVector DoubleBuffer tArgs);
@@ -6447,6 +6453,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs);
public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs);
public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs);
+public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntPointer dArgs);
+public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
+public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
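On the Context class the new D arguments sit next to the existing T/I/B setters, both in count/pointer form and in @StdVector form. A hedged sketch of populating and reading them back; only setDArguments, getDArguments and numD come from this diff, and the Context constructor shape shown here is an assumption:

    // Hedged sketch; Context construction and the surrounding op setup are assumed.
    Context ctx = new Context(1);                               // node id; constructor signature assumed
    ctx.setDArguments(new int[] { 5 /* assumed nd4j::DataType::FLOAT32 ordinal */ });
    long howMany = ctx.numD();                                  // expected to report 1 afterwards
    IntPointer requested = ctx.getDArguments();                 // read back as a JavaCPP IntPointer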
@@ -6550,6 +6559,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native @StdVector DoublePointer getTArguments();
public native @StdVector IntPointer getIArguments();
public native @Cast("bool*") @StdVector BooleanPointer getBArguments();
+public native @Cast("nd4j::DataType*") @StdVector IntPointer getDArguments();
public native @StdVector IntPointer getAxis();

public native @Cast("samediff::Engine") int engine();
@@ -6557,6 +6567,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native @Cast("size_t") long numT();
public native @Cast("size_t") long numI();
public native @Cast("size_t") long numB();
+public native @Cast("size_t") long numD();

public native IntIntPair input(int idx);
@@ -11130,7 +11141,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #define INPUT_LIST(INDEX) reinterpret_cast<nd4j::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())

+// #define D_ARG(INDEX) block.getDArguments()->at(INDEX)
// #define INT_ARG(INDEX) block.getIArguments()->at(INDEX)
+// #define I_ARG(INDEX) INT_ARG(INDEX)
// #define T_ARG(INDEX) block.getTArguments()->at(INDEX)
// #define B_ARG(INDEX) block.getBArguments()->at(INDEX)
@@ -11629,39 +11642,43 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
*/
public native @Cast("Nd4jStatus") int execute(Context block);

-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
-public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
+public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector<nd4j::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);

public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
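Read side by side, the old ResultSet-returning execute overloads become evaluate and pick up an optional vector of data-type arguments, while the status-returning execute overloads gain the same dArgs parameter. A hedged sketch of calling the new array-based evaluate overload from Java; the op instance and its NDArrayVector inputs are assumed to be set up elsewhere, and the enum ordinal is an assumption:

    // Hedged sketch: `op` is some DeclarableOp subclass, `inputs` a populated NDArrayVector.
    double[]  tArgs = new double[0];
    long[]    iArgs = new long[0];
    boolean[] bArgs = new boolean[0];
    int[]     dArgs = new int[] { 5 };   // assumed ordinal of nd4j::DataType::FLOAT32
    ResultSet results = op.evaluate(inputs, tArgs, iArgs, bArgs, dArgs, /*isInplace=*/ false);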
@@ -11860,8 +11877,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public BooleanOp(Pointer p) { super(p); }

-public native @Cast("bool") boolean evaluate(@ByRef NDArrayVector args);
-public native @Cast("bool") boolean evaluate(@ByRef Context block);
+public native @Cast("bool") boolean verify(@Const @ByRef NDArrayVector args);
+public native @Cast("bool") boolean verify(@ByRef Context block);

public native @Cast("Nd4jStatus") int execute(Context block);
@@ -1,4 +1,4 @@
-//Generated by flatc compiler (version 1.9.0)
+//Generated by flatc compiler (version 1.10.0)
//If you make any local changes, they will be lost
//source: graph.fbs
@@ -31,7 +31,7 @@ public final class GraphInferenceServerGrpc {
private GraphInferenceServerGrpc() {}

-public static final String SERVICE_NAME = "nd4j.graph.GraphInferenceServer";
+public static final String SERVICE_NAME = "org.nd4j.graph.GraphInferenceServer";

// Static method descriptors that strictly reflect the proto.
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
@@ -81,7 +81,7 @@ public final class GraphInferenceServerGrpc {
io.grpc.MethodDescriptor.<org.nd4j.graph.FlatGraph, org.nd4j.graph.FlatResponse>newBuilder()
.setType(io.grpc.MethodDescriptor.MethodType.UNARY)
.setFullMethodName(generateFullMethodName(
-"nd4j.graph.GraphInferenceServer", "RegisterGraph"))
+"org.nd4j.graph.GraphInferenceServer", "RegisterGraph"))
.setSampledToLocalTracing(true)
.setRequestMarshaller(FlatbuffersUtils.marshaller(
org.nd4j.graph.FlatGraph.class, getExtractorOfFlatGraph()))
@@ -128,7 +128,7 @@ public final class GraphInferenceServerGrpc {
io.grpc.MethodDescriptor.<org.nd4j.graph.FlatDropRequest, org.nd4j.graph.FlatResponse>newBuilder()
.setType(io.grpc.MethodDescriptor.MethodType.UNARY)
.setFullMethodName(generateFullMethodName(
-"nd4j.graph.GraphInferenceServer", "ForgetGraph"))
+"org.nd4j.graph.GraphInferenceServer", "ForgetGraph"))
.setSampledToLocalTracing(true)
.setRequestMarshaller(FlatbuffersUtils.marshaller(
org.nd4j.graph.FlatDropRequest.class, getExtractorOfFlatDropRequest()))
@@ -142,6 +142,39 @@ public final class GraphInferenceServerGrpc {
return getForgetGraphMethod;
}

+@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
+@Deprecated // Use {@link #getReplaceGraphMethod()} instead.
+public static final io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph,
+org.nd4j.graph.FlatResponse> METHOD_REPLACE_GRAPH = getReplaceGraphMethod();
+
+private static volatile io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph,
+org.nd4j.graph.FlatResponse> getReplaceGraphMethod;
+
+@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
+public static io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph,
+org.nd4j.graph.FlatResponse> getReplaceGraphMethod() {
+io.grpc.MethodDescriptor<org.nd4j.graph.FlatGraph, org.nd4j.graph.FlatResponse> getReplaceGraphMethod;
+if ((getReplaceGraphMethod = GraphInferenceServerGrpc.getReplaceGraphMethod) == null) {
+synchronized (GraphInferenceServerGrpc.class) {
+if ((getReplaceGraphMethod = GraphInferenceServerGrpc.getReplaceGraphMethod) == null) {
+GraphInferenceServerGrpc.getReplaceGraphMethod = getReplaceGraphMethod =
+io.grpc.MethodDescriptor.<org.nd4j.graph.FlatGraph, org.nd4j.graph.FlatResponse>newBuilder()
+.setType(io.grpc.MethodDescriptor.MethodType.UNARY)
+.setFullMethodName(generateFullMethodName(
+"org.nd4j.graph.GraphInferenceServer", "ReplaceGraph"))
+.setSampledToLocalTracing(true)
+.setRequestMarshaller(FlatbuffersUtils.marshaller(
+org.nd4j.graph.FlatGraph.class, getExtractorOfFlatGraph()))
+.setResponseMarshaller(FlatbuffersUtils.marshaller(
+org.nd4j.graph.FlatResponse.class, getExtractorOfFlatResponse()))
+.setSchemaDescriptor(null)
+.build();
+}
+}
+}
+return getReplaceGraphMethod;
+}

@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
@Deprecated // Use {@link #getInferenceRequestMethod()} instead.
public static final io.grpc.MethodDescriptor<org.nd4j.graph.FlatInferenceRequest,
@@ -189,7 +222,7 @@ public final class GraphInferenceServerGrpc {
io.grpc.MethodDescriptor.<org.nd4j.graph.FlatInferenceRequest, org.nd4j.graph.FlatResult>newBuilder()
.setType(io.grpc.MethodDescriptor.MethodType.UNARY)
.setFullMethodName(generateFullMethodName(
-"nd4j.graph.GraphInferenceServer", "InferenceRequest"))
+"org.nd4j.graph.GraphInferenceServer", "InferenceRequest"))
.setSampledToLocalTracing(true)
.setRequestMarshaller(FlatbuffersUtils.marshaller(
org.nd4j.graph.FlatInferenceRequest.class, getExtractorOfFlatInferenceRequest()))
@@ -244,6 +277,13 @@ public final class GraphInferenceServerGrpc {
asyncUnimplementedUnaryCall(getForgetGraphMethod(), responseObserver);
}

+/**
+*/
+public void replaceGraph(org.nd4j.graph.FlatGraph request,
+io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse> responseObserver) {
+asyncUnimplementedUnaryCall(getReplaceGraphMethod(), responseObserver);
+}
+
/**
*/
public void inferenceRequest(org.nd4j.graph.FlatInferenceRequest request,
@@ -267,6 +307,13 @@ public final class GraphInferenceServerGrpc {
org.nd4j.graph.FlatDropRequest,
org.nd4j.graph.FlatResponse>(
this, METHODID_FORGET_GRAPH)))
+.addMethod(
+getReplaceGraphMethod(),
+asyncUnaryCall(
+new MethodHandlers<
+org.nd4j.graph.FlatGraph,
+org.nd4j.graph.FlatResponse>(
+this, METHODID_REPLACE_GRAPH)))
.addMethod(
getInferenceRequestMethod(),
asyncUnaryCall(
@@ -312,6 +359,14 @@ public final class GraphInferenceServerGrpc {
getChannel().newCall(getForgetGraphMethod(), getCallOptions()), request, responseObserver);
}

+/**
+*/
+public void replaceGraph(org.nd4j.graph.FlatGraph request,
+io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse> responseObserver) {
+asyncUnaryCall(
+getChannel().newCall(getReplaceGraphMethod(), getCallOptions()), request, responseObserver);
+}
+
/**
*/
public void inferenceRequest(org.nd4j.graph.FlatInferenceRequest request,
@@ -353,6 +408,13 @@ public final class GraphInferenceServerGrpc {
getChannel(), getForgetGraphMethod(), getCallOptions(), request);
}

+/**
+*/
+public org.nd4j.graph.FlatResponse replaceGraph(org.nd4j.graph.FlatGraph request) {
+return blockingUnaryCall(
+getChannel(), getReplaceGraphMethod(), getCallOptions(), request);
+}
+
/**
*/
public org.nd4j.graph.FlatResult inferenceRequest(org.nd4j.graph.FlatInferenceRequest request) {
@@ -395,6 +457,14 @@ public final class GraphInferenceServerGrpc {
getChannel().newCall(getForgetGraphMethod(), getCallOptions()), request);
}

+/**
+*/
+public com.google.common.util.concurrent.ListenableFuture<org.nd4j.graph.FlatResponse> replaceGraph(
+org.nd4j.graph.FlatGraph request) {
+return futureUnaryCall(
+getChannel().newCall(getReplaceGraphMethod(), getCallOptions()), request);
+}
+
/**
*/
public com.google.common.util.concurrent.ListenableFuture<org.nd4j.graph.FlatResult> inferenceRequest(
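With ReplaceGraph threaded through the method descriptor, the service definition, and the async, blocking, and future stubs, a client can swap a previously registered graph without dropping and re-registering it. A hedged sketch of a blocking-stub call; the channel target and the buildUpdatedFlatGraph() helper are placeholders, and only the replaceGraph stub method itself comes from this diff:

    // Hedged sketch; host, port and graph construction are assumptions.
    io.grpc.ManagedChannel channel = io.grpc.ManagedChannelBuilder
            .forAddress("localhost", 40123)
            .usePlaintext()
            .build();
    GraphInferenceServerGrpc.GraphInferenceServerBlockingStub stub =
            GraphInferenceServerGrpc.newBlockingStub(channel);
    org.nd4j.graph.FlatGraph updatedGraph = buildUpdatedFlatGraph();   // hypothetical helper
    org.nd4j.graph.FlatResponse response = stub.replaceGraph(updatedGraph);
    channel.shutdown();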
@@ -406,7 +476,8 @@ public final class GraphInferenceServerGrpc {
private static final int METHODID_REGISTER_GRAPH = 0;
private static final int METHODID_FORGET_GRAPH = 1;
-private static final int METHODID_INFERENCE_REQUEST = 2;
+private static final int METHODID_REPLACE_GRAPH = 2;
+private static final int METHODID_INFERENCE_REQUEST = 3;

private static final class MethodHandlers<Req, Resp> implements
io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp>,
@@ -433,6 +504,10 @@ public final class GraphInferenceServerGrpc {
serviceImpl.forgetGraph((org.nd4j.graph.FlatDropRequest) request,
(io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse>) responseObserver);
break;
+case METHODID_REPLACE_GRAPH:
+serviceImpl.replaceGraph((org.nd4j.graph.FlatGraph) request,
+(io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResponse>) responseObserver);
+break;
case METHODID_INFERENCE_REQUEST:
serviceImpl.inferenceRequest((org.nd4j.graph.FlatInferenceRequest) request,
(io.grpc.stub.StreamObserver<org.nd4j.graph.FlatResult>) responseObserver);
@@ -465,6 +540,7 @@ public final class GraphInferenceServerGrpc {
.setSchemaDescriptor(null)
.addMethod(getRegisterGraphMethod())
.addMethod(getForgetGraphMethod())
+.addMethod(getReplaceGraphMethod())
.addMethod(getInferenceRequestMethod())
.build();
}